Rewrite contextvar usage. Implemented ContextInstanceMixin and DataMixin

This commit is contained in:
Alex Root Junior 2018-10-20 15:55:57 +03:00
parent 8ef279bba1
commit 39c333251f
12 changed files with 82 additions and 120 deletions

View file

@ -173,43 +173,6 @@ class BaseBot:
return await self.request(method, payload, files)
@property
def data(self) -> Dict:
"""
Data stored in bot object
:return: Dictionary
"""
return self._data
def __setitem__(self, key, value):
"""
Store data in bot instance
:param key: Key in dict
:param value: Value
"""
self._data[key] = value
def __getitem__(self, item):
"""
Get item from bot instance by key
:param item: key name
:return: value
"""
return self._data[item]
def get(self, key, default=None):
"""
Get item from bot instance by key or return default value
:param key: key in dict
:param default: default value
:return: value or default value
"""
return self._data.get(key, default)
@property
def parse_mode(self):
return getattr(self, '_parse_mode', None)

View file

@ -1,15 +1,15 @@
from __future__ import annotations
import typing
from contextvars import ContextVar
from .base import BaseBot, api
from .. import types
from ..types import base
from ..utils.mixins import DataMixin, ContextInstanceMixin
from ..utils.payload import generate_payload, prepare_arg, prepare_attachment, prepare_file
class Bot(BaseBot):
class Bot(BaseBot, DataMixin, ContextInstanceMixin):
"""
Base bot class
"""
@ -39,14 +39,6 @@ class Bot(BaseBot):
if hasattr(self, '_me'):
delattr(self, '_me')
@classmethod
def current(cls) -> Bot:
"""
Return active bot instance from the current context or None
:return: Bot or None
"""
return bot.get()
async def download_file_by_id(self, file_id: base.String, destination=None,
timeout: base.Integer = 30, chunk_size: base.Integer = 65536,
seek: base.Boolean = True):
@ -98,7 +90,7 @@ class Bot(BaseBot):
allowed_updates = prepare_arg(allowed_updates)
payload = generate_payload(**locals())
result = await self.request(api.Methods.GET_UPDATES, payload, timeout=timeout + 2 if timeout else None)
result = await self.request(api.Methods.GET_UPDATES, payload)
return [types.Update(**update) for update in result]
async def set_webhook(self, url: base.String,
@ -522,7 +514,7 @@ class Bot(BaseBot):
"""
reply_markup = prepare_arg(reply_markup)
payload = generate_payload(**locals(), exclude=["animation", "thumb"])
files = {}
prepare_file(payload, files, 'animation', animation)
prepare_attachment(payload, files, 'thumb', thumb)
@ -2064,6 +2056,3 @@ class Bot(BaseBot):
result = await self.request(api.Methods.GET_GAME_HIGH_SCORES, payload)
return [types.GameHighScore(**gamehighscore) for gamehighscore in result]
bot: ContextVar[Bot] = ContextVar('bot_instance', default=None)

View file

@ -116,7 +116,7 @@ class I18nMiddleware(BaseMiddleware):
:param args: event arguments
:return: locale name
"""
user: types.User = types.User.current()
user: types.User = types.User.get_current()
locale: Locale = user.locale
if locale:

View file

@ -3,12 +3,11 @@ from . import handler
from . import middlewares
from . import storage
from . import webhook
from .dispatcher import Dispatcher, dispatcher, FSMContext, DEFAULT_RATE_LIMIT
from .dispatcher import Dispatcher, FSMContext, DEFAULT_RATE_LIMIT
__all__ = [
'DEFAULT_RATE_LIMIT',
'Dispatcher',
'dispatcher',
'FSMContext',
'filters',
'handler',

View file

@ -4,7 +4,6 @@ import itertools
import logging
import time
import typing
from contextvars import ContextVar
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
RegexpCommandsFilter, StateFilter, Text
@ -14,15 +13,16 @@ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMCon
LAST_CALL, RATE_LIMIT, RESULT
from .webhook import BaseResponse
from .. import types
from ..bot import Bot, bot
from ..bot import Bot
from ..utils.exceptions import TelegramAPIError, Throttled
from ..utils.mixins import ContextInstanceMixin, DataMixin
log = logging.getLogger(__name__)
DEFAULT_RATE_LIMIT = .1
class Dispatcher:
class Dispatcher(DataMixin, ContextInstanceMixin):
"""
Simple Updates dispatcher
@ -112,23 +112,6 @@ class Dispatcher:
def __del__(self):
self.stop_polling()
@property
def data(self):
return self.bot.data
def __setitem__(self, key, value):
self.bot.data[key] = value
def __getitem__(self, item):
return self.bot.data[item]
def get(self, key, default=None):
return self.bot.data.get(key, default)
@classmethod
def current(cls):
return dispatcher.get()
async def skip_updates(self):
"""
You can skip old incoming updates from queue.
@ -245,8 +228,8 @@ class Dispatcher:
log.info('Start polling.')
# context.set_value(MODE, LONG_POLLING)
dispatcher.set(self)
bot.bot.set(self.bot)
Dispatcher.set_current(self)
Bot.set_current(self.bot)
if reset_webhook is None:
await self.reset_webhook(check=False)
@ -867,10 +850,10 @@ class Dispatcher:
:return:
"""
if chat is None:
chat_obj = types.Chat.current()
chat_obj = types.Chat.get_current()
chat = chat_obj.id if chat_obj else None
if user is None:
user_obj = types.User.current()
user_obj = types.User.get_current()
user = user_obj.id if user_obj else None
return FSMContext(storage=self.storage, chat=chat, user=user)
@ -895,8 +878,8 @@ class Dispatcher:
if rate is None:
rate = self.throttling_rate_limit
if user is None and chat is None:
user = types.User.current()
chat = types.Chat.current()
user = types.User.get_current()
chat = types.Chat.get_current()
# Detect current time
now = time.time()
@ -945,8 +928,8 @@ class Dispatcher:
raise RuntimeError('This storage does not provide Leaky Bucket')
if user is None and chat is None:
user = types.User.current()
chat = types.Chat.current()
user = types.User.get_current()
chat = types.Chat.get_current()
bucket = await self.storage.get_bucket(chat=chat, user=user)
data = bucket.get(key, {})
@ -965,8 +948,8 @@ class Dispatcher:
raise RuntimeError('This storage does not provide Leaky Bucket')
if user is None and chat is None:
user = types.User.current()
chat = types.Chat.current()
user = types.User.get_current()
chat = types.Chat.get_current()
bucket = await self.storage.get_bucket(chat=chat, user=user)
if bucket and key in bucket:
@ -997,7 +980,7 @@ class Dispatcher:
response = task.result()
except Exception as e:
self.loop.create_task(
self.errors_handlers.notify(types.Update.current(), e))
self.errors_handlers.notify(types.Update.get_current(), e))
else:
if isinstance(response, BaseResponse):
self.loop.create_task(response.execute_response(self.bot))
@ -1016,6 +999,3 @@ class Dispatcher:
if run_task:
return self.async_task(callback)
return callback
dispatcher: ContextVar[Dispatcher] = ContextVar('dispatcher_instance', default=None)

View file

@ -53,7 +53,7 @@ class State:
__repr__ = __str__
async def set(self):
state = Dispatcher.current().current_state()
state = Dispatcher.get_current().current_state()
await state.set_state(self.state)
@ -143,7 +143,7 @@ class StatesGroupMeta(type):
class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def next(cls) -> str:
state = Dispatcher.current().current_state()
state = Dispatcher.get_current().current_state()
state_name = await state.get_state()
try:
@ -161,7 +161,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def previous(cls) -> str:
state = Dispatcher.current().current_state()
state = Dispatcher.get_current().current_state()
state_name = await state.get_state()
try:
@ -179,7 +179,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def first(cls) -> str:
state = Dispatcher.current().current_state()
state = Dispatcher.get_current().current_state()
first_step_name = cls.states_names[0]
await state.set_state(first_step_name)
@ -187,7 +187,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
async def last(cls) -> str:
state = Dispatcher.current().current_state()
state = Dispatcher.get_current().current_state()
last_step_name = cls.states_names[-1]
await state.set_state(last_step_name)

View file

@ -89,10 +89,9 @@ class WebhookRequestHandler(web.View):
"""
dp = self.request.app[BOT_DISPATCHER_KEY]
try:
from aiogram.bot import bot
from aiogram.dispatcher import dispatcher
dispatcher.set(dp)
bot.bot.set(dp.bot)
from aiogram import Bot, Dispatcher
Dispatcher.set_current(dp)
Bot.set_current(dp.bot)
except RuntimeError:
pass
return dp
@ -204,7 +203,7 @@ class WebhookRequestHandler(web.View):
results = task.result()
except Exception as e:
loop.create_task(
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
dispatcher.errors_handlers.notify(dispatcher, types.Update.get_current(), e))
else:
response = self.get_response(results)
if response is not None:
@ -355,7 +354,7 @@ class BaseResponse:
async def __call__(self, bot=None):
if bot is None:
from aiogram import Bot
bot = Bot.current()
bot = Bot.get_current()
return await self.execute_response(bot)
async def __aenter__(self):
@ -449,7 +448,7 @@ class ParseModeMixin:
:return:
"""
from aiogram import Bot
bot = Bot.current()
bot = Bot.get_current()
if bot is not None:
return bot.parse_mode

View file

@ -2,11 +2,11 @@ from __future__ import annotations
import io
import typing
from contextvars import ContextVar
from typing import TypeVar
from .fields import BaseField
from ..utils import json
from ..utils.mixins import ContextInstanceMixin
__all__ = ('MetaTelegramObject', 'TelegramObject', 'InputFile', 'String', 'Integer', 'Float', 'Boolean')
@ -57,7 +57,6 @@ class MetaTelegramObject(type):
mcs._objects[cls.__name__] = cls
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
return cls
@property
@ -65,7 +64,7 @@ class MetaTelegramObject(type):
return cls._objects
class TelegramObject(metaclass=MetaTelegramObject):
class TelegramObject(ContextInstanceMixin, metaclass=MetaTelegramObject):
"""
Abstract class for telegram objects
"""
@ -93,14 +92,6 @@ class TelegramObject(metaclass=MetaTelegramObject):
if value.default and key not in self.values:
self.values[key] = value.default
@classmethod
def current(cls):
return cls._current.get()
@classmethod
def set_current(cls, obj: TelegramObject):
return cls._current.set(obj)
@property
def conf(self) -> typing.Dict[str, typing.Any]:
return self._conf
@ -151,7 +142,7 @@ class TelegramObject(metaclass=MetaTelegramObject):
@property
def bot(self):
from ..bot.bot import Bot
return Bot.current()
return Bot.get_current()
def to_python(self) -> typing.Dict:
"""

View file

@ -511,7 +511,7 @@ class ChatActions(helper.Helper):
@classmethod
async def _do(cls, action: str, sleep=None):
from aiogram import Bot
await Bot.current().send_chat_action(Chat.current().id, action)
await Bot.get_current().send_chat_action(Chat.get_current().id, action)
if sleep:
await asyncio.sleep(sleep)

View file

@ -103,10 +103,9 @@ class Executor:
self._freeze = False
from aiogram.bot.bot import bot as ctx_bot
from aiogram.dispatcher import dispatcher as ctx_dp
ctx_bot.set(dispatcher.bot)
ctx_dp.set(dispatcher)
from aiogram import Bot, Dispatcher
Bot.set_current(dispatcher.bot)
Dispatcher.set_current(dispatcher)
@property
def frozen(self):

42
aiogram/utils/mixins.py Normal file
View file

@ -0,0 +1,42 @@
import contextvars
class DataMixin:
@property
def data(self):
data = getattr(self, '_data', None)
if data is None:
data = {}
setattr(self, '_data', data)
return data
def __getitem__(self, item):
return self.data[item]
def __setitem__(self, key, value):
self.data[key] = value
def __delitem__(self, key):
del self.data[key]
def get(self, key, default=None):
return self.data.get(key, default)
class ContextInstanceMixin:
def __init_subclass__(cls, **kwargs):
cls.__context_instance = contextvars.ContextVar('instance_' + cls.__name__)
return cls
@classmethod
def get_current(cls, no_error=True):
if no_error:
return cls.__context_instance.get(None)
return cls.__context_instance.get()
@classmethod
def set_current(cls, value):
if not isinstance(value, cls):
raise TypeError(f"Value should be instance of '{cls.__name__}' not '{type(value).__name__}'")
cls.__context_instance.set(value)

View file

@ -56,7 +56,7 @@ class ThrottlingMiddleware(BaseMiddleware):
handler = current_handler.get()
# Get dispatcher from context
dispatcher = Dispatcher.current()
dispatcher = Dispatcher.get_current()
# If handler was configured, get rate limit and key from handler
if handler:
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
@ -83,7 +83,7 @@ class ThrottlingMiddleware(BaseMiddleware):
:param throttled:
"""
handler = current_handler.get()
dispatcher = Dispatcher.current()
dispatcher = Dispatcher.get_current()
if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else: