Remove context util.

This commit is contained in:
Alex Root Junior 2018-06-25 01:31:57 +03:00
parent 5c1eee4fa9
commit 06fbe0d9cd
18 changed files with 1169 additions and 1333 deletions

View file

@ -1,5 +1,4 @@
from aiogram import types
from aiogram.dispatcher import ctx
from aiogram.dispatcher.middlewares import BaseMiddleware
OBJ_KEY = '_context_data'
@ -46,7 +45,7 @@ class ContextMiddleware(BaseMiddleware):
:return:
"""
update = ctx.get_update()
update = types.Update.current()
obj = update.conf.get(OBJ_KEY, None)
if obj is None:
obj = self._configure_update(update)

File diff suppressed because it is too large Load diff

View file

@ -1,42 +0,0 @@
from . import Bot
from .. import types
from ..dispatcher import Dispatcher, FSMContext, MODE, UPDATE_OBJECT
from ..utils import context
def _get(key, default=None, no_error=False):
result = context.get_value(key, default)
if not no_error and result is None:
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
f"Maybe asyncio task factory is not configured!\n"
f"\t>>> from aiogram.utils import context\n"
f"\t>>> loop.set_task_factory(context.task_factory)")
return result
def get_bot() -> Bot:
return _get('bot')
def get_dispatcher() -> Dispatcher:
return _get('dispatcher')
def get_update() -> types.Update:
return _get(UPDATE_OBJECT)
def get_mode() -> str:
return _get(MODE, 'unknown')
def get_chat() -> int:
return _get('chat', no_error=True)
def get_user() -> int:
return _get('user', no_error=True)
def get_state() -> FSMContext:
return get_dispatcher().current_state()

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,9 @@
import asyncio
import re
from _contextvars import ContextVar
from aiogram.dispatcher.filters.filters import BaseFilter, Filter, check_filter
from aiogram.types import CallbackQuery, ContentType, Message
from aiogram.utils import context
USER_STATE = 'USER_STATE'
@ -130,6 +130,8 @@ class StateFilter(BaseFilter):
"""
key = 'state'
ctx_state = ContextVar('user_state')
def __init__(self, dispatcher, state):
super().__init__(dispatcher)
if isinstance(state, str):
@ -143,14 +145,16 @@ class StateFilter(BaseFilter):
if self.state == '*':
return True
if context.check_value(USER_STATE):
context_state = context.get_value(USER_STATE)
return self.state == context_state
else:
try:
return self.state == self.ctx_state.get()
except LookupError:
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
state = await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
self.ctx_state.set(state)
return state == self.state
return False

View file

@ -4,7 +4,6 @@ from .filters import AbstractFilter, FilterRecord
from ..handler import Handler
# TODO: provide to set default filters (Like state. It will be always be added to filters set)
# TODO: move check_filter/check_filters functions to FiltersFactory class
class FiltersFactory:
@ -18,15 +17,17 @@ class FiltersFactory:
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.List[Handler]] = None):
event_handlers: typing.Optional[typing.List[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
"""
Register filter
:param callback: callable or subclass of :obj:`AbstractFilter`
:param validator: custom validator.
:param event_handlers: list of instances of :obj:`Handler`
:param exclude_event_handlers: list of excluded event handlers (:obj:`Handler`)
"""
record = FilterRecord(callback, validator, event_handlers)
record = FilterRecord(callback, validator, event_handlers, exclude_event_handlers)
self._registered.append(record)
def unbind(self, callback: typing.Union[typing.Callable, AbstractFilter]):
@ -52,17 +53,21 @@ class FiltersFactory:
filters_set = []
if custom_filters:
filters_set.extend(custom_filters)
if full_config:
filters_set.extend(self._resolve_registered(self._dispatcher, event_handler,
{k: v for k, v in full_config.items() if v is not None}))
filters_set.extend(self._resolve_registered(event_handler,
{k: v for k, v in full_config.items() if v is not None}))
return filters_set
def _resolve_registered(self, dispatcher, event_handler, full_config) -> typing.Generator:
for record in self._registered:
if not full_config:
break
def _resolve_registered(self, event_handler, full_config) -> typing.Generator:
"""
Resolve registered filters
filter_ = record.resolve(dispatcher, event_handler, full_config)
:param event_handler:
:param full_config:
:return:
"""
for record in self._registered:
filter_ = record.resolve(self._dispatcher, event_handler, full_config)
if filter_:
yield filter_

View file

@ -72,6 +72,10 @@ class FilterRecord:
return
config = self.resolver(full_config)
if config:
for key in config:
if key in full_config:
full_config.pop(key)
return self.callback(dispatcher, **config)
def _check_event_handler(self, event_handler) -> bool:

View file

@ -1,6 +1,3 @@
from ..utils import context
class SkipHandler(BaseException):
pass
@ -70,7 +67,7 @@ class Handler:
if await check_filters(filters, args):
try:
if self.middleware_key:
context.set_value('handler', handler)
# context.set_value('handler', handler)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args)
if response is not None:

View file

@ -8,11 +8,13 @@ from typing import Dict, List, Optional, Union
from aiohttp import web
from aiogram import Bot
from aiogram.bot import bot
from aiogram.dispatcher import dispatcher
from .. import types
from ..bot import api
from ..types import ParseMode
from ..types.base import Boolean, Float, Integer, String
from ..utils import context
from ..utils import helper, markdown
from ..utils import json
from ..utils.deprecated import warn_deprecated as warn
@ -88,8 +90,8 @@ class WebhookRequestHandler(web.View):
"""
dp = self.request.app[BOT_DISPATCHER_KEY]
try:
context.set_value('dispatcher', dp)
context.set_value('bot', dp.bot)
dispatcher.set(dp)
bot.bot.set(dp.bot)
except RuntimeError:
pass
return dp
@ -116,9 +118,9 @@ class WebhookRequestHandler(web.View):
"""
self.validate_ip()
context.update_state({'CALLER': WEBHOOK,
WEBHOOK_CONNECTION: True,
WEBHOOK_REQUEST: self.request})
# context.update_state({'CALLER': WEBHOOK,
# WEBHOOK_CONNECTION: True,
# WEBHOOK_REQUEST: self.request})
dispatcher = self.get_dispatcher()
update = await self.parse_update(dispatcher.bot)
@ -170,7 +172,7 @@ class WebhookRequestHandler(web.View):
if fut.done():
return fut.result()
else:
context.set_value(WEBHOOK_CONNECTION, False)
# context.set_value(WEBHOOK_CONNECTION, False)
fut.remove_done_callback(cb)
fut.add_done_callback(self.respond_via_request)
finally:
@ -195,7 +197,7 @@ class WebhookRequestHandler(web.View):
results = task.result()
except Exception as e:
loop.create_task(
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
else:
response = self.get_response(results)
if response is not None:
@ -242,7 +244,7 @@ class WebhookRequestHandler(web.View):
ip_address, accept = self.check_ip()
if not accept:
raise web.HTTPUnauthorized()
context.set_value('TELEGRAM_IP', ip_address)
# context.set_value('TELEGRAM_IP', ip_address)
def configure_app(dispatcher, app: web.Application, path=DEFAULT_WEB_PATH):
@ -332,8 +334,8 @@ class BaseResponse:
async def __call__(self, bot=None):
if bot is None:
from aiogram.dispatcher import ctx
bot = ctx.get_bot()
from aiogram import Bot
bot = Bot.current()
return await self.execute_response(bot)
async def __aenter__(self):
@ -426,7 +428,7 @@ class ParseModeMixin:
:return:
"""
bot = context.get_value('bot', None)
bot = Bot.current()
if bot is not None:
return bot.parse_mode

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import io
import typing
from contextvars import ContextVar
from typing import TypeVar
from .fields import BaseField
@ -53,6 +56,8 @@ class MetaTelegramObject(type):
setattr(cls, ALIASES_ATTR_NAME, aliases)
mcs._objects[cls.__name__] = cls
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
return cls
@property
@ -88,6 +93,14 @@ class TelegramObject(metaclass=MetaTelegramObject):
if value.default and key not in self.values:
self.values[key] = value.default
@classmethod
def current(cls):
return cls._current.get()
@classmethod
def set_current(cls, obj: TelegramObject):
return cls._current.set(obj)
@property
def conf(self) -> typing.Dict[str, typing.Any]:
return self._conf

View file

@ -1,5 +1,8 @@
from __future__ import annotations
import asyncio
import typing
from contextvars import ContextVar
from . import base
from . import fields
@ -64,7 +67,7 @@ class Chat(base.TelegramObject):
if as_html:
return markdown.hlink(name, self.user_url)
return markdown.link(name, self.user_url)
async def get_url(self):
"""
Use this method to get chat link.
@ -507,8 +510,8 @@ class ChatActions(helper.Helper):
@classmethod
async def _do(cls, action: str, sleep=None):
from ..dispatcher.ctx import get_bot, get_chat
await get_bot().send_chat_action(get_chat(), action)
from aiogram import Bot
await Bot.current().send_chat_action(Chat.current(), action)
if sleep:
await asyncio.sleep(sleep)

View file

@ -190,7 +190,7 @@ class Message(base.TelegramObject):
return text
async def reply(self, text, parse_mode=None, disable_web_page_preview=None,
disable_notification=None, reply_markup=None, reply=True) -> 'Message':
disable_notification=None, reply_markup=None, reply=True) -> Message:
"""
Reply to this message

View file

@ -1,7 +1,5 @@
from __future__ import annotations
from contextvars import ContextVar
from . import base
from . import fields
from .callback_query import CallbackQuery
@ -12,8 +10,6 @@ from .pre_checkout_query import PreCheckoutQuery
from .shipping_query import ShippingQuery
from ..utils import helper
current_update: ContextVar[Update] = ContextVar('current_update_object', default=None)
class Update(base.TelegramObject):
"""
@ -33,14 +29,6 @@ class Update(base.TelegramObject):
shipping_query: ShippingQuery = fields.Field(base=ShippingQuery)
pre_checkout_query: PreCheckoutQuery = fields.Field(base=PreCheckoutQuery)
@classmethod
def current(cls):
return current_update.get()
@classmethod
def set_current(cls, update: Update):
return current_update.set(update)
def __hash__(self):
return self.update_id

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import babel
from . import base

View file

@ -1,140 +0,0 @@
"""
You need to setup task factory:
>>> from aiogram.utils import context
>>> loop = asyncio.get_event_loop()
>>> loop.set_task_factory(context.task_factory)
"""
import asyncio
import typing
CONFIGURED = '@CONFIGURED_TASK_FACTORY'
def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
"""
Task factory for implementing context processor
:param loop:
:param coro:
:return: new task
:rtype: :obj:`asyncio.Task`
"""
# Is not allowed when loop is closed.
if loop.is_closed():
raise RuntimeError('Event loop is closed.')
task = asyncio.Task(coro, loop=loop)
# Hide factory
if task._source_traceback:
del task._source_traceback[-1]
try:
task.context = asyncio.Task.current_task().context.copy()
except AttributeError:
task.context = {CONFIGURED: True}
return task
def get_current_state() -> typing.Dict:
"""
Get current execution context from task
:return: context
:rtype: :obj:`dict`
"""
task = asyncio.Task.current_task()
if task is None:
raise RuntimeError('Can be used only in Task context.')
context_ = getattr(task, 'context', None)
if context_ is None:
context_ = task.context = {}
return context_
def get_value(key, default=None):
"""
Get value from task
:param key:
:param default:
:return: value
"""
return get_current_state().get(key, default)
def check_value(key):
"""
Key in context?
:param key:
:return:
"""
return key in get_current_state()
def set_value(key, value):
"""
Set value
:param key:
:param value:
:return:
"""
get_current_state()[key] = value
def del_value(key):
"""
Remove value from context
:param key:
:return:
"""
del get_current_state()[key]
def update_state(data=None, **kwargs):
"""
Update multiple state items
:param data:
:param kwargs:
:return:
"""
if data is None:
data = {}
state = get_current_state()
state.update(data, **kwargs)
def check_configured():
"""
Check loop is configured
:return:
"""
return get_value(CONFIGURED)
class _Context:
"""
Other things for interactions with the execution context.
"""
def __getitem__(self, item):
return get_value(item)
def __setitem__(self, key, value):
set_value(key, value)
def __delitem__(self, key):
del_value(key)
@staticmethod
def get_context():
return get_current_state()
context = _Context()

View file

@ -6,7 +6,6 @@ from warnings import warn
from aiohttp import web
from . import context
from ..bot.api import log
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler
@ -179,13 +178,13 @@ class Executor:
self._check_frozen()
self._freeze = True
self.loop.set_task_factory(context.task_factory)
# self.loop.set_task_factory(context.task_factory)
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler):
self._check_frozen()
self._freeze = True
self.loop.set_task_factory(context.task_factory)
# self.loop.set_task_factory(context.task_factory)
app = self._web_app
if app is None:

View file

@ -2,9 +2,9 @@ import asyncio
from aiogram import Bot, types
from aiogram.contrib.fsm_storage.redis import RedisStorage2
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils import context, executor
from aiogram.utils import executor
from aiogram.utils.exceptions import Throttled
TOKEN = 'BOT TOKEN HERE'
@ -53,10 +53,10 @@ class ThrottlingMiddleware(BaseMiddleware):
:param message:
"""
# Get current handler
handler = context.get_value('handler')
# handler = context.get_value('handler')
# Get dispatcher from context
dispatcher = ctx.get_dispatcher()
dispatcher = Dispatcher.current()
# If handler was configured, get rate limit and key from handler
if handler:
@ -83,8 +83,8 @@ class ThrottlingMiddleware(BaseMiddleware):
:param message:
:param throttled:
"""
handler = context.get_value('handler')
dispatcher = ctx.get_dispatcher()
# handler = context.get_value('handler')
dispatcher = Dispatcher.current()
if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else: