mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 01:54:53 +00:00
Merge branch 'dev-1.x'
# Conflicts: # aiogram/__init__.py
This commit is contained in:
commit
e881064f6c
43 changed files with 1527 additions and 218 deletions
28
Makefile
28
Makefile
|
|
@ -2,17 +2,18 @@ VENV_NAME := venv
|
|||
PYTHON := $(VENV_NAME)/bin/python
|
||||
AIOGRAM_VERSION := $(shell $(PYTHON) -c "import aiogram;print(aiogram.__version__)")
|
||||
|
||||
RM := rm -rf
|
||||
|
||||
mkvenv:
|
||||
virtualenv $(VENV_NAME)
|
||||
$(PYTHON) -m pip install -r requirements.txt
|
||||
|
||||
clean:
|
||||
find . -name '*.pyc' -exec rm --force {} +
|
||||
find . -name '*.pyo' -exec rm --force {} +
|
||||
find . -name '*~' -exec rm --force {} +
|
||||
rm --force --recursive build/
|
||||
rm --force --recursive dist/
|
||||
rm --force --recursive *.egg-info
|
||||
find . -name '*.pyc' -exec $(RM) {} +
|
||||
find . -name '*.pyo' -exec $(RM) {} +
|
||||
find . -name '*~' -exec $(RM) {} +
|
||||
find . -name '__pycache__' -exec $(RM) {} +
|
||||
$(RM) build/ dist/ docs/build/ .tox/ .cache/ *.egg-info
|
||||
|
||||
tag:
|
||||
@echo "Add tag: '$(AIOGRAM_VERSION)'"
|
||||
|
|
@ -26,14 +27,23 @@ upload:
|
|||
|
||||
release:
|
||||
make clean
|
||||
make tag
|
||||
make test
|
||||
make build
|
||||
make tag
|
||||
@echo "Released aiogram $(AIOGRAM_VERSION)"
|
||||
|
||||
full-release:
|
||||
make release
|
||||
make upload
|
||||
|
||||
|
||||
make install:
|
||||
install:
|
||||
$(PYTHON) setup.py install
|
||||
|
||||
test:
|
||||
tox
|
||||
|
||||
summary:
|
||||
cloc aiogram/ tests/ examples/ setup.py
|
||||
|
||||
docs: docs/source/*
|
||||
cd docs && $(MAKE) html
|
||||
|
|
|
|||
|
|
@ -10,7 +10,16 @@ except ImportError as e:
|
|||
|
||||
from .utils.versions import Stage, Version
|
||||
|
||||
VERSION = Version(1, 0, 2, stage=Stage.FINAL, build=0)
|
||||
try:
|
||||
import uvloop
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
import asyncio
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
VERSION = Version(1, 0, 3, stage=Stage.FINAL, build=0)
|
||||
API_VERSION = Version(3, 5)
|
||||
|
||||
__version__ = VERSION.version
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ def _compose_data(params=None, files=None):
|
|||
return data
|
||||
|
||||
|
||||
async def request(session, token, method, data=None, files=None, continue_retry=False, **kwargs) -> bool or dict:
|
||||
async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict:
|
||||
"""
|
||||
Make request to API
|
||||
|
||||
|
|
@ -144,8 +144,6 @@ async def request(session, token, method, data=None, files=None, continue_retry=
|
|||
:type data: :obj:`dict`
|
||||
:param files: files
|
||||
:type files: :obj:`dict`
|
||||
:param continue_retry:
|
||||
:type continue_retry: :obj:`dict`
|
||||
:return: result
|
||||
:rtype :obj:`bool` or :obj:`dict`
|
||||
"""
|
||||
|
|
@ -158,18 +156,13 @@ async def request(session, token, method, data=None, files=None, continue_retry=
|
|||
return await _check_result(method, response)
|
||||
except aiohttp.ClientError as e:
|
||||
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
||||
except exceptions.RetryAfter as e:
|
||||
if continue_retry:
|
||||
await asyncio.sleep(e.timeout)
|
||||
return await request(session, token, method, data, files, **kwargs)
|
||||
raise
|
||||
|
||||
|
||||
class Methods(Helper):
|
||||
"""
|
||||
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api
|
||||
|
||||
List is updated to Bot API 3.4
|
||||
List is updated to Bot API 3.5
|
||||
"""
|
||||
mode = HelperMode.lowerCamelCase
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ class BaseBot:
|
|||
loop: Optional[Union[asyncio.BaseEventLoop, asyncio.AbstractEventLoop]] = None,
|
||||
connections_limit: Optional[base.Integer] = 10,
|
||||
proxy: str = None, proxy_auth: Optional[aiohttp.BasicAuth] = None,
|
||||
continue_retry: Optional[bool] = False,
|
||||
validate_token: Optional[bool] = True):
|
||||
"""
|
||||
Instructions how to get Bot token is found here: https://core.telegram.org/bots#3-how-do-i-create-a-bot
|
||||
|
|
@ -33,8 +32,6 @@ class BaseBot:
|
|||
:type proxy: :obj:`str`
|
||||
:param proxy_auth: Authentication information
|
||||
:type proxy_auth: Optional :obj:`aiohttp.BasicAuth`
|
||||
:param continue_retry: automatic retry sent request when flood control exceeded
|
||||
:type continue_retry: :obj:`bool`
|
||||
:param validate_token: Validate token.
|
||||
:type validate_token: :obj:`bool`
|
||||
:raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError`
|
||||
|
|
@ -48,9 +45,6 @@ class BaseBot:
|
|||
self.proxy = proxy
|
||||
self.proxy_auth = proxy_auth
|
||||
|
||||
# Action on error
|
||||
self.continue_retry = continue_retry
|
||||
|
||||
# Asyncio loop instance
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@ -68,8 +62,11 @@ class BaseBot:
|
|||
self._data = {}
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
When bot object is deleting - need close all sessions
|
||||
Close all client sessions
|
||||
"""
|
||||
for session in self._temp_sessions:
|
||||
if not session.closed:
|
||||
|
|
@ -77,17 +74,19 @@ class BaseBot:
|
|||
if self.session and not self.session.closed:
|
||||
self.session.close()
|
||||
|
||||
def create_temp_session(self, limit: int = 1) -> aiohttp.ClientSession:
|
||||
def create_temp_session(self, limit: int = 1, force_close: bool = False) -> aiohttp.ClientSession:
|
||||
"""
|
||||
Create temporary session
|
||||
|
||||
:param limit: Limit of connections
|
||||
:type limit: :obj:`int`
|
||||
:param force_close: Set to True to force close and do reconnect after each request (and between redirects).
|
||||
:type force_close: :obj:`bool`
|
||||
:return: New session
|
||||
:rtype: :obj:`aiohttp.TCPConnector`
|
||||
"""
|
||||
session = aiohttp.ClientSession(
|
||||
connector=aiohttp.TCPConnector(limit=limit, force_close=True),
|
||||
connector=aiohttp.TCPConnector(limit=limit, force_close=force_close),
|
||||
loop=self.loop, json_serialize=json.dumps)
|
||||
self._temp_sessions.append(session)
|
||||
return session
|
||||
|
|
@ -123,8 +122,7 @@ class BaseBot:
|
|||
:raise: :obj:`aiogram.exceptions.TelegramApiError`
|
||||
"""
|
||||
return await api.request(self.session, self.__token, method, data, files,
|
||||
proxy=self.proxy, proxy_auth=self.proxy_auth,
|
||||
continue_retry=self.continue_retry)
|
||||
proxy=self.proxy, proxy_auth=self.proxy_auth)
|
||||
|
||||
async def download_file(self, file_path: base.String,
|
||||
destination: Optional[base.InputFile] = None,
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ class Bot(BaseBot):
|
|||
delattr(self, '_me')
|
||||
|
||||
async def download_file_by_id(self, file_id: base.String, destination=None,
|
||||
timeout: base.Integer=30, chunk_size: base.Integer=65536, seek: base.Boolean=True):
|
||||
timeout: base.Integer = 30, chunk_size: base.Integer = 65536,
|
||||
seek: base.Boolean = True):
|
||||
"""
|
||||
Download file by file_id to destination
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class MemoryStorage(BaseStorage):
|
|||
chat_id = str(chat_id)
|
||||
user_id = str(user_id)
|
||||
if user_id not in self.data[chat_id]:
|
||||
self.data[chat_id][user_id] = {'state': None, 'data': {}}
|
||||
self.data[chat_id][user_id] = {'state': None, 'data': {}, 'bucket': {}}
|
||||
return self.data[chat_id][user_id]
|
||||
|
||||
async def get_state(self, *,
|
||||
|
|
@ -82,3 +82,27 @@ class MemoryStorage(BaseStorage):
|
|||
await self.set_state(chat=chat, user=user, state=None)
|
||||
if with_data:
|
||||
await self.set_data(chat=chat, user=user, data={})
|
||||
|
||||
def has_bucket(self):
|
||||
return True
|
||||
|
||||
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
user = self._get_user(chat, user)
|
||||
return user['bucket']
|
||||
|
||||
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
user = self._get_user(chat, user)
|
||||
user['bucket'] = bucket
|
||||
|
||||
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None, **kwargs):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
user = self._get_user(chat, user)
|
||||
if bucket is None:
|
||||
bucket = []
|
||||
user['bucket'].update(bucket, **kwargs)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ This module has redis storage for finite-state machine based on `aioredis <https
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import aioredis
|
||||
|
|
@ -10,6 +11,10 @@ import aioredis
|
|||
from ...dispatcher.storage import BaseStorage
|
||||
from ...utils import json
|
||||
|
||||
STATE_KEY = 'state'
|
||||
STATE_DATA_KEY = 'data'
|
||||
STATE_BUCKET_KEY = 'bucket'
|
||||
|
||||
|
||||
class RedisStorage(BaseStorage):
|
||||
"""
|
||||
|
|
@ -90,10 +95,11 @@ class RedisStorage(BaseStorage):
|
|||
return json.loads(data)
|
||||
|
||||
async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
state=None, data=None):
|
||||
state=None, data=None, bucket=None):
|
||||
"""
|
||||
Write record to storage
|
||||
|
||||
:param bucket:
|
||||
:param chat:
|
||||
:param user:
|
||||
:param state:
|
||||
|
|
@ -102,11 +108,13 @@ class RedisStorage(BaseStorage):
|
|||
"""
|
||||
if data is None:
|
||||
data = {}
|
||||
if bucket is None:
|
||||
bucket = {}
|
||||
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
addr = f"fsm:{chat}:{user}"
|
||||
|
||||
record = {'state': state, 'data': data}
|
||||
record = {'state': state, 'data': data, 'bucket': bucket}
|
||||
|
||||
conn = await self.redis
|
||||
await conn.execute('SET', addr, json.dumps(record))
|
||||
|
|
@ -168,3 +176,220 @@ class RedisStorage(BaseStorage):
|
|||
else:
|
||||
keys = await conn.execute('KEYS', 'fsm:*')
|
||||
conn.execute('DEL', *keys)
|
||||
|
||||
def has_bucket(self):
|
||||
return True
|
||||
|
||||
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Dict:
|
||||
record = await self.get_record(chat=chat, user=user)
|
||||
return record.get('bucket', {})
|
||||
|
||||
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None):
|
||||
record = await self.get_record(chat=chat, user=user)
|
||||
await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket)
|
||||
|
||||
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None, **kwargs):
|
||||
record = await self.get_record(chat=chat, user=user)
|
||||
record_bucket = record.get('bucket', {})
|
||||
record_bucket.update(bucket, **kwargs)
|
||||
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)
|
||||
|
||||
|
||||
class RedisStorage2(BaseStorage):
|
||||
"""
|
||||
Busted Redis-base storage for FSM.
|
||||
Works with Redis connection pool and customizable keys prefix.
|
||||
|
||||
Usage:
|
||||
|
||||
.. code-block:: python3
|
||||
|
||||
storage = RedisStorage('localhost', 6379, db=5, pool_size=10, prefix='my_fsm_key')
|
||||
dp = Dispatcher(bot, storage=storage)
|
||||
|
||||
And need to close Redis connection when shutdown
|
||||
|
||||
.. code-block:: python3
|
||||
|
||||
await dp.storage.close()
|
||||
await dp.storage.wait_closed()
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None,
|
||||
pool_size=10, loop=None, prefix='fsm', **kwargs):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._db = db
|
||||
self._password = password
|
||||
self._ssl = ssl
|
||||
self._pool_size = pool_size
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._kwargs = kwargs
|
||||
self._prefix = (prefix,)
|
||||
|
||||
self._redis: aioredis.RedisConnection = None
|
||||
self._connection_lock = asyncio.Lock(loop=self._loop)
|
||||
|
||||
@property
|
||||
async def redis(self) -> aioredis.Redis:
|
||||
"""
|
||||
Get Redis connection
|
||||
|
||||
This property is awaitable.
|
||||
"""
|
||||
# Use thread-safe asyncio Lock because this method without that is not safe
|
||||
async with self._connection_lock:
|
||||
if self._redis is None:
|
||||
self._redis = await aioredis.create_redis_pool((self._host, self._port),
|
||||
db=self._db, password=self._password, ssl=self._ssl,
|
||||
minsize=1, maxsize=self._pool_size,
|
||||
loop=self._loop, **self._kwargs)
|
||||
return self._redis
|
||||
|
||||
def generate_key(self, *parts):
|
||||
return ':'.join(self._prefix + tuple(map(str, parts)))
|
||||
|
||||
async def close(self):
|
||||
async with self._connection_lock:
|
||||
if self._redis and not self._redis.closed:
|
||||
self._redis.close()
|
||||
del self._redis
|
||||
self._redis = None
|
||||
|
||||
async def wait_closed(self):
|
||||
async with self._connection_lock:
|
||||
if self._redis:
|
||||
return await self._redis.wait_closed()
|
||||
return True
|
||||
|
||||
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_KEY)
|
||||
redis = await self.redis
|
||||
return await redis.get(key, encoding='utf8') or None
|
||||
|
||||
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||
redis = await self.redis
|
||||
raw_result = await redis.get(key, encoding='utf8')
|
||||
if raw_result:
|
||||
return json.loads(raw_result)
|
||||
return default or {}
|
||||
|
||||
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
state: typing.Optional[typing.AnyStr] = None):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_KEY)
|
||||
redis = await self.redis
|
||||
if state is None:
|
||||
await redis.delete(key)
|
||||
else:
|
||||
await redis.set(key, state)
|
||||
|
||||
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||
redis = await self.redis
|
||||
await redis.set(key, json.dumps(data))
|
||||
|
||||
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
data: typing.Dict = None, **kwargs):
|
||||
temp_data = await self.get_data(chat=chat, user=user, default={})
|
||||
temp_data.update(data, **kwargs)
|
||||
await self.set_data(chat=chat, user=user, data=temp_data)
|
||||
|
||||
def has_bucket(self):
|
||||
return True
|
||||
|
||||
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||
redis = await self.redis
|
||||
raw_result = await redis.get(key, encoding='utf8')
|
||||
if raw_result:
|
||||
return json.loads(raw_result)
|
||||
return default or {}
|
||||
|
||||
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None):
|
||||
chat, user = self.check_address(chat=chat, user=user)
|
||||
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||
redis = await self.redis
|
||||
await redis.set(key, json.dumps(bucket))
|
||||
|
||||
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None, **kwargs):
|
||||
temp_bucket = await self.get_data(chat=chat, user=user)
|
||||
temp_bucket.update(bucket, **kwargs)
|
||||
await self.set_data(chat=chat, user=user, data=temp_bucket)
|
||||
|
||||
async def reset_all(self, full=True):
|
||||
"""
|
||||
Reset states in DB
|
||||
|
||||
:param full: clean DB or clean only states
|
||||
:return:
|
||||
"""
|
||||
conn = await self.redis
|
||||
|
||||
if full:
|
||||
conn.flushdb()
|
||||
else:
|
||||
keys = await conn.keys(self.generate_key('*'))
|
||||
conn.delete(*keys)
|
||||
|
||||
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
|
||||
"""
|
||||
Get list of all stored chat's and user's
|
||||
|
||||
:return: list of tuples where first element is chat id and second is user id
|
||||
"""
|
||||
conn = await self.redis
|
||||
result = []
|
||||
|
||||
keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8')
|
||||
for item in keys:
|
||||
*_, chat, user, _ = item.split(':')
|
||||
result.append((chat, user))
|
||||
|
||||
return result
|
||||
|
||||
async def import_redis1(self, redis1):
|
||||
await migrate_redis1_to_redis2(redis1, self)
|
||||
|
||||
|
||||
async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2):
|
||||
"""
|
||||
Helper for migrating from RedisStorage to RedisStorage2
|
||||
|
||||
:param storage1: instance of RedisStorage
|
||||
:param storage2: instance of RedisStorage2
|
||||
:return:
|
||||
"""
|
||||
assert isinstance(storage1, RedisStorage)
|
||||
assert isinstance(storage2, RedisStorage2)
|
||||
|
||||
log = logging.getLogger('aiogram.RedisStorage')
|
||||
|
||||
for chat, user in await storage1.get_states_list():
|
||||
state = await storage1.get_state(chat=chat, user=user)
|
||||
await storage2.set_state(chat=chat, user=user, state=state)
|
||||
|
||||
data = await storage1.get_data(chat=chat, user=user)
|
||||
await storage2.set_data(chat=chat, user=user, data=data)
|
||||
|
||||
bucket = await storage1.get_bucket(chat=chat, user=user)
|
||||
await storage2.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
|
||||
log.info(f"Migrated user {user} in chat {chat}")
|
||||
|
|
|
|||
0
aiogram/contrib/middlewares/__init__.py
Normal file
0
aiogram/contrib/middlewares/__init__.py
Normal file
132
aiogram/contrib/middlewares/logging.py
Normal file
132
aiogram/contrib/middlewares/logging.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from aiogram import types
|
||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||
|
||||
HANDLED_STR = ['Unhandled', 'Handled']
|
||||
|
||||
|
||||
class LoggingMiddleware(BaseMiddleware):
|
||||
def __init__(self, logger=__name__):
|
||||
if not isinstance(logger, logging.Logger):
|
||||
logger = logging.getLogger(logger)
|
||||
|
||||
self.logger = logger
|
||||
|
||||
super(LoggingMiddleware, self).__init__()
|
||||
|
||||
def check_timeout(self, obj):
|
||||
start = obj.conf.get('_start', None)
|
||||
if start:
|
||||
del obj.conf['_start']
|
||||
return round((time.time() - start) * 1000)
|
||||
return -1
|
||||
|
||||
async def on_pre_process_update(self, update: types.Update):
|
||||
update.conf['_start'] = time.time()
|
||||
self.logger.debug(f"Received update [ID:{update.update_id}]")
|
||||
|
||||
async def on_post_process_update(self, update: types.Update, result):
|
||||
timeout = self.check_timeout(update)
|
||||
if timeout > 0:
|
||||
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
|
||||
|
||||
async def on_pre_process_message(self, message: types.Message):
|
||||
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
||||
|
||||
async def on_post_process_message(self, message: types.Message, results):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
||||
|
||||
async def on_pre_process_edited_message(self, edited_message):
|
||||
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
|
||||
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
||||
|
||||
async def on_post_process_edited_message(self, edited_message, results):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"edited message [ID:{edited_message.message_id}] "
|
||||
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
||||
|
||||
async def on_pre_process_channel_post(self, channel_post: types.Message):
|
||||
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
|
||||
f"in channel [ID:{channel_post.chat.id}]")
|
||||
|
||||
async def on_post_process_channel_post(self, channel_post: types.Message, results):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"channel post [ID:{channel_post.message_id}] "
|
||||
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
|
||||
|
||||
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message):
|
||||
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
|
||||
f"in channel [ID:{edited_channel_post.chat.id}]")
|
||||
|
||||
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"edited channel post [ID:{edited_channel_post.message_id}] "
|
||||
f"in channel [ID:{edited_channel_post.chat.id}]")
|
||||
|
||||
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery):
|
||||
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
|
||||
f"from user [ID:{inline_query.from_user.id}]")
|
||||
|
||||
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"inline query [ID:{inline_query.id}] "
|
||||
f"from user [ID:{inline_query.from_user.id}]")
|
||||
|
||||
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult):
|
||||
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
||||
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
||||
f"result [ID:{chosen_inline_result.result_id}]")
|
||||
|
||||
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
||||
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
||||
f"result [ID:{chosen_inline_result.result_id}]")
|
||||
|
||||
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery):
|
||||
if callback_query.message:
|
||||
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
|
||||
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
|
||||
f"from user [ID:{callback_query.message.from_user.id}]")
|
||||
else:
|
||||
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
|
||||
f"from inline message [ID:{callback_query.inline_message_id}] "
|
||||
f"from user [ID:{callback_query.from_user.id}]")
|
||||
|
||||
async def on_post_process_callback_query(self, callback_query, results):
|
||||
if callback_query.message:
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"callback query [ID:{callback_query.id}] "
|
||||
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
|
||||
f"from user [ID:{callback_query.message.from_user.id}]")
|
||||
else:
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"callback query [ID:{callback_query.id}] "
|
||||
f"from inline message [ID:{callback_query.inline_message_id}] "
|
||||
f"from user [ID:{callback_query.from_user.id}]")
|
||||
|
||||
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery):
|
||||
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
|
||||
f"from user [ID:{shipping_query.from_user.id}]")
|
||||
|
||||
async def on_post_process_shipping_query(self, shipping_query, results):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"shipping query [ID:{shipping_query.id}] "
|
||||
f"from user [ID:{shipping_query.from_user.id}]")
|
||||
|
||||
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery):
|
||||
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
|
||||
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
||||
|
||||
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results):
|
||||
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||
f"pre-checkout query [ID:{pre_checkout_query.id}] "
|
||||
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
||||
|
||||
async def on_pre_process_error(self, dispatcher, update, error):
|
||||
timeout = self.check_timeout(update)
|
||||
if timeout > 0:
|
||||
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")
|
||||
|
|
@ -1,20 +1,21 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import time
|
||||
import typing
|
||||
|
||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
||||
generate_default_filters
|
||||
from .handler import Handler
|
||||
from .storage import BaseStorage, DisabledStorage, FSMContext
|
||||
from .handler import CancelHandler, Handler, SkipHandler
|
||||
from .middlewares import MiddlewareManager
|
||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
|
||||
from .webhook import BaseResponse
|
||||
from ..bot import Bot
|
||||
from ..types.message import ContentType
|
||||
from ..utils import context
|
||||
from ..utils.deprecated import deprecated
|
||||
from ..utils.exceptions import NetworkError, TelegramAPIError
|
||||
from ..utils.exceptions import NetworkError, TelegramAPIError, Throttled
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -22,6 +23,8 @@ MODE = 'MODE'
|
|||
LONG_POLLING = 'long-polling'
|
||||
UPDATE_OBJECT = 'update_object'
|
||||
|
||||
DEFAULT_RATE_LIMIT = .1
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
"""
|
||||
|
|
@ -33,7 +36,9 @@ class Dispatcher:
|
|||
"""
|
||||
|
||||
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
||||
run_tasks_by_default: bool = False):
|
||||
run_tasks_by_default: bool = False,
|
||||
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
|
||||
|
||||
if loop is None:
|
||||
loop = bot.loop
|
||||
if storage is None:
|
||||
|
|
@ -44,27 +49,33 @@ class Dispatcher:
|
|||
self.storage = storage
|
||||
self.run_tasks_by_default = run_tasks_by_default
|
||||
|
||||
self.throttling_rate_limit = throttling_rate_limit
|
||||
self.no_throttle_error = no_throttle_error
|
||||
|
||||
self.last_update_id = 0
|
||||
|
||||
self.updates_handler = Handler(self)
|
||||
self.message_handlers = Handler(self)
|
||||
self.edited_message_handlers = Handler(self)
|
||||
self.channel_post_handlers = Handler(self)
|
||||
self.edited_channel_post_handlers = Handler(self)
|
||||
self.inline_query_handlers = Handler(self)
|
||||
self.chosen_inline_result_handlers = Handler(self)
|
||||
self.callback_query_handlers = Handler(self)
|
||||
self.shipping_query_handlers = Handler(self)
|
||||
self.pre_checkout_query_handlers = Handler(self)
|
||||
self.updates_handler = Handler(self, middleware_key='update')
|
||||
self.message_handlers = Handler(self, middleware_key='message')
|
||||
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
|
||||
self.channel_post_handlers = Handler(self, middleware_key='channel_post')
|
||||
self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post')
|
||||
self.inline_query_handlers = Handler(self, middleware_key='inline_query')
|
||||
self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result')
|
||||
self.callback_query_handlers = Handler(self, middleware_key='callback_query')
|
||||
self.shipping_query_handlers = Handler(self, middleware_key='shipping_query')
|
||||
self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query')
|
||||
self.errors_handlers = Handler(self, once=False, middleware_key='error')
|
||||
|
||||
self.middleware = MiddlewareManager(self)
|
||||
|
||||
self.updates_handler.register(self.process_update)
|
||||
|
||||
self.errors_handlers = Handler(self, once=False)
|
||||
|
||||
self._polling = False
|
||||
self._closed = True
|
||||
self._close_waiter = loop.create_future()
|
||||
|
||||
def __del__(self):
|
||||
self._polling = False
|
||||
self.stop_polling()
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
|
|
@ -105,7 +116,7 @@ class Dispatcher:
|
|||
"""
|
||||
tasks = []
|
||||
for update in updates:
|
||||
tasks.append(self.process_update(update))
|
||||
tasks.append(self.updates_handler.notify(update))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def process_update(self, update):
|
||||
|
|
@ -115,72 +126,59 @@ class Dispatcher:
|
|||
:param update:
|
||||
:return:
|
||||
"""
|
||||
start = time.time()
|
||||
success = True
|
||||
|
||||
self.last_update_id = update.update_id
|
||||
context.set_value(UPDATE_OBJECT, update)
|
||||
try:
|
||||
self.last_update_id = update.update_id
|
||||
has_context = context.check_configured()
|
||||
if has_context:
|
||||
context.set_value(UPDATE_OBJECT, update)
|
||||
if update.message:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id)
|
||||
context.update_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id)
|
||||
context.update_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id,
|
||||
state=state)
|
||||
return await self.message_handlers.notify(update.message)
|
||||
if update.edited_message:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id)
|
||||
context.update_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id)
|
||||
context.update_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id,
|
||||
state=state)
|
||||
return await self.edited_message_handlers.notify(update.edited_message)
|
||||
if update.channel_post:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.channel_post.chat.id)
|
||||
context.update_state(chat=update.channel_post.chat.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.channel_post.chat.id)
|
||||
context.update_state(chat=update.channel_post.chat.id,
|
||||
state=state)
|
||||
return await self.channel_post_handlers.notify(update.channel_post)
|
||||
if update.edited_channel_post:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
|
||||
context.update_state(chat=update.edited_channel_post.chat.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
|
||||
context.update_state(chat=update.edited_channel_post.chat.id,
|
||||
state=state)
|
||||
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
|
||||
if update.inline_query:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(user=update.inline_query.from_user.id)
|
||||
context.update_state(user=update.inline_query.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(user=update.inline_query.from_user.id)
|
||||
context.update_state(user=update.inline_query.from_user.id,
|
||||
state=state)
|
||||
return await self.inline_query_handlers.notify(update.inline_query)
|
||||
if update.chosen_inline_result:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
||||
context.update_state(user=update.chosen_inline_result.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
||||
context.update_state(user=update.chosen_inline_result.from_user.id,
|
||||
state=state)
|
||||
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
|
||||
if update.callback_query:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
|
||||
user=update.callback_query.from_user.id)
|
||||
context.update_state(user=update.callback_query.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(
|
||||
chat=update.callback_query.message.chat.id if update.callback_query.message else None,
|
||||
user=update.callback_query.from_user.id)
|
||||
context.update_state(user=update.callback_query.from_user.id,
|
||||
state=state)
|
||||
return await self.callback_query_handlers.notify(update.callback_query)
|
||||
if update.shipping_query:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
|
||||
context.update_state(user=update.shipping_query.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
|
||||
context.update_state(user=update.shipping_query.from_user.id,
|
||||
state=state)
|
||||
return await self.shipping_query_handlers.notify(update.shipping_query)
|
||||
if update.pre_checkout_query:
|
||||
if has_context:
|
||||
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
||||
context.update_state(user=update.pre_checkout_query.from_user.id,
|
||||
state=state)
|
||||
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
||||
context.update_state(user=update.pre_checkout_query.from_user.id,
|
||||
state=state)
|
||||
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
|
||||
except Exception as e:
|
||||
success = False
|
||||
|
|
@ -188,10 +186,6 @@ class Dispatcher:
|
|||
if err:
|
||||
return err
|
||||
raise
|
||||
finally:
|
||||
log.info(f"Process update [ID:{update.update_id}]: "
|
||||
f"{['failed', 'success'][success]} "
|
||||
f"(in {round((time.time() - start) * 1000)} ms)")
|
||||
|
||||
async def reset_webhook(self, check=True) -> bool:
|
||||
"""
|
||||
|
|
@ -244,24 +238,26 @@ class Dispatcher:
|
|||
|
||||
self._polling = True
|
||||
offset = None
|
||||
while self._polling:
|
||||
try:
|
||||
updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout)
|
||||
except NetworkError:
|
||||
log.exception('Cause exception while getting updates.')
|
||||
await asyncio.sleep(15)
|
||||
continue
|
||||
try:
|
||||
while self._polling:
|
||||
try:
|
||||
updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout)
|
||||
except NetworkError:
|
||||
log.exception('Cause exception while getting updates.')
|
||||
await asyncio.sleep(15)
|
||||
continue
|
||||
|
||||
if updates:
|
||||
log.debug(f"Received {len(updates)} updates.")
|
||||
offset = updates[-1].update_id + 1
|
||||
if updates:
|
||||
log.debug(f"Received {len(updates)} updates.")
|
||||
offset = updates[-1].update_id + 1
|
||||
|
||||
self.loop.create_task(self._process_polling_updates(updates))
|
||||
self.loop.create_task(self._process_polling_updates(updates))
|
||||
|
||||
if relax:
|
||||
await asyncio.sleep(relax)
|
||||
|
||||
log.warning('Polling is stopped.')
|
||||
if relax:
|
||||
await asyncio.sleep(relax)
|
||||
finally:
|
||||
self._close_waiter.set_result(None)
|
||||
log.warning('Polling is stopped.')
|
||||
|
||||
async def _process_polling_updates(self, updates):
|
||||
"""
|
||||
|
|
@ -270,8 +266,8 @@ class Dispatcher:
|
|||
:param updates: list of updates.
|
||||
"""
|
||||
need_to_call = []
|
||||
for response in await self.process_updates(updates):
|
||||
for response in response:
|
||||
for responses in itertools.chain.from_iterable(await self.process_updates(updates)):
|
||||
for response in responses:
|
||||
if not isinstance(response, BaseResponse):
|
||||
continue
|
||||
need_to_call.append(response.execute_response(self.bot))
|
||||
|
|
@ -288,12 +284,21 @@ class Dispatcher:
|
|||
def stop_polling(self):
|
||||
"""
|
||||
Break long-polling process.
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self._polling:
|
||||
log.info('Stop polling.')
|
||||
log.info('Stop polling...')
|
||||
self._polling = False
|
||||
|
||||
async def wait_closed(self):
|
||||
"""
|
||||
Wait closing the long polling
|
||||
|
||||
:return:
|
||||
"""
|
||||
await asyncio.shield(self._close_waiter, loop=self.loop)
|
||||
|
||||
@deprecated('The old method was renamed to `is_polling`')
|
||||
def is_pooling(self):
|
||||
return self.is_polling()
|
||||
|
|
@ -897,7 +902,8 @@ class Dispatcher:
|
|||
"""
|
||||
|
||||
def decorator(callback):
|
||||
self.register_errors_handler(callback, func=func, exception=exception)
|
||||
self.register_errors_handler(self._wrap_async_task(callback, run_task),
|
||||
func=func, exception=exception)
|
||||
return callback
|
||||
|
||||
return decorator
|
||||
|
|
@ -929,6 +935,109 @@ class Dispatcher:
|
|||
|
||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||
|
||||
async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool:
|
||||
"""
|
||||
Execute throttling manager.
|
||||
Return True limit is not exceeded otherwise raise ThrottleError or return False
|
||||
|
||||
:param key: key in storage
|
||||
:param rate: limit (by default is equals with default rate limit)
|
||||
:param user: user id
|
||||
:param chat: chat id
|
||||
:param no_error: return boolean value instead of raising error
|
||||
:return: bool
|
||||
"""
|
||||
if not self.storage.has_bucket():
|
||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if no_error is None:
|
||||
no_error = self.no_throttle_error
|
||||
if rate is None:
|
||||
rate = self.throttling_rate_limit
|
||||
if user is None and chat is None:
|
||||
from . import ctx
|
||||
user = ctx.get_user()
|
||||
chat = ctx.get_chat()
|
||||
|
||||
# Detect current time
|
||||
now = time.time()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
|
||||
# Fix bucket
|
||||
if bucket is None:
|
||||
bucket = {key: {}}
|
||||
if key not in bucket:
|
||||
bucket[key] = {}
|
||||
data = bucket[key]
|
||||
|
||||
# Calculate
|
||||
called = data.get(LAST_CALL, now)
|
||||
delta = now - called
|
||||
result = delta >= rate or delta <= 0
|
||||
|
||||
# Save results
|
||||
data[RESULT] = result
|
||||
data[RATE_LIMIT] = rate
|
||||
data[LAST_CALL] = now
|
||||
data[DELTA] = delta
|
||||
if not result:
|
||||
data[EXCEEDED_COUNT] += 1
|
||||
else:
|
||||
data[EXCEEDED_COUNT] = 1
|
||||
bucket[key].update(data)
|
||||
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
|
||||
if not result and not no_error:
|
||||
# Raise if that is allowed
|
||||
raise Throttled(key=key, chat=chat, user=user, **data)
|
||||
return result
|
||||
|
||||
async def check_key(self, key, chat=None, user=None):
|
||||
"""
|
||||
Get information about key in bucket
|
||||
|
||||
:param key:
|
||||
:param chat:
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
if not self.storage.has_bucket():
|
||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if user is None and chat is None:
|
||||
from . import ctx
|
||||
user = ctx.get_user()
|
||||
chat = ctx.get_chat()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
data = bucket.get(key, {})
|
||||
return Throttled(key=key, chat=chat, user=user, **data)
|
||||
|
||||
async def release_key(self, key, chat=None, user=None):
|
||||
"""
|
||||
Release blocked key
|
||||
|
||||
:param key:
|
||||
:param chat:
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
if not self.storage.has_bucket():
|
||||
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||
|
||||
if user is None and chat is None:
|
||||
from . import ctx
|
||||
user = ctx.get_user()
|
||||
chat = ctx.get_chat()
|
||||
|
||||
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||
if bucket and key in bucket:
|
||||
del bucket['key']
|
||||
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||
return True
|
||||
return False
|
||||
|
||||
def async_task(self, func):
|
||||
"""
|
||||
Execute handler as task and return None.
|
||||
|
|
@ -947,10 +1056,14 @@ class Dispatcher:
|
|||
"""
|
||||
|
||||
def process_response(task):
|
||||
response = task.result()
|
||||
|
||||
if isinstance(response, BaseResponse):
|
||||
self.loop.create_task(response.execute_response(self.bot))
|
||||
try:
|
||||
response = task.result()
|
||||
except Exception as e:
|
||||
self.loop.create_task(
|
||||
self.errors_handlers.notify(self, task.context.get(UPDATE_OBJECT, None), e))
|
||||
else:
|
||||
if isinstance(response, BaseResponse):
|
||||
self.loop.create_task(response.execute_response(self.bot))
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@ from ..utils import context
|
|||
def _get(key, default=None, no_error=False):
|
||||
result = context.get_value(key, default)
|
||||
if not no_error and result is None:
|
||||
raise RuntimeError(f"Context is not configured for '{key}'")
|
||||
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
|
||||
f"Maybe asyncio task factory is not configured!\n"
|
||||
f"\t>>> from aiogram.utils import context\n"
|
||||
f"\t>>> loop.set_task_factory(context.task_factory)")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,26 +9,45 @@ from ..utils.helper import Helper, HelperMode, Item
|
|||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
||||
async def check_filter(filter_, args, kwargs):
|
||||
async def check_filter(filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
:param filter_:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if not callable(filter_):
|
||||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
|
||||
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
||||
return await filter_(*args, **kwargs)
|
||||
return await filter_(*args)
|
||||
else:
|
||||
return filter_(*args, **kwargs)
|
||||
return filter_(*args)
|
||||
|
||||
|
||||
async def check_filters(filters, args, kwargs):
|
||||
async def check_filters(filters, args):
|
||||
"""
|
||||
Check list of filters
|
||||
|
||||
:param filters:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(filter_, args, kwargs)
|
||||
f = await check_filter(filter_, args)
|
||||
if not f:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Filter:
|
||||
"""
|
||||
Base class for filters
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.check(*args, **kwargs)
|
||||
|
||||
|
|
@ -37,6 +56,10 @@ class Filter:
|
|||
|
||||
|
||||
class AsyncFilter(Filter):
|
||||
"""
|
||||
Base class for asynchronous filters
|
||||
"""
|
||||
|
||||
def __aiter__(self):
|
||||
return None
|
||||
|
||||
|
|
@ -48,23 +71,35 @@ class AsyncFilter(Filter):
|
|||
|
||||
|
||||
class AnyFilter(AsyncFilter):
|
||||
"""
|
||||
One filter from many
|
||||
"""
|
||||
|
||||
def __init__(self, *filters: callable):
|
||||
self.filters = filters
|
||||
|
||||
async def check(self, *args, **kwargs):
|
||||
f = (check_filter(filter_, args, kwargs) for filter_ in self.filters)
|
||||
async def check(self, *args):
|
||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||
return any(await asyncio.gather(*f))
|
||||
|
||||
|
||||
class NotFilter(AsyncFilter):
|
||||
"""
|
||||
Reverse filter
|
||||
"""
|
||||
|
||||
def __init__(self, filter_: callable):
|
||||
self.filter = filter_
|
||||
|
||||
async def check(self, *args, **kwargs):
|
||||
return not await check_filter(self.filter, args, kwargs)
|
||||
async def check(self, *args):
|
||||
return not await check_filter(self.filter, args)
|
||||
|
||||
|
||||
class CommandsFilter(AsyncFilter):
|
||||
"""
|
||||
Check commands in message
|
||||
"""
|
||||
|
||||
def __init__(self, commands):
|
||||
self.commands = commands
|
||||
|
||||
|
|
@ -85,6 +120,10 @@ class CommandsFilter(AsyncFilter):
|
|||
|
||||
|
||||
class RegexpFilter(Filter):
|
||||
"""
|
||||
Regexp filter for messages
|
||||
"""
|
||||
|
||||
def __init__(self, regexp):
|
||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||
|
||||
|
|
@ -94,6 +133,10 @@ class RegexpFilter(Filter):
|
|||
|
||||
|
||||
class ContentTypeFilter(Filter):
|
||||
"""
|
||||
Check message content type
|
||||
"""
|
||||
|
||||
def __init__(self, content_types):
|
||||
self.content_types = content_types
|
||||
|
||||
|
|
@ -103,6 +146,10 @@ class ContentTypeFilter(Filter):
|
|||
|
||||
|
||||
class CancelFilter(Filter):
|
||||
"""
|
||||
Find cancel in message text
|
||||
"""
|
||||
|
||||
def __init__(self, cancel_set=None):
|
||||
if cancel_set is None:
|
||||
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
||||
|
|
@ -114,6 +161,10 @@ class CancelFilter(Filter):
|
|||
|
||||
|
||||
class StateFilter(AsyncFilter):
|
||||
"""
|
||||
Check user state
|
||||
"""
|
||||
|
||||
def __init__(self, dispatcher, state):
|
||||
self.dispatcher = dispatcher
|
||||
self.state = state
|
||||
|
|
@ -137,6 +188,10 @@ class StateFilter(AsyncFilter):
|
|||
|
||||
|
||||
class StatesListFilter(StateFilter):
|
||||
"""
|
||||
List of states
|
||||
"""
|
||||
|
||||
async def check(self, obj):
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
|
|
@ -146,6 +201,10 @@ class StatesListFilter(StateFilter):
|
|||
|
||||
|
||||
class ExceptionsFilter(Filter):
|
||||
"""
|
||||
Filter for exceptions
|
||||
"""
|
||||
|
||||
def __init__(self, exception):
|
||||
self.exception = exception
|
||||
|
||||
|
|
@ -159,6 +218,14 @@ class ExceptionsFilter(Filter):
|
|||
|
||||
|
||||
def generate_default_filters(dispatcher, *args, **kwargs):
|
||||
"""
|
||||
Prepare filters
|
||||
|
||||
:param dispatcher:
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
filters_set = []
|
||||
|
||||
for name, filter_ in kwargs.items():
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from aiogram.utils import context
|
||||
from .filters import check_filters
|
||||
|
||||
|
||||
|
|
@ -10,11 +11,12 @@ class CancelHandler(BaseException):
|
|||
|
||||
|
||||
class Handler:
|
||||
def __init__(self, dispatcher, once=True):
|
||||
def __init__(self, dispatcher, once=True, middleware_key=None):
|
||||
self.dispatcher = dispatcher
|
||||
self.once = once
|
||||
|
||||
self.handlers = []
|
||||
self.middleware_key = middleware_key
|
||||
|
||||
def register(self, handler, filters=None, index=None):
|
||||
"""
|
||||
|
|
@ -48,20 +50,24 @@ class Handler:
|
|||
return True
|
||||
raise ValueError('This handler is not registered!')
|
||||
|
||||
async def notify(self, *args, **kwargs):
|
||||
async def notify(self, *args):
|
||||
"""
|
||||
Notify handlers
|
||||
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
results = []
|
||||
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
|
||||
for filters, handler in self.handlers:
|
||||
if await check_filters(filters, args, kwargs):
|
||||
if await check_filters(filters, args):
|
||||
try:
|
||||
response = await handler(*args, **kwargs)
|
||||
if self.middleware_key:
|
||||
context.set_value('handler', handler)
|
||||
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||
response = await handler(*args)
|
||||
if results is not None:
|
||||
results.append(response)
|
||||
if self.once:
|
||||
|
|
@ -70,5 +76,8 @@ class Handler:
|
|||
continue
|
||||
except CancelHandler:
|
||||
break
|
||||
if self.middleware_key:
|
||||
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||
args + (results,))
|
||||
|
||||
return results
|
||||
|
|
|
|||
101
aiogram/dispatcher/middlewares.py
Normal file
101
aiogram/dispatcher/middlewares.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
import logging
|
||||
import typing
|
||||
|
||||
log = logging.getLogger('aiogram.Middleware')
|
||||
|
||||
|
||||
class MiddlewareManager:
|
||||
"""
|
||||
Middlewares manager. Works only with dispatcher.
|
||||
"""
|
||||
|
||||
def __init__(self, dispatcher):
|
||||
"""
|
||||
Init
|
||||
|
||||
:param dispatcher: instance of Dispatcher
|
||||
"""
|
||||
self.dispatcher = dispatcher
|
||||
self.loop = dispatcher.loop
|
||||
self.bot = dispatcher.bot
|
||||
self.storage = dispatcher.storage
|
||||
self.applications = []
|
||||
|
||||
def setup(self, middleware):
|
||||
"""
|
||||
Setup middleware
|
||||
|
||||
:param middleware:
|
||||
:return:
|
||||
"""
|
||||
assert isinstance(middleware, BaseMiddleware)
|
||||
if middleware.is_configured():
|
||||
raise ValueError('That middleware is already used!')
|
||||
|
||||
self.applications.append(middleware)
|
||||
middleware.setup(self)
|
||||
log.debug(f"Loaded middleware '{middleware.__class__.__name__}'")
|
||||
|
||||
async def trigger(self, action: str, args: typing.Iterable):
|
||||
"""
|
||||
Call action to middlewares with args lilt.
|
||||
|
||||
:param action:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
for app in self.applications:
|
||||
await app.trigger(action, args)
|
||||
|
||||
|
||||
class BaseMiddleware:
|
||||
"""
|
||||
Base class for middleware.
|
||||
|
||||
All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message".
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._configured = False
|
||||
self._manager = None
|
||||
|
||||
@property
|
||||
def manager(self) -> MiddlewareManager:
|
||||
"""
|
||||
Instance of MiddlewareManager
|
||||
"""
|
||||
if self._manager is None:
|
||||
raise RuntimeError('Middleware is not configured!')
|
||||
return self._manager
|
||||
|
||||
def setup(self, manager):
|
||||
"""
|
||||
Mark middleware as configured
|
||||
|
||||
:param manager:
|
||||
:return:
|
||||
"""
|
||||
self._manager = manager
|
||||
self._configured = True
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
"""
|
||||
Check middleware is configured
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self._configured
|
||||
|
||||
async def trigger(self, action, args):
|
||||
"""
|
||||
Trigger action.
|
||||
|
||||
:param action:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
handler_name = f"on_{action}"
|
||||
handler = getattr(self, handler_name, None)
|
||||
if not handler:
|
||||
return None
|
||||
await handler(*args)
|
||||
|
|
@ -1,5 +1,14 @@
|
|||
import typing
|
||||
|
||||
# Leak bucket
|
||||
KEY = 'key'
|
||||
LAST_CALL = 'called_at'
|
||||
RATE_LIMIT = 'rate_limit'
|
||||
RESULT = 'result'
|
||||
EXCEEDED_COUNT = 'exceeded'
|
||||
DELTA = 'delta'
|
||||
THROTTLE_MANAGER = '$throttle_manager'
|
||||
|
||||
|
||||
class BaseStorage:
|
||||
"""
|
||||
|
|
@ -184,6 +193,78 @@ class BaseStorage:
|
|||
"""
|
||||
await self.reset_state(chat=chat, user=user, with_data=True)
|
||||
|
||||
def has_bucket(self):
|
||||
return False
|
||||
|
||||
async def get_bucket(self, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||
"""
|
||||
Get state-data for user in chat. Return `default` if data is not presented in storage.
|
||||
|
||||
Chat or user is always required. If one of this is not presented,
|
||||
need set the missing value based on the presented
|
||||
|
||||
:param chat:
|
||||
:param user:
|
||||
:param default:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def set_bucket(self, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None):
|
||||
"""
|
||||
Set data for user in chat
|
||||
|
||||
Chat or user is always required. If one of this is not presented,
|
||||
need set the missing value based on the presented
|
||||
|
||||
:param chat:
|
||||
:param user:
|
||||
:param bucket:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def update_bucket(self, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None,
|
||||
bucket: typing.Dict = None,
|
||||
**kwargs):
|
||||
"""
|
||||
Update data for user in chat
|
||||
|
||||
You can use data parameter or|and kwargs.
|
||||
|
||||
Chat or user is always required. If one of this is not presented,
|
||||
need set the missing value based on the presented
|
||||
|
||||
:param bucket:
|
||||
:param chat:
|
||||
:param user:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reset_bucket(self, *,
|
||||
chat: typing.Union[str, int, None] = None,
|
||||
user: typing.Union[str, int, None] = None):
|
||||
"""
|
||||
Reset data dor user in chat.
|
||||
|
||||
Chat or user is always required. If one of this is not presented,
|
||||
need set the missing value based on the presented
|
||||
|
||||
:param chat:
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
await self.set_data(chat=chat, user=user, data={})
|
||||
|
||||
|
||||
class FSMContext:
|
||||
def __init__(self, storage, chat, user):
|
||||
|
|
|
|||
|
|
@ -186,10 +186,15 @@ class WebhookRequestHandler(web.View):
|
|||
dispatcher = self.get_dispatcher()
|
||||
loop = dispatcher.loop
|
||||
|
||||
results = task.result()
|
||||
response = self.get_response(results)
|
||||
if response is not None:
|
||||
asyncio.ensure_future(response.execute_response(self.get_dispatcher().bot), loop=loop)
|
||||
try:
|
||||
results = task.result()
|
||||
except Exception as e:
|
||||
loop.create_task(
|
||||
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
|
||||
else:
|
||||
response = self.get_response(results)
|
||||
if response is not None:
|
||||
asyncio.ensure_future(response.execute_response(dispatcher.bot), loop=loop)
|
||||
|
||||
def get_response(self, results):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -138,10 +138,8 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
|||
|
||||
@property
|
||||
def bot(self):
|
||||
bot = get_value('bot')
|
||||
if bot is None:
|
||||
raise RuntimeError('Can not found bot instance in current context!')
|
||||
return bot
|
||||
from ..dispatcher import ctx
|
||||
return ctx.get_bot()
|
||||
|
||||
def to_python(self) -> typing.Dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -211,7 +211,7 @@ class ChatActions(helper.Helper):
|
|||
|
||||
@classmethod
|
||||
async def _do(cls, action: str, sleep=None):
|
||||
from aiogram.dispatcher.ctx import get_bot, get_chat
|
||||
from ..dispatcher.ctx import get_bot, get_chat
|
||||
await get_bot().send_chat_action(get_chat(), action)
|
||||
if sleep:
|
||||
await asyncio.sleep(sleep)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import datetime
|
||||
|
||||
from aiogram.utils import helper
|
||||
from . import base
|
||||
from . import fields
|
||||
from .user import User
|
||||
from ..utils import helper
|
||||
|
||||
|
||||
class ChatMember(base.TelegramObject):
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ class MediaGroup(base.TelegramObject):
|
|||
"""
|
||||
Helper for sending media group
|
||||
"""
|
||||
|
||||
def __init__(self, medias: typing.Optional[typing.List[typing.Union[InputMedia, typing.Dict]]] = None):
|
||||
super(MediaGroup, self).__init__()
|
||||
self.media = []
|
||||
|
|
|
|||
|
|
@ -31,4 +31,4 @@ class PreCheckoutQuery(base.TelegramObject):
|
|||
def __eq__(self, other):
|
||||
if isinstance(other, type(self)):
|
||||
return other.id == self.id
|
||||
return self.id == other
|
||||
return self.id == other
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
|
|||
del task._source_traceback[-1]
|
||||
|
||||
try:
|
||||
task.context = asyncio.Task.current_task().context
|
||||
task.context = asyncio.Task.current_task().context.copy()
|
||||
except AttributeError:
|
||||
task.context = {CONFIGURED: True}
|
||||
|
||||
|
|
@ -114,3 +114,25 @@ def check_configured():
|
|||
:return:
|
||||
"""
|
||||
return get_value(CONFIGURED)
|
||||
|
||||
|
||||
class _Context:
|
||||
"""
|
||||
Other things for interactions with the execution context.
|
||||
"""
|
||||
|
||||
def __getitem__(self, item):
|
||||
return get_value(item)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
set_value(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
del_value(key)
|
||||
|
||||
@staticmethod
|
||||
def get_context():
|
||||
return get_current_state()
|
||||
|
||||
|
||||
context = _Context()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import time
|
||||
|
||||
_PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: ']
|
||||
|
||||
|
||||
|
|
@ -51,3 +53,21 @@ class MigrateToChat(TelegramAPIError):
|
|||
def __init__(self, chat_id):
|
||||
super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.")
|
||||
self.migrate_to_chat_id = chat_id
|
||||
|
||||
|
||||
class Throttled(Exception):
|
||||
def __init__(self, **kwargs):
|
||||
from ..dispatcher.storage import DELTA, EXCEEDED_COUNT, KEY, LAST_CALL, RATE_LIMIT, RESULT
|
||||
self.key = kwargs.pop(KEY, '<None>')
|
||||
self.called_at = kwargs.pop(LAST_CALL, time.time())
|
||||
self.rate = kwargs.pop(RATE_LIMIT, None)
|
||||
self.result = kwargs.pop(RESULT, False)
|
||||
self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0)
|
||||
self.delta = kwargs.pop(DELTA, 0)
|
||||
self.user = kwargs.pop('user', None)
|
||||
self.chat = kwargs.pop('chat', None)
|
||||
|
||||
def __str__(self):
|
||||
return f"Rate limit exceeded! (Limit: {self.rate} s, " \
|
||||
f"exceeded: {self.exceeded_count}, " \
|
||||
f"time delta: {round(self.delta, 3)} s)"
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ async def _shutdown(dispatcher: Dispatcher, callback=None):
|
|||
|
||||
if dispatcher.is_polling():
|
||||
dispatcher.stop_polling()
|
||||
# await dispatcher.wait_closed()
|
||||
|
||||
await dispatcher.storage.close()
|
||||
await dispatcher.storage.wait_closed()
|
||||
|
|
|
|||
7
dev_requirements.txt
Normal file
7
dev_requirements.txt
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-r requirements.txt
|
||||
ujson
|
||||
emoji
|
||||
pytest
|
||||
pytest-asyncio
|
||||
uvloop
|
||||
aioredis
|
||||
123
examples/middleware_and_antiflood.py
Normal file
123
examples/middleware_and_antiflood.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
import asyncio
|
||||
|
||||
from aiogram import Bot, types
|
||||
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
||||
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
|
||||
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||
from aiogram.utils import context, executor
|
||||
from aiogram.utils.exceptions import Throttled
|
||||
|
||||
TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# In this example used Redis storage
|
||||
storage = RedisStorage2(db=5)
|
||||
|
||||
bot = Bot(token=TOKEN, loop=loop)
|
||||
dp = Dispatcher(bot, storage=storage)
|
||||
|
||||
|
||||
def rate_limit(limit: int, key=None):
|
||||
"""
|
||||
Decorator for configuring rate limit and key in different functions.
|
||||
|
||||
:param limit:
|
||||
:param key:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
setattr(func, 'throttling_rate_limit', limit)
|
||||
if key:
|
||||
setattr(func, 'throttling_key', key)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ThrottlingMiddleware(BaseMiddleware):
|
||||
"""
|
||||
Simple middleware
|
||||
"""
|
||||
|
||||
def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
|
||||
self.rate_limit = limit
|
||||
self.prefix = key_prefix
|
||||
super(ThrottlingMiddleware, self).__init__()
|
||||
|
||||
async def on_process_message(self, message: types.Message):
|
||||
"""
|
||||
That handler will be called when dispatcher receive message
|
||||
|
||||
:param message:
|
||||
"""
|
||||
# Get current handler
|
||||
handler = context.get_value('handler')
|
||||
|
||||
# Get dispatcher from context
|
||||
dispatcher = ctx.get_dispatcher()
|
||||
|
||||
# If handler was configured get rate limit and key from handler
|
||||
if handler:
|
||||
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
|
||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||
else:
|
||||
limit = self.rate_limit
|
||||
key = f"{self.prefix}_message"
|
||||
|
||||
# Use Dispatcher.throttle method.
|
||||
try:
|
||||
await dispatcher.throttle(key, rate=limit)
|
||||
except Throttled as t:
|
||||
# Execute action
|
||||
await self.message_throttled(message, t)
|
||||
|
||||
# Cancel current handler
|
||||
raise CancelHandler()
|
||||
|
||||
async def message_throttled(self, message: types.Message, throttled: Throttled):
|
||||
"""
|
||||
Notify user only on first exceed and notify about unlocking only on last exceed
|
||||
|
||||
:param message:
|
||||
:param throttled:
|
||||
"""
|
||||
handler = context.get_value('handler')
|
||||
dispatcher = ctx.get_dispatcher()
|
||||
if handler:
|
||||
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||
else:
|
||||
key = f"{self.prefix}_message"
|
||||
|
||||
# Calculate how many time left to the end of block.
|
||||
delta = throttled.rate - throttled.delta
|
||||
|
||||
# Prevent flooding
|
||||
if throttled.exceeded_count <= 2:
|
||||
await message.reply('Too many requests! ')
|
||||
|
||||
# Sleep.
|
||||
await asyncio.sleep(delta)
|
||||
|
||||
# Check lock status
|
||||
thr = await dispatcher.check_key(key)
|
||||
|
||||
# If current message is not last with current key - do not send message
|
||||
if thr.exceeded_count == throttled.exceeded_count:
|
||||
await message.reply('Unlocked.')
|
||||
|
||||
|
||||
@dp.message_handler(commands=['start'])
|
||||
@rate_limit(5, 'start') # is not required but with that you can configure throttling manager for current handler
|
||||
async def cmd_test(message: types.Message):
|
||||
# You can use that command every 5 seconds
|
||||
await message.reply('Test passed! You can use that command every 5 seconds.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Setup middleware
|
||||
dp.middleware.setup(ThrottlingMiddleware())
|
||||
|
||||
# Start long-polling
|
||||
executor.start_polling(dp, loop=loop)
|
||||
43
examples/throtling_example.py
Normal file
43
examples/throtling_example.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""
|
||||
Example for throttling manager.
|
||||
|
||||
You can use that for flood controlling.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiogram import Bot, types
|
||||
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||
from aiogram.dispatcher import Dispatcher
|
||||
from aiogram.utils.exceptions import Throttled
|
||||
from aiogram.utils.executor import start_polling
|
||||
|
||||
API_TOKEN = 'BOT TOKEN HERE'
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
bot = Bot(token=API_TOKEN, loop=loop)
|
||||
|
||||
# Throttling manager does not working without Leaky Bucket.
|
||||
# Then need to use storage's. For example use simple in-memory storage.
|
||||
storage = MemoryStorage()
|
||||
dp = Dispatcher(bot, storage=storage)
|
||||
|
||||
|
||||
@dp.message_handler(commands=['start', 'help'])
|
||||
async def send_welcome(message: types.Message):
|
||||
try:
|
||||
# Execute throttling manager with rate-limit equals to 2 seconds for key "start"
|
||||
await dp.throttle('start', rate=2)
|
||||
except Throttled:
|
||||
# If request is throttled the `Throttled` exception will be raised.
|
||||
await message.reply('Too many requests!')
|
||||
else:
|
||||
# Otherwise do something.
|
||||
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_polling(dp, loop=loop, skip_updates=True)
|
||||
1
tests/conftest.py
Normal file
1
tests/conftest.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# pytest_plugins = "pytest_asyncio.plugin"
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
UPDATE = {
|
||||
"update_id": 128526,
|
||||
"message": {
|
||||
"message_id": 11223,
|
||||
"from": {
|
||||
"id": 12345678,
|
||||
"is_bot": False,
|
||||
"first_name": "FirstName",
|
||||
"last_name": "LastName",
|
||||
"username": "username",
|
||||
"language_code": "ru"
|
||||
},
|
||||
"chat": {
|
||||
"id": 12345678,
|
||||
"first_name": "FirstName",
|
||||
"last_name": "LastName",
|
||||
"username": "username",
|
||||
"type": "private"
|
||||
},
|
||||
"date": 1508709711,
|
||||
"text": "Hi, world!"
|
||||
}
|
||||
}
|
||||
|
||||
MESSAGE = UPDATE['message']
|
||||
4
tests/test_bot.py
Normal file
4
tests/test_bot.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
import aiogram
|
||||
|
||||
# bot = aiogram.Bot('123456789:AABBCCDDEEFFaabbccddeeff-1234567890')
|
||||
# TODO: mock for aiogram.bot.api.request and then test all AI methods.
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
import datetime
|
||||
import unittest
|
||||
|
||||
from aiogram import types
|
||||
from dataset import MESSAGE
|
||||
|
||||
|
||||
class TestMessage(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.message = types.Message(**MESSAGE)
|
||||
|
||||
def test_update_id(self):
|
||||
self.assertEqual(self.message.message_id, MESSAGE['message_id'], 'test')
|
||||
self.assertEqual(self.message['message_id'], MESSAGE['message_id'])
|
||||
|
||||
def test_from(self):
|
||||
self.assertIsInstance(self.message.from_user, types.User)
|
||||
self.assertEqual(self.message.from_user, self.message['from'])
|
||||
|
||||
def test_chat(self):
|
||||
self.assertIsInstance(self.message.chat, types.Chat)
|
||||
self.assertEqual(self.message.chat, self.message['chat'])
|
||||
|
||||
def test_date(self):
|
||||
self.assertIsInstance(self.message.date, datetime.datetime)
|
||||
self.assertEqual(int(self.message.date.timestamp()), MESSAGE['date'])
|
||||
self.assertEqual(self.message.date, self.message['date'])
|
||||
|
||||
def test_text(self):
|
||||
self.assertEqual(self.message.text, MESSAGE['text'])
|
||||
self.assertEqual(self.message['text'], MESSAGE['text'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
0
tests/types/__init__.py
Normal file
0
tests/types/__init__.py
Normal file
75
tests/types/dataset.py
Normal file
75
tests/types/dataset.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
USER = {
|
||||
"id": 12345678,
|
||||
"is_bot": False,
|
||||
"first_name": "FirstName",
|
||||
"last_name": "LastName",
|
||||
"username": "username",
|
||||
"language_code": "ru-RU"
|
||||
}
|
||||
|
||||
CHAT = {
|
||||
"id": 12345678,
|
||||
"first_name": "FirstName",
|
||||
"last_name": "LastName",
|
||||
"username": "username",
|
||||
"type": "private"
|
||||
}
|
||||
|
||||
MESSAGE = {
|
||||
"message_id": 11223,
|
||||
"from": USER,
|
||||
"chat": CHAT,
|
||||
"date": 1508709711,
|
||||
"text": "Hi, world!"
|
||||
}
|
||||
|
||||
DOCUMENT = {
|
||||
"file_name": "test.docx",
|
||||
"mime_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"file_id": "BQADAgADpgADy_JxS66XQTBRHFleAg",
|
||||
"file_size": 21331
|
||||
}
|
||||
|
||||
MESSAGE_WITH_DOCUMENT = {
|
||||
"message_id": 12345,
|
||||
"from": USER,
|
||||
"chat": CHAT,
|
||||
"date": 1508768012,
|
||||
"document": DOCUMENT,
|
||||
"caption": "doc description"
|
||||
}
|
||||
|
||||
UPDATE = {
|
||||
"update_id": 128526,
|
||||
"message": MESSAGE
|
||||
}
|
||||
|
||||
PHOTO = {
|
||||
"file_id": "AgADBAADFak0G88YZAf8OAug7bHyS9x2ZxkABHVfpJywcloRAAGAAQABAg",
|
||||
"file_size": 1101,
|
||||
"width": 90,
|
||||
"height": 51
|
||||
}
|
||||
|
||||
ANIMATION = {
|
||||
"file_name": "a9b0e0ca537aa344338f80978f0896b7.gif.mp4",
|
||||
"mime_type": "video/mp4",
|
||||
"thumb": PHOTO,
|
||||
"file_id": "CgADBAAD4DUAAoceZAe2WiE9y0crrAI",
|
||||
"file_size": 65837
|
||||
}
|
||||
|
||||
GAME = {
|
||||
"title": "Karate Kido",
|
||||
"description": "No trees were harmed in the making of this game :)",
|
||||
"photo": [PHOTO, PHOTO, PHOTO],
|
||||
"animation": ANIMATION
|
||||
}
|
||||
|
||||
MESSAGE_WITH_GAME = {
|
||||
"message_id": 12345,
|
||||
"from": USER,
|
||||
"chat": CHAT,
|
||||
"date": 1508824810,
|
||||
"game": GAME
|
||||
}
|
||||
39
tests/types/test_animation.py
Normal file
39
tests/types/test_animation.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from aiogram import types
|
||||
from .dataset import ANIMATION
|
||||
|
||||
animation = types.Animation(**ANIMATION)
|
||||
|
||||
|
||||
def test_export():
|
||||
exported = animation.to_python()
|
||||
assert isinstance(exported, dict)
|
||||
assert exported == ANIMATION
|
||||
|
||||
|
||||
def test_file_name():
|
||||
assert isinstance(animation.file_name, str)
|
||||
assert animation.file_name == ANIMATION['file_name']
|
||||
|
||||
|
||||
def test_mime_type():
|
||||
assert isinstance(animation.mime_type, str)
|
||||
assert animation.mime_type == ANIMATION['mime_type']
|
||||
|
||||
|
||||
def test_file_id():
|
||||
assert isinstance(animation.file_id, str)
|
||||
# assert hash(animation) == ANIMATION['file_id']
|
||||
assert animation.file_id == ANIMATION['file_id']
|
||||
|
||||
|
||||
def test_file_size():
|
||||
assert isinstance(animation.file_size, int)
|
||||
assert animation.file_size == ANIMATION['file_size']
|
||||
|
||||
|
||||
def test_thumb():
|
||||
assert isinstance(animation.thumb, types.PhotoSize)
|
||||
assert animation.thumb.file_id == ANIMATION['thumb']['file_id']
|
||||
assert animation.thumb.width == ANIMATION['thumb']['width']
|
||||
assert animation.thumb.height == ANIMATION['thumb']['height']
|
||||
assert animation.thumb.file_size == ANIMATION['thumb']['file_size']
|
||||
61
tests/types/test_chat.py
Normal file
61
tests/types/test_chat.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from aiogram import types
|
||||
from .dataset import CHAT
|
||||
|
||||
chat = types.Chat(**CHAT)
|
||||
|
||||
|
||||
def test_export():
|
||||
exported = chat.to_python()
|
||||
assert isinstance(exported, dict)
|
||||
assert exported == CHAT
|
||||
|
||||
|
||||
def test_id():
|
||||
assert isinstance(chat.id, int)
|
||||
assert chat.id == CHAT['id']
|
||||
assert hash(chat) == CHAT['id']
|
||||
|
||||
|
||||
def test_name():
|
||||
assert isinstance(chat.first_name, str)
|
||||
assert chat.first_name == CHAT['first_name']
|
||||
|
||||
assert isinstance(chat.last_name, str)
|
||||
assert chat.last_name == CHAT['last_name']
|
||||
|
||||
assert isinstance(chat.username, str)
|
||||
assert chat.username == CHAT['username']
|
||||
|
||||
|
||||
def test_type():
|
||||
assert isinstance(chat.type, str)
|
||||
assert chat.type == CHAT['type']
|
||||
|
||||
|
||||
def test_chat_types():
|
||||
assert types.ChatType.PRIVATE == 'private'
|
||||
assert types.ChatType.GROUP == 'group'
|
||||
assert types.ChatType.SUPER_GROUP == 'supergroup'
|
||||
assert types.ChatType.CHANNEL == 'channel'
|
||||
|
||||
|
||||
def test_chat_type_filters():
|
||||
from . import test_message
|
||||
assert types.ChatType.is_private(test_message.message)
|
||||
assert not types.ChatType.is_group(test_message.message)
|
||||
assert not types.ChatType.is_super_group(test_message.message)
|
||||
assert not types.ChatType.is_group_or_super_group(test_message.message)
|
||||
assert not types.ChatType.is_channel(test_message.message)
|
||||
|
||||
|
||||
def test_chat_actions():
|
||||
assert types.ChatActions.TYPING == 'typing'
|
||||
assert types.ChatActions.UPLOAD_PHOTO == 'upload_photo'
|
||||
assert types.ChatActions.RECORD_VIDEO == 'record_video'
|
||||
assert types.ChatActions.UPLOAD_VIDEO == 'upload_video'
|
||||
assert types.ChatActions.RECORD_AUDIO == 'record_audio'
|
||||
assert types.ChatActions.UPLOAD_AUDIO == 'upload_audio'
|
||||
assert types.ChatActions.UPLOAD_DOCUMENT == 'upload_document'
|
||||
assert types.ChatActions.FIND_LOCATION == 'find_location'
|
||||
assert types.ChatActions.RECORD_VIDEO_NOTE == 'record_video_note'
|
||||
assert types.ChatActions.UPLOAD_VIDEO_NOTE == 'upload_video_note'
|
||||
35
tests/types/test_document.py
Normal file
35
tests/types/test_document.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from aiogram import types
|
||||
from .dataset import DOCUMENT
|
||||
|
||||
document = types.Document(**DOCUMENT)
|
||||
|
||||
|
||||
def test_export():
|
||||
exported = document.to_python()
|
||||
assert isinstance(exported, dict)
|
||||
assert exported == DOCUMENT
|
||||
|
||||
|
||||
def test_file_name():
|
||||
assert isinstance(document.file_name, str)
|
||||
assert document.file_name == DOCUMENT['file_name']
|
||||
|
||||
|
||||
def test_mime_type():
|
||||
assert isinstance(document.mime_type, str)
|
||||
assert document.mime_type == DOCUMENT['mime_type']
|
||||
|
||||
|
||||
def test_file_id():
|
||||
assert isinstance(document.file_id, str)
|
||||
# assert hash(document) == DOCUMENT['file_id']
|
||||
assert document.file_id == DOCUMENT['file_id']
|
||||
|
||||
|
||||
def test_file_size():
|
||||
assert isinstance(document.file_size, int)
|
||||
assert document.file_size == DOCUMENT['file_size']
|
||||
|
||||
|
||||
def test_thumb():
|
||||
assert document.thumb is None
|
||||
29
tests/types/test_game.py
Normal file
29
tests/types/test_game.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from aiogram import types
|
||||
from .dataset import GAME
|
||||
|
||||
game = types.Game(**GAME)
|
||||
|
||||
def test_export():
|
||||
exported = game.to_python()
|
||||
assert isinstance(exported, dict)
|
||||
assert exported == GAME
|
||||
|
||||
|
||||
def test_title():
|
||||
assert isinstance(game.title, str)
|
||||
assert game.title == GAME['title']
|
||||
|
||||
|
||||
def test_description():
|
||||
assert isinstance(game.description, str)
|
||||
assert game.description == GAME['description']
|
||||
|
||||
|
||||
def test_photo():
|
||||
assert isinstance(game.photo, list)
|
||||
assert len(game.photo) == len(GAME['photo'])
|
||||
assert all(map(lambda t: isinstance(t, types.PhotoSize), game.photo))
|
||||
|
||||
|
||||
def test_animation():
|
||||
assert isinstance(game.animation, types.Animation)
|
||||
39
tests/types/test_message.py
Normal file
39
tests/types/test_message.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import datetime
|
||||
|
||||
from aiogram import types
|
||||
from .dataset import MESSAGE
|
||||
|
||||
message = types.Message(**MESSAGE)
|
||||
|
||||
|
||||
def test_export():
|
||||
exported_chat = message.to_python()
|
||||
assert isinstance(exported_chat, dict)
|
||||
assert exported_chat == MESSAGE
|
||||
|
||||
|
||||
def test_message_id():
|
||||
assert hash(message) == MESSAGE['message_id']
|
||||
assert message.message_id == MESSAGE['message_id']
|
||||
assert message['message_id'] == MESSAGE['message_id']
|
||||
|
||||
|
||||
def test_from():
|
||||
assert isinstance(message.from_user, types.User)
|
||||
assert message.from_user == message['from']
|
||||
|
||||
|
||||
def test_chat():
|
||||
assert isinstance(message.chat, types.Chat)
|
||||
assert message.chat == message['chat']
|
||||
|
||||
|
||||
def test_date():
|
||||
assert isinstance(message.date, datetime.datetime)
|
||||
assert int(message.date.timestamp()) == MESSAGE['date']
|
||||
assert message.date == message['date']
|
||||
|
||||
|
||||
def test_text():
|
||||
assert message.text == MESSAGE['text']
|
||||
assert message['text'] == MESSAGE['text']
|
||||
27
tests/types/test_photo.py
Normal file
27
tests/types/test_photo.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from aiogram import types
|
||||
from .dataset import PHOTO
|
||||
|
||||
photo = types.PhotoSize(**PHOTO)
|
||||
|
||||
|
||||
def test_export():
|
||||
exported = photo.to_python()
|
||||
assert isinstance(exported, dict)
|
||||
assert exported == PHOTO
|
||||
|
||||
|
||||
def test_file_id():
|
||||
assert isinstance(photo.file_id, str)
|
||||
assert photo.file_id == PHOTO['file_id']
|
||||
|
||||
|
||||
def test_file_size():
|
||||
assert isinstance(photo.file_size, int)
|
||||
assert photo.file_size == PHOTO['file_size']
|
||||
|
||||
|
||||
def test_size():
|
||||
assert isinstance(photo.width, int)
|
||||
assert isinstance(photo.height, int)
|
||||
assert photo.width == PHOTO['width']
|
||||
assert photo.height == PHOTO['height']
|
||||
20
tests/types/test_update.py
Normal file
20
tests/types/test_update.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
from aiogram import types
|
||||
from .dataset import UPDATE
|
||||
|
||||
update = types.Update(**UPDATE)
|
||||
|
||||
|
||||
def test_export():
|
||||
exported = update.to_python()
|
||||
assert isinstance(exported, dict)
|
||||
assert exported == UPDATE
|
||||
|
||||
|
||||
def test_update_id():
|
||||
assert isinstance(update.update_id, int)
|
||||
assert hash(update) == UPDATE['update_id']
|
||||
assert update.update_id == UPDATE['update_id']
|
||||
|
||||
|
||||
def test_message():
|
||||
assert isinstance(update.message, types.Message)
|
||||
48
tests/types/test_user.py
Normal file
48
tests/types/test_user.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
from babel import Locale
|
||||
|
||||
from aiogram import types
|
||||
from .dataset import USER
|
||||
|
||||
user = types.User(**USER)
|
||||
|
||||
|
||||
def test_export():
|
||||
exported = user.to_python()
|
||||
assert isinstance(exported, dict)
|
||||
assert exported == USER
|
||||
|
||||
|
||||
def test_id():
|
||||
assert isinstance(user.id, int)
|
||||
assert user.id == USER['id']
|
||||
assert hash(user) == USER['id']
|
||||
|
||||
|
||||
def test_bot():
|
||||
assert isinstance(user.is_bot, bool)
|
||||
assert user.is_bot == USER['is_bot']
|
||||
|
||||
|
||||
def test_name():
|
||||
assert user.first_name == USER['first_name']
|
||||
assert user.last_name == USER['last_name']
|
||||
assert user.username == USER['username']
|
||||
|
||||
|
||||
def test_language_code():
|
||||
assert user.language_code == USER['language_code']
|
||||
assert user.locale == Locale.parse(USER['language_code'], sep='-')
|
||||
|
||||
|
||||
def test_full_name():
|
||||
assert user.full_name == f"{USER['first_name']} {USER['last_name']}"
|
||||
|
||||
|
||||
def test_mention():
|
||||
assert user.mention == f"@{USER['username']}"
|
||||
assert user.get_mention('foo') == f"[foo](tg://user?id={USER['id']})"
|
||||
assert user.get_mention('foo', as_html=True) == f"<a href=\"tg://user?id={USER['id']}\">foo</a>"
|
||||
|
||||
|
||||
def test_url():
|
||||
assert user.url == f"tg://user?id={USER['id']}"
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
def out(*message, sep=' '):
|
||||
print('Test', sep.join(message))
|
||||
7
tox.ini
Normal file
7
tox.ini
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
[tox]
|
||||
envlist = py36
|
||||
|
||||
[testenv]
|
||||
deps = -rdev_requirements.txt
|
||||
commands = pytest
|
||||
skip_install = true
|
||||
Loading…
Add table
Add a link
Reference in a new issue