mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-10 17:53:17 +00:00
Remove context util.
This commit is contained in:
parent
5c1eee4fa9
commit
06fbe0d9cd
18 changed files with 1169 additions and 1333 deletions
|
|
@ -1,5 +1,4 @@
|
||||||
from aiogram import types
|
from aiogram import types
|
||||||
from aiogram.dispatcher import ctx
|
|
||||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||||
|
|
||||||
OBJ_KEY = '_context_data'
|
OBJ_KEY = '_context_data'
|
||||||
|
|
@ -46,7 +45,7 @@ class ContextMiddleware(BaseMiddleware):
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
update = ctx.get_update()
|
update = types.Update.current()
|
||||||
obj = update.conf.get(OBJ_KEY, None)
|
obj = update.conf.get(OBJ_KEY, None)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
obj = self._configure_update(update)
|
obj = self._configure_update(update)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,42 +0,0 @@
|
||||||
from . import Bot
|
|
||||||
from .. import types
|
|
||||||
from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT
|
|
||||||
from ..utils import context
|
|
||||||
|
|
||||||
|
|
||||||
def _get(key, default=None, no_error=False):
|
|
||||||
result = context.get_value(key, default)
|
|
||||||
if not no_error and result is None:
|
|
||||||
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
|
|
||||||
f"Maybe asyncio task factory is not configured!\n"
|
|
||||||
f"\t>>> from aiogram.utils import context\n"
|
|
||||||
f"\t>>> loop.set_task_factory(context.task_factory)")
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_bot() -> Bot:
|
|
||||||
return _get('bot')
|
|
||||||
|
|
||||||
|
|
||||||
def get_dispatcher() -> Dispatcher:
|
|
||||||
return _get('dispatcher')
|
|
||||||
|
|
||||||
|
|
||||||
def get_update() -> types.Update:
|
|
||||||
return _get(UPDATE_OBJECT)
|
|
||||||
|
|
||||||
|
|
||||||
def get_mode() -> str:
|
|
||||||
return _get(MODE, 'unknown')
|
|
||||||
|
|
||||||
|
|
||||||
def get_chat() -> int:
|
|
||||||
return _get('chat', no_error=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_user() -> int:
|
|
||||||
return _get('user', no_error=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_state() -> FSMContext:
|
|
||||||
return get_dispatcher().current_state()
|
|
||||||
1076
aiogram/dispatcher/dispatcher.py
Normal file
1076
aiogram/dispatcher/dispatcher.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,9 +1,9 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
from _contextvars import ContextVar
|
||||||
|
|
||||||
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
|
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
|
||||||
from aiogram.types import CallbackQuery, ContentType, Message
|
from aiogram.types import CallbackQuery, ContentType, Message
|
||||||
from aiogram.utils import context
|
|
||||||
|
|
||||||
USER_STATE = 'USER_STATE'
|
USER_STATE = 'USER_STATE'
|
||||||
|
|
||||||
|
|
@ -130,6 +130,8 @@ class StateFilter(BaseFilter):
|
||||||
"""
|
"""
|
||||||
key = 'state'
|
key = 'state'
|
||||||
|
|
||||||
|
ctx_state = ContextVar('user_state')
|
||||||
|
|
||||||
def __init__(self, dispatcher, state):
|
def __init__(self, dispatcher, state):
|
||||||
super().__init__(dispatcher)
|
super().__init__(dispatcher)
|
||||||
if isinstance(state, str):
|
if isinstance(state, str):
|
||||||
|
|
@ -143,14 +145,16 @@ class StateFilter(BaseFilter):
|
||||||
if self.state == '*':
|
if self.state == '*':
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if context.check_value(USER_STATE):
|
try:
|
||||||
context_state = context.get_value(USER_STATE)
|
return self.state == self.ctx_state.get()
|
||||||
return self.state == context_state
|
except LookupError:
|
||||||
else:
|
|
||||||
chat, user = self.get_target(obj)
|
chat, user = self.get_target(obj)
|
||||||
|
|
||||||
if chat or user:
|
if chat or user:
|
||||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||||
|
self.ctx_state.set(state)
|
||||||
|
return state == self.state
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from .filters import AbstractFilter, FilterRecord
|
||||||
from ..handler import Handler
|
from ..handler import Handler
|
||||||
|
|
||||||
|
|
||||||
# TODO: provide to set default filters (Like state. It will be always be added to filters set)
|
|
||||||
# TODO: move check_filter/check_filters functions to FiltersFactory class
|
# TODO: move check_filter/check_filters functions to FiltersFactory class
|
||||||
|
|
||||||
class FiltersFactory:
|
class FiltersFactory:
|
||||||
|
|
@ -18,15 +17,17 @@ class FiltersFactory:
|
||||||
|
|
||||||
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
|
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
|
||||||
validator: typing.Optional[typing.Callable] = None,
|
validator: typing.Optional[typing.Callable] = None,
|
||||||
event_handlers: typing.Optional[typing.List[Handler]] = None):
|
event_handlers: typing.Optional[typing.List[Handler]] = None,
|
||||||
|
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
|
||||||
"""
|
"""
|
||||||
Register filter
|
Register filter
|
||||||
|
|
||||||
:param callback: callable or subclass of :obj:`AbstractFilter`
|
:param callback: callable or subclass of :obj:`AbstractFilter`
|
||||||
:param validator: custom validator.
|
:param validator: custom validator.
|
||||||
:param event_handlers: list of instances of :obj:`Handler`
|
:param event_handlers: list of instances of :obj:`Handler`
|
||||||
|
:param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`)
|
||||||
"""
|
"""
|
||||||
record = FilterRecord(callback, validator, event_handlers)
|
record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers)
|
||||||
self._registered.append(record)
|
self._registered.append(record)
|
||||||
|
|
||||||
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
|
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
|
||||||
|
|
@ -52,17 +53,21 @@ class FiltersFactory:
|
||||||
filters_set = []
|
filters_set = []
|
||||||
if custom_filters:
|
if custom_filters:
|
||||||
filters_set.extend(custom_filters)
|
filters_set.extend(custom_filters)
|
||||||
if full_config:
|
filters_set.extend(self._resolve_registered(event_handler,
|
||||||
filters_set.extend(self._resolve_registered(self._dispatcher, event_handler,
|
{k: v for k, v in full_config.items() if v is not None}))
|
||||||
{k: v for k, v in full_config.items() if v is not None}))
|
|
||||||
return filters_set
|
return filters_set
|
||||||
|
|
||||||
def _resolve_registered(self, dispatcher, event_handler, full_config) -> typing.Generator:
|
def _resolve_registered(self, event_handler, full_config) -> typing.Generator:
|
||||||
for record in self._registered:
|
"""
|
||||||
if not full_config:
|
Resolve registered filters
|
||||||
break
|
|
||||||
|
|
||||||
filter_ = record.resolve(dispatcher, event_handler, full_config)
|
:param event_handler:
|
||||||
|
:param full_config:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for record in self._registered:
|
||||||
|
filter_ = record.resolve(self._dispatcher, event_handler, full_config)
|
||||||
if filter_:
|
if filter_:
|
||||||
yield filter_
|
yield filter_
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,10 @@ class FilterRecord:
|
||||||
return
|
return
|
||||||
config = self.resolver(full_config)
|
config = self.resolver(full_config)
|
||||||
if config:
|
if config:
|
||||||
|
for key in config:
|
||||||
|
if key in full_config:
|
||||||
|
full_config.pop(key)
|
||||||
|
|
||||||
return self.callback(dispatcher, **config)
|
return self.callback(dispatcher, **config)
|
||||||
|
|
||||||
def _check_event_handler(self, event_handler) -> bool:
|
def _check_event_handler(self, event_handler) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,3 @@
|
||||||
from ..utils import context
|
|
||||||
|
|
||||||
|
|
||||||
class SkipHandler(BaseException):
|
class SkipHandler(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -70,7 +67,7 @@ class Handler:
|
||||||
if await check_filters(filters, args):
|
if await check_filters(filters, args):
|
||||||
try:
|
try:
|
||||||
if self.middleware_key:
|
if self.middleware_key:
|
||||||
context.set_value('handler', handler)
|
# context.set_value('handler', handler)
|
||||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||||
response = await handler(*args)
|
response = await handler(*args)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
|
|
|
||||||
|
|
@ -8,11 +8,13 @@ from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
|
from aiogram import Bot
|
||||||
|
from aiogram.bot import bot
|
||||||
|
from aiogram.dispatcher import dispatcher
|
||||||
from .. import types
|
from .. import types
|
||||||
from ..bot import api
|
from ..bot import api
|
||||||
from ..types import ParseMode
|
from ..types import ParseMode
|
||||||
from ..types.base import Boolean, Float, Integer, String
|
from ..types.base import Boolean, Float, Integer, String
|
||||||
from ..utils import context
|
|
||||||
from ..utils import helper, markdown
|
from ..utils import helper, markdown
|
||||||
from ..utils import json
|
from ..utils import json
|
||||||
from ..utils.deprecated import warn_deprecated as warn
|
from ..utils.deprecated import warn_deprecated as warn
|
||||||
|
|
@ -88,8 +90,8 @@ class WebhookRequestHandler(web.View):
|
||||||
"""
|
"""
|
||||||
dp = self.request.app[BOT_DISPATCHER_KEY]
|
dp = self.request.app[BOT_DISPATCHER_KEY]
|
||||||
try:
|
try:
|
||||||
context.set_value('dispatcher', dp)
|
dispatcher.set(dp)
|
||||||
context.set_value('bot', dp.bot)
|
bot.bot.set(dp.bot)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
return dp
|
return dp
|
||||||
|
|
@ -116,9 +118,9 @@ class WebhookRequestHandler(web.View):
|
||||||
"""
|
"""
|
||||||
self.validate_ip()
|
self.validate_ip()
|
||||||
|
|
||||||
context.update_state({'CALLER': WEBHOOK,
|
# context.update_state({'CALLER': WEBHOOK,
|
||||||
WEBHOOK_CONNECTION: True,
|
# WEBHOOK_CONNECTION: True,
|
||||||
WEBHOOK_REQUEST: self.request})
|
# WEBHOOK_REQUEST: self.request})
|
||||||
|
|
||||||
dispatcher = self.get_dispatcher()
|
dispatcher = self.get_dispatcher()
|
||||||
update = await self.parse_update(dispatcher.bot)
|
update = await self.parse_update(dispatcher.bot)
|
||||||
|
|
@ -170,7 +172,7 @@ class WebhookRequestHandler(web.View):
|
||||||
if fut.done():
|
if fut.done():
|
||||||
return fut.result()
|
return fut.result()
|
||||||
else:
|
else:
|
||||||
context.set_value(WEBHOOK_CONNECTION, False)
|
# context.set_value(WEBHOOK_CONNECTION, False)
|
||||||
fut.remove_done_callback(cb)
|
fut.remove_done_callback(cb)
|
||||||
fut.add_done_callback(self.respond_via_request)
|
fut.add_done_callback(self.respond_via_request)
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -195,7 +197,7 @@ class WebhookRequestHandler(web.View):
|
||||||
results = task.result()
|
results = task.result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
loop.create_task(
|
loop.create_task(
|
||||||
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
|
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
|
||||||
else:
|
else:
|
||||||
response = self.get_response(results)
|
response = self.get_response(results)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
|
|
@ -242,7 +244,7 @@ class WebhookRequestHandler(web.View):
|
||||||
ip_address, accept = self.check_ip()
|
ip_address, accept = self.check_ip()
|
||||||
if not accept:
|
if not accept:
|
||||||
raise web.HTTPUnauthorized()
|
raise web.HTTPUnauthorized()
|
||||||
context.set_value('TELEGRAM_IP', ip_address)
|
# context.set_value('TELEGRAM_IP', ip_address)
|
||||||
|
|
||||||
|
|
||||||
def configure_app(dispatcher, app: web.Application, path=DEFAULT_WEB_PATH):
|
def configure_app(dispatcher, app: web.Application, path=DEFAULT_WEB_PATH):
|
||||||
|
|
@ -332,8 +334,8 @@ class BaseResponse:
|
||||||
|
|
||||||
async def __call__(self, bot=None):
|
async def __call__(self, bot=None):
|
||||||
if bot is None:
|
if bot is None:
|
||||||
from aiogram.dispatcher import ctx
|
from aiogram import Bot
|
||||||
bot = ctx.get_bot()
|
bot = Bot.current()
|
||||||
return await self.execute_response(bot)
|
return await self.execute_response(bot)
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
|
|
@ -426,7 +428,7 @@ class ParseModeMixin:
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
bot = context.get_value('bot', None)
|
bot = Bot.current()
|
||||||
if bot is not None:
|
if bot is not None:
|
||||||
return bot.parse_mode
|
return bot.parse_mode
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import typing
|
import typing
|
||||||
|
from contextvars import ContextVar
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from .fields import BaseField
|
from .fields import BaseField
|
||||||
|
|
@ -53,6 +56,8 @@ class MetaTelegramObject(type):
|
||||||
setattr(cls, ALIASES_ATTR_NAME, aliases)
|
setattr(cls, ALIASES_ATTR_NAME, aliases)
|
||||||
|
|
||||||
mcs._objects[cls.__name__] = cls
|
mcs._objects[cls.__name__] = cls
|
||||||
|
|
||||||
|
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
||||||
if value.default and key not in self.values:
|
if value.default and key not in self.values:
|
||||||
self.values[key] = value.default
|
self.values[key] = value.default
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def current(cls):
|
||||||
|
return cls._current.get()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_current(cls, obj: TelegramObject):
|
||||||
|
return cls._current.set(obj)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def conf(self) -> typing.Dict[str, typing.Any]:
|
def conf(self) -> typing.Dict[str, typing.Any]:
|
||||||
return self._conf
|
return self._conf
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
from . import base
|
from . import base
|
||||||
from . import fields
|
from . import fields
|
||||||
|
|
@ -64,7 +67,7 @@ class Chat(base.TelegramObject):
|
||||||
if as_html:
|
if as_html:
|
||||||
return markdown.hlink(name, self.user_url)
|
return markdown.hlink(name, self.user_url)
|
||||||
return markdown.link(name, self.user_url)
|
return markdown.link(name, self.user_url)
|
||||||
|
|
||||||
async def get_url(self):
|
async def get_url(self):
|
||||||
"""
|
"""
|
||||||
Use this method to get chat link.
|
Use this method to get chat link.
|
||||||
|
|
@ -507,8 +510,8 @@ class ChatActions(helper.Helper):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _do(cls, action: str, sleep=None):
|
async def _do(cls, action: str, sleep=None):
|
||||||
from ..dispatcher.ctx import get_bot, get_chat
|
from aiogram import Bot
|
||||||
await get_bot().send_chat_action(get_chat(), action)
|
await Bot.current().send_chat_action(Chat.current(), action)
|
||||||
if sleep:
|
if sleep:
|
||||||
await asyncio.sleep(sleep)
|
await asyncio.sleep(sleep)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,7 @@ class Message(base.TelegramObject):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
|
async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
|
||||||
disable_notification=None, reply_markup=None, reply=True) -> 'Message':
|
disable_notification=None, reply_markup=None, reply=True) -> Message:
|
||||||
"""
|
"""
|
||||||
Reply to this message
|
Reply to this message
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextvars import ContextVar
|
|
||||||
|
|
||||||
from . import base
|
from . import base
|
||||||
from . import fields
|
from . import fields
|
||||||
from .callback_query import CallbackQuery
|
from .callback_query import CallbackQuery
|
||||||
|
|
@ -12,8 +10,6 @@ from .pre_checkout_query import PreCheckoutQuery
|
||||||
from .shipping_query import ShippingQuery
|
from .shipping_query import ShippingQuery
|
||||||
from ..utils import helper
|
from ..utils import helper
|
||||||
|
|
||||||
current_update: ContextVar[Update] = ContextVar('current_update_object', default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class Update(base.TelegramObject):
|
class Update(base.TelegramObject):
|
||||||
"""
|
"""
|
||||||
|
|
@ -33,14 +29,6 @@ class Update(base.TelegramObject):
|
||||||
shipping_query: ShippingQuery = fields.Field(base=ShippingQuery)
|
shipping_query: ShippingQuery = fields.Field(base=ShippingQuery)
|
||||||
pre_checkout_query: PreCheckoutQuery = fields.Field(base=PreCheckoutQuery)
|
pre_checkout_query: PreCheckoutQuery = fields.Field(base=PreCheckoutQuery)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def current(cls):
|
|
||||||
return current_update.get()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def set_current(cls, update: Update):
|
|
||||||
return current_update.set(update)
|
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return self.update_id
|
return self.update_id
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import babel
|
import babel
|
||||||
|
|
||||||
from . import base
|
from . import base
|
||||||
|
|
|
||||||
|
|
@ -1,140 +0,0 @@
|
||||||
"""
|
|
||||||
You need to setup task factory:
|
|
||||||
>>> from aiogram.utils import context
|
|
||||||
>>> loop = asyncio.get_event_loop()
|
|
||||||
>>> loop.set_task_factory(context.task_factory)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import typing
|
|
||||||
|
|
||||||
CONFIGURED = '@CONFIGURED_TASK_FACTORY'
|
|
||||||
|
|
||||||
|
|
||||||
def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
|
|
||||||
"""
|
|
||||||
Task factory for implementing context processor
|
|
||||||
|
|
||||||
:param loop:
|
|
||||||
:param coro:
|
|
||||||
:return: new task
|
|
||||||
:rtype: :obj:`asyncio.Task`
|
|
||||||
"""
|
|
||||||
# Is not allowed when loop is closed.
|
|
||||||
if loop.is_closed():
|
|
||||||
raise RuntimeError('Event loop is closed.')
|
|
||||||
|
|
||||||
task = asyncio.Task(coro, loop=loop)
|
|
||||||
|
|
||||||
# Hide factory
|
|
||||||
if task._source_traceback:
|
|
||||||
del task._source_traceback[-1]
|
|
||||||
|
|
||||||
try:
|
|
||||||
task.context = asyncio.Task.current_task().context.copy()
|
|
||||||
except AttributeError:
|
|
||||||
task.context = {CONFIGURED: True}
|
|
||||||
|
|
||||||
return task
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_state() -> typing.Dict:
|
|
||||||
"""
|
|
||||||
Get current execution context from task
|
|
||||||
|
|
||||||
:return: context
|
|
||||||
:rtype: :obj:`dict`
|
|
||||||
"""
|
|
||||||
task = asyncio.Task.current_task()
|
|
||||||
if task is None:
|
|
||||||
raise RuntimeError('Can be used only in Task context.')
|
|
||||||
context_ = getattr(task, 'context', None)
|
|
||||||
if context_ is None:
|
|
||||||
context_ = task.context = {}
|
|
||||||
return context_
|
|
||||||
|
|
||||||
|
|
||||||
def get_value(key, default=None):
|
|
||||||
"""
|
|
||||||
Get value from task
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:param default:
|
|
||||||
:return: value
|
|
||||||
"""
|
|
||||||
return get_current_state().get(key, default)
|
|
||||||
|
|
||||||
|
|
||||||
def check_value(key):
|
|
||||||
"""
|
|
||||||
Key in context?
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return key in get_current_state()
|
|
||||||
|
|
||||||
|
|
||||||
def set_value(key, value):
|
|
||||||
"""
|
|
||||||
Set value
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:param value:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
get_current_state()[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
def del_value(key):
|
|
||||||
"""
|
|
||||||
Remove value from context
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
del get_current_state()[key]
|
|
||||||
|
|
||||||
|
|
||||||
def update_state(data=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Update multiple state items
|
|
||||||
|
|
||||||
:param data:
|
|
||||||
:param kwargs:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if data is None:
|
|
||||||
data = {}
|
|
||||||
state = get_current_state()
|
|
||||||
state.update(data, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def check_configured():
|
|
||||||
"""
|
|
||||||
Check loop is configured
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return get_value(CONFIGURED)
|
|
||||||
|
|
||||||
|
|
||||||
class _Context:
|
|
||||||
"""
|
|
||||||
Other things for interactions with the execution context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return get_value(item)
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
set_value(key, value)
|
|
||||||
|
|
||||||
def __delitem__(self, key):
|
|
||||||
del_value(key)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_context():
|
|
||||||
return get_current_state()
|
|
||||||
|
|
||||||
|
|
||||||
context = _Context()
|
|
||||||
|
|
@ -6,7 +6,6 @@ from warnings import warn
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from . import context
|
|
||||||
from ..bot.api import log
|
from ..bot.api import log
|
||||||
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
|
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
|
||||||
|
|
||||||
|
|
@ -179,13 +178,13 @@ class Executor:
|
||||||
self._check_frozen()
|
self._check_frozen()
|
||||||
self._freeze = True
|
self._freeze = True
|
||||||
|
|
||||||
self.loop.set_task_factory(context.task_factory)
|
# self.loop.set_task_factory(context.task_factory)
|
||||||
|
|
||||||
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
|
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
|
||||||
self._check_frozen()
|
self._check_frozen()
|
||||||
self._freeze = True
|
self._freeze = True
|
||||||
|
|
||||||
self.loop.set_task_factory(context.task_factory)
|
# self.loop.set_task_factory(context.task_factory)
|
||||||
|
|
||||||
app = self._web_app
|
app = self._web_app
|
||||||
if app is None:
|
if app is None:
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ import asyncio
|
||||||
|
|
||||||
from aiogram import Bot, types
|
from aiogram import Bot, types
|
||||||
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
||||||
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
|
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher
|
||||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||||
from aiogram.utils import context, executor
|
from aiogram.utils import executor
|
||||||
from aiogram.utils.exceptions import Throttled
|
from aiogram.utils.exceptions import Throttled
|
||||||
|
|
||||||
TOKEN = 'BOT TOKEN HERE'
|
TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
@ -53,10 +53,10 @@ class ThrottlingMiddleware(BaseMiddleware):
|
||||||
:param message:
|
:param message:
|
||||||
"""
|
"""
|
||||||
# Get current handler
|
# Get current handler
|
||||||
handler = context.get_value('handler')
|
# handler = context.get_value('handler')
|
||||||
|
|
||||||
# Get dispatcher from context
|
# Get dispatcher from context
|
||||||
dispatcher = ctx.get_dispatcher()
|
dispatcher = Dispatcher.current()
|
||||||
|
|
||||||
# If handler was configured, get rate limit and key from handler
|
# If handler was configured, get rate limit and key from handler
|
||||||
if handler:
|
if handler:
|
||||||
|
|
@ -83,8 +83,8 @@ class ThrottlingMiddleware(BaseMiddleware):
|
||||||
:param message:
|
:param message:
|
||||||
:param throttled:
|
:param throttled:
|
||||||
"""
|
"""
|
||||||
handler = context.get_value('handler')
|
# handler = context.get_value('handler')
|
||||||
dispatcher = ctx.get_dispatcher()
|
dispatcher = Dispatcher.current()
|
||||||
if handler:
|
if handler:
|
||||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue