diff --git a/Makefile b/Makefile index 6ed9eecf..aac5cbce 100644 --- a/Makefile +++ b/Makefile @@ -2,17 +2,18 @@ VENV_NAME := venv PYTHON := $(VENV_NAME)/bin/python AIOGRAM_VERSION := $(shell $(PYTHON) -c "import aiogram;print(aiogram.__version__)") +RM := rm -rf + mkvenv: virtualenv $(VENV_NAME) $(PYTHON) -m pip install -r requirements.txt clean: - find . -name '*.pyc' -exec rm --force {} + - find . -name '*.pyo' -exec rm --force {} + - find . -name '*~' -exec rm --force {} + - rm --force --recursive build/ - rm --force --recursive dist/ - rm --force --recursive *.egg-info + find . -name '*.pyc' -exec $(RM) {} + + find . -name '*.pyo' -exec $(RM) {} + + find . -name '*~' -exec $(RM) {} + + find . -name '__pycache__' -exec $(RM) {} + + $(RM) build/ dist/ docs/build/ .tox/ .cache/ *.egg-info tag: @echo "Add tag: '$(AIOGRAM_VERSION)'" @@ -26,14 +27,23 @@ upload: release: make clean - make tag + make test make build + make tag @echo "Released aiogram $(AIOGRAM_VERSION)" full-release: make release make upload - -make install: +install: $(PYTHON) setup.py install + +test: + tox + +summary: + cloc aiogram/ tests/ examples/ setup.py + +docs: docs/source/* + cd docs && $(MAKE) html diff --git a/aiogram/__init__.py b/aiogram/__init__.py index f7ec1957..e7a1292d 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -10,7 +10,16 @@ except ImportError as e: from .utils.versions import Stage, Version -VERSION = Version(1, 0, 2, stage=Stage.FINAL, build=0) +try: + import uvloop +except ImportError: + pass +else: + import asyncio + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +VERSION = Version(1, 0, 3, stage=Stage.FINAL, build=0) API_VERSION = Version(3, 5) __version__ = VERSION.version diff --git a/aiogram/bot/api.py b/aiogram/bot/api.py index 9025ee00..1498f2a3 100644 --- a/aiogram/bot/api.py +++ b/aiogram/bot/api.py @@ -124,7 +124,7 @@ def _compose_data(params=None, files=None): return data -async def request(session, token, method, data=None, files=None, continue_retry=False, **kwargs) -> bool or dict: +async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict: """ Make request to API @@ -144,8 +144,6 @@ async def request(session, token, method, data=None, files=None, continue_retry= :type data: :obj:`dict` :param files: files :type files: :obj:`dict` - :param continue_retry: - :type continue_retry: :obj:`dict` :return: result :rtype :obj:`bool` or :obj:`dict` """ @@ -158,18 +156,13 @@ async def request(session, token, method, data=None, files=None, continue_retry= return await _check_result(method, response) except aiohttp.ClientError as e: raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}") - except exceptions.RetryAfter as e: - if continue_retry: - await asyncio.sleep(e.timeout) - return await request(session, token, method, data, files, **kwargs) - raise class Methods(Helper): """ Helper for Telegram API Methods listed on https://core.telegram.org/bots/api - List is updated to Bot API 3.4 + List is updated to Bot API 3.5 """ mode = HelperMode.lowerCamelCase diff --git a/aiogram/bot/base.py b/aiogram/bot/base.py index 9557a7f4..0617dee9 100644 --- a/aiogram/bot/base.py +++ b/aiogram/bot/base.py @@ -18,7 +18,6 @@ class BaseBot: loop: Optional[Union[asyncio.BaseEventLoop, asyncio.AbstractEventLoop]] = None, connections_limit: Optional[base.Integer] = 10, proxy: str = None, proxy_auth: Optional[aiohttp.BasicAuth] = None, - continue_retry: Optional[bool] = False, validate_token: Optional[bool] = True): """ Instructions how to get Bot token is found here: https://core.telegram.org/bots#3-how-do-i-create-a-bot @@ -33,8 +32,6 @@ class BaseBot: :type proxy: :obj:`str` :param proxy_auth: Authentication information :type proxy_auth: Optional :obj:`aiohttp.BasicAuth` - :param continue_retry: automatic retry sent request when flood control exceeded - :type continue_retry: :obj:`bool` :param validate_token: Validate token. :type validate_token: :obj:`bool` :raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError` @@ -48,9 +45,6 @@ class BaseBot: self.proxy = proxy self.proxy_auth = proxy_auth - # Action on error - self.continue_retry = continue_retry - # Asyncio loop instance if loop is None: loop = asyncio.get_event_loop() @@ -68,8 +62,11 @@ class BaseBot: self._data = {} def __del__(self): + self.close() + + def close(self): """ - When bot object is deleting - need close all sessions + Close all client sessions """ for session in self._temp_sessions: if not session.closed: @@ -77,17 +74,19 @@ class BaseBot: if self.session and not self.session.closed: self.session.close() - def create_temp_session(self, limit: int = 1) -> aiohttp.ClientSession: + def create_temp_session(self, limit: int = 1, force_close: bool = False) -> aiohttp.ClientSession: """ Create temporary session :param limit: Limit of connections :type limit: :obj:`int` + :param force_close: Set to True to force close and do reconnect after each request (and between redirects). + :type force_close: :obj:`bool` :return: New session :rtype: :obj:`aiohttp.TCPConnector` """ session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(limit=limit, force_close=True), + connector=aiohttp.TCPConnector(limit=limit, force_close=force_close), loop=self.loop, json_serialize=json.dumps) self._temp_sessions.append(session) return session @@ -123,8 +122,7 @@ class BaseBot: :raise: :obj:`aiogram.exceptions.TelegramApiError` """ return await api.request(self.session, self.__token, method, data, files, - proxy=self.proxy, proxy_auth=self.proxy_auth, - continue_retry=self.continue_retry) + proxy=self.proxy, proxy_auth=self.proxy_auth) async def download_file(self, file_path: base.String, destination: Optional[base.InputFile] = None, diff --git a/aiogram/bot/bot.py b/aiogram/bot/bot.py index 792d8bfa..bdc38f94 100644 --- a/aiogram/bot/bot.py +++ b/aiogram/bot/bot.py @@ -31,7 +31,8 @@ class Bot(BaseBot): delattr(self, '_me') async def download_file_by_id(self, file_id: base.String, destination=None, - timeout: base.Integer=30, chunk_size: base.Integer=65536, seek: base.Boolean=True): + timeout: base.Integer = 30, chunk_size: base.Integer = 65536, + seek: base.Boolean = True): """ Download file by file_id to destination diff --git a/aiogram/contrib/fsm_storage/memory.py b/aiogram/contrib/fsm_storage/memory.py index fa993e0e..ac463d14 100644 --- a/aiogram/contrib/fsm_storage/memory.py +++ b/aiogram/contrib/fsm_storage/memory.py @@ -30,7 +30,7 @@ class MemoryStorage(BaseStorage): chat_id = str(chat_id) user_id = str(user_id) if user_id not in self.data[chat_id]: - self.data[chat_id][user_id] = {'state': None, 'data': {}} + self.data[chat_id][user_id] = {'state': None, 'data': {}, 'bucket': {}} return self.data[chat_id][user_id] async def get_state(self, *, @@ -82,3 +82,27 @@ class MemoryStorage(BaseStorage): await self.set_state(chat=chat, user=user, state=None) if with_data: await self.set_data(chat=chat, user=user, data={}) + + def has_bucket(self): + return True + + async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[dict] = None) -> typing.Dict: + chat, user = self.check_address(chat=chat, user=user) + user = self._get_user(chat, user) + return user['bucket'] + + async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None): + chat, user = self.check_address(chat=chat, user=user) + user = self._get_user(chat, user) + user['bucket'] = bucket + + async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None, **kwargs): + chat, user = self.check_address(chat=chat, user=user) + user = self._get_user(chat, user) + if bucket is None: + bucket = [] + user['bucket'].update(bucket, **kwargs) diff --git a/aiogram/contrib/fsm_storage/redis.py b/aiogram/contrib/fsm_storage/redis.py index b4e350bc..ecf81afa 100644 --- a/aiogram/contrib/fsm_storage/redis.py +++ b/aiogram/contrib/fsm_storage/redis.py @@ -3,6 +3,7 @@ This module has redis storage for finite-state machine based on `aioredis typing.Dict: + record = await self.get_record(chat=chat, user=user) + return record.get('bucket', {}) + + async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None): + record = await self.get_record(chat=chat, user=user) + await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket) + + async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None, **kwargs): + record = await self.get_record(chat=chat, user=user) + record_bucket = record.get('bucket', {}) + record_bucket.update(bucket, **kwargs) + await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket) + + +class RedisStorage2(BaseStorage): + """ + Busted Redis-base storage for FSM. + Works with Redis connection pool and customizable keys prefix. + + Usage: + + .. code-block:: python3 + + storage = RedisStorage('localhost', 6379, db=5, pool_size=10, prefix='my_fsm_key') + dp = Dispatcher(bot, storage=storage) + + And need to close Redis connection when shutdown + + .. code-block:: python3 + + await dp.storage.close() + await dp.storage.wait_closed() + + """ + + def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None, + pool_size=10, loop=None, prefix='fsm', **kwargs): + self._host = host + self._port = port + self._db = db + self._password = password + self._ssl = ssl + self._pool_size = pool_size + self._loop = loop or asyncio.get_event_loop() + self._kwargs = kwargs + self._prefix = (prefix,) + + self._redis: aioredis.RedisConnection = None + self._connection_lock = asyncio.Lock(loop=self._loop) + + @property + async def redis(self) -> aioredis.Redis: + """ + Get Redis connection + + This property is awaitable. + """ + # Use thread-safe asyncio Lock because this method without that is not safe + async with self._connection_lock: + if self._redis is None: + self._redis = await aioredis.create_redis_pool((self._host, self._port), + db=self._db, password=self._password, ssl=self._ssl, + minsize=1, maxsize=self._pool_size, + loop=self._loop, **self._kwargs) + return self._redis + + def generate_key(self, *parts): + return ':'.join(self._prefix + tuple(map(str, parts))) + + async def close(self): + async with self._connection_lock: + if self._redis and not self._redis.closed: + self._redis.close() + del self._redis + self._redis = None + + async def wait_closed(self): + async with self._connection_lock: + if self._redis: + return await self._redis.wait_closed() + return True + + async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[str] = None) -> typing.Optional[str]: + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_KEY) + redis = await self.redis + return await redis.get(key, encoding='utf8') or None + + async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[dict] = None) -> typing.Dict: + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_DATA_KEY) + redis = await self.redis + raw_result = await redis.get(key, encoding='utf8') + if raw_result: + return json.loads(raw_result) + return default or {} + + async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + state: typing.Optional[typing.AnyStr] = None): + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_KEY) + redis = await self.redis + if state is None: + await redis.delete(key) + else: + await redis.set(key, state) + + async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + data: typing.Dict = None): + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_DATA_KEY) + redis = await self.redis + await redis.set(key, json.dumps(data)) + + async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + data: typing.Dict = None, **kwargs): + temp_data = await self.get_data(chat=chat, user=user, default={}) + temp_data.update(data, **kwargs) + await self.set_data(chat=chat, user=user, data=temp_data) + + def has_bucket(self): + return True + + async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + default: typing.Optional[dict] = None) -> typing.Dict: + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_BUCKET_KEY) + redis = await self.redis + raw_result = await redis.get(key, encoding='utf8') + if raw_result: + return json.loads(raw_result) + return default or {} + + async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None): + chat, user = self.check_address(chat=chat, user=user) + key = self.generate_key(chat, user, STATE_BUCKET_KEY) + redis = await self.redis + await redis.set(key, json.dumps(bucket)) + + async def update_bucket(self, *, chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None, **kwargs): + temp_bucket = await self.get_data(chat=chat, user=user) + temp_bucket.update(bucket, **kwargs) + await self.set_data(chat=chat, user=user, data=temp_bucket) + + async def reset_all(self, full=True): + """ + Reset states in DB + + :param full: clean DB or clean only states + :return: + """ + conn = await self.redis + + if full: + conn.flushdb() + else: + keys = await conn.keys(self.generate_key('*')) + conn.delete(*keys) + + async def get_states_list(self) -> typing.List[typing.Tuple[int]]: + """ + Get list of all stored chat's and user's + + :return: list of tuples where first element is chat id and second is user id + """ + conn = await self.redis + result = [] + + keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8') + for item in keys: + *_, chat, user, _ = item.split(':') + result.append((chat, user)) + + return result + + async def import_redis1(self, redis1): + await migrate_redis1_to_redis2(redis1, self) + + +async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2): + """ + Helper for migrating from RedisStorage to RedisStorage2 + + :param storage1: instance of RedisStorage + :param storage2: instance of RedisStorage2 + :return: + """ + assert isinstance(storage1, RedisStorage) + assert isinstance(storage2, RedisStorage2) + + log = logging.getLogger('aiogram.RedisStorage') + + for chat, user in await storage1.get_states_list(): + state = await storage1.get_state(chat=chat, user=user) + await storage2.set_state(chat=chat, user=user, state=state) + + data = await storage1.get_data(chat=chat, user=user) + await storage2.set_data(chat=chat, user=user, data=data) + + bucket = await storage1.get_bucket(chat=chat, user=user) + await storage2.set_bucket(chat=chat, user=user, bucket=bucket) + + log.info(f"Migrated user {user} in chat {chat}") diff --git a/aiogram/contrib/middlewares/__init__.py b/aiogram/contrib/middlewares/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/contrib/middlewares/logging.py b/aiogram/contrib/middlewares/logging.py new file mode 100644 index 00000000..ee9ac65a --- /dev/null +++ b/aiogram/contrib/middlewares/logging.py @@ -0,0 +1,132 @@ +import logging +import time + +from aiogram import types +from aiogram.dispatcher.middlewares import BaseMiddleware + +HANDLED_STR = ['Unhandled', 'Handled'] + + +class LoggingMiddleware(BaseMiddleware): + def __init__(self, logger=__name__): + if not isinstance(logger, logging.Logger): + logger = logging.getLogger(logger) + + self.logger = logger + + super(LoggingMiddleware, self).__init__() + + def check_timeout(self, obj): + start = obj.conf.get('_start', None) + if start: + del obj.conf['_start'] + return round((time.time() - start) * 1000) + return -1 + + async def on_pre_process_update(self, update: types.Update): + update.conf['_start'] = time.time() + self.logger.debug(f"Received update [ID:{update.update_id}]") + + async def on_post_process_update(self, update: types.Update, result): + timeout = self.check_timeout(update) + if timeout > 0: + self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)") + + async def on_pre_process_message(self, message: types.Message): + self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]") + + async def on_post_process_message(self, message: types.Message, results): + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]") + + async def on_pre_process_edited_message(self, edited_message): + self.logger.info(f"Received edited message [ID:{edited_message.message_id}] " + f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]") + + async def on_post_process_edited_message(self, edited_message, results): + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"edited message [ID:{edited_message.message_id}] " + f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]") + + async def on_pre_process_channel_post(self, channel_post: types.Message): + self.logger.info(f"Received channel post [ID:{channel_post.message_id}] " + f"in channel [ID:{channel_post.chat.id}]") + + async def on_post_process_channel_post(self, channel_post: types.Message, results): + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"channel post [ID:{channel_post.message_id}] " + f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]") + + async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message): + self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] " + f"in channel [ID:{edited_channel_post.chat.id}]") + + async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results): + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"edited channel post [ID:{edited_channel_post.message_id}] " + f"in channel [ID:{edited_channel_post.chat.id}]") + + async def on_pre_process_inline_query(self, inline_query: types.InlineQuery): + self.logger.info(f"Received inline query [ID:{inline_query.id}] " + f"from user [ID:{inline_query.from_user.id}]") + + async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results): + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"inline query [ID:{inline_query.id}] " + f"from user [ID:{inline_query.from_user.id}]") + + async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult): + self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] " + f"from user [ID:{chosen_inline_result.from_user.id}] " + f"result [ID:{chosen_inline_result.result_id}]") + + async def on_post_process_chosen_inline_result(self, chosen_inline_result, results): + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] " + f"from user [ID:{chosen_inline_result.from_user.id}] " + f"result [ID:{chosen_inline_result.result_id}]") + + async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery): + if callback_query.message: + self.logger.info(f"Received callback query [ID:{callback_query.id}] " + f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] " + f"from user [ID:{callback_query.message.from_user.id}]") + else: + self.logger.info(f"Received callback query [ID:{callback_query.id}] " + f"from inline message [ID:{callback_query.inline_message_id}] " + f"from user [ID:{callback_query.from_user.id}]") + + async def on_post_process_callback_query(self, callback_query, results): + if callback_query.message: + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"callback query [ID:{callback_query.id}] " + f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] " + f"from user [ID:{callback_query.message.from_user.id}]") + else: + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"callback query [ID:{callback_query.id}] " + f"from inline message [ID:{callback_query.inline_message_id}] " + f"from user [ID:{callback_query.from_user.id}]") + + async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery): + self.logger.info(f"Received shipping query [ID:{shipping_query.id}] " + f"from user [ID:{shipping_query.from_user.id}]") + + async def on_post_process_shipping_query(self, shipping_query, results): + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"shipping query [ID:{shipping_query.id}] " + f"from user [ID:{shipping_query.from_user.id}]") + + async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery): + self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] " + f"from user [ID:{pre_checkout_query.from_user.id}]") + + async def on_post_process_pre_checkout_query(self, pre_checkout_query, results): + self.logger.debug(f"{HANDLED_STR[bool(len(results))]} " + f"pre-checkout query [ID:{pre_checkout_query.id}] " + f"from user [ID:{pre_checkout_query.from_user.id}]") + + async def on_pre_process_error(self, dispatcher, update, error): + timeout = self.check_timeout(update) + if timeout > 0: + self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)") diff --git a/aiogram/dispatcher/__init__.py b/aiogram/dispatcher/__init__.py index 76fb3b4d..dc378439 100644 --- a/aiogram/dispatcher/__init__.py +++ b/aiogram/dispatcher/__init__.py @@ -1,20 +1,21 @@ import asyncio import functools +import itertools import logging -import typing - import time +import typing from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \ generate_default_filters -from .handler import Handler -from .storage import BaseStorage, DisabledStorage, FSMContext +from .handler import CancelHandler, Handler, SkipHandler +from .middlewares import MiddlewareManager +from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT from .webhook import BaseResponse from ..bot import Bot from ..types.message import ContentType from ..utils import context from ..utils.deprecated import deprecated -from ..utils.exceptions import NetworkError, TelegramAPIError +from ..utils.exceptions import NetworkError, TelegramAPIError, Throttled log = logging.getLogger(__name__) @@ -22,6 +23,8 @@ MODE = 'MODE' LONG_POLLING = 'long-polling' UPDATE_OBJECT = 'update_object' +DEFAULT_RATE_LIMIT = .1 + class Dispatcher: """ @@ -33,7 +36,9 @@ class Dispatcher: """ def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None, - run_tasks_by_default: bool = False): + run_tasks_by_default: bool = False, + throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False): + if loop is None: loop = bot.loop if storage is None: @@ -44,27 +49,33 @@ class Dispatcher: self.storage = storage self.run_tasks_by_default = run_tasks_by_default + self.throttling_rate_limit = throttling_rate_limit + self.no_throttle_error = no_throttle_error + self.last_update_id = 0 - self.updates_handler = Handler(self) - self.message_handlers = Handler(self) - self.edited_message_handlers = Handler(self) - self.channel_post_handlers = Handler(self) - self.edited_channel_post_handlers = Handler(self) - self.inline_query_handlers = Handler(self) - self.chosen_inline_result_handlers = Handler(self) - self.callback_query_handlers = Handler(self) - self.shipping_query_handlers = Handler(self) - self.pre_checkout_query_handlers = Handler(self) + self.updates_handler = Handler(self, middleware_key='update') + self.message_handlers = Handler(self, middleware_key='message') + self.edited_message_handlers = Handler(self, middleware_key='edited_message') + self.channel_post_handlers = Handler(self, middleware_key='channel_post') + self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post') + self.inline_query_handlers = Handler(self, middleware_key='inline_query') + self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result') + self.callback_query_handlers = Handler(self, middleware_key='callback_query') + self.shipping_query_handlers = Handler(self, middleware_key='shipping_query') + self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query') + self.errors_handlers = Handler(self, once=False, middleware_key='error') + + self.middleware = MiddlewareManager(self) self.updates_handler.register(self.process_update) - self.errors_handlers = Handler(self, once=False) - self._polling = False + self._closed = True + self._close_waiter = loop.create_future() def __del__(self): - self._polling = False + self.stop_polling() @property def data(self): @@ -105,7 +116,7 @@ class Dispatcher: """ tasks = [] for update in updates: - tasks.append(self.process_update(update)) + tasks.append(self.updates_handler.notify(update)) return await asyncio.gather(*tasks) async def process_update(self, update): @@ -115,72 +126,59 @@ class Dispatcher: :param update: :return: """ - start = time.time() - success = True - + self.last_update_id = update.update_id + context.set_value(UPDATE_OBJECT, update) try: - self.last_update_id = update.update_id - has_context = context.check_configured() - if has_context: - context.set_value(UPDATE_OBJECT, update) if update.message: - if has_context: - state = await self.storage.get_state(chat=update.message.chat.id, - user=update.message.from_user.id) - context.update_state(chat=update.message.chat.id, - user=update.message.from_user.id, - state=state) + state = await self.storage.get_state(chat=update.message.chat.id, + user=update.message.from_user.id) + context.update_state(chat=update.message.chat.id, + user=update.message.from_user.id, + state=state) return await self.message_handlers.notify(update.message) if update.edited_message: - if has_context: - state = await self.storage.get_state(chat=update.edited_message.chat.id, - user=update.edited_message.from_user.id) - context.update_state(chat=update.edited_message.chat.id, - user=update.edited_message.from_user.id, - state=state) + state = await self.storage.get_state(chat=update.edited_message.chat.id, + user=update.edited_message.from_user.id) + context.update_state(chat=update.edited_message.chat.id, + user=update.edited_message.from_user.id, + state=state) return await self.edited_message_handlers.notify(update.edited_message) if update.channel_post: - if has_context: - state = await self.storage.get_state(chat=update.channel_post.chat.id) - context.update_state(chat=update.channel_post.chat.id, - state=state) + state = await self.storage.get_state(chat=update.channel_post.chat.id) + context.update_state(chat=update.channel_post.chat.id, + state=state) return await self.channel_post_handlers.notify(update.channel_post) if update.edited_channel_post: - if has_context: - state = await self.storage.get_state(chat=update.edited_channel_post.chat.id) - context.update_state(chat=update.edited_channel_post.chat.id, - state=state) + state = await self.storage.get_state(chat=update.edited_channel_post.chat.id) + context.update_state(chat=update.edited_channel_post.chat.id, + state=state) return await self.edited_channel_post_handlers.notify(update.edited_channel_post) if update.inline_query: - if has_context: - state = await self.storage.get_state(user=update.inline_query.from_user.id) - context.update_state(user=update.inline_query.from_user.id, - state=state) + state = await self.storage.get_state(user=update.inline_query.from_user.id) + context.update_state(user=update.inline_query.from_user.id, + state=state) return await self.inline_query_handlers.notify(update.inline_query) if update.chosen_inline_result: - if has_context: - state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id) - context.update_state(user=update.chosen_inline_result.from_user.id, - state=state) + state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id) + context.update_state(user=update.chosen_inline_result.from_user.id, + state=state) return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result) if update.callback_query: - if has_context: - state = await self.storage.get_state(chat=update.callback_query.message.chat.id, - user=update.callback_query.from_user.id) - context.update_state(user=update.callback_query.from_user.id, - state=state) + state = await self.storage.get_state( + chat=update.callback_query.message.chat.id if update.callback_query.message else None, + user=update.callback_query.from_user.id) + context.update_state(user=update.callback_query.from_user.id, + state=state) return await self.callback_query_handlers.notify(update.callback_query) if update.shipping_query: - if has_context: - state = await self.storage.get_state(user=update.shipping_query.from_user.id) - context.update_state(user=update.shipping_query.from_user.id, - state=state) + state = await self.storage.get_state(user=update.shipping_query.from_user.id) + context.update_state(user=update.shipping_query.from_user.id, + state=state) return await self.shipping_query_handlers.notify(update.shipping_query) if update.pre_checkout_query: - if has_context: - state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id) - context.update_state(user=update.pre_checkout_query.from_user.id, - state=state) + state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id) + context.update_state(user=update.pre_checkout_query.from_user.id, + state=state) return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query) except Exception as e: success = False @@ -188,10 +186,6 @@ class Dispatcher: if err: return err raise - finally: - log.info(f"Process update [ID:{update.update_id}]: " - f"{['failed', 'success'][success]} " - f"(in {round((time.time() - start) * 1000)} ms)") async def reset_webhook(self, check=True) -> bool: """ @@ -244,24 +238,26 @@ class Dispatcher: self._polling = True offset = None - while self._polling: - try: - updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout) - except NetworkError: - log.exception('Cause exception while getting updates.') - await asyncio.sleep(15) - continue + try: + while self._polling: + try: + updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout) + except NetworkError: + log.exception('Cause exception while getting updates.') + await asyncio.sleep(15) + continue - if updates: - log.debug(f"Received {len(updates)} updates.") - offset = updates[-1].update_id + 1 + if updates: + log.debug(f"Received {len(updates)} updates.") + offset = updates[-1].update_id + 1 - self.loop.create_task(self._process_polling_updates(updates)) + self.loop.create_task(self._process_polling_updates(updates)) - if relax: - await asyncio.sleep(relax) - - log.warning('Polling is stopped.') + if relax: + await asyncio.sleep(relax) + finally: + self._close_waiter.set_result(None) + log.warning('Polling is stopped.') async def _process_polling_updates(self, updates): """ @@ -270,8 +266,8 @@ class Dispatcher: :param updates: list of updates. """ need_to_call = [] - for response in await self.process_updates(updates): - for response in response: + for responses in itertools.chain.from_iterable(await self.process_updates(updates)): + for response in responses: if not isinstance(response, BaseResponse): continue need_to_call.append(response.execute_response(self.bot)) @@ -288,12 +284,21 @@ class Dispatcher: def stop_polling(self): """ Break long-polling process. + :return: """ if self._polling: - log.info('Stop polling.') + log.info('Stop polling...') self._polling = False + async def wait_closed(self): + """ + Wait closing the long polling + + :return: + """ + await asyncio.shield(self._close_waiter, loop=self.loop) + @deprecated('The old method was renamed to `is_polling`') def is_pooling(self): return self.is_polling() @@ -897,7 +902,8 @@ class Dispatcher: """ def decorator(callback): - self.register_errors_handler(callback, func=func, exception=exception) + self.register_errors_handler(self._wrap_async_task(callback, run_task), + func=func, exception=exception) return callback return decorator @@ -929,6 +935,109 @@ class Dispatcher: return FSMContext(storage=self.storage, chat=chat, user=user) + async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool: + """ + Execute throttling manager. + Return True limit is not exceeded otherwise raise ThrottleError or return False + + :param key: key in storage + :param rate: limit (by default is equals with default rate limit) + :param user: user id + :param chat: chat id + :param no_error: return boolean value instead of raising error + :return: bool + """ + if not self.storage.has_bucket(): + raise RuntimeError('This storage does not provide Leaky Bucket') + + if no_error is None: + no_error = self.no_throttle_error + if rate is None: + rate = self.throttling_rate_limit + if user is None and chat is None: + from . import ctx + user = ctx.get_user() + chat = ctx.get_chat() + + # Detect current time + now = time.time() + + bucket = await self.storage.get_bucket(chat=chat, user=user) + + # Fix bucket + if bucket is None: + bucket = {key: {}} + if key not in bucket: + bucket[key] = {} + data = bucket[key] + + # Calculate + called = data.get(LAST_CALL, now) + delta = now - called + result = delta >= rate or delta <= 0 + + # Save results + data[RESULT] = result + data[RATE_LIMIT] = rate + data[LAST_CALL] = now + data[DELTA] = delta + if not result: + data[EXCEEDED_COUNT] += 1 + else: + data[EXCEEDED_COUNT] = 1 + bucket[key].update(data) + await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) + + if not result and not no_error: + # Raise if that is allowed + raise Throttled(key=key, chat=chat, user=user, **data) + return result + + async def check_key(self, key, chat=None, user=None): + """ + Get information about key in bucket + + :param key: + :param chat: + :param user: + :return: + """ + if not self.storage.has_bucket(): + raise RuntimeError('This storage does not provide Leaky Bucket') + + if user is None and chat is None: + from . import ctx + user = ctx.get_user() + chat = ctx.get_chat() + + bucket = await self.storage.get_bucket(chat=chat, user=user) + data = bucket.get(key, {}) + return Throttled(key=key, chat=chat, user=user, **data) + + async def release_key(self, key, chat=None, user=None): + """ + Release blocked key + + :param key: + :param chat: + :param user: + :return: + """ + if not self.storage.has_bucket(): + raise RuntimeError('This storage does not provide Leaky Bucket') + + if user is None and chat is None: + from . import ctx + user = ctx.get_user() + chat = ctx.get_chat() + + bucket = await self.storage.get_bucket(chat=chat, user=user) + if bucket and key in bucket: + del bucket['key'] + await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) + return True + return False + def async_task(self, func): """ Execute handler as task and return None. @@ -947,10 +1056,14 @@ class Dispatcher: """ def process_response(task): - response = task.result() - - if isinstance(response, BaseResponse): - self.loop.create_task(response.execute_response(self.bot)) + try: + response = task.result() + except Exception as e: + self.loop.create_task( + self.errors_handlers.notify(self, task.context.get(UPDATE_OBJECT, None), e)) + else: + if isinstance(response, BaseResponse): + self.loop.create_task(response.execute_response(self.bot)) @functools.wraps(func) async def wrapper(*args, **kwargs): diff --git a/aiogram/dispatcher/ctx.py b/aiogram/dispatcher/ctx.py index f1ecce68..18229125 100644 --- a/aiogram/dispatcher/ctx.py +++ b/aiogram/dispatcher/ctx.py @@ -7,7 +7,10 @@ from ..utils import context def _get(key, default=None, no_error=False): result = context.get_value(key, default) if not no_error and result is None: - raise RuntimeError(f"Context is not configured for '{key}'") + raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n" + f"Maybe asyncio task factory is not configured!\n" + f"\t>>> from aiogram.utils import context\n" + f"\t>>> loop.set_task_factory(context.task_factory)") return result diff --git a/aiogram/dispatcher/filters.py b/aiogram/dispatcher/filters.py index b431ae67..599b4aef 100644 --- a/aiogram/dispatcher/filters.py +++ b/aiogram/dispatcher/filters.py @@ -9,26 +9,45 @@ from ..utils.helper import Helper, HelperMode, Item USER_STATE = 'USER_STATE' -async def check_filter(filter_, args, kwargs): +async def check_filter(filter_, args): + """ + Helper for executing filter + + :param filter_: + :param args: + :param kwargs: + :return: + """ if not callable(filter_): raise TypeError('Filter must be callable and/or awaitable!') if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_): - return await filter_(*args, **kwargs) + return await filter_(*args) else: - return filter_(*args, **kwargs) + return filter_(*args) -async def check_filters(filters, args, kwargs): +async def check_filters(filters, args): + """ + Check list of filters + + :param filters: + :param args: + :return: + """ if filters is not None: for filter_ in filters: - f = await check_filter(filter_, args, kwargs) + f = await check_filter(filter_, args) if not f: return False return True class Filter: + """ + Base class for filters + """ + def __call__(self, *args, **kwargs): return self.check(*args, **kwargs) @@ -37,6 +56,10 @@ class Filter: class AsyncFilter(Filter): + """ + Base class for asynchronous filters + """ + def __aiter__(self): return None @@ -48,23 +71,35 @@ class AsyncFilter(Filter): class AnyFilter(AsyncFilter): + """ + One filter from many + """ + def __init__(self, *filters: callable): self.filters = filters - async def check(self, *args, **kwargs): - f = (check_filter(filter_, args, kwargs) for filter_ in self.filters) + async def check(self, *args): + f = (check_filter(filter_, args) for filter_ in self.filters) return any(await asyncio.gather(*f)) class NotFilter(AsyncFilter): + """ + Reverse filter + """ + def __init__(self, filter_: callable): self.filter = filter_ - async def check(self, *args, **kwargs): - return not await check_filter(self.filter, args, kwargs) + async def check(self, *args): + return not await check_filter(self.filter, args) class CommandsFilter(AsyncFilter): + """ + Check commands in message + """ + def __init__(self, commands): self.commands = commands @@ -85,6 +120,10 @@ class CommandsFilter(AsyncFilter): class RegexpFilter(Filter): + """ + Regexp filter for messages + """ + def __init__(self, regexp): self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) @@ -94,6 +133,10 @@ class RegexpFilter(Filter): class ContentTypeFilter(Filter): + """ + Check message content type + """ + def __init__(self, content_types): self.content_types = content_types @@ -103,6 +146,10 @@ class ContentTypeFilter(Filter): class CancelFilter(Filter): + """ + Find cancel in message text + """ + def __init__(self, cancel_set=None): if cancel_set is None: cancel_set = ['/cancel', 'cancel', 'cancel.'] @@ -114,6 +161,10 @@ class CancelFilter(Filter): class StateFilter(AsyncFilter): + """ + Check user state + """ + def __init__(self, dispatcher, state): self.dispatcher = dispatcher self.state = state @@ -137,6 +188,10 @@ class StateFilter(AsyncFilter): class StatesListFilter(StateFilter): + """ + List of states + """ + async def check(self, obj): chat, user = self.get_target(obj) @@ -146,6 +201,10 @@ class StatesListFilter(StateFilter): class ExceptionsFilter(Filter): + """ + Filter for exceptions + """ + def __init__(self, exception): self.exception = exception @@ -159,6 +218,14 @@ class ExceptionsFilter(Filter): def generate_default_filters(dispatcher, *args, **kwargs): + """ + Prepare filters + + :param dispatcher: + :param args: + :param kwargs: + :return: + """ filters_set = [] for name, filter_ in kwargs.items(): diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index 517b8b75..c721369b 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -1,3 +1,4 @@ +from aiogram.utils import context from .filters import check_filters @@ -10,11 +11,12 @@ class CancelHandler(BaseException): class Handler: - def __init__(self, dispatcher, once=True): + def __init__(self, dispatcher, once=True, middleware_key=None): self.dispatcher = dispatcher self.once = once self.handlers = [] + self.middleware_key = middleware_key def register(self, handler, filters=None, index=None): """ @@ -48,20 +50,24 @@ class Handler: return True raise ValueError('This handler is not registered!') - async def notify(self, *args, **kwargs): + async def notify(self, *args): """ Notify handlers :param args: - :param kwargs: :return: """ results = [] + if self.middleware_key: + await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args) for filters, handler in self.handlers: - if await check_filters(filters, args, kwargs): + if await check_filters(filters, args): try: - response = await handler(*args, **kwargs) + if self.middleware_key: + context.set_value('handler', handler) + await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args) + response = await handler(*args) if results is not None: results.append(response) if self.once: @@ -70,5 +76,8 @@ class Handler: continue except CancelHandler: break + if self.middleware_key: + await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}", + args + (results,)) return results diff --git a/aiogram/dispatcher/middlewares.py b/aiogram/dispatcher/middlewares.py new file mode 100644 index 00000000..85c7ef2d --- /dev/null +++ b/aiogram/dispatcher/middlewares.py @@ -0,0 +1,101 @@ +import logging +import typing + +log = logging.getLogger('aiogram.Middleware') + + +class MiddlewareManager: + """ + Middlewares manager. Works only with dispatcher. + """ + + def __init__(self, dispatcher): + """ + Init + + :param dispatcher: instance of Dispatcher + """ + self.dispatcher = dispatcher + self.loop = dispatcher.loop + self.bot = dispatcher.bot + self.storage = dispatcher.storage + self.applications = [] + + def setup(self, middleware): + """ + Setup middleware + + :param middleware: + :return: + """ + assert isinstance(middleware, BaseMiddleware) + if middleware.is_configured(): + raise ValueError('That middleware is already used!') + + self.applications.append(middleware) + middleware.setup(self) + log.debug(f"Loaded middleware '{middleware.__class__.__name__}'") + + async def trigger(self, action: str, args: typing.Iterable): + """ + Call action to middlewares with args lilt. + + :param action: + :param args: + :return: + """ + for app in self.applications: + await app.trigger(action, args) + + +class BaseMiddleware: + """ + Base class for middleware. + + All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message". + """ + + def __init__(self): + self._configured = False + self._manager = None + + @property + def manager(self) -> MiddlewareManager: + """ + Instance of MiddlewareManager + """ + if self._manager is None: + raise RuntimeError('Middleware is not configured!') + return self._manager + + def setup(self, manager): + """ + Mark middleware as configured + + :param manager: + :return: + """ + self._manager = manager + self._configured = True + + def is_configured(self) -> bool: + """ + Check middleware is configured + + :return: + """ + return self._configured + + async def trigger(self, action, args): + """ + Trigger action. + + :param action: + :param args: + :return: + """ + handler_name = f"on_{action}" + handler = getattr(self, handler_name, None) + if not handler: + return None + await handler(*args) diff --git a/aiogram/dispatcher/storage.py b/aiogram/dispatcher/storage.py index 85c23a61..622e4182 100644 --- a/aiogram/dispatcher/storage.py +++ b/aiogram/dispatcher/storage.py @@ -1,5 +1,14 @@ import typing +# Leak bucket +KEY = 'key' +LAST_CALL = 'called_at' +RATE_LIMIT = 'rate_limit' +RESULT = 'result' +EXCEEDED_COUNT = 'exceeded' +DELTA = 'delta' +THROTTLE_MANAGER = '$throttle_manager' + class BaseStorage: """ @@ -184,6 +193,78 @@ class BaseStorage: """ await self.reset_state(chat=chat, user=user, with_data=True) + def has_bucket(self): + return False + + async def get_bucket(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + default: typing.Optional[dict] = None) -> typing.Dict: + """ + Get state-data for user in chat. Return `default` if data is not presented in storage. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :param default: + :return: + """ + raise NotImplementedError + + async def set_bucket(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None): + """ + Set data for user in chat + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :param bucket: + """ + raise NotImplementedError + + async def update_bucket(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None, + bucket: typing.Dict = None, + **kwargs): + """ + Update data for user in chat + + You can use data parameter or|and kwargs. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param bucket: + :param chat: + :param user: + :param kwargs: + :return: + """ + raise NotImplementedError + + async def reset_bucket(self, *, + chat: typing.Union[str, int, None] = None, + user: typing.Union[str, int, None] = None): + """ + Reset data dor user in chat. + + Chat or user is always required. If one of this is not presented, + need set the missing value based on the presented + + :param chat: + :param user: + :return: + """ + await self.set_data(chat=chat, user=user, data={}) + class FSMContext: def __init__(self, storage, chat, user): diff --git a/aiogram/dispatcher/webhook.py b/aiogram/dispatcher/webhook.py index 80163818..e98635f9 100644 --- a/aiogram/dispatcher/webhook.py +++ b/aiogram/dispatcher/webhook.py @@ -186,10 +186,15 @@ class WebhookRequestHandler(web.View): dispatcher = self.get_dispatcher() loop = dispatcher.loop - results = task.result() - response = self.get_response(results) - if response is not None: - asyncio.ensure_future(response.execute_response(self.get_dispatcher().bot), loop=loop) + try: + results = task.result() + except Exception as e: + loop.create_task( + dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e)) + else: + response = self.get_response(results) + if response is not None: + asyncio.ensure_future(response.execute_response(dispatcher.bot), loop=loop) def get_response(self, results): """ diff --git a/aiogram/types/base.py b/aiogram/types/base.py index 7a4bdd56..e291afe5 100644 --- a/aiogram/types/base.py +++ b/aiogram/types/base.py @@ -138,10 +138,8 @@ class TelegramObject(metaclass=MetaTelegramObject): @property def bot(self): - bot = get_value('bot') - if bot is None: - raise RuntimeError('Can not found bot instance in current context!') - return bot + from ..dispatcher import ctx + return ctx.get_bot() def to_python(self) -> typing.Dict: """ diff --git a/aiogram/types/chat.py b/aiogram/types/chat.py index 04996f5e..392de77b 100644 --- a/aiogram/types/chat.py +++ b/aiogram/types/chat.py @@ -211,7 +211,7 @@ class ChatActions(helper.Helper): @classmethod async def _do(cls, action: str, sleep=None): - from aiogram.dispatcher.ctx import get_bot, get_chat + from ..dispatcher.ctx import get_bot, get_chat await get_bot().send_chat_action(get_chat(), action) if sleep: await asyncio.sleep(sleep) diff --git a/aiogram/types/chat_member.py b/aiogram/types/chat_member.py index 96be7b76..f445d69a 100644 --- a/aiogram/types/chat_member.py +++ b/aiogram/types/chat_member.py @@ -1,9 +1,9 @@ import datetime -from aiogram.utils import helper from . import base from . import fields from .user import User +from ..utils import helper class ChatMember(base.TelegramObject): diff --git a/aiogram/types/input_media.py b/aiogram/types/input_media.py index 391e6581..12c9e8ca 100644 --- a/aiogram/types/input_media.py +++ b/aiogram/types/input_media.py @@ -76,6 +76,7 @@ class MediaGroup(base.TelegramObject): """ Helper for sending media group """ + def __init__(self, medias: typing.Optional[typing.List[typing.Union[InputMedia, typing.Dict]]] = None): super(MediaGroup, self).__init__() self.media = [] diff --git a/aiogram/types/pre_checkout_query.py b/aiogram/types/pre_checkout_query.py index b9bcbb12..4786ce48 100644 --- a/aiogram/types/pre_checkout_query.py +++ b/aiogram/types/pre_checkout_query.py @@ -31,4 +31,4 @@ class PreCheckoutQuery(base.TelegramObject): def __eq__(self, other): if isinstance(other, type(self)): return other.id == self.id - return self.id == other \ No newline at end of file + return self.id == other diff --git a/aiogram/utils/context.py b/aiogram/utils/context.py index 2a846b23..dfef6882 100644 --- a/aiogram/utils/context.py +++ b/aiogram/utils/context.py @@ -31,7 +31,7 @@ def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine): del task._source_traceback[-1] try: - task.context = asyncio.Task.current_task().context + task.context = asyncio.Task.current_task().context.copy() except AttributeError: task.context = {CONFIGURED: True} @@ -114,3 +114,25 @@ def check_configured(): :return: """ return get_value(CONFIGURED) + + +class _Context: + """ + Other things for interactions with the execution context. + """ + + def __getitem__(self, item): + return get_value(item) + + def __setitem__(self, key, value): + set_value(key, value) + + def __delitem__(self, key): + del_value(key) + + @staticmethod + def get_context(): + return get_current_state() + + +context = _Context() diff --git a/aiogram/utils/exceptions.py b/aiogram/utils/exceptions.py index edbf5bcd..d0a40c0d 100644 --- a/aiogram/utils/exceptions.py +++ b/aiogram/utils/exceptions.py @@ -1,3 +1,5 @@ +import time + _PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: '] @@ -51,3 +53,21 @@ class MigrateToChat(TelegramAPIError): def __init__(self, chat_id): super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.") self.migrate_to_chat_id = chat_id + + +class Throttled(Exception): + def __init__(self, **kwargs): + from ..dispatcher.storage import DELTA, EXCEEDED_COUNT, KEY, LAST_CALL, RATE_LIMIT, RESULT + self.key = kwargs.pop(KEY, '') + self.called_at = kwargs.pop(LAST_CALL, time.time()) + self.rate = kwargs.pop(RATE_LIMIT, None) + self.result = kwargs.pop(RESULT, False) + self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0) + self.delta = kwargs.pop(DELTA, 0) + self.user = kwargs.pop('user', None) + self.chat = kwargs.pop('chat', None) + + def __str__(self): + return f"Rate limit exceeded! (Limit: {self.rate} s, " \ + f"exceeded: {self.exceeded_count}, " \ + f"time delta: {round(self.delta, 3)} s)" diff --git a/aiogram/utils/executor.py b/aiogram/utils/executor.py index 45fdf7da..b1fec35e 100644 --- a/aiogram/utils/executor.py +++ b/aiogram/utils/executor.py @@ -34,6 +34,7 @@ async def _shutdown(dispatcher: Dispatcher, callback=None): if dispatcher.is_polling(): dispatcher.stop_polling() + # await dispatcher.wait_closed() await dispatcher.storage.close() await dispatcher.storage.wait_closed() diff --git a/dev_requirements.txt b/dev_requirements.txt new file mode 100644 index 00000000..7fd66b33 --- /dev/null +++ b/dev_requirements.txt @@ -0,0 +1,7 @@ +-r requirements.txt +ujson +emoji +pytest +pytest-asyncio +uvloop +aioredis diff --git a/examples/middleware_and_antiflood.py b/examples/middleware_and_antiflood.py new file mode 100644 index 00000000..80d029ea --- /dev/null +++ b/examples/middleware_and_antiflood.py @@ -0,0 +1,123 @@ +import asyncio + +from aiogram import Bot, types +from aiogram.contrib.fsm_storage.redis import RedisStorage2 +from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx +from aiogram.dispatcher.middlewares import BaseMiddleware +from aiogram.utils import context, executor +from aiogram.utils.exceptions import Throttled + +TOKEN = 'BOT TOKEN HERE' + +loop = asyncio.get_event_loop() + +# In this example used Redis storage +storage = RedisStorage2(db=5) + +bot = Bot(token=TOKEN, loop=loop) +dp = Dispatcher(bot, storage=storage) + + +def rate_limit(limit: int, key=None): + """ + Decorator for configuring rate limit and key in different functions. + + :param limit: + :param key: + :return: + """ + + def decorator(func): + setattr(func, 'throttling_rate_limit', limit) + if key: + setattr(func, 'throttling_key', key) + return func + + return decorator + + +class ThrottlingMiddleware(BaseMiddleware): + """ + Simple middleware + """ + + def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'): + self.rate_limit = limit + self.prefix = key_prefix + super(ThrottlingMiddleware, self).__init__() + + async def on_process_message(self, message: types.Message): + """ + That handler will be called when dispatcher receive message + + :param message: + """ + # Get current handler + handler = context.get_value('handler') + + # Get dispatcher from context + dispatcher = ctx.get_dispatcher() + + # If handler was configured get rate limit and key from handler + if handler: + limit = getattr(handler, 'throttling_rate_limit', self.rate_limit) + key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") + else: + limit = self.rate_limit + key = f"{self.prefix}_message" + + # Use Dispatcher.throttle method. + try: + await dispatcher.throttle(key, rate=limit) + except Throttled as t: + # Execute action + await self.message_throttled(message, t) + + # Cancel current handler + raise CancelHandler() + + async def message_throttled(self, message: types.Message, throttled: Throttled): + """ + Notify user only on first exceed and notify about unlocking only on last exceed + + :param message: + :param throttled: + """ + handler = context.get_value('handler') + dispatcher = ctx.get_dispatcher() + if handler: + key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") + else: + key = f"{self.prefix}_message" + + # Calculate how many time left to the end of block. + delta = throttled.rate - throttled.delta + + # Prevent flooding + if throttled.exceeded_count <= 2: + await message.reply('Too many requests! ') + + # Sleep. + await asyncio.sleep(delta) + + # Check lock status + thr = await dispatcher.check_key(key) + + # If current message is not last with current key - do not send message + if thr.exceeded_count == throttled.exceeded_count: + await message.reply('Unlocked.') + + +@dp.message_handler(commands=['start']) +@rate_limit(5, 'start') # is not required but with that you can configure throttling manager for current handler +async def cmd_test(message: types.Message): + # You can use that command every 5 seconds + await message.reply('Test passed! You can use that command every 5 seconds.') + + +if __name__ == '__main__': + # Setup middleware + dp.middleware.setup(ThrottlingMiddleware()) + + # Start long-polling + executor.start_polling(dp, loop=loop) diff --git a/examples/throtling_example.py b/examples/throtling_example.py new file mode 100644 index 00000000..e1bda994 --- /dev/null +++ b/examples/throtling_example.py @@ -0,0 +1,43 @@ +""" +Example for throttling manager. + +You can use that for flood controlling. +""" + +import asyncio +import logging + +from aiogram import Bot, types +from aiogram.contrib.fsm_storage.memory import MemoryStorage +from aiogram.dispatcher import Dispatcher +from aiogram.utils.exceptions import Throttled +from aiogram.utils.executor import start_polling + +API_TOKEN = 'BOT TOKEN HERE' + +logging.basicConfig(level=logging.INFO) + +loop = asyncio.get_event_loop() +bot = Bot(token=API_TOKEN, loop=loop) + +# Throttling manager does not working without Leaky Bucket. +# Then need to use storage's. For example use simple in-memory storage. +storage = MemoryStorage() +dp = Dispatcher(bot, storage=storage) + + +@dp.message_handler(commands=['start', 'help']) +async def send_welcome(message: types.Message): + try: + # Execute throttling manager with rate-limit equals to 2 seconds for key "start" + await dp.throttle('start', rate=2) + except Throttled: + # If request is throttled the `Throttled` exception will be raised. + await message.reply('Too many requests!') + else: + # Otherwise do something. + await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.") + + +if __name__ == '__main__': + start_polling(dp, loop=loop, skip_updates=True) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..fe936e18 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1 @@ +# pytest_plugins = "pytest_asyncio.plugin" diff --git a/tests/dataset.py b/tests/dataset.py deleted file mode 100644 index be9432df..00000000 --- a/tests/dataset.py +++ /dev/null @@ -1,25 +0,0 @@ -UPDATE = { - "update_id": 128526, - "message": { - "message_id": 11223, - "from": { - "id": 12345678, - "is_bot": False, - "first_name": "FirstName", - "last_name": "LastName", - "username": "username", - "language_code": "ru" - }, - "chat": { - "id": 12345678, - "first_name": "FirstName", - "last_name": "LastName", - "username": "username", - "type": "private" - }, - "date": 1508709711, - "text": "Hi, world!" - } -} - -MESSAGE = UPDATE['message'] diff --git a/tests/test_bot.py b/tests/test_bot.py new file mode 100644 index 00000000..9c0f860d --- /dev/null +++ b/tests/test_bot.py @@ -0,0 +1,4 @@ +import aiogram + +# bot = aiogram.Bot('123456789:AABBCCDDEEFFaabbccddeeff-1234567890') +# TODO: mock for aiogram.bot.api.request and then test all AI methods. diff --git a/tests/test_message.py b/tests/test_message.py deleted file mode 100644 index a3cec0ad..00000000 --- a/tests/test_message.py +++ /dev/null @@ -1,35 +0,0 @@ -import datetime -import unittest - -from aiogram import types -from dataset import MESSAGE - - -class TestMessage(unittest.TestCase): - def setUp(self): - self.message = types.Message(**MESSAGE) - - def test_update_id(self): - self.assertEqual(self.message.message_id, MESSAGE['message_id'], 'test') - self.assertEqual(self.message['message_id'], MESSAGE['message_id']) - - def test_from(self): - self.assertIsInstance(self.message.from_user, types.User) - self.assertEqual(self.message.from_user, self.message['from']) - - def test_chat(self): - self.assertIsInstance(self.message.chat, types.Chat) - self.assertEqual(self.message.chat, self.message['chat']) - - def test_date(self): - self.assertIsInstance(self.message.date, datetime.datetime) - self.assertEqual(int(self.message.date.timestamp()), MESSAGE['date']) - self.assertEqual(self.message.date, self.message['date']) - - def test_text(self): - self.assertEqual(self.message.text, MESSAGE['text']) - self.assertEqual(self.message['text'], MESSAGE['text']) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/types/dataset.py b/tests/types/dataset.py new file mode 100644 index 00000000..0a991a4a --- /dev/null +++ b/tests/types/dataset.py @@ -0,0 +1,75 @@ +USER = { + "id": 12345678, + "is_bot": False, + "first_name": "FirstName", + "last_name": "LastName", + "username": "username", + "language_code": "ru-RU" +} + +CHAT = { + "id": 12345678, + "first_name": "FirstName", + "last_name": "LastName", + "username": "username", + "type": "private" +} + +MESSAGE = { + "message_id": 11223, + "from": USER, + "chat": CHAT, + "date": 1508709711, + "text": "Hi, world!" +} + +DOCUMENT = { + "file_name": "test.docx", + "mime_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "file_id": "BQADAgADpgADy_JxS66XQTBRHFleAg", + "file_size": 21331 +} + +MESSAGE_WITH_DOCUMENT = { + "message_id": 12345, + "from": USER, + "chat": CHAT, + "date": 1508768012, + "document": DOCUMENT, + "caption": "doc description" +} + +UPDATE = { + "update_id": 128526, + "message": MESSAGE +} + +PHOTO = { + "file_id": "AgADBAADFak0G88YZAf8OAug7bHyS9x2ZxkABHVfpJywcloRAAGAAQABAg", + "file_size": 1101, + "width": 90, + "height": 51 +} + +ANIMATION = { + "file_name": "a9b0e0ca537aa344338f80978f0896b7.gif.mp4", + "mime_type": "video/mp4", + "thumb": PHOTO, + "file_id": "CgADBAAD4DUAAoceZAe2WiE9y0crrAI", + "file_size": 65837 +} + +GAME = { + "title": "Karate Kido", + "description": "No trees were harmed in the making of this game :)", + "photo": [PHOTO, PHOTO, PHOTO], + "animation": ANIMATION +} + +MESSAGE_WITH_GAME = { + "message_id": 12345, + "from": USER, + "chat": CHAT, + "date": 1508824810, + "game": GAME +} diff --git a/tests/types/test_animation.py b/tests/types/test_animation.py new file mode 100644 index 00000000..96b44174 --- /dev/null +++ b/tests/types/test_animation.py @@ -0,0 +1,39 @@ +from aiogram import types +from .dataset import ANIMATION + +animation = types.Animation(**ANIMATION) + + +def test_export(): + exported = animation.to_python() + assert isinstance(exported, dict) + assert exported == ANIMATION + + +def test_file_name(): + assert isinstance(animation.file_name, str) + assert animation.file_name == ANIMATION['file_name'] + + +def test_mime_type(): + assert isinstance(animation.mime_type, str) + assert animation.mime_type == ANIMATION['mime_type'] + + +def test_file_id(): + assert isinstance(animation.file_id, str) + # assert hash(animation) == ANIMATION['file_id'] + assert animation.file_id == ANIMATION['file_id'] + + +def test_file_size(): + assert isinstance(animation.file_size, int) + assert animation.file_size == ANIMATION['file_size'] + + +def test_thumb(): + assert isinstance(animation.thumb, types.PhotoSize) + assert animation.thumb.file_id == ANIMATION['thumb']['file_id'] + assert animation.thumb.width == ANIMATION['thumb']['width'] + assert animation.thumb.height == ANIMATION['thumb']['height'] + assert animation.thumb.file_size == ANIMATION['thumb']['file_size'] diff --git a/tests/types/test_chat.py b/tests/types/test_chat.py new file mode 100644 index 00000000..c2b6de4a --- /dev/null +++ b/tests/types/test_chat.py @@ -0,0 +1,61 @@ +from aiogram import types +from .dataset import CHAT + +chat = types.Chat(**CHAT) + + +def test_export(): + exported = chat.to_python() + assert isinstance(exported, dict) + assert exported == CHAT + + +def test_id(): + assert isinstance(chat.id, int) + assert chat.id == CHAT['id'] + assert hash(chat) == CHAT['id'] + + +def test_name(): + assert isinstance(chat.first_name, str) + assert chat.first_name == CHAT['first_name'] + + assert isinstance(chat.last_name, str) + assert chat.last_name == CHAT['last_name'] + + assert isinstance(chat.username, str) + assert chat.username == CHAT['username'] + + +def test_type(): + assert isinstance(chat.type, str) + assert chat.type == CHAT['type'] + + +def test_chat_types(): + assert types.ChatType.PRIVATE == 'private' + assert types.ChatType.GROUP == 'group' + assert types.ChatType.SUPER_GROUP == 'supergroup' + assert types.ChatType.CHANNEL == 'channel' + + +def test_chat_type_filters(): + from . import test_message + assert types.ChatType.is_private(test_message.message) + assert not types.ChatType.is_group(test_message.message) + assert not types.ChatType.is_super_group(test_message.message) + assert not types.ChatType.is_group_or_super_group(test_message.message) + assert not types.ChatType.is_channel(test_message.message) + + +def test_chat_actions(): + assert types.ChatActions.TYPING == 'typing' + assert types.ChatActions.UPLOAD_PHOTO == 'upload_photo' + assert types.ChatActions.RECORD_VIDEO == 'record_video' + assert types.ChatActions.UPLOAD_VIDEO == 'upload_video' + assert types.ChatActions.RECORD_AUDIO == 'record_audio' + assert types.ChatActions.UPLOAD_AUDIO == 'upload_audio' + assert types.ChatActions.UPLOAD_DOCUMENT == 'upload_document' + assert types.ChatActions.FIND_LOCATION == 'find_location' + assert types.ChatActions.RECORD_VIDEO_NOTE == 'record_video_note' + assert types.ChatActions.UPLOAD_VIDEO_NOTE == 'upload_video_note' diff --git a/tests/types/test_document.py b/tests/types/test_document.py new file mode 100644 index 00000000..64b53360 --- /dev/null +++ b/tests/types/test_document.py @@ -0,0 +1,35 @@ +from aiogram import types +from .dataset import DOCUMENT + +document = types.Document(**DOCUMENT) + + +def test_export(): + exported = document.to_python() + assert isinstance(exported, dict) + assert exported == DOCUMENT + + +def test_file_name(): + assert isinstance(document.file_name, str) + assert document.file_name == DOCUMENT['file_name'] + + +def test_mime_type(): + assert isinstance(document.mime_type, str) + assert document.mime_type == DOCUMENT['mime_type'] + + +def test_file_id(): + assert isinstance(document.file_id, str) + # assert hash(document) == DOCUMENT['file_id'] + assert document.file_id == DOCUMENT['file_id'] + + +def test_file_size(): + assert isinstance(document.file_size, int) + assert document.file_size == DOCUMENT['file_size'] + + +def test_thumb(): + assert document.thumb is None diff --git a/tests/types/test_game.py b/tests/types/test_game.py new file mode 100644 index 00000000..c81809f3 --- /dev/null +++ b/tests/types/test_game.py @@ -0,0 +1,29 @@ +from aiogram import types +from .dataset import GAME + +game = types.Game(**GAME) + +def test_export(): + exported = game.to_python() + assert isinstance(exported, dict) + assert exported == GAME + + +def test_title(): + assert isinstance(game.title, str) + assert game.title == GAME['title'] + + +def test_description(): + assert isinstance(game.description, str) + assert game.description == GAME['description'] + + +def test_photo(): + assert isinstance(game.photo, list) + assert len(game.photo) == len(GAME['photo']) + assert all(map(lambda t: isinstance(t, types.PhotoSize), game.photo)) + + +def test_animation(): + assert isinstance(game.animation, types.Animation) diff --git a/tests/types/test_message.py b/tests/types/test_message.py new file mode 100644 index 00000000..8071207e --- /dev/null +++ b/tests/types/test_message.py @@ -0,0 +1,39 @@ +import datetime + +from aiogram import types +from .dataset import MESSAGE + +message = types.Message(**MESSAGE) + + +def test_export(): + exported_chat = message.to_python() + assert isinstance(exported_chat, dict) + assert exported_chat == MESSAGE + + +def test_message_id(): + assert hash(message) == MESSAGE['message_id'] + assert message.message_id == MESSAGE['message_id'] + assert message['message_id'] == MESSAGE['message_id'] + + +def test_from(): + assert isinstance(message.from_user, types.User) + assert message.from_user == message['from'] + + +def test_chat(): + assert isinstance(message.chat, types.Chat) + assert message.chat == message['chat'] + + +def test_date(): + assert isinstance(message.date, datetime.datetime) + assert int(message.date.timestamp()) == MESSAGE['date'] + assert message.date == message['date'] + + +def test_text(): + assert message.text == MESSAGE['text'] + assert message['text'] == MESSAGE['text'] diff --git a/tests/types/test_photo.py b/tests/types/test_photo.py new file mode 100644 index 00000000..73d87fb7 --- /dev/null +++ b/tests/types/test_photo.py @@ -0,0 +1,27 @@ +from aiogram import types +from .dataset import PHOTO + +photo = types.PhotoSize(**PHOTO) + + +def test_export(): + exported = photo.to_python() + assert isinstance(exported, dict) + assert exported == PHOTO + + +def test_file_id(): + assert isinstance(photo.file_id, str) + assert photo.file_id == PHOTO['file_id'] + + +def test_file_size(): + assert isinstance(photo.file_size, int) + assert photo.file_size == PHOTO['file_size'] + + +def test_size(): + assert isinstance(photo.width, int) + assert isinstance(photo.height, int) + assert photo.width == PHOTO['width'] + assert photo.height == PHOTO['height'] diff --git a/tests/types/test_update.py b/tests/types/test_update.py new file mode 100644 index 00000000..72b97571 --- /dev/null +++ b/tests/types/test_update.py @@ -0,0 +1,20 @@ +from aiogram import types +from .dataset import UPDATE + +update = types.Update(**UPDATE) + + +def test_export(): + exported = update.to_python() + assert isinstance(exported, dict) + assert exported == UPDATE + + +def test_update_id(): + assert isinstance(update.update_id, int) + assert hash(update) == UPDATE['update_id'] + assert update.update_id == UPDATE['update_id'] + + +def test_message(): + assert isinstance(update.message, types.Message) diff --git a/tests/types/test_user.py b/tests/types/test_user.py new file mode 100644 index 00000000..ae8413aa --- /dev/null +++ b/tests/types/test_user.py @@ -0,0 +1,48 @@ +from babel import Locale + +from aiogram import types +from .dataset import USER + +user = types.User(**USER) + + +def test_export(): + exported = user.to_python() + assert isinstance(exported, dict) + assert exported == USER + + +def test_id(): + assert isinstance(user.id, int) + assert user.id == USER['id'] + assert hash(user) == USER['id'] + + +def test_bot(): + assert isinstance(user.is_bot, bool) + assert user.is_bot == USER['is_bot'] + + +def test_name(): + assert user.first_name == USER['first_name'] + assert user.last_name == USER['last_name'] + assert user.username == USER['username'] + + +def test_language_code(): + assert user.language_code == USER['language_code'] + assert user.locale == Locale.parse(USER['language_code'], sep='-') + + +def test_full_name(): + assert user.full_name == f"{USER['first_name']} {USER['last_name']}" + + +def test_mention(): + assert user.mention == f"@{USER['username']}" + assert user.get_mention('foo') == f"[foo](tg://user?id={USER['id']})" + assert user.get_mention('foo', as_html=True) == f"foo" + + +def test_url(): + assert user.url == f"tg://user?id={USER['id']}" diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index f1141634..00000000 --- a/tests/utils.py +++ /dev/null @@ -1,2 +0,0 @@ -def out(*message, sep=' '): - print('Test', sep.join(message)) diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..1460b55c --- /dev/null +++ b/tox.ini @@ -0,0 +1,7 @@ +[tox] +envlist = py36 + +[testenv] +deps = -rdev_requirements.txt +commands = pytest +skip_install = true \ No newline at end of file