mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-10 17:53:17 +00:00
Remove context util.
This commit is contained in:
parent
5c1eee4fa9
commit
06fbe0d9cd
18 changed files with 1169 additions and 1333 deletions
|
|
@ -1,5 +1,4 @@
|
|||
from aiogram import types
|
||||
from aiogram.dispatcher import ctx
|
||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||
|
||||
OBJ_KEY = '_context_data'
|
||||
|
|
@ -46,7 +45,7 @@ class ContextMiddleware(BaseMiddleware):
|
|||
|
||||
:return:
|
||||
"""
|
||||
update = ctx.get_update()
|
||||
update = types.Update.current()
|
||||
obj = update.conf.get(OBJ_KEY, None)
|
||||
if obj is None:
|
||||
obj = self._configure_update(update)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,42 +0,0 @@
|
|||
from . import Bot
|
||||
from .. import types
|
||||
from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT
|
||||
from ..utils import context
|
||||
|
||||
|
||||
def _get(key, default=None, no_error=False):
|
||||
result = context.get_value(key, default)
|
||||
if not no_error and result is None:
|
||||
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
|
||||
f"Maybe asyncio task factory is not configured!\n"
|
||||
f"\t>>> from aiogram.utils import context\n"
|
||||
f"\t>>> loop.set_task_factory(context.task_factory)")
|
||||
return result
|
||||
|
||||
|
||||
def get_bot() -> Bot:
|
||||
return _get('bot')
|
||||
|
||||
|
||||
def get_dispatcher() -> Dispatcher:
|
||||
return _get('dispatcher')
|
||||
|
||||
|
||||
def get_update() -> types.Update:
|
||||
return _get(UPDATE_OBJECT)
|
||||
|
||||
|
||||
def get_mode() -> str:
|
||||
return _get(MODE, 'unknown')
|
||||
|
||||
|
||||
def get_chat() -> int:
|
||||
return _get('chat', no_error=True)
|
||||
|
||||
|
||||
def get_user() -> int:
|
||||
return _get('user', no_error=True)
|
||||
|
||||
|
||||
def get_state() -> FSMContext:
|
||||
return get_dispatcher().current_state()
|
||||
1076
aiogram/dispatcher/dispatcher.py
Normal file
1076
aiogram/dispatcher/dispatcher.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,9 +1,9 @@
|
|||
import asyncio
|
||||
import re
|
||||
from _contextvars import ContextVar
|
||||
|
||||
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
|
||||
from aiogram.types import CallbackQuery, ContentType, Message
|
||||
from aiogram.utils import context
|
||||
|
||||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
|
@ -130,6 +130,8 @@ class StateFilter(BaseFilter):
|
|||
"""
|
||||
key = 'state'
|
||||
|
||||
ctx_state = ContextVar('user_state')
|
||||
|
||||
def __init__(self, dispatcher, state):
|
||||
super().__init__(dispatcher)
|
||||
if isinstance(state, str):
|
||||
|
|
@ -143,14 +145,16 @@ class StateFilter(BaseFilter):
|
|||
if self.state == '*':
|
||||
return True
|
||||
|
||||
if context.check_value(USER_STATE):
|
||||
context_state = context.get_value(USER_STATE)
|
||||
return self.state == context_state
|
||||
else:
|
||||
try:
|
||||
return self.state == self.ctx_state.get()
|
||||
except LookupError:
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||
self.ctx_state.set(state)
|
||||
return state == self.state
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from .filters import AbstractFilter, FilterRecord
|
|||
from ..handler import Handler
|
||||
|
||||
|
||||
# TODO: provide to set default filters (Like state. It will be always be added to filters set)
|
||||
# TODO: move check_filter/check_filters functions to FiltersFactory class
|
||||
|
||||
class FiltersFactory:
|
||||
|
|
@ -18,15 +17,17 @@ class FiltersFactory:
|
|||
|
||||
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
|
||||
validator: typing.Optional[typing.Callable] = None,
|
||||
event_handlers: typing.Optional[typing.List[Handler]] = None):
|
||||
event_handlers: typing.Optional[typing.List[Handler]] = None,
|
||||
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
||||
"""
|
||||
Register filter
|
||||
|
||||
:param callback: callable or subclass of :obj:`AbstractFilter`
|
||||
:param validator: custom validator.
|
||||
:param event_handlers: list of instances of :obj:`Handler`
|
||||
:param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`)
|
||||
"""
|
||||
record = FilterRecord(callback, validator, event_handlers)
|
||||
record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers)
|
||||
self._registered.append(record)
|
||||
|
||||
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
|
||||
|
|
@ -52,17 +53,21 @@ class FiltersFactory:
|
|||
filters_set = []
|
||||
if custom_filters:
|
||||
filters_set.extend(custom_filters)
|
||||
if full_config:
|
||||
filters_set.extend(self._resolve_registered(self._dispatcher, event_handler,
|
||||
filters_set.extend(self._resolve_registered(event_handler,
|
||||
{k: v for k, v in full_config.items() if v is not None}))
|
||||
|
||||
return filters_set
|
||||
|
||||
def _resolve_registered(self, dispatcher, event_handler, full_config) -> typing.Generator:
|
||||
for record in self._registered:
|
||||
if not full_config:
|
||||
break
|
||||
def _resolve_registered(self, event_handler, full_config) -> typing.Generator:
|
||||
"""
|
||||
Resolve registered filters
|
||||
|
||||
filter_ = record.resolve(dispatcher, event_handler, full_config)
|
||||
:param event_handler:
|
||||
:param full_config:
|
||||
:return:
|
||||
"""
|
||||
for record in self._registered:
|
||||
filter_ = record.resolve(self._dispatcher, event_handler, full_config)
|
||||
if filter_:
|
||||
yield filter_
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,10 @@ class FilterRecord:
|
|||
return
|
||||
config = self.resolver(full_config)
|
||||
if config:
|
||||
for key in config:
|
||||
if key in full_config:
|
||||
full_config.pop(key)
|
||||
|
||||
return self.callback(dispatcher, **config)
|
||||
|
||||
def _check_event_handler(self, event_handler) -> bool:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,3 @@
|
|||
from ..utils import context
|
||||
|
||||
|
||||
class SkipHandler(BaseException):
|
||||
pass
|
||||
|
||||
|
|
@ -70,7 +67,7 @@ class Handler:
|
|||
if await check_filters(filters, args):
|
||||
try:
|
||||
if self.middleware_key:
|
||||
context.set_value('handler', handler)
|
||||
# context.set_value('handler', handler)
|
||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||
response = await handler(*args)
|
||||
if response is not None:
|
||||
|
|
|
|||
|
|
@ -8,11 +8,13 @@ from typing import Dict, List, Optional, Union
|
|||
|
||||
from aiohttp import web
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.bot import bot
|
||||
from aiogram.dispatcher import dispatcher
|
||||
from .. import types
|
||||
from ..bot import api
|
||||
from ..types import ParseMode
|
||||
from ..types.base import Boolean, Float, Integer, String
|
||||
from ..utils import context
|
||||
from ..utils import helper, markdown
|
||||
from ..utils import json
|
||||
from ..utils.deprecated import warn_deprecated as warn
|
||||
|
|
@ -88,8 +90,8 @@ class WebhookRequestHandler(web.View):
|
|||
"""
|
||||
dp = self.request.app[BOT_DISPATCHER_KEY]
|
||||
try:
|
||||
context.set_value('dispatcher', dp)
|
||||
context.set_value('bot', dp.bot)
|
||||
dispatcher.set(dp)
|
||||
bot.bot.set(dp.bot)
|
||||
except RuntimeError:
|
||||
pass
|
||||
return dp
|
||||
|
|
@ -116,9 +118,9 @@ class WebhookRequestHandler(web.View):
|
|||
"""
|
||||
self.validate_ip()
|
||||
|
||||
context.update_state({'CALLER': WEBHOOK,
|
||||
WEBHOOK_CONNECTION: True,
|
||||
WEBHOOK_REQUEST: self.request})
|
||||
# context.update_state({'CALLER': WEBHOOK,
|
||||
# WEBHOOK_CONNECTION: True,
|
||||
# WEBHOOK_REQUEST: self.request})
|
||||
|
||||
dispatcher = self.get_dispatcher()
|
||||
update = await self.parse_update(dispatcher.bot)
|
||||
|
|
@ -170,7 +172,7 @@ class WebhookRequestHandler(web.View):
|
|||
if fut.done():
|
||||
return fut.result()
|
||||
else:
|
||||
context.set_value(WEBHOOK_CONNECTION, False)
|
||||
# context.set_value(WEBHOOK_CONNECTION, False)
|
||||
fut.remove_done_callback(cb)
|
||||
fut.add_done_callback(self.respond_via_request)
|
||||
finally:
|
||||
|
|
@ -195,7 +197,7 @@ class WebhookRequestHandler(web.View):
|
|||
results = task.result()
|
||||
except Exception as e:
|
||||
loop.create_task(
|
||||
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
|
||||
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
|
||||
else:
|
||||
response = self.get_response(results)
|
||||
if response is not None:
|
||||
|
|
@ -242,7 +244,7 @@ class WebhookRequestHandler(web.View):
|
|||
ip_address, accept = self.check_ip()
|
||||
if not accept:
|
||||
raise web.HTTPUnauthorized()
|
||||
context.set_value('TELEGRAM_IP', ip_address)
|
||||
# context.set_value('TELEGRAM_IP', ip_address)
|
||||
|
||||
|
||||
def configure_app(dispatcher, app: web.Application, path=DEFAULT_WEB_PATH):
|
||||
|
|
@ -332,8 +334,8 @@ class BaseResponse:
|
|||
|
||||
async def __call__(self, bot=None):
|
||||
if bot is None:
|
||||
from aiogram.dispatcher import ctx
|
||||
bot = ctx.get_bot()
|
||||
from aiogram import Bot
|
||||
bot = Bot.current()
|
||||
return await self.execute_response(bot)
|
||||
|
||||
async def __aenter__(self):
|
||||
|
|
@ -426,7 +428,7 @@ class ParseModeMixin:
|
|||
|
||||
:return:
|
||||
"""
|
||||
bot = context.get_value('bot', None)
|
||||
bot = Bot.current()
|
||||
if bot is not None:
|
||||
return bot.parse_mode
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import typing
|
||||
from contextvars import ContextVar
|
||||
from typing import TypeVar
|
||||
|
||||
from .fields import BaseField
|
||||
|
|
@ -53,6 +56,8 @@ class MetaTelegramObject(type):
|
|||
setattr(cls, ALIASES_ATTR_NAME, aliases)
|
||||
|
||||
mcs._objects[cls.__name__] = cls
|
||||
|
||||
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
|
||||
return cls
|
||||
|
||||
@property
|
||||
|
|
@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
|||
if value.default and key not in self.values:
|
||||
self.values[key] = value.default
|
||||
|
||||
@classmethod
|
||||
def current(cls):
|
||||
return cls._current.get()
|
||||
|
||||
@classmethod
|
||||
def set_current(cls, obj: TelegramObject):
|
||||
return cls._current.set(obj)
|
||||
|
||||
@property
|
||||
def conf(self) -> typing.Dict[str, typing.Any]:
|
||||
return self._conf
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
from contextvars import ContextVar
|
||||
|
||||
from . import base
|
||||
from . import fields
|
||||
|
|
@ -507,8 +510,8 @@ class ChatActions(helper.Helper):
|
|||
|
||||
@classmethod
|
||||
async def _do(cls, action: str, sleep=None):
|
||||
from ..dispatcher.ctx import get_bot, get_chat
|
||||
await get_bot().send_chat_action(get_chat(), action)
|
||||
from aiogram import Bot
|
||||
await Bot.current().send_chat_action(Chat.current(), action)
|
||||
if sleep:
|
||||
await asyncio.sleep(sleep)
|
||||
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ class Message(base.TelegramObject):
|
|||
return text
|
||||
|
||||
async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
|
||||
disable_notification=None, reply_markup=None, reply=True) -> 'Message':
|
||||
disable_notification=None, reply_markup=None, reply=True) -> Message:
|
||||
"""
|
||||
Reply to this message
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
from . import base
|
||||
from . import fields
|
||||
from .callback_query import CallbackQuery
|
||||
|
|
@ -12,8 +10,6 @@ from .pre_checkout_query import PreCheckoutQuery
|
|||
from .shipping_query import ShippingQuery
|
||||
from ..utils import helper
|
||||
|
||||
current_update: ContextVar[Update] = ContextVar('current_update_object', default=None)
|
||||
|
||||
|
||||
class Update(base.TelegramObject):
|
||||
"""
|
||||
|
|
@ -33,14 +29,6 @@ class Update(base.TelegramObject):
|
|||
shipping_query: ShippingQuery = fields.Field(base=ShippingQuery)
|
||||
pre_checkout_query: PreCheckoutQuery = fields.Field(base=PreCheckoutQuery)
|
||||
|
||||
@classmethod
|
||||
def current(cls):
|
||||
return current_update.get()
|
||||
|
||||
@classmethod
|
||||
def set_current(cls, update: Update):
|
||||
return current_update.set(update)
|
||||
|
||||
def __hash__(self):
|
||||
return self.update_id
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import babel
|
||||
|
||||
from . import base
|
||||
|
|
|
|||
|
|
@ -1,140 +0,0 @@
|
|||
"""
|
||||
You need to setup task factory:
|
||||
>>> from aiogram.utils import context
|
||||
>>> loop = asyncio.get_event_loop()
|
||||
>>> loop.set_task_factory(context.task_factory)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
CONFIGURED = '@CONFIGURED_TASK_FACTORY'
|
||||
|
||||
|
||||
def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
|
||||
"""
|
||||
Task factory for implementing context processor
|
||||
|
||||
:param loop:
|
||||
:param coro:
|
||||
:return: new task
|
||||
:rtype: :obj:`asyncio.Task`
|
||||
"""
|
||||
# Is not allowed when loop is closed.
|
||||
if loop.is_closed():
|
||||
raise RuntimeError('Event loop is closed.')
|
||||
|
||||
task = asyncio.Task(coro, loop=loop)
|
||||
|
||||
# Hide factory
|
||||
if task._source_traceback:
|
||||
del task._source_traceback[-1]
|
||||
|
||||
try:
|
||||
task.context = asyncio.Task.current_task().context.copy()
|
||||
except AttributeError:
|
||||
task.context = {CONFIGURED: True}
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def get_current_state() -> typing.Dict:
|
||||
"""
|
||||
Get current execution context from task
|
||||
|
||||
:return: context
|
||||
:rtype: :obj:`dict`
|
||||
"""
|
||||
task = asyncio.Task.current_task()
|
||||
if task is None:
|
||||
raise RuntimeError('Can be used only in Task context.')
|
||||
context_ = getattr(task, 'context', None)
|
||||
if context_ is None:
|
||||
context_ = task.context = {}
|
||||
return context_
|
||||
|
||||
|
||||
def get_value(key, default=None):
|
||||
"""
|
||||
Get value from task
|
||||
|
||||
:param key:
|
||||
:param default:
|
||||
:return: value
|
||||
"""
|
||||
return get_current_state().get(key, default)
|
||||
|
||||
|
||||
def check_value(key):
|
||||
"""
|
||||
Key in context?
|
||||
|
||||
:param key:
|
||||
:return:
|
||||
"""
|
||||
return key in get_current_state()
|
||||
|
||||
|
||||
def set_value(key, value):
|
||||
"""
|
||||
Set value
|
||||
|
||||
:param key:
|
||||
:param value:
|
||||
:return:
|
||||
"""
|
||||
get_current_state()[key] = value
|
||||
|
||||
|
||||
def del_value(key):
|
||||
"""
|
||||
Remove value from context
|
||||
|
||||
:param key:
|
||||
:return:
|
||||
"""
|
||||
del get_current_state()[key]
|
||||
|
||||
|
||||
def update_state(data=None, **kwargs):
|
||||
"""
|
||||
Update multiple state items
|
||||
|
||||
:param data:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if data is None:
|
||||
data = {}
|
||||
state = get_current_state()
|
||||
state.update(data, **kwargs)
|
||||
|
||||
|
||||
def check_configured():
|
||||
"""
|
||||
Check loop is configured
|
||||
:return:
|
||||
"""
|
||||
return get_value(CONFIGURED)
|
||||
|
||||
|
||||
class _Context:
|
||||
"""
|
||||
Other things for interactions with the execution context.
|
||||
"""
|
||||
|
||||
def __getitem__(self, item):
|
||||
return get_value(item)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
set_value(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
del_value(key)
|
||||
|
||||
@staticmethod
|
||||
def get_context():
|
||||
return get_current_state()
|
||||
|
||||
|
||||
context = _Context()
|
||||
|
|
@ -6,7 +6,6 @@ from warnings import warn
|
|||
|
||||
from aiohttp import web
|
||||
|
||||
from . import context
|
||||
from ..bot.api import log
|
||||
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
|
||||
|
||||
|
|
@ -179,13 +178,13 @@ class Executor:
|
|||
self._check_frozen()
|
||||
self._freeze = True
|
||||
|
||||
self.loop.set_task_factory(context.task_factory)
|
||||
# self.loop.set_task_factory(context.task_factory)
|
||||
|
||||
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
|
||||
self._check_frozen()
|
||||
self._freeze = True
|
||||
|
||||
self.loop.set_task_factory(context.task_factory)
|
||||
# self.loop.set_task_factory(context.task_factory)
|
||||
|
||||
app = self._web_app
|
||||
if app is None:
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import asyncio
|
|||
|
||||
from aiogram import Bot, types
|
||||
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
||||
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
|
||||
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher
|
||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||
from aiogram.utils import context, executor
|
||||
from aiogram.utils import executor
|
||||
from aiogram.utils.exceptions import Throttled
|
||||
|
||||
TOKEN = 'BOT TOKEN HERE'
|
||||
|
|
@ -53,10 +53,10 @@ class ThrottlingMiddleware(BaseMiddleware):
|
|||
:param message:
|
||||
"""
|
||||
# Get current handler
|
||||
handler = context.get_value('handler')
|
||||
# handler = context.get_value('handler')
|
||||
|
||||
# Get dispatcher from context
|
||||
dispatcher = ctx.get_dispatcher()
|
||||
dispatcher = Dispatcher.current()
|
||||
|
||||
# If handler was configured, get rate limit and key from handler
|
||||
if handler:
|
||||
|
|
@ -83,8 +83,8 @@ class ThrottlingMiddleware(BaseMiddleware):
|
|||
:param message:
|
||||
:param throttled:
|
||||
"""
|
||||
handler = context.get_value('handler')
|
||||
dispatcher = ctx.get_dispatcher()
|
||||
# handler = context.get_value('handler')
|
||||
dispatcher = Dispatcher.current()
|
||||
if handler:
|
||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue