mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 01:54:53 +00:00
Rewrite contextvar usage. Implemented ContextInstanceMixin and DataMixin
This commit is contained in:
parent
8ef279bba1
commit
39c333251f
12 changed files with 82 additions and 120 deletions
|
|
@ -173,43 +173,6 @@ class BaseBot:
|
|||
|
||||
return await self.request(method, payload, files)
|
||||
|
||||
@property
|
||||
def data(self) -> Dict:
|
||||
"""
|
||||
Data stored in bot object
|
||||
|
||||
:return: Dictionary
|
||||
"""
|
||||
return self._data
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""
|
||||
Store data in bot instance
|
||||
|
||||
:param key: Key in dict
|
||||
:param value: Value
|
||||
"""
|
||||
self._data[key] = value
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""
|
||||
Get item from bot instance by key
|
||||
|
||||
:param item: key name
|
||||
:return: value
|
||||
"""
|
||||
return self._data[item]
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Get item from bot instance by key or return default value
|
||||
|
||||
:param key: key in dict
|
||||
:param default: default value
|
||||
:return: value or default value
|
||||
"""
|
||||
return self._data.get(key, default)
|
||||
|
||||
@property
|
||||
def parse_mode(self):
|
||||
return getattr(self, '_parse_mode', None)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from contextvars import ContextVar
|
||||
|
||||
from .base import BaseBot, api
|
||||
from .. import types
|
||||
from ..types import base
|
||||
from ..utils.mixins import DataMixin, ContextInstanceMixin
|
||||
from ..utils.payload import generate_payload, prepare_arg, prepare_attachment, prepare_file
|
||||
|
||||
|
||||
class Bot(BaseBot):
|
||||
class Bot(BaseBot, DataMixin, ContextInstanceMixin):
|
||||
"""
|
||||
Base bot class
|
||||
"""
|
||||
|
|
@ -39,14 +39,6 @@ class Bot(BaseBot):
|
|||
if hasattr(self, '_me'):
|
||||
delattr(self, '_me')
|
||||
|
||||
@classmethod
|
||||
def current(cls) -> Bot:
|
||||
"""
|
||||
Return active bot instance from the current context or None
|
||||
:return: Bot or None
|
||||
"""
|
||||
return bot.get()
|
||||
|
||||
async def download_file_by_id(self, file_id: base.String, destination=None,
|
||||
timeout: base.Integer = 30, chunk_size: base.Integer = 65536,
|
||||
seek: base.Boolean = True):
|
||||
|
|
@ -98,7 +90,7 @@ class Bot(BaseBot):
|
|||
allowed_updates = prepare_arg(allowed_updates)
|
||||
payload = generate_payload(**locals())
|
||||
|
||||
result = await self.request(api.Methods.GET_UPDATES, payload, timeout=timeout + 2 if timeout else None)
|
||||
result = await self.request(api.Methods.GET_UPDATES, payload)
|
||||
return [types.Update(**update) for update in result]
|
||||
|
||||
async def set_webhook(self, url: base.String,
|
||||
|
|
@ -2064,6 +2056,3 @@ class Bot(BaseBot):
|
|||
result = await self.request(api.Methods.GET_GAME_HIGH_SCORES, payload)
|
||||
|
||||
return [types.GameHighScore(**gamehighscore) for gamehighscore in result]
|
||||
|
||||
|
||||
bot: ContextVar[Bot] = ContextVar('bot_instance', default=None)
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class I18nMiddleware(BaseMiddleware):
|
|||
:param args: event arguments
|
||||
:return: locale name
|
||||
"""
|
||||
user: types.User = types.User.current()
|
||||
user: types.User = types.User.get_current()
|
||||
locale: Locale = user.locale
|
||||
|
||||
if locale:
|
||||
|
|
|
|||
|
|
@ -3,12 +3,11 @@ from . import handler
|
|||
from . import middlewares
|
||||
from . import storage
|
||||
from . import webhook
|
||||
from .dispatcher import Dispatcher, dispatcher, FSMContext, DEFAULT_RATE_LIMIT
|
||||
from .dispatcher import Dispatcher, FSMContext, DEFAULT_RATE_LIMIT
|
||||
|
||||
__all__ = [
|
||||
'DEFAULT_RATE_LIMIT',
|
||||
'Dispatcher',
|
||||
'dispatcher',
|
||||
'FSMContext',
|
||||
'filters',
|
||||
'handler',
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import itertools
|
|||
import logging
|
||||
import time
|
||||
import typing
|
||||
from contextvars import ContextVar
|
||||
|
||||
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
|
||||
RegexpCommandsFilter, StateFilter, Text
|
||||
|
|
@ -14,15 +13,16 @@ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMCon
|
|||
LAST_CALL, RATE_LIMIT, RESULT
|
||||
from .webhook import BaseResponse
|
||||
from .. import types
|
||||
from ..bot import Bot, bot
|
||||
from ..bot import Bot
|
||||
from ..utils.exceptions import TelegramAPIError, Throttled
|
||||
from ..utils.mixins import ContextInstanceMixin, DataMixin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RATE_LIMIT = .1
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||
"""
|
||||
Simple Updates dispatcher
|
||||
|
||||
|
|
@ -112,23 +112,6 @@ class Dispatcher:
|
|||
def __del__(self):
|
||||
self.stop_polling()
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.bot.data
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.bot.data[key] = value
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.bot.data[item]
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.bot.data.get(key, default)
|
||||
|
||||
@classmethod
|
||||
def current(cls):
|
||||
return dispatcher.get()
|
||||
|
||||
async def skip_updates(self):
|
||||
"""
|
||||
You can skip old incoming updates from queue.
|
||||
|
|
@ -245,8 +228,8 @@ class Dispatcher:
|
|||
log.info('Start polling.')
|
||||
|
||||
# context.set_value(MODE, LONG_POLLING)
|
||||
dispatcher.set(self)
|
||||
bot.bot.set(self.bot)
|
||||
Dispatcher.set_current(self)
|
||||
Bot.set_current(self.bot)
|
||||
|
||||
if reset_webhook is None:
|
||||
await self.reset_webhook(check=False)
|
||||
|
|
@ -867,10 +850,10 @@ class Dispatcher:
|
|||
:return:
|
||||
"""
|
||||
if chat is None:
|
||||
chat_obj = types.Chat.current()
|
||||
chat_obj = types.Chat.get_current()
|
||||
chat = chat_obj.id if chat_obj else None
|
||||
if user is None:
|
||||
user_obj = types.User.current()
|
||||
user_obj = types.User.get_current()
|
||||
user = user_obj.id if user_obj else None
|
||||
|
||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||
|
|
@ -895,8 +878,8 @@ class Dispatcher:
|
|||
if rate is None:
|
||||
rate = self.throttling_rate_limit
|
||||
if user is None and chat is None:
|
||||
user = types.User.current()
|
||||
chat = types.Chat.current()
|
||||
user = types.User.get_current()
|
||||
chat = types.Chat.get_current()
|
||||
|
||||
# Detect current time
|
||||
now = time.time()
|
||||
|
|
@ -945,8 +928,8 @@ class Dispatcher:
|
|||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if user is None and chat is None:
|
||||
user = types.User.current()
|
||||
chat = types.Chat.current()
|
||||
user = types.User.get_current()
|
||||
chat = types.Chat.get_current()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
data = bucket.get(key, {})
|
||||
|
|
@ -965,8 +948,8 @@ class Dispatcher:
|
|||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if user is None and chat is None:
|
||||
user = types.User.current()
|
||||
chat = types.Chat.current()
|
||||
user = types.User.get_current()
|
||||
chat = types.Chat.get_current()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
if bucket and key in bucket:
|
||||
|
|
@ -997,7 +980,7 @@ class Dispatcher:
|
|||
response = task.result()
|
||||
except Exception as e:
|
||||
self.loop.create_task(
|
||||
self.errors_handlers.notify(types.Update.current(), e))
|
||||
self.errors_handlers.notify(types.Update.get_current(), e))
|
||||
else:
|
||||
if isinstance(response, BaseResponse):
|
||||
self.loop.create_task(response.execute_response(self.bot))
|
||||
|
|
@ -1016,6 +999,3 @@ class Dispatcher:
|
|||
if run_task:
|
||||
return self.async_task(callback)
|
||||
return callback
|
||||
|
||||
|
||||
dispatcher: ContextVar[Dispatcher] = ContextVar('dispatcher_instance', default=None)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class State:
|
|||
__repr__ = __str__
|
||||
|
||||
async def set(self):
|
||||
state = Dispatcher.current().current_state()
|
||||
state = Dispatcher.get_current().current_state()
|
||||
await state.set_state(self.state)
|
||||
|
||||
|
||||
|
|
@ -143,7 +143,7 @@ class StatesGroupMeta(type):
|
|||
class StatesGroup(metaclass=StatesGroupMeta):
|
||||
@classmethod
|
||||
async def next(cls) -> str:
|
||||
state = Dispatcher.current().current_state()
|
||||
state = Dispatcher.get_current().current_state()
|
||||
state_name = await state.get_state()
|
||||
|
||||
try:
|
||||
|
|
@ -161,7 +161,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
|||
|
||||
@classmethod
|
||||
async def previous(cls) -> str:
|
||||
state = Dispatcher.current().current_state()
|
||||
state = Dispatcher.get_current().current_state()
|
||||
state_name = await state.get_state()
|
||||
|
||||
try:
|
||||
|
|
@ -179,7 +179,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
|||
|
||||
@classmethod
|
||||
async def first(cls) -> str:
|
||||
state = Dispatcher.current().current_state()
|
||||
state = Dispatcher.get_current().current_state()
|
||||
first_step_name = cls.states_names[0]
|
||||
|
||||
await state.set_state(first_step_name)
|
||||
|
|
@ -187,7 +187,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
|||
|
||||
@classmethod
|
||||
async def last(cls) -> str:
|
||||
state = Dispatcher.current().current_state()
|
||||
state = Dispatcher.get_current().current_state()
|
||||
last_step_name = cls.states_names[-1]
|
||||
|
||||
await state.set_state(last_step_name)
|
||||
|
|
|
|||
|
|
@ -89,10 +89,9 @@ class WebhookRequestHandler(web.View):
|
|||
"""
|
||||
dp = self.request.app[BOT_DISPATCHER_KEY]
|
||||
try:
|
||||
from aiogram.bot import bot
|
||||
from aiogram.dispatcher import dispatcher
|
||||
dispatcher.set(dp)
|
||||
bot.bot.set(dp.bot)
|
||||
from aiogram import Bot, Dispatcher
|
||||
Dispatcher.set_current(dp)
|
||||
Bot.set_current(dp.bot)
|
||||
except RuntimeError:
|
||||
pass
|
||||
return dp
|
||||
|
|
@ -204,7 +203,7 @@ class WebhookRequestHandler(web.View):
|
|||
results = task.result()
|
||||
except Exception as e:
|
||||
loop.create_task(
|
||||
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
|
||||
dispatcher.errors_handlers.notify(dispatcher, types.Update.get_current(), e))
|
||||
else:
|
||||
response = self.get_response(results)
|
||||
if response is not None:
|
||||
|
|
@ -355,7 +354,7 @@ class BaseResponse:
|
|||
async def __call__(self, bot=None):
|
||||
if bot is None:
|
||||
from aiogram import Bot
|
||||
bot = Bot.current()
|
||||
bot = Bot.get_current()
|
||||
return await self.execute_response(bot)
|
||||
|
||||
async def __aenter__(self):
|
||||
|
|
@ -449,7 +448,7 @@ class ParseModeMixin:
|
|||
:return:
|
||||
"""
|
||||
from aiogram import Bot
|
||||
bot = Bot.current()
|
||||
bot = Bot.get_current()
|
||||
if bot is not None:
|
||||
return bot.parse_mode
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ from __future__ import annotations
|
|||
|
||||
import io
|
||||
import typing
|
||||
from contextvars import ContextVar
|
||||
from typing import TypeVar
|
||||
|
||||
from .fields import BaseField
|
||||
from ..utils import json
|
||||
from ..utils.mixins import ContextInstanceMixin
|
||||
|
||||
__all__ = ('MetaTelegramObject', 'TelegramObject', 'InputFile', 'String', 'Integer', 'Float', 'Boolean')
|
||||
|
||||
|
|
@ -57,7 +57,6 @@ class MetaTelegramObject(type):
|
|||
|
||||
mcs._objects[cls.__name__] = cls
|
||||
|
||||
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
|
||||
return cls
|
||||
|
||||
@property
|
||||
|
|
@ -65,7 +64,7 @@ class MetaTelegramObject(type):
|
|||
return cls._objects
|
||||
|
||||
|
||||
class TelegramObject(metaclass=MetaTelegramObject):
|
||||
class TelegramObject(ContextInstanceMixin, metaclass=MetaTelegramObject):
|
||||
"""
|
||||
Abstract class for telegram objects
|
||||
"""
|
||||
|
|
@ -93,14 +92,6 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
|||
if value.default and key not in self.values:
|
||||
self.values[key] = value.default
|
||||
|
||||
@classmethod
|
||||
def current(cls):
|
||||
return cls._current.get()
|
||||
|
||||
@classmethod
|
||||
def set_current(cls, obj: TelegramObject):
|
||||
return cls._current.set(obj)
|
||||
|
||||
@property
|
||||
def conf(self) -> typing.Dict[str, typing.Any]:
|
||||
return self._conf
|
||||
|
|
@ -151,7 +142,7 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
|||
@property
|
||||
def bot(self):
|
||||
from ..bot.bot import Bot
|
||||
return Bot.current()
|
||||
return Bot.get_current()
|
||||
|
||||
def to_python(self) -> typing.Dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -511,7 +511,7 @@ class ChatActions(helper.Helper):
|
|||
@classmethod
|
||||
async def _do(cls, action: str, sleep=None):
|
||||
from aiogram import Bot
|
||||
await Bot.current().send_chat_action(Chat.current().id, action)
|
||||
await Bot.get_current().send_chat_action(Chat.get_current().id, action)
|
||||
if sleep:
|
||||
await asyncio.sleep(sleep)
|
||||
|
||||
|
|
|
|||
|
|
@ -103,10 +103,9 @@ class Executor:
|
|||
|
||||
self._freeze = False
|
||||
|
||||
from aiogram.bot.bot import bot as ctx_bot
|
||||
from aiogram.dispatcher import dispatcher as ctx_dp
|
||||
ctx_bot.set(dispatcher.bot)
|
||||
ctx_dp.set(dispatcher)
|
||||
from aiogram import Bot, Dispatcher
|
||||
Bot.set_current(dispatcher.bot)
|
||||
Dispatcher.set_current(dispatcher)
|
||||
|
||||
@property
|
||||
def frozen(self):
|
||||
|
|
|
|||
42
aiogram/utils/mixins.py
Normal file
42
aiogram/utils/mixins.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import contextvars
|
||||
|
||||
|
||||
class DataMixin:
|
||||
@property
|
||||
def data(self):
|
||||
data = getattr(self, '_data', None)
|
||||
if data is None:
|
||||
data = {}
|
||||
setattr(self, '_data', data)
|
||||
return data
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.data[item]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.data[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.data.get(key, default)
|
||||
|
||||
|
||||
class ContextInstanceMixin:
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
cls.__context_instance = contextvars.ContextVar('instance_' + cls.__name__)
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def get_current(cls, no_error=True):
|
||||
if no_error:
|
||||
return cls.__context_instance.get(None)
|
||||
return cls.__context_instance.get()
|
||||
|
||||
@classmethod
|
||||
def set_current(cls, value):
|
||||
if not isinstance(value, cls):
|
||||
raise TypeError(f"Value should be instance of '{cls.__name__}' not '{type(value).__name__}'")
|
||||
cls.__context_instance.set(value)
|
||||
|
||||
|
|
@ -56,7 +56,7 @@ class ThrottlingMiddleware(BaseMiddleware):
|
|||
handler = current_handler.get()
|
||||
|
||||
# Get dispatcher from context
|
||||
dispatcher = Dispatcher.current()
|
||||
dispatcher = Dispatcher.get_current()
|
||||
# If handler was configured, get rate limit and key from handler
|
||||
if handler:
|
||||
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
|
||||
|
|
@ -83,7 +83,7 @@ class ThrottlingMiddleware(BaseMiddleware):
|
|||
:param throttled:
|
||||
"""
|
||||
handler = current_handler.get()
|
||||
dispatcher = Dispatcher.current()
|
||||
dispatcher = Dispatcher.get_current()
|
||||
if handler:
|
||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue