Rewrite contextvar usage. Implemented ContextInstanceMixin and DataMixin

This commit is contained in:
Alex Root Junior 2018-10-20 15:55:57 +03:00
parent 8ef279bba1
commit 39c333251f
12 changed files with 82 additions and 120 deletions

View file

@ -173,43 +173,6 @@ class BaseBot:
return await self.request(method, payload, files) return await self.request(method, payload, files)
@property
def data(self) -> Dict:
"""
Data stored in bot object
:return: Dictionary
"""
return self._data
def __setitem__(self, key, value):
"""
Store data in bot instance
:param key: Key in dict
:param value: Value
"""
self._data[key] = value
def __getitem__(self, item):
"""
Get item from bot instance by key
:param item: key name
:return: value
"""
return self._data[item]
def get(self, key, default=None):
"""
Get item from bot instance by key or return default value
:param key: key in dict
:param default: default value
:return: value or default value
"""
return self._data.get(key, default)
@property @property
def parse_mode(self): def parse_mode(self):
return getattr(self, '_parse_mode', None) return getattr(self, '_parse_mode', None)

View file

@ -1,15 +1,15 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
from contextvars import ContextVar
from .base import BaseBot, api from .base import BaseBot, api
from .. import types from .. import types
from ..types import base from ..types import base
from ..utils.mixins import DataMixin, ContextInstanceMixin
from ..utils.payload import generate_payload, prepare_arg, prepare_attachment, prepare_file from ..utils.payload import generate_payload, prepare_arg, prepare_attachment, prepare_file
class Bot(BaseBot): class Bot(BaseBot, DataMixin, ContextInstanceMixin):
""" """
Base bot class Base bot class
""" """
@ -39,14 +39,6 @@ class Bot(BaseBot):
if hasattr(self, '_me'): if hasattr(self, '_me'):
delattr(self, '_me') delattr(self, '_me')
@classmethod
def current(cls) -> Bot:
"""
Return active bot instance from the current context or None
:return: Bot or None
"""
return bot.get()
async def download_file_by_id(self, file_id: base.String, destination=None, async def download_file_by_id(self, file_id: base.String, destination=None,
timeout: base.Integer = 30, chunk_size: base.Integer = 65536, timeout: base.Integer = 30, chunk_size: base.Integer = 65536,
seek: base.Boolean = True): seek: base.Boolean = True):
@ -98,7 +90,7 @@ class Bot(BaseBot):
allowed_updates = prepare_arg(allowed_updates) allowed_updates = prepare_arg(allowed_updates)
payload = generate_payload(**locals()) payload = generate_payload(**locals())
result = await self.request(api.Methods.GET_UPDATES, payload, timeout=timeout + 2 if timeout else None) result = await self.request(api.Methods.GET_UPDATES, payload)
return [types.Update(**update) for update in result] return [types.Update(**update) for update in result]
async def set_webhook(self, url: base.String, async def set_webhook(self, url: base.String,
@ -522,7 +514,7 @@ class Bot(BaseBot):
""" """
reply_markup = prepare_arg(reply_markup) reply_markup = prepare_arg(reply_markup)
payload = generate_payload(**locals(), exclude=["animation", "thumb"]) payload = generate_payload(**locals(), exclude=["animation", "thumb"])
files = {} files = {}
prepare_file(payload, files, 'animation', animation) prepare_file(payload, files, 'animation', animation)
prepare_attachment(payload, files, 'thumb', thumb) prepare_attachment(payload, files, 'thumb', thumb)
@ -2064,6 +2056,3 @@ class Bot(BaseBot):
result = await self.request(api.Methods.GET_GAME_HIGH_SCORES, payload) result = await self.request(api.Methods.GET_GAME_HIGH_SCORES, payload)
return [types.GameHighScore(**gamehighscore) for gamehighscore in result] return [types.GameHighScore(**gamehighscore) for gamehighscore in result]
bot: ContextVar[Bot] = ContextVar('bot_instance', default=None)

View file

@ -116,7 +116,7 @@ class I18nMiddleware(BaseMiddleware):
:param args: event arguments :param args: event arguments
:return: locale name :return: locale name
""" """
user: types.User = types.User.current() user: types.User = types.User.get_current()
locale: Locale = user.locale locale: Locale = user.locale
if locale: if locale:

View file

@ -3,12 +3,11 @@ from . import handler
from . import middlewares from . import middlewares
from . import storage from . import storage
from . import webhook from . import webhook
from .dispatcher import Dispatcher, dispatcher, FSMContext, DEFAULT_RATE_LIMIT from .dispatcher import Dispatcher, FSMContext, DEFAULT_RATE_LIMIT
__all__ = [ __all__ = [
'DEFAULT_RATE_LIMIT', 'DEFAULT_RATE_LIMIT',
'Dispatcher', 'Dispatcher',
'dispatcher',
'FSMContext', 'FSMContext',
'filters', 'filters',
'handler', 'handler',

View file

@ -4,7 +4,6 @@ import itertools
import logging import logging
import time import time
import typing import typing
from contextvars import ContextVar
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \ from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
RegexpCommandsFilter, StateFilter, Text RegexpCommandsFilter, StateFilter, Text
@ -14,15 +13,16 @@ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMCon
LAST_CALL, RATE_LIMIT, RESULT LAST_CALL, RATE_LIMIT, RESULT
from .webhook import BaseResponse from .webhook import BaseResponse
from .. import types from .. import types
from ..bot import Bot, bot from ..bot import Bot
from ..utils.exceptions import TelegramAPIError, Throttled from ..utils.exceptions import TelegramAPIError, Throttled
from ..utils.mixins import ContextInstanceMixin, DataMixin
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
DEFAULT_RATE_LIMIT = .1 DEFAULT_RATE_LIMIT = .1
class Dispatcher: class Dispatcher(DataMixin, ContextInstanceMixin):
""" """
Simple Updates dispatcher Simple Updates dispatcher
@ -112,23 +112,6 @@ class Dispatcher:
def __del__(self): def __del__(self):
self.stop_polling() self.stop_polling()
@property
def data(self):
return self.bot.data
def __setitem__(self, key, value):
self.bot.data[key] = value
def __getitem__(self, item):
return self.bot.data[item]
def get(self, key, default=None):
return self.bot.data.get(key, default)
@classmethod
def current(cls):
return dispatcher.get()
async def skip_updates(self): async def skip_updates(self):
""" """
You can skip old incoming updates from queue. You can skip old incoming updates from queue.
@ -245,8 +228,8 @@ class Dispatcher:
log.info('Start polling.') log.info('Start polling.')
# context.set_value(MODE, LONG_POLLING) # context.set_value(MODE, LONG_POLLING)
dispatcher.set(self) Dispatcher.set_current(self)
bot.bot.set(self.bot) Bot.set_current(self.bot)
if reset_webhook is None: if reset_webhook is None:
await self.reset_webhook(check=False) await self.reset_webhook(check=False)
@ -867,10 +850,10 @@ class Dispatcher:
:return: :return:
""" """
if chat is None: if chat is None:
chat_obj = types.Chat.current() chat_obj = types.Chat.get_current()
chat = chat_obj.id if chat_obj else None chat = chat_obj.id if chat_obj else None
if user is None: if user is None:
user_obj = types.User.current() user_obj = types.User.get_current()
user = user_obj.id if user_obj else None user = user_obj.id if user_obj else None
return FSMContext(storage=self.storage, chat=chat, user=user) return FSMContext(storage=self.storage, chat=chat, user=user)
@ -895,8 +878,8 @@ class Dispatcher:
if rate is None: if rate is None:
rate = self.throttling_rate_limit rate = self.throttling_rate_limit
if user is None and chat is None: if user is None and chat is None:
user = types.User.current() user = types.User.get_current()
chat = types.Chat.current() chat = types.Chat.get_current()
# Detect current time # Detect current time
now = time.time() now = time.time()
@ -945,8 +928,8 @@ class Dispatcher:
raise RuntimeError('This storage does not provide Leaky Bucket') raise RuntimeError('This storage does not provide Leaky Bucket')
if user is None and chat is None: if user is None and chat is None:
user = types.User.current() user = types.User.get_current()
chat = types.Chat.current() chat = types.Chat.get_current()
bucket = await self.storage.get_bucket(chat=chat, user=user) bucket = await self.storage.get_bucket(chat=chat, user=user)
data = bucket.get(key, {}) data = bucket.get(key, {})
@ -965,8 +948,8 @@ class Dispatcher:
raise RuntimeError('This storage does not provide Leaky Bucket') raise RuntimeError('This storage does not provide Leaky Bucket')
if user is None and chat is None: if user is None and chat is None:
user = types.User.current() user = types.User.get_current()
chat = types.Chat.current() chat = types.Chat.get_current()
bucket = await self.storage.get_bucket(chat=chat, user=user) bucket = await self.storage.get_bucket(chat=chat, user=user)
if bucket and key in bucket: if bucket and key in bucket:
@ -997,7 +980,7 @@ class Dispatcher:
response = task.result() response = task.result()
except Exception as e: except Exception as e:
self.loop.create_task( self.loop.create_task(
self.errors_handlers.notify(types.Update.current(), e)) self.errors_handlers.notify(types.Update.get_current(), e))
else: else:
if isinstance(response, BaseResponse): if isinstance(response, BaseResponse):
self.loop.create_task(response.execute_response(self.bot)) self.loop.create_task(response.execute_response(self.bot))
@ -1016,6 +999,3 @@ class Dispatcher:
if run_task: if run_task:
return self.async_task(callback) return self.async_task(callback)
return callback return callback
dispatcher: ContextVar[Dispatcher] = ContextVar('dispatcher_instance', default=None)

View file

@ -53,7 +53,7 @@ class State:
__repr__ = __str__ __repr__ = __str__
async def set(self): async def set(self):
state = Dispatcher.current().current_state() state = Dispatcher.get_current().current_state()
await state.set_state(self.state) await state.set_state(self.state)
@ -143,7 +143,7 @@ class StatesGroupMeta(type):
class StatesGroup(metaclass=StatesGroupMeta): class StatesGroup(metaclass=StatesGroupMeta):
@classmethod @classmethod
async def next(cls) -> str: async def next(cls) -> str:
state = Dispatcher.current().current_state() state = Dispatcher.get_current().current_state()
state_name = await state.get_state() state_name = await state.get_state()
try: try:
@ -161,7 +161,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
@classmethod @classmethod
async def previous(cls) -> str: async def previous(cls) -> str:
state = Dispatcher.current().current_state() state = Dispatcher.get_current().current_state()
state_name = await state.get_state() state_name = await state.get_state()
try: try:
@ -179,7 +179,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
@classmethod @classmethod
async def first(cls) -> str: async def first(cls) -> str:
state = Dispatcher.current().current_state() state = Dispatcher.get_current().current_state()
first_step_name = cls.states_names[0] first_step_name = cls.states_names[0]
await state.set_state(first_step_name) await state.set_state(first_step_name)
@ -187,7 +187,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
@classmethod @classmethod
async def last(cls) -> str: async def last(cls) -> str:
state = Dispatcher.current().current_state() state = Dispatcher.get_current().current_state()
last_step_name = cls.states_names[-1] last_step_name = cls.states_names[-1]
await state.set_state(last_step_name) await state.set_state(last_step_name)

View file

@ -89,10 +89,9 @@ class WebhookRequestHandler(web.View):
""" """
dp = self.request.app[BOT_DISPATCHER_KEY] dp = self.request.app[BOT_DISPATCHER_KEY]
try: try:
from aiogram.bot import bot from aiogram import Bot, Dispatcher
from aiogram.dispatcher import dispatcher Dispatcher.set_current(dp)
dispatcher.set(dp) Bot.set_current(dp.bot)
bot.bot.set(dp.bot)
except RuntimeError: except RuntimeError:
pass pass
return dp return dp
@ -204,7 +203,7 @@ class WebhookRequestHandler(web.View):
results = task.result() results = task.result()
except Exception as e: except Exception as e:
loop.create_task( loop.create_task(
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e)) dispatcher.errors_handlers.notify(dispatcher, types.Update.get_current(), e))
else: else:
response = self.get_response(results) response = self.get_response(results)
if response is not None: if response is not None:
@ -355,7 +354,7 @@ class BaseResponse:
async def __call__(self, bot=None): async def __call__(self, bot=None):
if bot is None: if bot is None:
from aiogram import Bot from aiogram import Bot
bot = Bot.current() bot = Bot.get_current()
return await self.execute_response(bot) return await self.execute_response(bot)
async def __aenter__(self): async def __aenter__(self):
@ -449,7 +448,7 @@ class ParseModeMixin:
:return: :return:
""" """
from aiogram import Bot from aiogram import Bot
bot = Bot.current() bot = Bot.get_current()
if bot is not None: if bot is not None:
return bot.parse_mode return bot.parse_mode

View file

@ -2,11 +2,11 @@ from __future__ import annotations
import io import io
import typing import typing
from contextvars import ContextVar
from typing import TypeVar from typing import TypeVar
from .fields import BaseField from .fields import BaseField
from ..utils import json from ..utils import json
from ..utils.mixins import ContextInstanceMixin
__all__ = ('MetaTelegramObject', 'TelegramObject', 'InputFile', 'String', 'Integer', 'Float', 'Boolean') __all__ = ('MetaTelegramObject', 'TelegramObject', 'InputFile', 'String', 'Integer', 'Float', 'Boolean')
@ -57,7 +57,6 @@ class MetaTelegramObject(type):
mcs._objects[cls.__name__] = cls mcs._objects[cls.__name__] = cls
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
return cls return cls
@property @property
@ -65,7 +64,7 @@ class MetaTelegramObject(type):
return cls._objects return cls._objects
class TelegramObject(metaclass=MetaTelegramObject): class TelegramObject(ContextInstanceMixin, metaclass=MetaTelegramObject):
""" """
Abstract class for telegram objects Abstract class for telegram objects
""" """
@ -93,14 +92,6 @@ class TelegramObject(metaclass=MetaTelegramObject):
if value.default and key not in self.values: if value.default and key not in self.values:
self.values[key] = value.default self.values[key] = value.default
@classmethod
def current(cls):
return cls._current.get()
@classmethod
def set_current(cls, obj: TelegramObject):
return cls._current.set(obj)
@property @property
def conf(self) -> typing.Dict[str, typing.Any]: def conf(self) -> typing.Dict[str, typing.Any]:
return self._conf return self._conf
@ -151,7 +142,7 @@ class TelegramObject(metaclass=MetaTelegramObject):
@property @property
def bot(self): def bot(self):
from ..bot.bot import Bot from ..bot.bot import Bot
return Bot.current() return Bot.get_current()
def to_python(self) -> typing.Dict: def to_python(self) -> typing.Dict:
""" """

View file

@ -511,7 +511,7 @@ class ChatActions(helper.Helper):
@classmethod @classmethod
async def _do(cls, action: str, sleep=None): async def _do(cls, action: str, sleep=None):
from aiogram import Bot from aiogram import Bot
await Bot.current().send_chat_action(Chat.current().id, action) await Bot.get_current().send_chat_action(Chat.get_current().id, action)
if sleep: if sleep:
await asyncio.sleep(sleep) await asyncio.sleep(sleep)

View file

@ -103,10 +103,9 @@ class Executor:
self._freeze = False self._freeze = False
from aiogram.bot.bot import bot as ctx_bot from aiogram import Bot, Dispatcher
from aiogram.dispatcher import dispatcher as ctx_dp Bot.set_current(dispatcher.bot)
ctx_bot.set(dispatcher.bot) Dispatcher.set_current(dispatcher)
ctx_dp.set(dispatcher)
@property @property
def frozen(self): def frozen(self):

42
aiogram/utils/mixins.py Normal file
View file

@ -0,0 +1,42 @@
import contextvars
class DataMixin:
@property
def data(self):
data = getattr(self, '_data', None)
if data is None:
data = {}
setattr(self, '_data', data)
return data
def __getitem__(self, item):
return self.data[item]
def __setitem__(self, key, value):
self.data[key] = value
def __delitem__(self, key):
del self.data[key]
def get(self, key, default=None):
return self.data.get(key, default)
class ContextInstanceMixin:
def __init_subclass__(cls, **kwargs):
cls.__context_instance = contextvars.ContextVar('instance_' + cls.__name__)
return cls
@classmethod
def get_current(cls, no_error=True):
if no_error:
return cls.__context_instance.get(None)
return cls.__context_instance.get()
@classmethod
def set_current(cls, value):
if not isinstance(value, cls):
raise TypeError(f"Value should be instance of '{cls.__name__}' not '{type(value).__name__}'")
cls.__context_instance.set(value)

View file

@ -56,7 +56,7 @@ class ThrottlingMiddleware(BaseMiddleware):
handler = current_handler.get() handler = current_handler.get()
# Get dispatcher from context # Get dispatcher from context
dispatcher = Dispatcher.current() dispatcher = Dispatcher.get_current()
# If handler was configured, get rate limit and key from handler # If handler was configured, get rate limit and key from handler
if handler: if handler:
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit) limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
@ -83,7 +83,7 @@ class ThrottlingMiddleware(BaseMiddleware):
:param throttled: :param throttled:
""" """
handler = current_handler.get() handler = current_handler.get()
dispatcher = Dispatcher.current() dispatcher = Dispatcher.get_current()
if handler: if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else: else: