diff --git a/aiogram/bot/base.py b/aiogram/bot/base.py index e3a78054..aec439d0 100644 --- a/aiogram/bot/base.py +++ b/aiogram/bot/base.py @@ -173,43 +173,6 @@ class BaseBot: return await self.request(method, payload, files) - @property - def data(self) -> Dict: - """ - Data stored in bot object - - :return: Dictionary - """ - return self._data - - def __setitem__(self, key, value): - """ - Store data in bot instance - - :param key: Key in dict - :param value: Value - """ - self._data[key] = value - - def __getitem__(self, item): - """ - Get item from bot instance by key - - :param item: key name - :return: value - """ - return self._data[item] - - def get(self, key, default=None): - """ - Get item from bot instance by key or return default value - - :param key: key in dict - :param default: default value - :return: value or default value - """ - return self._data.get(key, default) - @property def parse_mode(self): return getattr(self, '_parse_mode', None) diff --git a/aiogram/bot/bot.py b/aiogram/bot/bot.py index 5a16d669..c3644f60 100644 --- a/aiogram/bot/bot.py +++ b/aiogram/bot/bot.py @@ -1,15 +1,15 @@ from __future__ import annotations import typing -from contextvars import ContextVar from .base import BaseBot, api from .. import types from ..types import base +from ..utils.mixins import DataMixin, ContextInstanceMixin from ..utils.payload import generate_payload, prepare_arg, prepare_attachment, prepare_file -class Bot(BaseBot): +class Bot(BaseBot, DataMixin, ContextInstanceMixin): """ Base bot class """ @@ -39,14 +39,6 @@ class Bot(BaseBot): if hasattr(self, '_me'): delattr(self, '_me') - @classmethod - def current(cls) -> Bot: - """ - Return active bot instance from the current context or None - :return: Bot or None - """ - return bot.get() - async def download_file_by_id(self, file_id: base.String, destination=None, timeout: base.Integer = 30, chunk_size: base.Integer = 65536, seek: base.Boolean = True): @@ -98,7 +90,7 @@ class Bot(BaseBot): allowed_updates = prepare_arg(allowed_updates) payload = generate_payload(**locals()) - result = await self.request(api.Methods.GET_UPDATES, payload, timeout=timeout + 2 if timeout else None) + result = await self.request(api.Methods.GET_UPDATES, payload) return [types.Update(**update) for update in result] async def set_webhook(self, url: base.String, @@ -522,7 +514,7 @@ class Bot(BaseBot): """ reply_markup = prepare_arg(reply_markup) payload = generate_payload(**locals(), exclude=["animation", "thumb"]) - + files = {} prepare_file(payload, files, 'animation', animation) prepare_attachment(payload, files, 'thumb', thumb) @@ -2064,6 +2056,3 @@ class Bot(BaseBot): result = await self.request(api.Methods.GET_GAME_HIGH_SCORES, payload) return [types.GameHighScore(**gamehighscore) for gamehighscore in result] - - -bot: ContextVar[Bot] = ContextVar('bot_instance', default=None) diff --git a/aiogram/contrib/middlewares/i18n.py b/aiogram/contrib/middlewares/i18n.py index 2ecc167a..65cb1400 100644 --- a/aiogram/contrib/middlewares/i18n.py +++ b/aiogram/contrib/middlewares/i18n.py @@ -116,7 +116,7 @@ class I18nMiddleware(BaseMiddleware): :param args: event arguments :return: locale name """ - user: types.User = types.User.current() + user: types.User = types.User.get_current() locale: Locale = user.locale if locale: diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 2ff5dc90..6ad43bbe 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -3,12 +3,11 @@ from . import handler from . import middlewares from . import storage from . import webhook -from .dispatcher import Dispatcher, dispatcher, FSMContext, DEFAULT_RATE_LIMIT +from .dispatcher import Dispatcher, FSMContext, DEFAULT_RATE_LIMIT __all__ = [ 'DEFAULT_RATE_LIMIT', 'Dispatcher', - 'dispatcher', 'FSMContext', 'filters', 'handler', diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index cade7358..3c1506f8 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -4,7 +4,6 @@ import itertools import logging import time import typing -from contextvars import ContextVar from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \ RegexpCommandsFilter, StateFilter, Text @@ -14,15 +13,16 @@ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMCon LAST_CALL, RATE_LIMIT, RESULT from .webhook import BaseResponse from .. import types -from ..bot import Bot, bot +from ..bot import Bot from ..utils.exceptions import TelegramAPIError, Throttled +from ..utils.mixins import ContextInstanceMixin, DataMixin log = logging.getLogger(__name__) DEFAULT_RATE_LIMIT = .1 -class Dispatcher: +class Dispatcher(DataMixin, ContextInstanceMixin): """ Simple Updates dispatcher @@ -112,23 +112,6 @@ class Dispatcher: def __del__(self): self.stop_polling() - @property - def data(self): - return self.bot.data - - def __setitem__(self, key, value): - self.bot.data[key] = value - - def __getitem__(self, item): - return self.bot.data[item] - - def get(self, key, default=None): - return self.bot.data.get(key, default) - - @classmethod - def current(cls): - return dispatcher.get() - async def skip_updates(self): """ You can skip old incoming updates from queue. @@ -245,8 +228,8 @@ class Dispatcher: log.info('Start polling.') # context.set_value(MODE, LONG_POLLING) - dispatcher.set(self) - bot.bot.set(self.bot) + Dispatcher.set_current(self) + Bot.set_current(self.bot) if reset_webhook is None: await self.reset_webhook(check=False) @@ -867,10 +850,10 @@ class Dispatcher: :return: """ if chat is None: - chat_obj = types.Chat.current() + chat_obj = types.Chat.get_current() chat = chat_obj.id if chat_obj else None if user is None: - user_obj = types.User.current() + user_obj = types.User.get_current() user = user_obj.id if user_obj else None return FSMContext(storage=self.storage, chat=chat, user=user) @@ -895,8 +878,8 @@ class Dispatcher: if rate is None: rate = self.throttling_rate_limit if user is None and chat is None: - user = types.User.current() - chat = types.Chat.current() + user = types.User.get_current() + chat = types.Chat.get_current() # Detect current time now = time.time() @@ -945,8 +928,8 @@ class Dispatcher: raise RuntimeError('This storage does not provide Leaky Bucket') if user is None and chat is None: - user = types.User.current() - chat = types.Chat.current() + user = types.User.get_current() + chat = types.Chat.get_current() bucket = await self.storage.get_bucket(chat=chat, user=user) data = bucket.get(key, {}) @@ -965,8 +948,8 @@ class Dispatcher: raise RuntimeError('This storage does not provide Leaky Bucket') if user is None and chat is None: - user = types.User.current() - chat = types.Chat.current() + user = types.User.get_current() + chat = types.Chat.get_current() bucket = await self.storage.get_bucket(chat=chat, user=user) if bucket and key in bucket: @@ -997,7 +980,7 @@ class Dispatcher: response = task.result() except Exception as e: self.loop.create_task( - self.errors_handlers.notify(types.Update.current(), e)) + self.errors_handlers.notify(types.Update.get_current(), e)) else: if isinstance(response, BaseResponse): self.loop.create_task(response.execute_response(self.bot)) @@ -1016,6 +999,3 @@ class Dispatcher: if run_task: return self.async_task(callback) return callback - - -dispatcher: ContextVar[Dispatcher] = ContextVar('dispatcher_instance', default=None) diff --git a/aiogram/dispatcher/filters/state.py b/aiogram/dispatcher/filters/state.py index fadc3687..afe08e64 100644 --- a/aiogram/dispatcher/filters/state.py +++ b/aiogram/dispatcher/filters/state.py @@ -53,7 +53,7 @@ class State: __repr__ = __str__ async def set(self): - state = Dispatcher.current().current_state() + state = Dispatcher.get_current().current_state() await state.set_state(self.state) @@ -143,7 +143,7 @@ class StatesGroupMeta(type): class StatesGroup(metaclass=StatesGroupMeta): @classmethod async def next(cls) -> str: - state = Dispatcher.current().current_state() + state = Dispatcher.get_current().current_state() state_name = await state.get_state() try: @@ -161,7 +161,7 @@ class StatesGroup(metaclass=StatesGroupMeta): @classmethod async def previous(cls) -> str: - state = Dispatcher.current().current_state() + state = Dispatcher.get_current().current_state() state_name = await state.get_state() try: @@ -179,7 +179,7 @@ class StatesGroup(metaclass=StatesGroupMeta): @classmethod async def first(cls) -> str: - state = Dispatcher.current().current_state() + state = Dispatcher.get_current().current_state() first_step_name = cls.states_names[0] await state.set_state(first_step_name) @@ -187,7 +187,7 @@ class StatesGroup(metaclass=StatesGroupMeta): @classmethod async def last(cls) -> str: - state = Dispatcher.current().current_state() + state = Dispatcher.get_current().current_state() last_step_name = cls.states_names[-1] await state.set_state(last_step_name) diff --git a/aiogram/dispatcher/webhook.py b/aiogram/dispatcher/webhook.py index bc2a0e60..2dc9c70b 100644 --- a/aiogram/dispatcher/webhook.py +++ b/aiogram/dispatcher/webhook.py @@ -89,10 +89,9 @@ class WebhookRequestHandler(web.View): """ dp = self.request.app[BOT_DISPATCHER_KEY] try: - from aiogram.bot import bot - from aiogram.dispatcher import dispatcher - dispatcher.set(dp) - bot.bot.set(dp.bot) + from aiogram import Bot, Dispatcher + Dispatcher.set_current(dp) + Bot.set_current(dp.bot) except RuntimeError: pass return dp @@ -204,7 +203,7 @@ class WebhookRequestHandler(web.View): results = task.result() except Exception as e: loop.create_task( - dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e)) + dispatcher.errors_handlers.notify(dispatcher, types.Update.get_current(), e)) else: response = self.get_response(results) if response is not None: @@ -355,7 +354,7 @@ class BaseResponse: async def __call__(self, bot=None): if bot is None: from aiogram import Bot - bot = Bot.current() + bot = Bot.get_current() return await self.execute_response(bot) async def __aenter__(self): @@ -449,7 +448,7 @@ class ParseModeMixin: :return: """ from aiogram import Bot - bot = Bot.current() + bot = Bot.get_current() if bot is not None: return bot.parse_mode diff --git a/aiogram/types/base.py b/aiogram/types/base.py index 9982ad35..8125a37d 100644 --- a/aiogram/types/base.py +++ b/aiogram/types/base.py @@ -2,11 +2,11 @@ from __future__ import annotations import io import typing -from contextvars import ContextVar from typing import TypeVar from .fields import BaseField from ..utils import json +from ..utils.mixins import ContextInstanceMixin __all__ = ('MetaTelegramObject', 'TelegramObject', 'InputFile', 'String', 'Integer', 'Float', 'Boolean') @@ -57,7 +57,6 @@ class MetaTelegramObject(type): mcs._objects[cls.__name__] = cls - cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None? return cls @property @@ -65,7 +64,7 @@ class MetaTelegramObject(type): return cls._objects -class TelegramObject(metaclass=MetaTelegramObject): +class TelegramObject(ContextInstanceMixin, metaclass=MetaTelegramObject): """ Abstract class for telegram objects """ @@ -93,14 +92,6 @@ class TelegramObject(metaclass=MetaTelegramObject): if value.default and key not in self.values: self.values[key] = value.default - @classmethod - def current(cls): - return cls._current.get() - - @classmethod - def set_current(cls, obj: TelegramObject): - return cls._current.set(obj) - @property def conf(self) -> typing.Dict[str, typing.Any]: return self._conf @@ -151,7 +142,7 @@ class TelegramObject(metaclass=MetaTelegramObject): @property def bot(self): from ..bot.bot import Bot - return Bot.current() + return Bot.get_current() def to_python(self) -> typing.Dict: """ diff --git a/aiogram/types/chat.py b/aiogram/types/chat.py index cc476de1..cd34f1be 100644 --- a/aiogram/types/chat.py +++ b/aiogram/types/chat.py @@ -511,7 +511,7 @@ class ChatActions(helper.Helper): @classmethod async def _do(cls, action: str, sleep=None): from aiogram import Bot - await Bot.current().send_chat_action(Chat.current().id, action) + await Bot.get_current().send_chat_action(Chat.get_current().id, action) if sleep: await asyncio.sleep(sleep) diff --git a/aiogram/utils/executor.py b/aiogram/utils/executor.py index f8611a26..0679de8c 100644 --- a/aiogram/utils/executor.py +++ b/aiogram/utils/executor.py @@ -103,10 +103,9 @@ class Executor: self._freeze = False - from aiogram.bot.bot import bot as ctx_bot - from aiogram.dispatcher import dispatcher as ctx_dp - ctx_bot.set(dispatcher.bot) - ctx_dp.set(dispatcher) + from aiogram import Bot, Dispatcher + Bot.set_current(dispatcher.bot) + Dispatcher.set_current(dispatcher) @property def frozen(self): diff --git a/aiogram/utils/mixins.py b/aiogram/utils/mixins.py new file mode 100644 index 00000000..dba2f2cd --- /dev/null +++ b/aiogram/utils/mixins.py @@ -0,0 +1,42 @@ +import contextvars + + +class DataMixin: + @property + def data(self): + data = getattr(self, '_data', None) + if data is None: + data = {} + setattr(self, '_data', data) + return data + + def __getitem__(self, item): + return self.data[item] + + def __setitem__(self, key, value): + self.data[key] = value + + def __delitem__(self, key): + del self.data[key] + + def get(self, key, default=None): + return self.data.get(key, default) + + +class ContextInstanceMixin: + def __init_subclass__(cls, **kwargs): + cls.__context_instance = contextvars.ContextVar('instance_' + cls.__name__) + return cls + + @classmethod + def get_current(cls, no_error=True): + if no_error: + return cls.__context_instance.get(None) + return cls.__context_instance.get() + + @classmethod + def set_current(cls, value): + if not isinstance(value, cls): + raise TypeError(f"Value should be instance of '{cls.__name__}' not '{type(value).__name__}'") + cls.__context_instance.set(value) + diff --git a/examples/middleware_and_antiflood.py b/examples/middleware_and_antiflood.py index 6f26b2ee..7986bf3f 100644 --- a/examples/middleware_and_antiflood.py +++ b/examples/middleware_and_antiflood.py @@ -56,7 +56,7 @@ class ThrottlingMiddleware(BaseMiddleware): handler = current_handler.get() # Get dispatcher from context - dispatcher = Dispatcher.current() + dispatcher = Dispatcher.get_current() # If handler was configured, get rate limit and key from handler if handler: limit = getattr(handler, 'throttling_rate_limit', self.rate_limit) @@ -83,7 +83,7 @@ class ThrottlingMiddleware(BaseMiddleware): :param throttled: """ handler = current_handler.get() - dispatcher = Dispatcher.current() + dispatcher = Dispatcher.get_current() if handler: key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") else: