mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 09:55:21 +00:00
Rewrite contextvar usage. Implemented ContextInstanceMixin and DataMixin
This commit is contained in:
parent
8ef279bba1
commit
39c333251f
12 changed files with 82 additions and 120 deletions
|
|
@ -173,43 +173,6 @@ class BaseBot:
|
||||||
|
|
||||||
return await self.request(method, payload, files)
|
return await self.request(method, payload, files)
|
||||||
|
|
||||||
@property
|
|
||||||
def data(self) -> Dict:
|
|
||||||
"""
|
|
||||||
Data stored in bot object
|
|
||||||
|
|
||||||
:return: Dictionary
|
|
||||||
"""
|
|
||||||
return self._data
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
"""
|
|
||||||
Store data in bot instance
|
|
||||||
|
|
||||||
:param key: Key in dict
|
|
||||||
:param value: Value
|
|
||||||
"""
|
|
||||||
self._data[key] = value
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
"""
|
|
||||||
Get item from bot instance by key
|
|
||||||
|
|
||||||
:param item: key name
|
|
||||||
:return: value
|
|
||||||
"""
|
|
||||||
return self._data[item]
|
|
||||||
|
|
||||||
def get(self, key, default=None):
|
|
||||||
"""
|
|
||||||
Get item from bot instance by key or return default value
|
|
||||||
|
|
||||||
:param key: key in dict
|
|
||||||
:param default: default value
|
|
||||||
:return: value or default value
|
|
||||||
"""
|
|
||||||
return self._data.get(key, default)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parse_mode(self):
|
def parse_mode(self):
|
||||||
return getattr(self, '_parse_mode', None)
|
return getattr(self, '_parse_mode', None)
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,15 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
from contextvars import ContextVar
|
|
||||||
|
|
||||||
from .base import BaseBot, api
|
from .base import BaseBot, api
|
||||||
from .. import types
|
from .. import types
|
||||||
from ..types import base
|
from ..types import base
|
||||||
|
from ..utils.mixins import DataMixin, ContextInstanceMixin
|
||||||
from ..utils.payload import generate_payload, prepare_arg, prepare_attachment, prepare_file
|
from ..utils.payload import generate_payload, prepare_arg, prepare_attachment, prepare_file
|
||||||
|
|
||||||
|
|
||||||
class Bot(BaseBot):
|
class Bot(BaseBot, DataMixin, ContextInstanceMixin):
|
||||||
"""
|
"""
|
||||||
Base bot class
|
Base bot class
|
||||||
"""
|
"""
|
||||||
|
|
@ -39,14 +39,6 @@ class Bot(BaseBot):
|
||||||
if hasattr(self, '_me'):
|
if hasattr(self, '_me'):
|
||||||
delattr(self, '_me')
|
delattr(self, '_me')
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def current(cls) -> Bot:
|
|
||||||
"""
|
|
||||||
Return active bot instance from the current context or None
|
|
||||||
:return: Bot or None
|
|
||||||
"""
|
|
||||||
return bot.get()
|
|
||||||
|
|
||||||
async def download_file_by_id(self, file_id: base.String, destination=None,
|
async def download_file_by_id(self, file_id: base.String, destination=None,
|
||||||
timeout: base.Integer = 30, chunk_size: base.Integer = 65536,
|
timeout: base.Integer = 30, chunk_size: base.Integer = 65536,
|
||||||
seek: base.Boolean = True):
|
seek: base.Boolean = True):
|
||||||
|
|
@ -98,7 +90,7 @@ class Bot(BaseBot):
|
||||||
allowed_updates = prepare_arg(allowed_updates)
|
allowed_updates = prepare_arg(allowed_updates)
|
||||||
payload = generate_payload(**locals())
|
payload = generate_payload(**locals())
|
||||||
|
|
||||||
result = await self.request(api.Methods.GET_UPDATES, payload, timeout=timeout + 2 if timeout else None)
|
result = await self.request(api.Methods.GET_UPDATES, payload)
|
||||||
return [types.Update(**update) for update in result]
|
return [types.Update(**update) for update in result]
|
||||||
|
|
||||||
async def set_webhook(self, url: base.String,
|
async def set_webhook(self, url: base.String,
|
||||||
|
|
@ -522,7 +514,7 @@ class Bot(BaseBot):
|
||||||
"""
|
"""
|
||||||
reply_markup = prepare_arg(reply_markup)
|
reply_markup = prepare_arg(reply_markup)
|
||||||
payload = generate_payload(**locals(), exclude=["animation", "thumb"])
|
payload = generate_payload(**locals(), exclude=["animation", "thumb"])
|
||||||
|
|
||||||
files = {}
|
files = {}
|
||||||
prepare_file(payload, files, 'animation', animation)
|
prepare_file(payload, files, 'animation', animation)
|
||||||
prepare_attachment(payload, files, 'thumb', thumb)
|
prepare_attachment(payload, files, 'thumb', thumb)
|
||||||
|
|
@ -2064,6 +2056,3 @@ class Bot(BaseBot):
|
||||||
result = await self.request(api.Methods.GET_GAME_HIGH_SCORES, payload)
|
result = await self.request(api.Methods.GET_GAME_HIGH_SCORES, payload)
|
||||||
|
|
||||||
return [types.GameHighScore(**gamehighscore) for gamehighscore in result]
|
return [types.GameHighScore(**gamehighscore) for gamehighscore in result]
|
||||||
|
|
||||||
|
|
||||||
bot: ContextVar[Bot] = ContextVar('bot_instance', default=None)
|
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ class I18nMiddleware(BaseMiddleware):
|
||||||
:param args: event arguments
|
:param args: event arguments
|
||||||
:return: locale name
|
:return: locale name
|
||||||
"""
|
"""
|
||||||
user: types.User = types.User.current()
|
user: types.User = types.User.get_current()
|
||||||
locale: Locale = user.locale
|
locale: Locale = user.locale
|
||||||
|
|
||||||
if locale:
|
if locale:
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,11 @@ from . import handler
|
||||||
from . import middlewares
|
from . import middlewares
|
||||||
from . import storage
|
from . import storage
|
||||||
from . import webhook
|
from . import webhook
|
||||||
from .dispatcher import Dispatcher, dispatcher, FSMContext, DEFAULT_RATE_LIMIT
|
from .dispatcher import Dispatcher, FSMContext, DEFAULT_RATE_LIMIT
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'DEFAULT_RATE_LIMIT',
|
'DEFAULT_RATE_LIMIT',
|
||||||
'Dispatcher',
|
'Dispatcher',
|
||||||
'dispatcher',
|
|
||||||
'FSMContext',
|
'FSMContext',
|
||||||
'filters',
|
'filters',
|
||||||
'handler',
|
'handler',
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import itertools
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
from contextvars import ContextVar
|
|
||||||
|
|
||||||
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
|
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
|
||||||
RegexpCommandsFilter, StateFilter, Text
|
RegexpCommandsFilter, StateFilter, Text
|
||||||
|
|
@ -14,15 +13,16 @@ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMCon
|
||||||
LAST_CALL, RATE_LIMIT, RESULT
|
LAST_CALL, RATE_LIMIT, RESULT
|
||||||
from .webhook import BaseResponse
|
from .webhook import BaseResponse
|
||||||
from .. import types
|
from .. import types
|
||||||
from ..bot import Bot, bot
|
from ..bot import Bot
|
||||||
from ..utils.exceptions import TelegramAPIError, Throttled
|
from ..utils.exceptions import TelegramAPIError, Throttled
|
||||||
|
from ..utils.mixins import ContextInstanceMixin, DataMixin
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_RATE_LIMIT = .1
|
DEFAULT_RATE_LIMIT = .1
|
||||||
|
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher(DataMixin, ContextInstanceMixin):
|
||||||
"""
|
"""
|
||||||
Simple Updates dispatcher
|
Simple Updates dispatcher
|
||||||
|
|
||||||
|
|
@ -112,23 +112,6 @@ class Dispatcher:
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.stop_polling()
|
self.stop_polling()
|
||||||
|
|
||||||
@property
|
|
||||||
def data(self):
|
|
||||||
return self.bot.data
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
self.bot.data[key] = value
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self.bot.data[item]
|
|
||||||
|
|
||||||
def get(self, key, default=None):
|
|
||||||
return self.bot.data.get(key, default)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def current(cls):
|
|
||||||
return dispatcher.get()
|
|
||||||
|
|
||||||
async def skip_updates(self):
|
async def skip_updates(self):
|
||||||
"""
|
"""
|
||||||
You can skip old incoming updates from queue.
|
You can skip old incoming updates from queue.
|
||||||
|
|
@ -245,8 +228,8 @@ class Dispatcher:
|
||||||
log.info('Start polling.')
|
log.info('Start polling.')
|
||||||
|
|
||||||
# context.set_value(MODE, LONG_POLLING)
|
# context.set_value(MODE, LONG_POLLING)
|
||||||
dispatcher.set(self)
|
Dispatcher.set_current(self)
|
||||||
bot.bot.set(self.bot)
|
Bot.set_current(self.bot)
|
||||||
|
|
||||||
if reset_webhook is None:
|
if reset_webhook is None:
|
||||||
await self.reset_webhook(check=False)
|
await self.reset_webhook(check=False)
|
||||||
|
|
@ -867,10 +850,10 @@ class Dispatcher:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if chat is None:
|
if chat is None:
|
||||||
chat_obj = types.Chat.current()
|
chat_obj = types.Chat.get_current()
|
||||||
chat = chat_obj.id if chat_obj else None
|
chat = chat_obj.id if chat_obj else None
|
||||||
if user is None:
|
if user is None:
|
||||||
user_obj = types.User.current()
|
user_obj = types.User.get_current()
|
||||||
user = user_obj.id if user_obj else None
|
user = user_obj.id if user_obj else None
|
||||||
|
|
||||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||||
|
|
@ -895,8 +878,8 @@ class Dispatcher:
|
||||||
if rate is None:
|
if rate is None:
|
||||||
rate = self.throttling_rate_limit
|
rate = self.throttling_rate_limit
|
||||||
if user is None and chat is None:
|
if user is None and chat is None:
|
||||||
user = types.User.current()
|
user = types.User.get_current()
|
||||||
chat = types.Chat.current()
|
chat = types.Chat.get_current()
|
||||||
|
|
||||||
# Detect current time
|
# Detect current time
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
@ -945,8 +928,8 @@ class Dispatcher:
|
||||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||||
|
|
||||||
if user is None and chat is None:
|
if user is None and chat is None:
|
||||||
user = types.User.current()
|
user = types.User.get_current()
|
||||||
chat = types.Chat.current()
|
chat = types.Chat.get_current()
|
||||||
|
|
||||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||||
data = bucket.get(key, {})
|
data = bucket.get(key, {})
|
||||||
|
|
@ -965,8 +948,8 @@ class Dispatcher:
|
||||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||||
|
|
||||||
if user is None and chat is None:
|
if user is None and chat is None:
|
||||||
user = types.User.current()
|
user = types.User.get_current()
|
||||||
chat = types.Chat.current()
|
chat = types.Chat.get_current()
|
||||||
|
|
||||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||||
if bucket and key in bucket:
|
if bucket and key in bucket:
|
||||||
|
|
@ -997,7 +980,7 @@ class Dispatcher:
|
||||||
response = task.result()
|
response = task.result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.loop.create_task(
|
self.loop.create_task(
|
||||||
self.errors_handlers.notify(types.Update.current(), e))
|
self.errors_handlers.notify(types.Update.get_current(), e))
|
||||||
else:
|
else:
|
||||||
if isinstance(response, BaseResponse):
|
if isinstance(response, BaseResponse):
|
||||||
self.loop.create_task(response.execute_response(self.bot))
|
self.loop.create_task(response.execute_response(self.bot))
|
||||||
|
|
@ -1016,6 +999,3 @@ class Dispatcher:
|
||||||
if run_task:
|
if run_task:
|
||||||
return self.async_task(callback)
|
return self.async_task(callback)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
|
|
||||||
dispatcher: ContextVar[Dispatcher] = ContextVar('dispatcher_instance', default=None)
|
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ class State:
|
||||||
__repr__ = __str__
|
__repr__ = __str__
|
||||||
|
|
||||||
async def set(self):
|
async def set(self):
|
||||||
state = Dispatcher.current().current_state()
|
state = Dispatcher.get_current().current_state()
|
||||||
await state.set_state(self.state)
|
await state.set_state(self.state)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -143,7 +143,7 @@ class StatesGroupMeta(type):
|
||||||
class StatesGroup(metaclass=StatesGroupMeta):
|
class StatesGroup(metaclass=StatesGroupMeta):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def next(cls) -> str:
|
async def next(cls) -> str:
|
||||||
state = Dispatcher.current().current_state()
|
state = Dispatcher.get_current().current_state()
|
||||||
state_name = await state.get_state()
|
state_name = await state.get_state()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -161,7 +161,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def previous(cls) -> str:
|
async def previous(cls) -> str:
|
||||||
state = Dispatcher.current().current_state()
|
state = Dispatcher.get_current().current_state()
|
||||||
state_name = await state.get_state()
|
state_name = await state.get_state()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -179,7 +179,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def first(cls) -> str:
|
async def first(cls) -> str:
|
||||||
state = Dispatcher.current().current_state()
|
state = Dispatcher.get_current().current_state()
|
||||||
first_step_name = cls.states_names[0]
|
first_step_name = cls.states_names[0]
|
||||||
|
|
||||||
await state.set_state(first_step_name)
|
await state.set_state(first_step_name)
|
||||||
|
|
@ -187,7 +187,7 @@ class StatesGroup(metaclass=StatesGroupMeta):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def last(cls) -> str:
|
async def last(cls) -> str:
|
||||||
state = Dispatcher.current().current_state()
|
state = Dispatcher.get_current().current_state()
|
||||||
last_step_name = cls.states_names[-1]
|
last_step_name = cls.states_names[-1]
|
||||||
|
|
||||||
await state.set_state(last_step_name)
|
await state.set_state(last_step_name)
|
||||||
|
|
|
||||||
|
|
@ -89,10 +89,9 @@ class WebhookRequestHandler(web.View):
|
||||||
"""
|
"""
|
||||||
dp = self.request.app[BOT_DISPATCHER_KEY]
|
dp = self.request.app[BOT_DISPATCHER_KEY]
|
||||||
try:
|
try:
|
||||||
from aiogram.bot import bot
|
from aiogram import Bot, Dispatcher
|
||||||
from aiogram.dispatcher import dispatcher
|
Dispatcher.set_current(dp)
|
||||||
dispatcher.set(dp)
|
Bot.set_current(dp.bot)
|
||||||
bot.bot.set(dp.bot)
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
return dp
|
return dp
|
||||||
|
|
@ -204,7 +203,7 @@ class WebhookRequestHandler(web.View):
|
||||||
results = task.result()
|
results = task.result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
loop.create_task(
|
loop.create_task(
|
||||||
dispatcher.errors_handlers.notify(dispatcher, types.Update.current(), e))
|
dispatcher.errors_handlers.notify(dispatcher, types.Update.get_current(), e))
|
||||||
else:
|
else:
|
||||||
response = self.get_response(results)
|
response = self.get_response(results)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
|
|
@ -355,7 +354,7 @@ class BaseResponse:
|
||||||
async def __call__(self, bot=None):
|
async def __call__(self, bot=None):
|
||||||
if bot is None:
|
if bot is None:
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
bot = Bot.current()
|
bot = Bot.get_current()
|
||||||
return await self.execute_response(bot)
|
return await self.execute_response(bot)
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
|
|
@ -449,7 +448,7 @@ class ParseModeMixin:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
bot = Bot.current()
|
bot = Bot.get_current()
|
||||||
if bot is not None:
|
if bot is not None:
|
||||||
return bot.parse_mode
|
return bot.parse_mode
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,11 @@ from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import typing
|
import typing
|
||||||
from contextvars import ContextVar
|
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from .fields import BaseField
|
from .fields import BaseField
|
||||||
from ..utils import json
|
from ..utils import json
|
||||||
|
from ..utils.mixins import ContextInstanceMixin
|
||||||
|
|
||||||
__all__ = ('MetaTelegramObject', 'TelegramObject', 'InputFile', 'String', 'Integer', 'Float', 'Boolean')
|
__all__ = ('MetaTelegramObject', 'TelegramObject', 'InputFile', 'String', 'Integer', 'Float', 'Boolean')
|
||||||
|
|
||||||
|
|
@ -57,7 +57,6 @@ class MetaTelegramObject(type):
|
||||||
|
|
||||||
mcs._objects[cls.__name__] = cls
|
mcs._objects[cls.__name__] = cls
|
||||||
|
|
||||||
cls._current = ContextVar('current_' + cls.__name__, default=None) # Maybe need to set default=None?
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -65,7 +64,7 @@ class MetaTelegramObject(type):
|
||||||
return cls._objects
|
return cls._objects
|
||||||
|
|
||||||
|
|
||||||
class TelegramObject(metaclass=MetaTelegramObject):
|
class TelegramObject(ContextInstanceMixin, metaclass=MetaTelegramObject):
|
||||||
"""
|
"""
|
||||||
Abstract class for telegram objects
|
Abstract class for telegram objects
|
||||||
"""
|
"""
|
||||||
|
|
@ -93,14 +92,6 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
||||||
if value.default and key not in self.values:
|
if value.default and key not in self.values:
|
||||||
self.values[key] = value.default
|
self.values[key] = value.default
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def current(cls):
|
|
||||||
return cls._current.get()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def set_current(cls, obj: TelegramObject):
|
|
||||||
return cls._current.set(obj)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def conf(self) -> typing.Dict[str, typing.Any]:
|
def conf(self) -> typing.Dict[str, typing.Any]:
|
||||||
return self._conf
|
return self._conf
|
||||||
|
|
@ -151,7 +142,7 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
||||||
@property
|
@property
|
||||||
def bot(self):
|
def bot(self):
|
||||||
from ..bot.bot import Bot
|
from ..bot.bot import Bot
|
||||||
return Bot.current()
|
return Bot.get_current()
|
||||||
|
|
||||||
def to_python(self) -> typing.Dict:
|
def to_python(self) -> typing.Dict:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -511,7 +511,7 @@ class ChatActions(helper.Helper):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _do(cls, action: str, sleep=None):
|
async def _do(cls, action: str, sleep=None):
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
await Bot.current().send_chat_action(Chat.current().id, action)
|
await Bot.get_current().send_chat_action(Chat.get_current().id, action)
|
||||||
if sleep:
|
if sleep:
|
||||||
await asyncio.sleep(sleep)
|
await asyncio.sleep(sleep)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -103,10 +103,9 @@ class Executor:
|
||||||
|
|
||||||
self._freeze = False
|
self._freeze = False
|
||||||
|
|
||||||
from aiogram.bot.bot import bot as ctx_bot
|
from aiogram import Bot, Dispatcher
|
||||||
from aiogram.dispatcher import dispatcher as ctx_dp
|
Bot.set_current(dispatcher.bot)
|
||||||
ctx_bot.set(dispatcher.bot)
|
Dispatcher.set_current(dispatcher)
|
||||||
ctx_dp.set(dispatcher)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def frozen(self):
|
def frozen(self):
|
||||||
|
|
|
||||||
42
aiogram/utils/mixins.py
Normal file
42
aiogram/utils/mixins.py
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
import contextvars
|
||||||
|
|
||||||
|
|
||||||
|
class DataMixin:
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
data = getattr(self, '_data', None)
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
|
setattr(self, '_data', data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.data[item]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self.data[key] = value
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
del self.data[key]
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
return self.data.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextInstanceMixin:
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
cls.__context_instance = contextvars.ContextVar('instance_' + cls.__name__)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_current(cls, no_error=True):
|
||||||
|
if no_error:
|
||||||
|
return cls.__context_instance.get(None)
|
||||||
|
return cls.__context_instance.get()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_current(cls, value):
|
||||||
|
if not isinstance(value, cls):
|
||||||
|
raise TypeError(f"Value should be instance of '{cls.__name__}' not '{type(value).__name__}'")
|
||||||
|
cls.__context_instance.set(value)
|
||||||
|
|
||||||
|
|
@ -56,7 +56,7 @@ class ThrottlingMiddleware(BaseMiddleware):
|
||||||
handler = current_handler.get()
|
handler = current_handler.get()
|
||||||
|
|
||||||
# Get dispatcher from context
|
# Get dispatcher from context
|
||||||
dispatcher = Dispatcher.current()
|
dispatcher = Dispatcher.get_current()
|
||||||
# If handler was configured, get rate limit and key from handler
|
# If handler was configured, get rate limit and key from handler
|
||||||
if handler:
|
if handler:
|
||||||
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
|
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
|
||||||
|
|
@ -83,7 +83,7 @@ class ThrottlingMiddleware(BaseMiddleware):
|
||||||
:param throttled:
|
:param throttled:
|
||||||
"""
|
"""
|
||||||
handler = current_handler.get()
|
handler = current_handler.get()
|
||||||
dispatcher = Dispatcher.current()
|
dispatcher = Dispatcher.get_current()
|
||||||
if handler:
|
if handler:
|
||||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue