Remove context util.

This commit is contained in:
Alex Root Junior 2018-06-25 01:31:57 +03:00
parent 5c1eee4fa9
commit 06fbe0d9cd
18 changed files with 1169 additions and 1333 deletions

View file

@ -1,5 +1,4 @@
from aiogram import types from aiogram import types
from aiogram.dispatcher import ctx
from aiogram.dispatcher.middlewares import BaseMiddleware from aiogram.dispatcher.middlewares import BaseMiddleware
OBJ_KEY = '_context_data' OBJ_KEY = '_context_data'
@ -46,7 +45,7 @@ class ContextMiddleware(BaseMiddleware):
:return: :return:
""" """
update = ctx.get_update() update = types.Update.current()
obj = update.conf.get(OBJ_KEY, None) obj = update.conf.get(OBJ_KEY, None)
if obj is None: if obj is None:
obj = self._configure_update(update) obj = self._configure_update(update)

File diff suppressed because it is too large Load diff

View file

@ -1,42 +0,0 @@
from . import Bot
from .. import types
from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT
from ..utils import context
def _get(key, default=None, no_error=False):
result = context.get_value(key, default)
if not no_error and result is None:
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
f"Maybe asyncio task factory is not configured!\n"
f"\t>>> from aiogram.utils import context\n"
f"\t>>> loop.set_task_factory(context.task_factory)")
return result
def get_bot() -> Bot:
return _get('bot')
def get_dispatcher() -> Dispatcher:
return _get('dispatcher')
def get_update() -> types.Update:
return _get(UPDATE_OBJECT)
def get_mode() -> str:
return _get(MODE, 'unknown')
def get_chat() -> int:
return _get('chat', no_error=True)
def get_user() -> int:
return _get('user', no_error=True)
def get_state() -> FSMContext:
return get_dispatcher().current_state()

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,9 @@
import asyncio import asyncio
import re import re
from _contextvars import ContextVar
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
from aiogram.types import CallbackQuery, ContentType, Message from aiogram.types import CallbackQuery, ContentType, Message
from aiogram.utils import context
USER_STATE = 'USER_STATE' USER_STATE = 'USER_STATE'
@ -130,6 +130,8 @@ class StateFilter(BaseFilter):
""" """
key = 'state' key = 'state'
ctx_state = ContextVar('user_state')
def __init__(self, dispatcher, state): def __init__(self, dispatcher, state):
super().__init__(dispatcher) super().__init__(dispatcher)
if isinstance(state, str): if isinstance(state, str):
@ -143,14 +145,16 @@ class StateFilter(BaseFilter):
if self.state == '*': if self.state == '*':
return True return True
if context.check_value(USER_STATE): try:
context_state = context.get_value(USER_STATE) return self.state == self.ctx_state.get()
return self.state == context_state except LookupError:
else:
chat, user = self.get_target(obj) chat, user = self.get_target(obj)
if chat or user: if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
self.ctx_state.set(state)
return state == self.state
return False return False

View file

@ -4,7 +4,6 @@ from .filters import AbstractFilter, FilterRecord
from ..handler import Handler from ..handler import Handler
# TODO: provide to set default filters (Like state. It will be always be added to filters set)
# TODO: move check_filter/check_filters functions to FiltersFactory class # TODO: move check_filter/check_filters functions to FiltersFactory class
class FiltersFactory: class FiltersFactory:
@ -18,15 +17,17 @@ class FiltersFactory:
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter], def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
validator: typing.Optional[typing.Callable] = None, validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.List[Handler]] = None): event_handlers: typing.Optional[typing.List[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
""" """
Register filter Register filter
:param callback: callable or subclass of :obj:`AbstractFilter` :param callback: callable or subclass of :obj:`AbstractFilter`
:param validator: custom validator. :param validator: custom validator.
:param event_handlers: list of instances of :obj:`Handler` :param event_handlers: list of instances of :obj:`Handler`
:param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`)
""" """
record = FilterRecord(callback, validator, event_handlers) record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers)
self._registered.append(record) self._registered.append(record)
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]): def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
@ -52,17 +53,21 @@ class FiltersFactory:
filters_set = [] filters_set = []
if custom_filters: if custom_filters:
filters_set.extend(custom_filters) filters_set.extend(custom_filters)
if full_config: filters_set.extend(self._resolve_registered(event_handler,
filters_set.extend(self._resolve_registered(self._dispatcher, event_handler, {k: v for k, v in full_config.items() if v is not None}))
{k: v for k, v in full_config.items() if v is not None}))
return filters_set return filters_set
def _resolve_registered(self, dispatcher, event_handler, full_config) -> typing.Generator: def _resolve_registered(self, event_handler, full_config) -> typing.Generator:
for record in self._registered: """
if not full_config: Resolve registered filters
break
filter_ = record.resolve(dispatcher, event_handler, full_config) :param event_handler:
:param full_config:
:return:
"""
for record in self._registered:
filter_ = record.resolve(self._dispatcher, event_handler, full_config)
if filter_: if filter_:
yield filter_ yield filter_

View file

@ -72,6 +72,10 @@ class FilterRecord:
return return
config = self.resolver(full_config) config = self.resolver(full_config)
if config: if config:
for key in config:
if key in full_config:
full_config.pop(key)
return self.callback(dispatcher, **config) return self.callback(dispatcher, **config)
def _check_event_handler(self, event_handler) -> bool: def _check_event_handler(self, event_handler) -> bool:

View file

@ -1,6 +1,3 @@
from ..utils import context
class SkipHandler(BaseException): class SkipHandler(BaseException):
pass pass
@ -70,7 +67,7 @@ class Handler:
if await check_filters(filters, args): if await check_filters(filters, args):
try: try:
if self.middleware_key: if self.middleware_key:
context.set_value('handler', handler) # context.set_value('handler', handler)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args) await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args) response = await handler(*args)
if response is not None: if response is not None:

View file

@ -8,11 +8,13 @@ from typing import Dict, List, Optional, Union
from aiohttp import web from aiohttp import web
from aiogram import Bot
from aiogram.bot import bot
from aiogram.dispatcher import dispatcher
from .. import types from .. import types
from ..bot import api from ..bot import api
from ..types import ParseMode from ..types import ParseMode
from ..types.base import Boolean, Float, Integer, String from ..types.base import Boolean, Float, Integer, String
from ..utils import context
from ..utils import helper, markdown from ..utils import helper, markdown
from ..utils import json from ..utils import json
from ..utils.deprecated import warn_deprecated as warn from ..utils.deprecated import warn_deprecated as warn
@ -88,8 +90,8 @@ class WebhookRequestHandler(web.View):
""" """
dp = self.request.app[BOT_DISPATCHER_KEY] dp = self.request.app[BOT_DISPATCHER_KEY]
try: try:
context.set_value('dispatcher', dp) dispatcher.set(dp)
context.set_value('bot', dp.bot) bot.bot.set(dp.bot)
except RuntimeError: except RuntimeError:
pass pass
return dp return dp
@ -116,9 +118,9 @@ class WebhookRequestHandler(web.View):
""" """
self.validate_ip() self.validate_ip()
context.update_state({'CALLER': WEBHOOK, # context.update_state({'CALLER': WEBHOOK,
WEBHOOK_CONNECTION: True, # WEBHOOK_CONNECTION: True,
WEBHOOK_REQUEST: self.request}) # WEBHOOK_REQUEST: self.request})
dispatcher = self.get_dispatcher() dispatcher = self.get_dispatcher()
update = await self.parse_update(dispatcher.bot) update = await self.parse_update(dispatcher.bot)
@ -170,7 +172,7 @@ class WebhookRequestHandler(web.View):
if fut.done(): if fut.done():
return fut.result() return fut.result()
else: else:
context.set_value(WEBHOOK_CONNECTION, False) # context.set_value(WEBHOOK_CONNECTION, False)
fut.remove_done_callback(cb) fut.remove_done_callback(cb)
fut.add_done_callback(self.respond_via_request) fut.add_done_callback(self.respond_via_request)
finally: finally:
@ -195,7 +197,7 @@ class WebhookRequestHandler(web.View):
results = task.result() results = task.result()
except Exception as e: except Exception as e:
loop.create_task( loop.create_task(
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e)) dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
else: else:
response = self.get_response(results) response = self.get_response(results)
if response is not None: if response is not None:
@ -242,7 +244,7 @@ class WebhookRequestHandler(web.View):
ip_address, accept = self.check_ip() ip_address, accept = self.check_ip()
if not accept: if not accept:
raise web.HTTPUnauthorized() raise web.HTTPUnauthorized()
context.set_value('TELEGRAM_IP', ip_address) # context.set_value('TELEGRAM_IP', ip_address)
def configure_app(dispatcher, app: web.Application, path=DEFAULT_WEB_PATH): def configure_app(dispatcher, app: web.Application, path=DEFAULT_WEB_PATH):
@ -332,8 +334,8 @@ class BaseResponse:
async def __call__(self, bot=None): async def __call__(self, bot=None):
if bot is None: if bot is None:
from aiogram.dispatcher import ctx from aiogram import Bot
bot = ctx.get_bot() bot = Bot.current()
return await self.execute_response(bot) return await self.execute_response(bot)
async def __aenter__(self): async def __aenter__(self):
@ -426,7 +428,7 @@ class ParseModeMixin:
:return: :return:
""" """
bot = context.get_value('bot', None) bot = Bot.current()
if bot is not None: if bot is not None:
return bot.parse_mode return bot.parse_mode

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import io import io
import typing import typing
from contextvars import ContextVar
from typing import TypeVar from typing import TypeVar
from .fields import BaseField from .fields import BaseField
@ -53,6 +56,8 @@ class MetaTelegramObject(type):
setattr(cls, ALIASES_ATTR_NAME, aliases) setattr(cls, ALIASES_ATTR_NAME, aliases)
mcs._objects[cls.__name__] = cls mcs._objects[cls.__name__] = cls
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
return cls return cls
@property @property
@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject):
if value.default and key not in self.values: if value.default and key not in self.values:
self.values[key] = value.default self.values[key] = value.default
@classmethod
def current(cls):
return cls._current.get()
@classmethod
def set_current(cls, obj: TelegramObject):
return cls._current.set(obj)
@property @property
def conf(self) -> typing.Dict[str, typing.Any]: def conf(self) -> typing.Dict[str, typing.Any]:
return self._conf return self._conf

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import asyncio import asyncio
import typing import typing
from contextvars import ContextVar
from . import base from . import base
from . import fields from . import fields
@ -64,7 +67,7 @@ class Chat(base.TelegramObject):
if as_html: if as_html:
return markdown.hlink(name, self.user_url) return markdown.hlink(name, self.user_url)
return markdown.link(name, self.user_url) return markdown.link(name, self.user_url)
async def get_url(self): async def get_url(self):
""" """
Use this method to get chat link. Use this method to get chat link.
@ -507,8 +510,8 @@ class ChatActions(helper.Helper):
@classmethod @classmethod
async def _do(cls, action: str, sleep=None): async def _do(cls, action: str, sleep=None):
from ..dispatcher.ctx import get_bot, get_chat from aiogram import Bot
await get_bot().send_chat_action(get_chat(), action) await Bot.current().send_chat_action(Chat.current(), action)
if sleep: if sleep:
await asyncio.sleep(sleep) await asyncio.sleep(sleep)

View file

@ -190,7 +190,7 @@ class Message(base.TelegramObject):
return text return text
async def reply(self, text, parse_mode=None, disable_web_page_preview=None, async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
disable_notification=None, reply_markup=None, reply=True) -> 'Message': disable_notification=None, reply_markup=None, reply=True) -> Message:
""" """
Reply to this message Reply to this message

View file

@ -1,7 +1,5 @@
from __future__ import annotations from __future__ import annotations
from contextvars import ContextVar
from . import base from . import base
from . import fields from . import fields
from .callback_query import CallbackQuery from .callback_query import CallbackQuery
@ -12,8 +10,6 @@ from .pre_checkout_query import PreCheckoutQuery
from .shipping_query import ShippingQuery from .shipping_query import ShippingQuery
from ..utils import helper from ..utils import helper
current_update: ContextVar[Update] = ContextVar('current_update_object', default=None)
class Update(base.TelegramObject): class Update(base.TelegramObject):
""" """
@ -33,14 +29,6 @@ class Update(base.TelegramObject):
shipping_query: ShippingQuery = fields.Field(base=ShippingQuery) shipping_query: ShippingQuery = fields.Field(base=ShippingQuery)
pre_checkout_query: PreCheckoutQuery = fields.Field(base=PreCheckoutQuery) pre_checkout_query: PreCheckoutQuery = fields.Field(base=PreCheckoutQuery)
@classmethod
def current(cls):
return current_update.get()
@classmethod
def set_current(cls, update: Update):
return current_update.set(update)
def __hash__(self): def __hash__(self):
return self.update_id return self.update_id

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import babel import babel
from . import base from . import base

View file

@ -1,140 +0,0 @@
"""
You need to setup task factory:
>>> from aiogram.utils import context
>>> loop = asyncio.get_event_loop()
>>> loop.set_task_factory(context.task_factory)
"""
import asyncio
import typing
CONFIGURED = '@CONFIGURED_TASK_FACTORY'
def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
"""
Task factory for implementing context processor
:param loop:
:param coro:
:return: new task
:rtype: :obj:`asyncio.Task`
"""
# Is not allowed when loop is closed.
if loop.is_closed():
raise RuntimeError('Event loop is closed.')
task = asyncio.Task(coro, loop=loop)
# Hide factory
if task._source_traceback:
del task._source_traceback[-1]
try:
task.context = asyncio.Task.current_task().context.copy()
except AttributeError:
task.context = {CONFIGURED: True}
return task
def get_current_state() -> typing.Dict:
"""
Get current execution context from task
:return: context
:rtype: :obj:`dict`
"""
task = asyncio.Task.current_task()
if task is None:
raise RuntimeError('Can be used only in Task context.')
context_ = getattr(task, 'context', None)
if context_ is None:
context_ = task.context = {}
return context_
def get_value(key, default=None):
"""
Get value from task
:param key:
:param default:
:return: value
"""
return get_current_state().get(key, default)
def check_value(key):
"""
Key in context?
:param key:
:return:
"""
return key in get_current_state()
def set_value(key, value):
"""
Set value
:param key:
:param value:
:return:
"""
get_current_state()[key] = value
def del_value(key):
"""
Remove value from context
:param key:
:return:
"""
del get_current_state()[key]
def update_state(data=None, **kwargs):
"""
Update multiple state items
:param data:
:param kwargs:
:return:
"""
if data is None:
data = {}
state = get_current_state()
state.update(data, **kwargs)
def check_configured():
"""
Check loop is configured
:return:
"""
return get_value(CONFIGURED)
class _Context:
"""
Other things for interactions with the execution context.
"""
def __getitem__(self, item):
return get_value(item)
def __setitem__(self, key, value):
set_value(key, value)
def __delitem__(self, key):
del_value(key)
@staticmethod
def get_context():
return get_current_state()
context = _Context()

View file

@ -6,7 +6,6 @@ from warnings import warn
from aiohttp import web from aiohttp import web
from . import context
from ..bot.api import log from ..bot.api import log
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
@ -179,13 +178,13 @@ class Executor:
self._check_frozen() self._check_frozen()
self._freeze = True self._freeze = True
self.loop.set_task_factory(context.task_factory) # self.loop.set_task_factory(context.task_factory)
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler): def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
self._check_frozen() self._check_frozen()
self._freeze = True self._freeze = True
self.loop.set_task_factory(context.task_factory) # self.loop.set_task_factory(context.task_factory)
app = self._web_app app = self._web_app
if app is None: if app is None:

View file

@ -2,9 +2,9 @@ import asyncio
from aiogram import Bot, types from aiogram import Bot, types
from aiogram.contrib.fsm_storage.redis import RedisStorage2 from aiogram.contrib.fsm_storage.redis import RedisStorage2
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher
from aiogram.dispatcher.middlewares import BaseMiddleware from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils import context, executor from aiogram.utils import executor
from aiogram.utils.exceptions import Throttled from aiogram.utils.exceptions import Throttled
TOKEN = 'BOT TOKEN HERE' TOKEN = 'BOT TOKEN HERE'
@ -53,10 +53,10 @@ class ThrottlingMiddleware(BaseMiddleware):
:param message: :param message:
""" """
# Get current handler # Get current handler
handler = context.get_value('handler') # handler = context.get_value('handler')
# Get dispatcher from context # Get dispatcher from context
dispatcher = ctx.get_dispatcher() dispatcher = Dispatcher.current()
# If handler was configured, get rate limit and key from handler # If handler was configured, get rate limit and key from handler
if handler: if handler:
@ -83,8 +83,8 @@ class ThrottlingMiddleware(BaseMiddleware):
:param message: :param message:
:param throttled: :param throttled:
""" """
handler = context.get_value('handler') # handler = context.get_value('handler')
dispatcher = ctx.get_dispatcher() dispatcher = Dispatcher.current()
if handler: if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else: else: