mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 18:19:34 +00:00
Merge branch 'dev-1.x'
# Conflicts: # aiogram/__init__.py
This commit is contained in:
commit
e881064f6c
43 changed files with 1527 additions and 218 deletions
28
Makefile
28
Makefile
|
|
@ -2,17 +2,18 @@ VENV_NAME := venv
|
||||||
PYTHON := $(VENV_NAME)/bin/python
|
PYTHON := $(VENV_NAME)/bin/python
|
||||||
AIOGRAM_VERSION := $(shell $(PYTHON) -c "import aiogram;print(aiogram.__version__)")
|
AIOGRAM_VERSION := $(shell $(PYTHON) -c "import aiogram;print(aiogram.__version__)")
|
||||||
|
|
||||||
|
RM := rm -rf
|
||||||
|
|
||||||
mkvenv:
|
mkvenv:
|
||||||
virtualenv $(VENV_NAME)
|
virtualenv $(VENV_NAME)
|
||||||
$(PYTHON) -m pip install -r requirements.txt
|
$(PYTHON) -m pip install -r requirements.txt
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
find . -name '*.pyc' -exec rm --force {} +
|
find . -name '*.pyc' -exec $(RM) {} +
|
||||||
find . -name '*.pyo' -exec rm --force {} +
|
find . -name '*.pyo' -exec $(RM) {} +
|
||||||
find . -name '*~' -exec rm --force {} +
|
find . -name '*~' -exec $(RM) {} +
|
||||||
rm --force --recursive build/
|
find . -name '__pycache__' -exec $(RM) {} +
|
||||||
rm --force --recursive dist/
|
$(RM) build/ dist/ docs/build/ .tox/ .cache/ *.egg-info
|
||||||
rm --force --recursive *.egg-info
|
|
||||||
|
|
||||||
tag:
|
tag:
|
||||||
@echo "Add tag: '$(AIOGRAM_VERSION)'"
|
@echo "Add tag: '$(AIOGRAM_VERSION)'"
|
||||||
|
|
@ -26,14 +27,23 @@ upload:
|
||||||
|
|
||||||
release:
|
release:
|
||||||
make clean
|
make clean
|
||||||
make tag
|
make test
|
||||||
make build
|
make build
|
||||||
|
make tag
|
||||||
@echo "Released aiogram $(AIOGRAM_VERSION)"
|
@echo "Released aiogram $(AIOGRAM_VERSION)"
|
||||||
|
|
||||||
full-release:
|
full-release:
|
||||||
make release
|
make release
|
||||||
make upload
|
make upload
|
||||||
|
|
||||||
|
install:
|
||||||
make install:
|
|
||||||
$(PYTHON) setup.py install
|
$(PYTHON) setup.py install
|
||||||
|
|
||||||
|
test:
|
||||||
|
tox
|
||||||
|
|
||||||
|
summary:
|
||||||
|
cloc aiogram/ tests/ examples/ setup.py
|
||||||
|
|
||||||
|
docs: docs/source/*
|
||||||
|
cd docs && $(MAKE) html
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,16 @@ except ImportError as e:
|
||||||
|
|
||||||
from .utils.versions import Stage, Version
|
from .utils.versions import Stage, Version
|
||||||
|
|
||||||
VERSION = Version(1, 0, 2, stage=Stage.FINAL, build=0)
|
try:
|
||||||
|
import uvloop
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
VERSION = Version(1, 0, 3, stage=Stage.FINAL, build=0)
|
||||||
API_VERSION = Version(3, 5)
|
API_VERSION = Version(3, 5)
|
||||||
|
|
||||||
__version__ = VERSION.version
|
__version__ = VERSION.version
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,7 @@ def _compose_data(params=None, files=None):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
async def request(session, token, method, data=None, files=None, continue_retry=False, **kwargs) -> bool or dict:
|
async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict:
|
||||||
"""
|
"""
|
||||||
Make request to API
|
Make request to API
|
||||||
|
|
||||||
|
|
@ -144,8 +144,6 @@ async def request(session, token, method, data=None, files=None, continue_retry=
|
||||||
:type data: :obj:`dict`
|
:type data: :obj:`dict`
|
||||||
:param files: files
|
:param files: files
|
||||||
:type files: :obj:`dict`
|
:type files: :obj:`dict`
|
||||||
:param continue_retry:
|
|
||||||
:type continue_retry: :obj:`dict`
|
|
||||||
:return: result
|
:return: result
|
||||||
:rtype :obj:`bool` or :obj:`dict`
|
:rtype :obj:`bool` or :obj:`dict`
|
||||||
"""
|
"""
|
||||||
|
|
@ -158,18 +156,13 @@ async def request(session, token, method, data=None, files=None, continue_retry=
|
||||||
return await _check_result(method, response)
|
return await _check_result(method, response)
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
|
||||||
except exceptions.RetryAfter as e:
|
|
||||||
if continue_retry:
|
|
||||||
await asyncio.sleep(e.timeout)
|
|
||||||
return await request(session, token, method, data, files, **kwargs)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class Methods(Helper):
|
class Methods(Helper):
|
||||||
"""
|
"""
|
||||||
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api
|
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api
|
||||||
|
|
||||||
List is updated to Bot API 3.4
|
List is updated to Bot API 3.5
|
||||||
"""
|
"""
|
||||||
mode = HelperMode.lowerCamelCase
|
mode = HelperMode.lowerCamelCase
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@ class BaseBot:
|
||||||
loop: Optional[Union[asyncio.BaseEventLoop, asyncio.AbstractEventLoop]] = None,
|
loop: Optional[Union[asyncio.BaseEventLoop, asyncio.AbstractEventLoop]] = None,
|
||||||
connections_limit: Optional[base.Integer] = 10,
|
connections_limit: Optional[base.Integer] = 10,
|
||||||
proxy: str = None, proxy_auth: Optional[aiohttp.BasicAuth] = None,
|
proxy: str = None, proxy_auth: Optional[aiohttp.BasicAuth] = None,
|
||||||
continue_retry: Optional[bool] = False,
|
|
||||||
validate_token: Optional[bool] = True):
|
validate_token: Optional[bool] = True):
|
||||||
"""
|
"""
|
||||||
Instructions how to get Bot token is found here: https://core.telegram.org/bots#3-how-do-i-create-a-bot
|
Instructions how to get Bot token is found here: https://core.telegram.org/bots#3-how-do-i-create-a-bot
|
||||||
|
|
@ -33,8 +32,6 @@ class BaseBot:
|
||||||
:type proxy: :obj:`str`
|
:type proxy: :obj:`str`
|
||||||
:param proxy_auth: Authentication information
|
:param proxy_auth: Authentication information
|
||||||
:type proxy_auth: Optional :obj:`aiohttp.BasicAuth`
|
:type proxy_auth: Optional :obj:`aiohttp.BasicAuth`
|
||||||
:param continue_retry: automatic retry sent request when flood control exceeded
|
|
||||||
:type continue_retry: :obj:`bool`
|
|
||||||
:param validate_token: Validate token.
|
:param validate_token: Validate token.
|
||||||
:type validate_token: :obj:`bool`
|
:type validate_token: :obj:`bool`
|
||||||
:raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError`
|
:raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError`
|
||||||
|
|
@ -48,9 +45,6 @@ class BaseBot:
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
self.proxy_auth = proxy_auth
|
self.proxy_auth = proxy_auth
|
||||||
|
|
||||||
# Action on error
|
|
||||||
self.continue_retry = continue_retry
|
|
||||||
|
|
||||||
# Asyncio loop instance
|
# Asyncio loop instance
|
||||||
if loop is None:
|
if loop is None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
@ -68,8 +62,11 @@ class BaseBot:
|
||||||
self._data = {}
|
self._data = {}
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
"""
|
"""
|
||||||
When bot object is deleting - need close all sessions
|
Close all client sessions
|
||||||
"""
|
"""
|
||||||
for session in self._temp_sessions:
|
for session in self._temp_sessions:
|
||||||
if not session.closed:
|
if not session.closed:
|
||||||
|
|
@ -77,17 +74,19 @@ class BaseBot:
|
||||||
if self.session and not self.session.closed:
|
if self.session and not self.session.closed:
|
||||||
self.session.close()
|
self.session.close()
|
||||||
|
|
||||||
def create_temp_session(self, limit: int = 1) -> aiohttp.ClientSession:
|
def create_temp_session(self, limit: int = 1, force_close: bool = False) -> aiohttp.ClientSession:
|
||||||
"""
|
"""
|
||||||
Create temporary session
|
Create temporary session
|
||||||
|
|
||||||
:param limit: Limit of connections
|
:param limit: Limit of connections
|
||||||
:type limit: :obj:`int`
|
:type limit: :obj:`int`
|
||||||
|
:param force_close: Set to True to force close and do reconnect after each request (and between redirects).
|
||||||
|
:type force_close: :obj:`bool`
|
||||||
:return: New session
|
:return: New session
|
||||||
:rtype: :obj:`aiohttp.TCPConnector`
|
:rtype: :obj:`aiohttp.TCPConnector`
|
||||||
"""
|
"""
|
||||||
session = aiohttp.ClientSession(
|
session = aiohttp.ClientSession(
|
||||||
connector=aiohttp.TCPConnector(limit=limit, force_close=True),
|
connector=aiohttp.TCPConnector(limit=limit, force_close=force_close),
|
||||||
loop=self.loop, json_serialize=json.dumps)
|
loop=self.loop, json_serialize=json.dumps)
|
||||||
self._temp_sessions.append(session)
|
self._temp_sessions.append(session)
|
||||||
return session
|
return session
|
||||||
|
|
@ -123,8 +122,7 @@ class BaseBot:
|
||||||
:raise: :obj:`aiogram.exceptions.TelegramApiError`
|
:raise: :obj:`aiogram.exceptions.TelegramApiError`
|
||||||
"""
|
"""
|
||||||
return await api.request(self.session, self.__token, method, data, files,
|
return await api.request(self.session, self.__token, method, data, files,
|
||||||
proxy=self.proxy, proxy_auth=self.proxy_auth,
|
proxy=self.proxy, proxy_auth=self.proxy_auth)
|
||||||
continue_retry=self.continue_retry)
|
|
||||||
|
|
||||||
async def download_file(self, file_path: base.String,
|
async def download_file(self, file_path: base.String,
|
||||||
destination: Optional[base.InputFile] = None,
|
destination: Optional[base.InputFile] = None,
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,8 @@ class Bot(BaseBot):
|
||||||
delattr(self, '_me')
|
delattr(self, '_me')
|
||||||
|
|
||||||
async def download_file_by_id(self, file_id: base.String, destination=None,
|
async def download_file_by_id(self, file_id: base.String, destination=None,
|
||||||
timeout: base.Integer=30, chunk_size: base.Integer=65536, seek: base.Boolean=True):
|
timeout: base.Integer = 30, chunk_size: base.Integer = 65536,
|
||||||
|
seek: base.Boolean = True):
|
||||||
"""
|
"""
|
||||||
Download file by file_id to destination
|
Download file by file_id to destination
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ class MemoryStorage(BaseStorage):
|
||||||
chat_id = str(chat_id)
|
chat_id = str(chat_id)
|
||||||
user_id = str(user_id)
|
user_id = str(user_id)
|
||||||
if user_id not in self.data[chat_id]:
|
if user_id not in self.data[chat_id]:
|
||||||
self.data[chat_id][user_id] = {'state': None, 'data': {}}
|
self.data[chat_id][user_id] = {'state': None, 'data': {}, 'bucket': {}}
|
||||||
return self.data[chat_id][user_id]
|
return self.data[chat_id][user_id]
|
||||||
|
|
||||||
async def get_state(self, *,
|
async def get_state(self, *,
|
||||||
|
|
@ -82,3 +82,27 @@ class MemoryStorage(BaseStorage):
|
||||||
await self.set_state(chat=chat, user=user, state=None)
|
await self.set_state(chat=chat, user=user, state=None)
|
||||||
if with_data:
|
if with_data:
|
||||||
await self.set_data(chat=chat, user=user, data={})
|
await self.set_data(chat=chat, user=user, data={})
|
||||||
|
|
||||||
|
def has_bucket(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
user = self._get_user(chat, user)
|
||||||
|
return user['bucket']
|
||||||
|
|
||||||
|
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None):
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
user = self._get_user(chat, user)
|
||||||
|
user['bucket'] = bucket
|
||||||
|
|
||||||
|
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||||
|
user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None, **kwargs):
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
user = self._get_user(chat, user)
|
||||||
|
if bucket is None:
|
||||||
|
bucket = []
|
||||||
|
user['bucket'].update(bucket, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ This module has redis storage for finite-state machine based on `aioredis <https
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import aioredis
|
import aioredis
|
||||||
|
|
@ -10,6 +11,10 @@ import aioredis
|
||||||
from ...dispatcher.storage import BaseStorage
|
from ...dispatcher.storage import BaseStorage
|
||||||
from ...utils import json
|
from ...utils import json
|
||||||
|
|
||||||
|
STATE_KEY = 'state'
|
||||||
|
STATE_DATA_KEY = 'data'
|
||||||
|
STATE_BUCKET_KEY = 'bucket'
|
||||||
|
|
||||||
|
|
||||||
class RedisStorage(BaseStorage):
|
class RedisStorage(BaseStorage):
|
||||||
"""
|
"""
|
||||||
|
|
@ -90,10 +95,11 @@ class RedisStorage(BaseStorage):
|
||||||
return json.loads(data)
|
return json.loads(data)
|
||||||
|
|
||||||
async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
state=None, data=None):
|
state=None, data=None, bucket=None):
|
||||||
"""
|
"""
|
||||||
Write record to storage
|
Write record to storage
|
||||||
|
|
||||||
|
:param bucket:
|
||||||
:param chat:
|
:param chat:
|
||||||
:param user:
|
:param user:
|
||||||
:param state:
|
:param state:
|
||||||
|
|
@ -102,11 +108,13 @@ class RedisStorage(BaseStorage):
|
||||||
"""
|
"""
|
||||||
if data is None:
|
if data is None:
|
||||||
data = {}
|
data = {}
|
||||||
|
if bucket is None:
|
||||||
|
bucket = {}
|
||||||
|
|
||||||
chat, user = self.check_address(chat=chat, user=user)
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
addr = f"fsm:{chat}:{user}"
|
addr = f"fsm:{chat}:{user}"
|
||||||
|
|
||||||
record = {'state': state, 'data': data}
|
record = {'state': state, 'data': data, 'bucket': bucket}
|
||||||
|
|
||||||
conn = await self.redis
|
conn = await self.redis
|
||||||
await conn.execute('SET', addr, json.dumps(record))
|
await conn.execute('SET', addr, json.dumps(record))
|
||||||
|
|
@ -168,3 +176,220 @@ class RedisStorage(BaseStorage):
|
||||||
else:
|
else:
|
||||||
keys = await conn.execute('KEYS', 'fsm:*')
|
keys = await conn.execute('KEYS', 'fsm:*')
|
||||||
conn.execute('DEL', *keys)
|
conn.execute('DEL', *keys)
|
||||||
|
|
||||||
|
def has_bucket(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[str] = None) -> typing.Dict:
|
||||||
|
record = await self.get_record(chat=chat, user=user)
|
||||||
|
return record.get('bucket', {})
|
||||||
|
|
||||||
|
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None):
|
||||||
|
record = await self.get_record(chat=chat, user=user)
|
||||||
|
await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket)
|
||||||
|
|
||||||
|
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||||
|
user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None, **kwargs):
|
||||||
|
record = await self.get_record(chat=chat, user=user)
|
||||||
|
record_bucket = record.get('bucket', {})
|
||||||
|
record_bucket.update(bucket, **kwargs)
|
||||||
|
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisStorage2(BaseStorage):
|
||||||
|
"""
|
||||||
|
Busted Redis-base storage for FSM.
|
||||||
|
Works with Redis connection pool and customizable keys prefix.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
.. code-block:: python3
|
||||||
|
|
||||||
|
storage = RedisStorage('localhost', 6379, db=5, pool_size=10, prefix='my_fsm_key')
|
||||||
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
|
||||||
|
And need to close Redis connection when shutdown
|
||||||
|
|
||||||
|
.. code-block:: python3
|
||||||
|
|
||||||
|
await dp.storage.close()
|
||||||
|
await dp.storage.wait_closed()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None,
|
||||||
|
pool_size=10, loop=None, prefix='fsm', **kwargs):
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
self._db = db
|
||||||
|
self._password = password
|
||||||
|
self._ssl = ssl
|
||||||
|
self._pool_size = pool_size
|
||||||
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self._prefix = (prefix,)
|
||||||
|
|
||||||
|
self._redis: aioredis.RedisConnection = None
|
||||||
|
self._connection_lock = asyncio.Lock(loop=self._loop)
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def redis(self) -> aioredis.Redis:
|
||||||
|
"""
|
||||||
|
Get Redis connection
|
||||||
|
|
||||||
|
This property is awaitable.
|
||||||
|
"""
|
||||||
|
# Use thread-safe asyncio Lock because this method without that is not safe
|
||||||
|
async with self._connection_lock:
|
||||||
|
if self._redis is None:
|
||||||
|
self._redis = await aioredis.create_redis_pool((self._host, self._port),
|
||||||
|
db=self._db, password=self._password, ssl=self._ssl,
|
||||||
|
minsize=1, maxsize=self._pool_size,
|
||||||
|
loop=self._loop, **self._kwargs)
|
||||||
|
return self._redis
|
||||||
|
|
||||||
|
def generate_key(self, *parts):
|
||||||
|
return ':'.join(self._prefix + tuple(map(str, parts)))
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
async with self._connection_lock:
|
||||||
|
if self._redis and not self._redis.closed:
|
||||||
|
self._redis.close()
|
||||||
|
del self._redis
|
||||||
|
self._redis = None
|
||||||
|
|
||||||
|
async def wait_closed(self):
|
||||||
|
async with self._connection_lock:
|
||||||
|
if self._redis:
|
||||||
|
return await self._redis.wait_closed()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[str] = None) -> typing.Optional[str]:
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
return await redis.get(key, encoding='utf8') or None
|
||||||
|
|
||||||
|
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
raw_result = await redis.get(key, encoding='utf8')
|
||||||
|
if raw_result:
|
||||||
|
return json.loads(raw_result)
|
||||||
|
return default or {}
|
||||||
|
|
||||||
|
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
state: typing.Optional[typing.AnyStr] = None):
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
if state is None:
|
||||||
|
await redis.delete(key)
|
||||||
|
else:
|
||||||
|
await redis.set(key, state)
|
||||||
|
|
||||||
|
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
data: typing.Dict = None):
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_DATA_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
await redis.set(key, json.dumps(data))
|
||||||
|
|
||||||
|
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
data: typing.Dict = None, **kwargs):
|
||||||
|
temp_data = await self.get_data(chat=chat, user=user, default={})
|
||||||
|
temp_data.update(data, **kwargs)
|
||||||
|
await self.set_data(chat=chat, user=user, data=temp_data)
|
||||||
|
|
||||||
|
def has_bucket(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
raw_result = await redis.get(key, encoding='utf8')
|
||||||
|
if raw_result:
|
||||||
|
return json.loads(raw_result)
|
||||||
|
return default or {}
|
||||||
|
|
||||||
|
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None):
|
||||||
|
chat, user = self.check_address(chat=chat, user=user)
|
||||||
|
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
|
||||||
|
redis = await self.redis
|
||||||
|
await redis.set(key, json.dumps(bucket))
|
||||||
|
|
||||||
|
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
|
||||||
|
user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None, **kwargs):
|
||||||
|
temp_bucket = await self.get_data(chat=chat, user=user)
|
||||||
|
temp_bucket.update(bucket, **kwargs)
|
||||||
|
await self.set_data(chat=chat, user=user, data=temp_bucket)
|
||||||
|
|
||||||
|
async def reset_all(self, full=True):
|
||||||
|
"""
|
||||||
|
Reset states in DB
|
||||||
|
|
||||||
|
:param full: clean DB or clean only states
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
conn = await self.redis
|
||||||
|
|
||||||
|
if full:
|
||||||
|
conn.flushdb()
|
||||||
|
else:
|
||||||
|
keys = await conn.keys(self.generate_key('*'))
|
||||||
|
conn.delete(*keys)
|
||||||
|
|
||||||
|
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
|
||||||
|
"""
|
||||||
|
Get list of all stored chat's and user's
|
||||||
|
|
||||||
|
:return: list of tuples where first element is chat id and second is user id
|
||||||
|
"""
|
||||||
|
conn = await self.redis
|
||||||
|
result = []
|
||||||
|
|
||||||
|
keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8')
|
||||||
|
for item in keys:
|
||||||
|
*_, chat, user, _ = item.split(':')
|
||||||
|
result.append((chat, user))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def import_redis1(self, redis1):
|
||||||
|
await migrate_redis1_to_redis2(redis1, self)
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2):
|
||||||
|
"""
|
||||||
|
Helper for migrating from RedisStorage to RedisStorage2
|
||||||
|
|
||||||
|
:param storage1: instance of RedisStorage
|
||||||
|
:param storage2: instance of RedisStorage2
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
assert isinstance(storage1, RedisStorage)
|
||||||
|
assert isinstance(storage2, RedisStorage2)
|
||||||
|
|
||||||
|
log = logging.getLogger('aiogram.RedisStorage')
|
||||||
|
|
||||||
|
for chat, user in await storage1.get_states_list():
|
||||||
|
state = await storage1.get_state(chat=chat, user=user)
|
||||||
|
await storage2.set_state(chat=chat, user=user, state=state)
|
||||||
|
|
||||||
|
data = await storage1.get_data(chat=chat, user=user)
|
||||||
|
await storage2.set_data(chat=chat, user=user, data=data)
|
||||||
|
|
||||||
|
bucket = await storage1.get_bucket(chat=chat, user=user)
|
||||||
|
await storage2.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||||
|
|
||||||
|
log.info(f"Migrated user {user} in chat {chat}")
|
||||||
|
|
|
||||||
0
aiogram/contrib/middlewares/__init__.py
Normal file
0
aiogram/contrib/middlewares/__init__.py
Normal file
132
aiogram/contrib/middlewares/logging.py
Normal file
132
aiogram/contrib/middlewares/logging.py
Normal file
|
|
@ -0,0 +1,132 @@
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from aiogram import types
|
||||||
|
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||||
|
|
||||||
|
HANDLED_STR = ['Unhandled', 'Handled']
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingMiddleware(BaseMiddleware):
|
||||||
|
def __init__(self, logger=__name__):
|
||||||
|
if not isinstance(logger, logging.Logger):
|
||||||
|
logger = logging.getLogger(logger)
|
||||||
|
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
super(LoggingMiddleware, self).__init__()
|
||||||
|
|
||||||
|
def check_timeout(self, obj):
|
||||||
|
start = obj.conf.get('_start', None)
|
||||||
|
if start:
|
||||||
|
del obj.conf['_start']
|
||||||
|
return round((time.time() - start) * 1000)
|
||||||
|
return -1
|
||||||
|
|
||||||
|
async def on_pre_process_update(self, update: types.Update):
|
||||||
|
update.conf['_start'] = time.time()
|
||||||
|
self.logger.debug(f"Received update [ID:{update.update_id}]")
|
||||||
|
|
||||||
|
async def on_post_process_update(self, update: types.Update, result):
|
||||||
|
timeout = self.check_timeout(update)
|
||||||
|
if timeout > 0:
|
||||||
|
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
|
||||||
|
|
||||||
|
async def on_pre_process_message(self, message: types.Message):
|
||||||
|
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
||||||
|
|
||||||
|
async def on_post_process_message(self, message: types.Message, results):
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_edited_message(self, edited_message):
|
||||||
|
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
|
||||||
|
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
||||||
|
|
||||||
|
async def on_post_process_edited_message(self, edited_message, results):
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"edited message [ID:{edited_message.message_id}] "
|
||||||
|
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_channel_post(self, channel_post: types.Message):
|
||||||
|
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
|
||||||
|
f"in channel [ID:{channel_post.chat.id}]")
|
||||||
|
|
||||||
|
async def on_post_process_channel_post(self, channel_post: types.Message, results):
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"channel post [ID:{channel_post.message_id}] "
|
||||||
|
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message):
|
||||||
|
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
|
||||||
|
f"in channel [ID:{edited_channel_post.chat.id}]")
|
||||||
|
|
||||||
|
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results):
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"edited channel post [ID:{edited_channel_post.message_id}] "
|
||||||
|
f"in channel [ID:{edited_channel_post.chat.id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery):
|
||||||
|
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
|
||||||
|
f"from user [ID:{inline_query.from_user.id}]")
|
||||||
|
|
||||||
|
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results):
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"inline query [ID:{inline_query.id}] "
|
||||||
|
f"from user [ID:{inline_query.from_user.id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult):
|
||||||
|
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
||||||
|
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
||||||
|
f"result [ID:{chosen_inline_result.result_id}]")
|
||||||
|
|
||||||
|
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results):
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
|
||||||
|
f"from user [ID:{chosen_inline_result.from_user.id}] "
|
||||||
|
f"result [ID:{chosen_inline_result.result_id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery):
|
||||||
|
if callback_query.message:
|
||||||
|
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
|
||||||
|
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
|
||||||
|
f"from user [ID:{callback_query.message.from_user.id}]")
|
||||||
|
else:
|
||||||
|
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
|
||||||
|
f"from inline message [ID:{callback_query.inline_message_id}] "
|
||||||
|
f"from user [ID:{callback_query.from_user.id}]")
|
||||||
|
|
||||||
|
async def on_post_process_callback_query(self, callback_query, results):
|
||||||
|
if callback_query.message:
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"callback query [ID:{callback_query.id}] "
|
||||||
|
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
|
||||||
|
f"from user [ID:{callback_query.message.from_user.id}]")
|
||||||
|
else:
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"callback query [ID:{callback_query.id}] "
|
||||||
|
f"from inline message [ID:{callback_query.inline_message_id}] "
|
||||||
|
f"from user [ID:{callback_query.from_user.id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery):
|
||||||
|
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
|
||||||
|
f"from user [ID:{shipping_query.from_user.id}]")
|
||||||
|
|
||||||
|
async def on_post_process_shipping_query(self, shipping_query, results):
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"shipping query [ID:{shipping_query.id}] "
|
||||||
|
f"from user [ID:{shipping_query.from_user.id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery):
|
||||||
|
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
|
||||||
|
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
||||||
|
|
||||||
|
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results):
|
||||||
|
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
|
||||||
|
f"pre-checkout query [ID:{pre_checkout_query.id}] "
|
||||||
|
f"from user [ID:{pre_checkout_query.from_user.id}]")
|
||||||
|
|
||||||
|
async def on_pre_process_error(self, dispatcher, update, error):
|
||||||
|
timeout = self.check_timeout(update)
|
||||||
|
if timeout > 0:
|
||||||
|
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")
|
||||||
|
|
@ -1,20 +1,21 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import typing
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import typing
|
||||||
|
|
||||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
|
||||||
generate_default_filters
|
generate_default_filters
|
||||||
from .handler import Handler
|
from .handler import CancelHandler, Handler, SkipHandler
|
||||||
from .storage import BaseStorage, DisabledStorage, FSMContext
|
from .middlewares import MiddlewareManager
|
||||||
|
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
|
||||||
from .webhook import BaseResponse
|
from .webhook import BaseResponse
|
||||||
from ..bot import Bot
|
from ..bot import Bot
|
||||||
from ..types.message import ContentType
|
from ..types.message import ContentType
|
||||||
from ..utils import context
|
from ..utils import context
|
||||||
from ..utils.deprecated import deprecated
|
from ..utils.deprecated import deprecated
|
||||||
from ..utils.exceptions import NetworkError, TelegramAPIError
|
from ..utils.exceptions import NetworkError, TelegramAPIError, Throttled
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -22,6 +23,8 @@ MODE = 'MODE'
|
||||||
LONG_POLLING = 'long-polling'
|
LONG_POLLING = 'long-polling'
|
||||||
UPDATE_OBJECT = 'update_object'
|
UPDATE_OBJECT = 'update_object'
|
||||||
|
|
||||||
|
DEFAULT_RATE_LIMIT = .1
|
||||||
|
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
"""
|
"""
|
||||||
|
|
@ -33,7 +36,9 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
||||||
run_tasks_by_default: bool = False):
|
run_tasks_by_default: bool = False,
|
||||||
|
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
|
||||||
|
|
||||||
if loop is None:
|
if loop is None:
|
||||||
loop = bot.loop
|
loop = bot.loop
|
||||||
if storage is None:
|
if storage is None:
|
||||||
|
|
@ -44,27 +49,33 @@ class Dispatcher:
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.run_tasks_by_default = run_tasks_by_default
|
self.run_tasks_by_default = run_tasks_by_default
|
||||||
|
|
||||||
|
self.throttling_rate_limit = throttling_rate_limit
|
||||||
|
self.no_throttle_error = no_throttle_error
|
||||||
|
|
||||||
self.last_update_id = 0
|
self.last_update_id = 0
|
||||||
|
|
||||||
self.updates_handler = Handler(self)
|
self.updates_handler = Handler(self, middleware_key='update')
|
||||||
self.message_handlers = Handler(self)
|
self.message_handlers = Handler(self, middleware_key='message')
|
||||||
self.edited_message_handlers = Handler(self)
|
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
|
||||||
self.channel_post_handlers = Handler(self)
|
self.channel_post_handlers = Handler(self, middleware_key='channel_post')
|
||||||
self.edited_channel_post_handlers = Handler(self)
|
self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post')
|
||||||
self.inline_query_handlers = Handler(self)
|
self.inline_query_handlers = Handler(self, middleware_key='inline_query')
|
||||||
self.chosen_inline_result_handlers = Handler(self)
|
self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result')
|
||||||
self.callback_query_handlers = Handler(self)
|
self.callback_query_handlers = Handler(self, middleware_key='callback_query')
|
||||||
self.shipping_query_handlers = Handler(self)
|
self.shipping_query_handlers = Handler(self, middleware_key='shipping_query')
|
||||||
self.pre_checkout_query_handlers = Handler(self)
|
self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query')
|
||||||
|
self.errors_handlers = Handler(self, once=False, middleware_key='error')
|
||||||
|
|
||||||
|
self.middleware = MiddlewareManager(self)
|
||||||
|
|
||||||
self.updates_handler.register(self.process_update)
|
self.updates_handler.register(self.process_update)
|
||||||
|
|
||||||
self.errors_handlers = Handler(self, once=False)
|
|
||||||
|
|
||||||
self._polling = False
|
self._polling = False
|
||||||
|
self._closed = True
|
||||||
|
self._close_waiter = loop.create_future()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self._polling = False
|
self.stop_polling()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self):
|
def data(self):
|
||||||
|
|
@ -105,7 +116,7 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
tasks = []
|
tasks = []
|
||||||
for update in updates:
|
for update in updates:
|
||||||
tasks.append(self.process_update(update))
|
tasks.append(self.updates_handler.notify(update))
|
||||||
return await asyncio.gather(*tasks)
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def process_update(self, update):
|
async def process_update(self, update):
|
||||||
|
|
@ -115,72 +126,59 @@ class Dispatcher:
|
||||||
:param update:
|
:param update:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
start = time.time()
|
self.last_update_id = update.update_id
|
||||||
success = True
|
context.set_value(UPDATE_OBJECT, update)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.last_update_id = update.update_id
|
|
||||||
has_context = context.check_configured()
|
|
||||||
if has_context:
|
|
||||||
context.set_value(UPDATE_OBJECT, update)
|
|
||||||
if update.message:
|
if update.message:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.message.chat.id,
|
||||||
state = await self.storage.get_state(chat=update.message.chat.id,
|
user=update.message.from_user.id)
|
||||||
user=update.message.from_user.id)
|
context.update_state(chat=update.message.chat.id,
|
||||||
context.update_state(chat=update.message.chat.id,
|
user=update.message.from_user.id,
|
||||||
user=update.message.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.message_handlers.notify(update.message)
|
return await self.message_handlers.notify(update.message)
|
||||||
if update.edited_message:
|
if update.edited_message:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.edited_message.chat.id,
|
||||||
state = await self.storage.get_state(chat=update.edited_message.chat.id,
|
user=update.edited_message.from_user.id)
|
||||||
user=update.edited_message.from_user.id)
|
context.update_state(chat=update.edited_message.chat.id,
|
||||||
context.update_state(chat=update.edited_message.chat.id,
|
user=update.edited_message.from_user.id,
|
||||||
user=update.edited_message.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.edited_message_handlers.notify(update.edited_message)
|
return await self.edited_message_handlers.notify(update.edited_message)
|
||||||
if update.channel_post:
|
if update.channel_post:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.channel_post.chat.id)
|
||||||
state = await self.storage.get_state(chat=update.channel_post.chat.id)
|
context.update_state(chat=update.channel_post.chat.id,
|
||||||
context.update_state(chat=update.channel_post.chat.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.channel_post_handlers.notify(update.channel_post)
|
return await self.channel_post_handlers.notify(update.channel_post)
|
||||||
if update.edited_channel_post:
|
if update.edited_channel_post:
|
||||||
if has_context:
|
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
|
||||||
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
|
context.update_state(chat=update.edited_channel_post.chat.id,
|
||||||
context.update_state(chat=update.edited_channel_post.chat.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
|
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
|
||||||
if update.inline_query:
|
if update.inline_query:
|
||||||
if has_context:
|
state = await self.storage.get_state(user=update.inline_query.from_user.id)
|
||||||
state = await self.storage.get_state(user=update.inline_query.from_user.id)
|
context.update_state(user=update.inline_query.from_user.id,
|
||||||
context.update_state(user=update.inline_query.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.inline_query_handlers.notify(update.inline_query)
|
return await self.inline_query_handlers.notify(update.inline_query)
|
||||||
if update.chosen_inline_result:
|
if update.chosen_inline_result:
|
||||||
if has_context:
|
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
||||||
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
context.update_state(user=update.chosen_inline_result.from_user.id,
|
||||||
context.update_state(user=update.chosen_inline_result.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
|
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
|
||||||
if update.callback_query:
|
if update.callback_query:
|
||||||
if has_context:
|
state = await self.storage.get_state(
|
||||||
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
|
chat=update.callback_query.message.chat.id if update.callback_query.message else None,
|
||||||
user=update.callback_query.from_user.id)
|
user=update.callback_query.from_user.id)
|
||||||
context.update_state(user=update.callback_query.from_user.id,
|
context.update_state(user=update.callback_query.from_user.id,
|
||||||
state=state)
|
state=state)
|
||||||
return await self.callback_query_handlers.notify(update.callback_query)
|
return await self.callback_query_handlers.notify(update.callback_query)
|
||||||
if update.shipping_query:
|
if update.shipping_query:
|
||||||
if has_context:
|
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
|
||||||
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
|
context.update_state(user=update.shipping_query.from_user.id,
|
||||||
context.update_state(user=update.shipping_query.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.shipping_query_handlers.notify(update.shipping_query)
|
return await self.shipping_query_handlers.notify(update.shipping_query)
|
||||||
if update.pre_checkout_query:
|
if update.pre_checkout_query:
|
||||||
if has_context:
|
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
||||||
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
context.update_state(user=update.pre_checkout_query.from_user.id,
|
||||||
context.update_state(user=update.pre_checkout_query.from_user.id,
|
state=state)
|
||||||
state=state)
|
|
||||||
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
|
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
success = False
|
success = False
|
||||||
|
|
@ -188,10 +186,6 @@ class Dispatcher:
|
||||||
if err:
|
if err:
|
||||||
return err
|
return err
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
log.info(f"Process update [ID:{update.update_id}]: "
|
|
||||||
f"{['failed', 'success'][success]} "
|
|
||||||
f"(in {round((time.time() - start) * 1000)} ms)")
|
|
||||||
|
|
||||||
async def reset_webhook(self, check=True) -> bool:
|
async def reset_webhook(self, check=True) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
@ -244,24 +238,26 @@ class Dispatcher:
|
||||||
|
|
||||||
self._polling = True
|
self._polling = True
|
||||||
offset = None
|
offset = None
|
||||||
while self._polling:
|
try:
|
||||||
try:
|
while self._polling:
|
||||||
updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout)
|
try:
|
||||||
except NetworkError:
|
updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout)
|
||||||
log.exception('Cause exception while getting updates.')
|
except NetworkError:
|
||||||
await asyncio.sleep(15)
|
log.exception('Cause exception while getting updates.')
|
||||||
continue
|
await asyncio.sleep(15)
|
||||||
|
continue
|
||||||
|
|
||||||
if updates:
|
if updates:
|
||||||
log.debug(f"Received {len(updates)} updates.")
|
log.debug(f"Received {len(updates)} updates.")
|
||||||
offset = updates[-1].update_id + 1
|
offset = updates[-1].update_id + 1
|
||||||
|
|
||||||
self.loop.create_task(self._process_polling_updates(updates))
|
self.loop.create_task(self._process_polling_updates(updates))
|
||||||
|
|
||||||
if relax:
|
if relax:
|
||||||
await asyncio.sleep(relax)
|
await asyncio.sleep(relax)
|
||||||
|
finally:
|
||||||
log.warning('Polling is stopped.')
|
self._close_waiter.set_result(None)
|
||||||
|
log.warning('Polling is stopped.')
|
||||||
|
|
||||||
async def _process_polling_updates(self, updates):
|
async def _process_polling_updates(self, updates):
|
||||||
"""
|
"""
|
||||||
|
|
@ -270,8 +266,8 @@ class Dispatcher:
|
||||||
:param updates: list of updates.
|
:param updates: list of updates.
|
||||||
"""
|
"""
|
||||||
need_to_call = []
|
need_to_call = []
|
||||||
for response in await self.process_updates(updates):
|
for responses in itertools.chain.from_iterable(await self.process_updates(updates)):
|
||||||
for response in response:
|
for response in responses:
|
||||||
if not isinstance(response, BaseResponse):
|
if not isinstance(response, BaseResponse):
|
||||||
continue
|
continue
|
||||||
need_to_call.append(response.execute_response(self.bot))
|
need_to_call.append(response.execute_response(self.bot))
|
||||||
|
|
@ -288,12 +284,21 @@ class Dispatcher:
|
||||||
def stop_polling(self):
|
def stop_polling(self):
|
||||||
"""
|
"""
|
||||||
Break long-polling process.
|
Break long-polling process.
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if self._polling:
|
if self._polling:
|
||||||
log.info('Stop polling.')
|
log.info('Stop polling...')
|
||||||
self._polling = False
|
self._polling = False
|
||||||
|
|
||||||
|
async def wait_closed(self):
|
||||||
|
"""
|
||||||
|
Wait closing the long polling
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await asyncio.shield(self._close_waiter, loop=self.loop)
|
||||||
|
|
||||||
@deprecated('The old method was renamed to `is_polling`')
|
@deprecated('The old method was renamed to `is_polling`')
|
||||||
def is_pooling(self):
|
def is_pooling(self):
|
||||||
return self.is_polling()
|
return self.is_polling()
|
||||||
|
|
@ -897,7 +902,8 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(callback):
|
def decorator(callback):
|
||||||
self.register_errors_handler(callback, func=func, exception=exception)
|
self.register_errors_handler(self._wrap_async_task(callback, run_task),
|
||||||
|
func=func, exception=exception)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
@ -929,6 +935,109 @@ class Dispatcher:
|
||||||
|
|
||||||
return FSMContext(storage=self.storage, chat=chat, user=user)
|
return FSMContext(storage=self.storage, chat=chat, user=user)
|
||||||
|
|
||||||
|
async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool:
|
||||||
|
"""
|
||||||
|
Execute throttling manager.
|
||||||
|
Return True limit is not exceeded otherwise raise ThrottleError or return False
|
||||||
|
|
||||||
|
:param key: key in storage
|
||||||
|
:param rate: limit (by default is equals with default rate limit)
|
||||||
|
:param user: user id
|
||||||
|
:param chat: chat id
|
||||||
|
:param no_error: return boolean value instead of raising error
|
||||||
|
:return: bool
|
||||||
|
"""
|
||||||
|
if not self.storage.has_bucket():
|
||||||
|
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||||
|
|
||||||
|
if no_error is None:
|
||||||
|
no_error = self.no_throttle_error
|
||||||
|
if rate is None:
|
||||||
|
rate = self.throttling_rate_limit
|
||||||
|
if user is None and chat is None:
|
||||||
|
from . import ctx
|
||||||
|
user = ctx.get_user()
|
||||||
|
chat = ctx.get_chat()
|
||||||
|
|
||||||
|
# Detect current time
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||||
|
|
||||||
|
# Fix bucket
|
||||||
|
if bucket is None:
|
||||||
|
bucket = {key: {}}
|
||||||
|
if key not in bucket:
|
||||||
|
bucket[key] = {}
|
||||||
|
data = bucket[key]
|
||||||
|
|
||||||
|
# Calculate
|
||||||
|
called = data.get(LAST_CALL, now)
|
||||||
|
delta = now - called
|
||||||
|
result = delta >= rate or delta <= 0
|
||||||
|
|
||||||
|
# Save results
|
||||||
|
data[RESULT] = result
|
||||||
|
data[RATE_LIMIT] = rate
|
||||||
|
data[LAST_CALL] = now
|
||||||
|
data[DELTA] = delta
|
||||||
|
if not result:
|
||||||
|
data[EXCEEDED_COUNT] += 1
|
||||||
|
else:
|
||||||
|
data[EXCEEDED_COUNT] = 1
|
||||||
|
bucket[key].update(data)
|
||||||
|
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||||
|
|
||||||
|
if not result and not no_error:
|
||||||
|
# Raise if that is allowed
|
||||||
|
raise Throttled(key=key, chat=chat, user=user, **data)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def check_key(self, key, chat=None, user=None):
|
||||||
|
"""
|
||||||
|
Get information about key in bucket
|
||||||
|
|
||||||
|
:param key:
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if not self.storage.has_bucket():
|
||||||
|
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||||
|
|
||||||
|
if user is None and chat is None:
|
||||||
|
from . import ctx
|
||||||
|
user = ctx.get_user()
|
||||||
|
chat = ctx.get_chat()
|
||||||
|
|
||||||
|
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||||
|
data = bucket.get(key, {})
|
||||||
|
return Throttled(key=key, chat=chat, user=user, **data)
|
||||||
|
|
||||||
|
async def release_key(self, key, chat=None, user=None):
|
||||||
|
"""
|
||||||
|
Release blocked key
|
||||||
|
|
||||||
|
:param key:
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if not self.storage.has_bucket():
|
||||||
|
raise RuntimeError('This storage does not provide Leaky Bucket')
|
||||||
|
|
||||||
|
if user is None and chat is None:
|
||||||
|
from . import ctx
|
||||||
|
user = ctx.get_user()
|
||||||
|
chat = ctx.get_chat()
|
||||||
|
|
||||||
|
bucket = await self.storage.get_bucket(chat=chat, user=user)
|
||||||
|
if bucket and key in bucket:
|
||||||
|
del bucket['key']
|
||||||
|
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def async_task(self, func):
|
def async_task(self, func):
|
||||||
"""
|
"""
|
||||||
Execute handler as task and return None.
|
Execute handler as task and return None.
|
||||||
|
|
@ -947,10 +1056,14 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def process_response(task):
|
def process_response(task):
|
||||||
response = task.result()
|
try:
|
||||||
|
response = task.result()
|
||||||
if isinstance(response, BaseResponse):
|
except Exception as e:
|
||||||
self.loop.create_task(response.execute_response(self.bot))
|
self.loop.create_task(
|
||||||
|
self.errors_handlers.notify(self, task.context.get(UPDATE_OBJECT, None), e))
|
||||||
|
else:
|
||||||
|
if isinstance(response, BaseResponse):
|
||||||
|
self.loop.create_task(response.execute_response(self.bot))
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,10 @@ from ..utils import context
|
||||||
def _get(key, default=None, no_error=False):
|
def _get(key, default=None, no_error=False):
|
||||||
result = context.get_value(key, default)
|
result = context.get_value(key, default)
|
||||||
if not no_error and result is None:
|
if not no_error and result is None:
|
||||||
raise RuntimeError(f"Context is not configured for '{key}'")
|
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
|
||||||
|
f"Maybe asyncio task factory is not configured!\n"
|
||||||
|
f"\t>>> from aiogram.utils import context\n"
|
||||||
|
f"\t>>> loop.set_task_factory(context.task_factory)")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,26 +9,45 @@ from ..utils.helper import Helper, HelperMode, Item
|
||||||
USER_STATE = 'USER_STATE'
|
USER_STATE = 'USER_STATE'
|
||||||
|
|
||||||
|
|
||||||
async def check_filter(filter_, args, kwargs):
|
async def check_filter(filter_, args):
|
||||||
|
"""
|
||||||
|
Helper for executing filter
|
||||||
|
|
||||||
|
:param filter_:
|
||||||
|
:param args:
|
||||||
|
:param kwargs:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
if not callable(filter_):
|
if not callable(filter_):
|
||||||
raise TypeError('Filter must be callable and/or awaitable!')
|
raise TypeError('Filter must be callable and/or awaitable!')
|
||||||
|
|
||||||
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
|
||||||
return await filter_(*args, **kwargs)
|
return await filter_(*args)
|
||||||
else:
|
else:
|
||||||
return filter_(*args, **kwargs)
|
return filter_(*args)
|
||||||
|
|
||||||
|
|
||||||
async def check_filters(filters, args, kwargs):
|
async def check_filters(filters, args):
|
||||||
|
"""
|
||||||
|
Check list of filters
|
||||||
|
|
||||||
|
:param filters:
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
if filters is not None:
|
if filters is not None:
|
||||||
for filter_ in filters:
|
for filter_ in filters:
|
||||||
f = await check_filter(filter_, args, kwargs)
|
f = await check_filter(filter_, args)
|
||||||
if not f:
|
if not f:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
|
"""
|
||||||
|
Base class for filters
|
||||||
|
"""
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.check(*args, **kwargs)
|
return self.check(*args, **kwargs)
|
||||||
|
|
||||||
|
|
@ -37,6 +56,10 @@ class Filter:
|
||||||
|
|
||||||
|
|
||||||
class AsyncFilter(Filter):
|
class AsyncFilter(Filter):
|
||||||
|
"""
|
||||||
|
Base class for asynchronous filters
|
||||||
|
"""
|
||||||
|
|
||||||
def __aiter__(self):
|
def __aiter__(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -48,23 +71,35 @@ class AsyncFilter(Filter):
|
||||||
|
|
||||||
|
|
||||||
class AnyFilter(AsyncFilter):
|
class AnyFilter(AsyncFilter):
|
||||||
|
"""
|
||||||
|
One filter from many
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *filters: callable):
|
def __init__(self, *filters: callable):
|
||||||
self.filters = filters
|
self.filters = filters
|
||||||
|
|
||||||
async def check(self, *args, **kwargs):
|
async def check(self, *args):
|
||||||
f = (check_filter(filter_, args, kwargs) for filter_ in self.filters)
|
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||||
return any(await asyncio.gather(*f))
|
return any(await asyncio.gather(*f))
|
||||||
|
|
||||||
|
|
||||||
class NotFilter(AsyncFilter):
|
class NotFilter(AsyncFilter):
|
||||||
|
"""
|
||||||
|
Reverse filter
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, filter_: callable):
|
def __init__(self, filter_: callable):
|
||||||
self.filter = filter_
|
self.filter = filter_
|
||||||
|
|
||||||
async def check(self, *args, **kwargs):
|
async def check(self, *args):
|
||||||
return not await check_filter(self.filter, args, kwargs)
|
return not await check_filter(self.filter, args)
|
||||||
|
|
||||||
|
|
||||||
class CommandsFilter(AsyncFilter):
|
class CommandsFilter(AsyncFilter):
|
||||||
|
"""
|
||||||
|
Check commands in message
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, commands):
|
def __init__(self, commands):
|
||||||
self.commands = commands
|
self.commands = commands
|
||||||
|
|
||||||
|
|
@ -85,6 +120,10 @@ class CommandsFilter(AsyncFilter):
|
||||||
|
|
||||||
|
|
||||||
class RegexpFilter(Filter):
|
class RegexpFilter(Filter):
|
||||||
|
"""
|
||||||
|
Regexp filter for messages
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, regexp):
|
def __init__(self, regexp):
|
||||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||||
|
|
||||||
|
|
@ -94,6 +133,10 @@ class RegexpFilter(Filter):
|
||||||
|
|
||||||
|
|
||||||
class ContentTypeFilter(Filter):
|
class ContentTypeFilter(Filter):
|
||||||
|
"""
|
||||||
|
Check message content type
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, content_types):
|
def __init__(self, content_types):
|
||||||
self.content_types = content_types
|
self.content_types = content_types
|
||||||
|
|
||||||
|
|
@ -103,6 +146,10 @@ class ContentTypeFilter(Filter):
|
||||||
|
|
||||||
|
|
||||||
class CancelFilter(Filter):
|
class CancelFilter(Filter):
|
||||||
|
"""
|
||||||
|
Find cancel in message text
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, cancel_set=None):
|
def __init__(self, cancel_set=None):
|
||||||
if cancel_set is None:
|
if cancel_set is None:
|
||||||
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
||||||
|
|
@ -114,6 +161,10 @@ class CancelFilter(Filter):
|
||||||
|
|
||||||
|
|
||||||
class StateFilter(AsyncFilter):
|
class StateFilter(AsyncFilter):
|
||||||
|
"""
|
||||||
|
Check user state
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, dispatcher, state):
|
def __init__(self, dispatcher, state):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
self.state = state
|
self.state = state
|
||||||
|
|
@ -137,6 +188,10 @@ class StateFilter(AsyncFilter):
|
||||||
|
|
||||||
|
|
||||||
class StatesListFilter(StateFilter):
|
class StatesListFilter(StateFilter):
|
||||||
|
"""
|
||||||
|
List of states
|
||||||
|
"""
|
||||||
|
|
||||||
async def check(self, obj):
|
async def check(self, obj):
|
||||||
chat, user = self.get_target(obj)
|
chat, user = self.get_target(obj)
|
||||||
|
|
||||||
|
|
@ -146,6 +201,10 @@ class StatesListFilter(StateFilter):
|
||||||
|
|
||||||
|
|
||||||
class ExceptionsFilter(Filter):
|
class ExceptionsFilter(Filter):
|
||||||
|
"""
|
||||||
|
Filter for exceptions
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, exception):
|
def __init__(self, exception):
|
||||||
self.exception = exception
|
self.exception = exception
|
||||||
|
|
||||||
|
|
@ -159,6 +218,14 @@ class ExceptionsFilter(Filter):
|
||||||
|
|
||||||
|
|
||||||
def generate_default_filters(dispatcher, *args, **kwargs):
|
def generate_default_filters(dispatcher, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Prepare filters
|
||||||
|
|
||||||
|
:param dispatcher:
|
||||||
|
:param args:
|
||||||
|
:param kwargs:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
filters_set = []
|
filters_set = []
|
||||||
|
|
||||||
for name, filter_ in kwargs.items():
|
for name, filter_ in kwargs.items():
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from aiogram.utils import context
|
||||||
from .filters import check_filters
|
from .filters import check_filters
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -10,11 +11,12 @@ class CancelHandler(BaseException):
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
def __init__(self, dispatcher, once=True):
|
def __init__(self, dispatcher, once=True, middleware_key=None):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
self.once = once
|
self.once = once
|
||||||
|
|
||||||
self.handlers = []
|
self.handlers = []
|
||||||
|
self.middleware_key = middleware_key
|
||||||
|
|
||||||
def register(self, handler, filters=None, index=None):
|
def register(self, handler, filters=None, index=None):
|
||||||
"""
|
"""
|
||||||
|
|
@ -48,20 +50,24 @@ class Handler:
|
||||||
return True
|
return True
|
||||||
raise ValueError('This handler is not registered!')
|
raise ValueError('This handler is not registered!')
|
||||||
|
|
||||||
async def notify(self, *args, **kwargs):
|
async def notify(self, *args):
|
||||||
"""
|
"""
|
||||||
Notify handlers
|
Notify handlers
|
||||||
|
|
||||||
:param args:
|
:param args:
|
||||||
:param kwargs:
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
if self.middleware_key:
|
||||||
|
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
|
||||||
for filters, handler in self.handlers:
|
for filters, handler in self.handlers:
|
||||||
if await check_filters(filters, args, kwargs):
|
if await check_filters(filters, args):
|
||||||
try:
|
try:
|
||||||
response = await handler(*args, **kwargs)
|
if self.middleware_key:
|
||||||
|
context.set_value('handler', handler)
|
||||||
|
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
|
||||||
|
response = await handler(*args)
|
||||||
if results is not None:
|
if results is not None:
|
||||||
results.append(response)
|
results.append(response)
|
||||||
if self.once:
|
if self.once:
|
||||||
|
|
@ -70,5 +76,8 @@ class Handler:
|
||||||
continue
|
continue
|
||||||
except CancelHandler:
|
except CancelHandler:
|
||||||
break
|
break
|
||||||
|
if self.middleware_key:
|
||||||
|
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
|
||||||
|
args + (results,))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
101
aiogram/dispatcher/middlewares.py
Normal file
101
aiogram/dispatcher/middlewares.py
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
import logging
|
||||||
|
import typing
|
||||||
|
|
||||||
|
log = logging.getLogger('aiogram.Middleware')
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareManager:
|
||||||
|
"""
|
||||||
|
Middlewares manager. Works only with dispatcher.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dispatcher):
|
||||||
|
"""
|
||||||
|
Init
|
||||||
|
|
||||||
|
:param dispatcher: instance of Dispatcher
|
||||||
|
"""
|
||||||
|
self.dispatcher = dispatcher
|
||||||
|
self.loop = dispatcher.loop
|
||||||
|
self.bot = dispatcher.bot
|
||||||
|
self.storage = dispatcher.storage
|
||||||
|
self.applications = []
|
||||||
|
|
||||||
|
def setup(self, middleware):
|
||||||
|
"""
|
||||||
|
Setup middleware
|
||||||
|
|
||||||
|
:param middleware:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
assert isinstance(middleware, BaseMiddleware)
|
||||||
|
if middleware.is_configured():
|
||||||
|
raise ValueError('That middleware is already used!')
|
||||||
|
|
||||||
|
self.applications.append(middleware)
|
||||||
|
middleware.setup(self)
|
||||||
|
log.debug(f"Loaded middleware '{middleware.__class__.__name__}'")
|
||||||
|
|
||||||
|
async def trigger(self, action: str, args: typing.Iterable):
|
||||||
|
"""
|
||||||
|
Call action to middlewares with args lilt.
|
||||||
|
|
||||||
|
:param action:
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for app in self.applications:
|
||||||
|
await app.trigger(action, args)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMiddleware:
|
||||||
|
"""
|
||||||
|
Base class for middleware.
|
||||||
|
|
||||||
|
All methods on the middle always must be coroutines and name starts with "on_" like "on_process_message".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._configured = False
|
||||||
|
self._manager = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def manager(self) -> MiddlewareManager:
|
||||||
|
"""
|
||||||
|
Instance of MiddlewareManager
|
||||||
|
"""
|
||||||
|
if self._manager is None:
|
||||||
|
raise RuntimeError('Middleware is not configured!')
|
||||||
|
return self._manager
|
||||||
|
|
||||||
|
def setup(self, manager):
|
||||||
|
"""
|
||||||
|
Mark middleware as configured
|
||||||
|
|
||||||
|
:param manager:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self._manager = manager
|
||||||
|
self._configured = True
|
||||||
|
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check middleware is configured
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return self._configured
|
||||||
|
|
||||||
|
async def trigger(self, action, args):
|
||||||
|
"""
|
||||||
|
Trigger action.
|
||||||
|
|
||||||
|
:param action:
|
||||||
|
:param args:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
handler_name = f"on_{action}"
|
||||||
|
handler = getattr(self, handler_name, None)
|
||||||
|
if not handler:
|
||||||
|
return None
|
||||||
|
await handler(*args)
|
||||||
|
|
@ -1,5 +1,14 @@
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
# Leak bucket
|
||||||
|
KEY = 'key'
|
||||||
|
LAST_CALL = 'called_at'
|
||||||
|
RATE_LIMIT = 'rate_limit'
|
||||||
|
RESULT = 'result'
|
||||||
|
EXCEEDED_COUNT = 'exceeded'
|
||||||
|
DELTA = 'delta'
|
||||||
|
THROTTLE_MANAGER = '$throttle_manager'
|
||||||
|
|
||||||
|
|
||||||
class BaseStorage:
|
class BaseStorage:
|
||||||
"""
|
"""
|
||||||
|
|
@ -184,6 +193,78 @@ class BaseStorage:
|
||||||
"""
|
"""
|
||||||
await self.reset_state(chat=chat, user=user, with_data=True)
|
await self.reset_state(chat=chat, user=user, with_data=True)
|
||||||
|
|
||||||
|
def has_bucket(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_bucket(self, *,
|
||||||
|
chat: typing.Union[str, int, None] = None,
|
||||||
|
user: typing.Union[str, int, None] = None,
|
||||||
|
default: typing.Optional[dict] = None) -> typing.Dict:
|
||||||
|
"""
|
||||||
|
Get state-data for user in chat. Return `default` if data is not presented in storage.
|
||||||
|
|
||||||
|
Chat or user is always required. If one of this is not presented,
|
||||||
|
need set the missing value based on the presented
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param default:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def set_bucket(self, *,
|
||||||
|
chat: typing.Union[str, int, None] = None,
|
||||||
|
user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None):
|
||||||
|
"""
|
||||||
|
Set data for user in chat
|
||||||
|
|
||||||
|
Chat or user is always required. If one of this is not presented,
|
||||||
|
need set the missing value based on the presented
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param bucket:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def update_bucket(self, *,
|
||||||
|
chat: typing.Union[str, int, None] = None,
|
||||||
|
user: typing.Union[str, int, None] = None,
|
||||||
|
bucket: typing.Dict = None,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Update data for user in chat
|
||||||
|
|
||||||
|
You can use data parameter or|and kwargs.
|
||||||
|
|
||||||
|
Chat or user is always required. If one of this is not presented,
|
||||||
|
need set the missing value based on the presented
|
||||||
|
|
||||||
|
:param bucket:
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:param kwargs:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def reset_bucket(self, *,
|
||||||
|
chat: typing.Union[str, int, None] = None,
|
||||||
|
user: typing.Union[str, int, None] = None):
|
||||||
|
"""
|
||||||
|
Reset data dor user in chat.
|
||||||
|
|
||||||
|
Chat or user is always required. If one of this is not presented,
|
||||||
|
need set the missing value based on the presented
|
||||||
|
|
||||||
|
:param chat:
|
||||||
|
:param user:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self.set_data(chat=chat, user=user, data={})
|
||||||
|
|
||||||
|
|
||||||
class FSMContext:
|
class FSMContext:
|
||||||
def __init__(self, storage, chat, user):
|
def __init__(self, storage, chat, user):
|
||||||
|
|
|
||||||
|
|
@ -186,10 +186,15 @@ class WebhookRequestHandler(web.View):
|
||||||
dispatcher = self.get_dispatcher()
|
dispatcher = self.get_dispatcher()
|
||||||
loop = dispatcher.loop
|
loop = dispatcher.loop
|
||||||
|
|
||||||
results = task.result()
|
try:
|
||||||
response = self.get_response(results)
|
results = task.result()
|
||||||
if response is not None:
|
except Exception as e:
|
||||||
asyncio.ensure_future(response.execute_response(self.get_dispatcher().bot), loop=loop)
|
loop.create_task(
|
||||||
|
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
|
||||||
|
else:
|
||||||
|
response = self.get_response(results)
|
||||||
|
if response is not None:
|
||||||
|
asyncio.ensure_future(response.execute_response(dispatcher.bot), loop=loop)
|
||||||
|
|
||||||
def get_response(self, results):
|
def get_response(self, results):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -138,10 +138,8 @@ class TelegramObject(metaclass=MetaTelegramObject):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bot(self):
|
def bot(self):
|
||||||
bot = get_value('bot')
|
from ..dispatcher import ctx
|
||||||
if bot is None:
|
return ctx.get_bot()
|
||||||
raise RuntimeError('Can not found bot instance in current context!')
|
|
||||||
return bot
|
|
||||||
|
|
||||||
def to_python(self) -> typing.Dict:
|
def to_python(self) -> typing.Dict:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -211,7 +211,7 @@ class ChatActions(helper.Helper):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _do(cls, action: str, sleep=None):
|
async def _do(cls, action: str, sleep=None):
|
||||||
from aiogram.dispatcher.ctx import get_bot, get_chat
|
from ..dispatcher.ctx import get_bot, get_chat
|
||||||
await get_bot().send_chat_action(get_chat(), action)
|
await get_bot().send_chat_action(get_chat(), action)
|
||||||
if sleep:
|
if sleep:
|
||||||
await asyncio.sleep(sleep)
|
await asyncio.sleep(sleep)
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from aiogram.utils import helper
|
|
||||||
from . import base
|
from . import base
|
||||||
from . import fields
|
from . import fields
|
||||||
from .user import User
|
from .user import User
|
||||||
|
from ..utils import helper
|
||||||
|
|
||||||
|
|
||||||
class ChatMember(base.TelegramObject):
|
class ChatMember(base.TelegramObject):
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,7 @@ class MediaGroup(base.TelegramObject):
|
||||||
"""
|
"""
|
||||||
Helper for sending media group
|
Helper for sending media group
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, medias: typing.Optional[typing.List[typing.Union[InputMedia, typing.Dict]]] = None):
|
def __init__(self, medias: typing.Optional[typing.List[typing.Union[InputMedia, typing.Dict]]] = None):
|
||||||
super(MediaGroup, self).__init__()
|
super(MediaGroup, self).__init__()
|
||||||
self.media = []
|
self.media = []
|
||||||
|
|
|
||||||
|
|
@ -31,4 +31,4 @@ class PreCheckoutQuery(base.TelegramObject):
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, type(self)):
|
if isinstance(other, type(self)):
|
||||||
return other.id == self.id
|
return other.id == self.id
|
||||||
return self.id == other
|
return self.id == other
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
|
||||||
del task._source_traceback[-1]
|
del task._source_traceback[-1]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
task.context = asyncio.Task.current_task().context
|
task.context = asyncio.Task.current_task().context.copy()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
task.context = {CONFIGURED: True}
|
task.context = {CONFIGURED: True}
|
||||||
|
|
||||||
|
|
@ -114,3 +114,25 @@ def check_configured():
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return get_value(CONFIGURED)
|
return get_value(CONFIGURED)
|
||||||
|
|
||||||
|
|
||||||
|
class _Context:
|
||||||
|
"""
|
||||||
|
Other things for interactions with the execution context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return get_value(item)
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
set_value(key, value)
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
del_value(key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_context():
|
||||||
|
return get_current_state()
|
||||||
|
|
||||||
|
|
||||||
|
context = _Context()
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import time
|
||||||
|
|
||||||
_PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: ']
|
_PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: ']
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -51,3 +53,21 @@ class MigrateToChat(TelegramAPIError):
|
||||||
def __init__(self, chat_id):
|
def __init__(self, chat_id):
|
||||||
super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.")
|
super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.")
|
||||||
self.migrate_to_chat_id = chat_id
|
self.migrate_to_chat_id = chat_id
|
||||||
|
|
||||||
|
|
||||||
|
class Throttled(Exception):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
from ..dispatcher.storage import DELTA, EXCEEDED_COUNT, KEY, LAST_CALL, RATE_LIMIT, RESULT
|
||||||
|
self.key = kwargs.pop(KEY, '<None>')
|
||||||
|
self.called_at = kwargs.pop(LAST_CALL, time.time())
|
||||||
|
self.rate = kwargs.pop(RATE_LIMIT, None)
|
||||||
|
self.result = kwargs.pop(RESULT, False)
|
||||||
|
self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0)
|
||||||
|
self.delta = kwargs.pop(DELTA, 0)
|
||||||
|
self.user = kwargs.pop('user', None)
|
||||||
|
self.chat = kwargs.pop('chat', None)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Rate limit exceeded! (Limit: {self.rate} s, " \
|
||||||
|
f"exceeded: {self.exceeded_count}, " \
|
||||||
|
f"time delta: {round(self.delta, 3)} s)"
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ async def _shutdown(dispatcher: Dispatcher, callback=None):
|
||||||
|
|
||||||
if dispatcher.is_polling():
|
if dispatcher.is_polling():
|
||||||
dispatcher.stop_polling()
|
dispatcher.stop_polling()
|
||||||
|
# await dispatcher.wait_closed()
|
||||||
|
|
||||||
await dispatcher.storage.close()
|
await dispatcher.storage.close()
|
||||||
await dispatcher.storage.wait_closed()
|
await dispatcher.storage.wait_closed()
|
||||||
|
|
|
||||||
7
dev_requirements.txt
Normal file
7
dev_requirements.txt
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
-r requirements.txt
|
||||||
|
ujson
|
||||||
|
emoji
|
||||||
|
pytest
|
||||||
|
pytest-asyncio
|
||||||
|
uvloop
|
||||||
|
aioredis
|
||||||
123
examples/middleware_and_antiflood.py
Normal file
123
examples/middleware_and_antiflood.py
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from aiogram import Bot, types
|
||||||
|
from aiogram.contrib.fsm_storage.redis import RedisStorage2
|
||||||
|
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
|
||||||
|
from aiogram.dispatcher.middlewares import BaseMiddleware
|
||||||
|
from aiogram.utils import context, executor
|
||||||
|
from aiogram.utils.exceptions import Throttled
|
||||||
|
|
||||||
|
TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
# In this example used Redis storage
|
||||||
|
storage = RedisStorage2(db=5)
|
||||||
|
|
||||||
|
bot = Bot(token=TOKEN, loop=loop)
|
||||||
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limit(limit: int, key=None):
|
||||||
|
"""
|
||||||
|
Decorator for configuring rate limit and key in different functions.
|
||||||
|
|
||||||
|
:param limit:
|
||||||
|
:param key:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
setattr(func, 'throttling_rate_limit', limit)
|
||||||
|
if key:
|
||||||
|
setattr(func, 'throttling_key', key)
|
||||||
|
return func
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class ThrottlingMiddleware(BaseMiddleware):
|
||||||
|
"""
|
||||||
|
Simple middleware
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
|
||||||
|
self.rate_limit = limit
|
||||||
|
self.prefix = key_prefix
|
||||||
|
super(ThrottlingMiddleware, self).__init__()
|
||||||
|
|
||||||
|
async def on_process_message(self, message: types.Message):
|
||||||
|
"""
|
||||||
|
That handler will be called when dispatcher receive message
|
||||||
|
|
||||||
|
:param message:
|
||||||
|
"""
|
||||||
|
# Get current handler
|
||||||
|
handler = context.get_value('handler')
|
||||||
|
|
||||||
|
# Get dispatcher from context
|
||||||
|
dispatcher = ctx.get_dispatcher()
|
||||||
|
|
||||||
|
# If handler was configured get rate limit and key from handler
|
||||||
|
if handler:
|
||||||
|
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
|
||||||
|
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||||
|
else:
|
||||||
|
limit = self.rate_limit
|
||||||
|
key = f"{self.prefix}_message"
|
||||||
|
|
||||||
|
# Use Dispatcher.throttle method.
|
||||||
|
try:
|
||||||
|
await dispatcher.throttle(key, rate=limit)
|
||||||
|
except Throttled as t:
|
||||||
|
# Execute action
|
||||||
|
await self.message_throttled(message, t)
|
||||||
|
|
||||||
|
# Cancel current handler
|
||||||
|
raise CancelHandler()
|
||||||
|
|
||||||
|
async def message_throttled(self, message: types.Message, throttled: Throttled):
|
||||||
|
"""
|
||||||
|
Notify user only on first exceed and notify about unlocking only on last exceed
|
||||||
|
|
||||||
|
:param message:
|
||||||
|
:param throttled:
|
||||||
|
"""
|
||||||
|
handler = context.get_value('handler')
|
||||||
|
dispatcher = ctx.get_dispatcher()
|
||||||
|
if handler:
|
||||||
|
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
|
||||||
|
else:
|
||||||
|
key = f"{self.prefix}_message"
|
||||||
|
|
||||||
|
# Calculate how many time left to the end of block.
|
||||||
|
delta = throttled.rate - throttled.delta
|
||||||
|
|
||||||
|
# Prevent flooding
|
||||||
|
if throttled.exceeded_count <= 2:
|
||||||
|
await message.reply('Too many requests! ')
|
||||||
|
|
||||||
|
# Sleep.
|
||||||
|
await asyncio.sleep(delta)
|
||||||
|
|
||||||
|
# Check lock status
|
||||||
|
thr = await dispatcher.check_key(key)
|
||||||
|
|
||||||
|
# If current message is not last with current key - do not send message
|
||||||
|
if thr.exceeded_count == throttled.exceeded_count:
|
||||||
|
await message.reply('Unlocked.')
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(commands=['start'])
|
||||||
|
@rate_limit(5, 'start') # is not required but with that you can configure throttling manager for current handler
|
||||||
|
async def cmd_test(message: types.Message):
|
||||||
|
# You can use that command every 5 seconds
|
||||||
|
await message.reply('Test passed! You can use that command every 5 seconds.')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Setup middleware
|
||||||
|
dp.middleware.setup(ThrottlingMiddleware())
|
||||||
|
|
||||||
|
# Start long-polling
|
||||||
|
executor.start_polling(dp, loop=loop)
|
||||||
43
examples/throtling_example.py
Normal file
43
examples/throtling_example.py
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
"""
|
||||||
|
Example for throttling manager.
|
||||||
|
|
||||||
|
You can use that for flood controlling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiogram import Bot, types
|
||||||
|
from aiogram.contrib.fsm_storage.memory import MemoryStorage
|
||||||
|
from aiogram.dispatcher import Dispatcher
|
||||||
|
from aiogram.utils.exceptions import Throttled
|
||||||
|
from aiogram.utils.executor import start_polling
|
||||||
|
|
||||||
|
API_TOKEN = 'BOT TOKEN HERE'
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
bot = Bot(token=API_TOKEN, loop=loop)
|
||||||
|
|
||||||
|
# Throttling manager does not working without Leaky Bucket.
|
||||||
|
# Then need to use storage's. For example use simple in-memory storage.
|
||||||
|
storage = MemoryStorage()
|
||||||
|
dp = Dispatcher(bot, storage=storage)
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message_handler(commands=['start', 'help'])
|
||||||
|
async def send_welcome(message: types.Message):
|
||||||
|
try:
|
||||||
|
# Execute throttling manager with rate-limit equals to 2 seconds for key "start"
|
||||||
|
await dp.throttle('start', rate=2)
|
||||||
|
except Throttled:
|
||||||
|
# If request is throttled the `Throttled` exception will be raised.
|
||||||
|
await message.reply('Too many requests!')
|
||||||
|
else:
|
||||||
|
# Otherwise do something.
|
||||||
|
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
start_polling(dp, loop=loop, skip_updates=True)
|
||||||
1
tests/conftest.py
Normal file
1
tests/conftest.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
# pytest_plugins = "pytest_asyncio.plugin"
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
UPDATE = {
|
|
||||||
"update_id": 128526,
|
|
||||||
"message": {
|
|
||||||
"message_id": 11223,
|
|
||||||
"from": {
|
|
||||||
"id": 12345678,
|
|
||||||
"is_bot": False,
|
|
||||||
"first_name": "FirstName",
|
|
||||||
"last_name": "LastName",
|
|
||||||
"username": "username",
|
|
||||||
"language_code": "ru"
|
|
||||||
},
|
|
||||||
"chat": {
|
|
||||||
"id": 12345678,
|
|
||||||
"first_name": "FirstName",
|
|
||||||
"last_name": "LastName",
|
|
||||||
"username": "username",
|
|
||||||
"type": "private"
|
|
||||||
},
|
|
||||||
"date": 1508709711,
|
|
||||||
"text": "Hi, world!"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MESSAGE = UPDATE['message']
|
|
||||||
4
tests/test_bot.py
Normal file
4
tests/test_bot.py
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
import aiogram
|
||||||
|
|
||||||
|
# bot = aiogram.Bot('123456789:AABBCCDDEEFFaabbccddeeff-1234567890')
|
||||||
|
# TODO: mock for aiogram.bot.api.request and then test all AI methods.
|
||||||
|
|
@ -1,35 +0,0 @@
|
||||||
import datetime
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from aiogram import types
|
|
||||||
from dataset import MESSAGE
|
|
||||||
|
|
||||||
|
|
||||||
class TestMessage(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.message = types.Message(**MESSAGE)
|
|
||||||
|
|
||||||
def test_update_id(self):
|
|
||||||
self.assertEqual(self.message.message_id, MESSAGE['message_id'], 'test')
|
|
||||||
self.assertEqual(self.message['message_id'], MESSAGE['message_id'])
|
|
||||||
|
|
||||||
def test_from(self):
|
|
||||||
self.assertIsInstance(self.message.from_user, types.User)
|
|
||||||
self.assertEqual(self.message.from_user, self.message['from'])
|
|
||||||
|
|
||||||
def test_chat(self):
|
|
||||||
self.assertIsInstance(self.message.chat, types.Chat)
|
|
||||||
self.assertEqual(self.message.chat, self.message['chat'])
|
|
||||||
|
|
||||||
def test_date(self):
|
|
||||||
self.assertIsInstance(self.message.date, datetime.datetime)
|
|
||||||
self.assertEqual(int(self.message.date.timestamp()), MESSAGE['date'])
|
|
||||||
self.assertEqual(self.message.date, self.message['date'])
|
|
||||||
|
|
||||||
def test_text(self):
|
|
||||||
self.assertEqual(self.message.text, MESSAGE['text'])
|
|
||||||
self.assertEqual(self.message['text'], MESSAGE['text'])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
0
tests/types/__init__.py
Normal file
0
tests/types/__init__.py
Normal file
75
tests/types/dataset.py
Normal file
75
tests/types/dataset.py
Normal file
|
|
@ -0,0 +1,75 @@
|
||||||
|
USER = {
|
||||||
|
"id": 12345678,
|
||||||
|
"is_bot": False,
|
||||||
|
"first_name": "FirstName",
|
||||||
|
"last_name": "LastName",
|
||||||
|
"username": "username",
|
||||||
|
"language_code": "ru-RU"
|
||||||
|
}
|
||||||
|
|
||||||
|
CHAT = {
|
||||||
|
"id": 12345678,
|
||||||
|
"first_name": "FirstName",
|
||||||
|
"last_name": "LastName",
|
||||||
|
"username": "username",
|
||||||
|
"type": "private"
|
||||||
|
}
|
||||||
|
|
||||||
|
MESSAGE = {
|
||||||
|
"message_id": 11223,
|
||||||
|
"from": USER,
|
||||||
|
"chat": CHAT,
|
||||||
|
"date": 1508709711,
|
||||||
|
"text": "Hi, world!"
|
||||||
|
}
|
||||||
|
|
||||||
|
DOCUMENT = {
|
||||||
|
"file_name": "test.docx",
|
||||||
|
"mime_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
|
"file_id": "BQADAgADpgADy_JxS66XQTBRHFleAg",
|
||||||
|
"file_size": 21331
|
||||||
|
}
|
||||||
|
|
||||||
|
MESSAGE_WITH_DOCUMENT = {
|
||||||
|
"message_id": 12345,
|
||||||
|
"from": USER,
|
||||||
|
"chat": CHAT,
|
||||||
|
"date": 1508768012,
|
||||||
|
"document": DOCUMENT,
|
||||||
|
"caption": "doc description"
|
||||||
|
}
|
||||||
|
|
||||||
|
UPDATE = {
|
||||||
|
"update_id": 128526,
|
||||||
|
"message": MESSAGE
|
||||||
|
}
|
||||||
|
|
||||||
|
PHOTO = {
|
||||||
|
"file_id": "AgADBAADFak0G88YZAf8OAug7bHyS9x2ZxkABHVfpJywcloRAAGAAQABAg",
|
||||||
|
"file_size": 1101,
|
||||||
|
"width": 90,
|
||||||
|
"height": 51
|
||||||
|
}
|
||||||
|
|
||||||
|
ANIMATION = {
|
||||||
|
"file_name": "a9b0e0ca537aa344338f80978f0896b7.gif.mp4",
|
||||||
|
"mime_type": "video/mp4",
|
||||||
|
"thumb": PHOTO,
|
||||||
|
"file_id": "CgADBAAD4DUAAoceZAe2WiE9y0crrAI",
|
||||||
|
"file_size": 65837
|
||||||
|
}
|
||||||
|
|
||||||
|
GAME = {
|
||||||
|
"title": "Karate Kido",
|
||||||
|
"description": "No trees were harmed in the making of this game :)",
|
||||||
|
"photo": [PHOTO, PHOTO, PHOTO],
|
||||||
|
"animation": ANIMATION
|
||||||
|
}
|
||||||
|
|
||||||
|
MESSAGE_WITH_GAME = {
|
||||||
|
"message_id": 12345,
|
||||||
|
"from": USER,
|
||||||
|
"chat": CHAT,
|
||||||
|
"date": 1508824810,
|
||||||
|
"game": GAME
|
||||||
|
}
|
||||||
39
tests/types/test_animation.py
Normal file
39
tests/types/test_animation.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
from aiogram import types
|
||||||
|
from .dataset import ANIMATION
|
||||||
|
|
||||||
|
animation = types.Animation(**ANIMATION)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
exported = animation.to_python()
|
||||||
|
assert isinstance(exported, dict)
|
||||||
|
assert exported == ANIMATION
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_name():
|
||||||
|
assert isinstance(animation.file_name, str)
|
||||||
|
assert animation.file_name == ANIMATION['file_name']
|
||||||
|
|
||||||
|
|
||||||
|
def test_mime_type():
|
||||||
|
assert isinstance(animation.mime_type, str)
|
||||||
|
assert animation.mime_type == ANIMATION['mime_type']
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_id():
|
||||||
|
assert isinstance(animation.file_id, str)
|
||||||
|
# assert hash(animation) == ANIMATION['file_id']
|
||||||
|
assert animation.file_id == ANIMATION['file_id']
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_size():
|
||||||
|
assert isinstance(animation.file_size, int)
|
||||||
|
assert animation.file_size == ANIMATION['file_size']
|
||||||
|
|
||||||
|
|
||||||
|
def test_thumb():
|
||||||
|
assert isinstance(animation.thumb, types.PhotoSize)
|
||||||
|
assert animation.thumb.file_id == ANIMATION['thumb']['file_id']
|
||||||
|
assert animation.thumb.width == ANIMATION['thumb']['width']
|
||||||
|
assert animation.thumb.height == ANIMATION['thumb']['height']
|
||||||
|
assert animation.thumb.file_size == ANIMATION['thumb']['file_size']
|
||||||
61
tests/types/test_chat.py
Normal file
61
tests/types/test_chat.py
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
from aiogram import types
|
||||||
|
from .dataset import CHAT
|
||||||
|
|
||||||
|
chat = types.Chat(**CHAT)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
exported = chat.to_python()
|
||||||
|
assert isinstance(exported, dict)
|
||||||
|
assert exported == CHAT
|
||||||
|
|
||||||
|
|
||||||
|
def test_id():
|
||||||
|
assert isinstance(chat.id, int)
|
||||||
|
assert chat.id == CHAT['id']
|
||||||
|
assert hash(chat) == CHAT['id']
|
||||||
|
|
||||||
|
|
||||||
|
def test_name():
|
||||||
|
assert isinstance(chat.first_name, str)
|
||||||
|
assert chat.first_name == CHAT['first_name']
|
||||||
|
|
||||||
|
assert isinstance(chat.last_name, str)
|
||||||
|
assert chat.last_name == CHAT['last_name']
|
||||||
|
|
||||||
|
assert isinstance(chat.username, str)
|
||||||
|
assert chat.username == CHAT['username']
|
||||||
|
|
||||||
|
|
||||||
|
def test_type():
|
||||||
|
assert isinstance(chat.type, str)
|
||||||
|
assert chat.type == CHAT['type']
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_types():
|
||||||
|
assert types.ChatType.PRIVATE == 'private'
|
||||||
|
assert types.ChatType.GROUP == 'group'
|
||||||
|
assert types.ChatType.SUPER_GROUP == 'supergroup'
|
||||||
|
assert types.ChatType.CHANNEL == 'channel'
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_type_filters():
|
||||||
|
from . import test_message
|
||||||
|
assert types.ChatType.is_private(test_message.message)
|
||||||
|
assert not types.ChatType.is_group(test_message.message)
|
||||||
|
assert not types.ChatType.is_super_group(test_message.message)
|
||||||
|
assert not types.ChatType.is_group_or_super_group(test_message.message)
|
||||||
|
assert not types.ChatType.is_channel(test_message.message)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_actions():
|
||||||
|
assert types.ChatActions.TYPING == 'typing'
|
||||||
|
assert types.ChatActions.UPLOAD_PHOTO == 'upload_photo'
|
||||||
|
assert types.ChatActions.RECORD_VIDEO == 'record_video'
|
||||||
|
assert types.ChatActions.UPLOAD_VIDEO == 'upload_video'
|
||||||
|
assert types.ChatActions.RECORD_AUDIO == 'record_audio'
|
||||||
|
assert types.ChatActions.UPLOAD_AUDIO == 'upload_audio'
|
||||||
|
assert types.ChatActions.UPLOAD_DOCUMENT == 'upload_document'
|
||||||
|
assert types.ChatActions.FIND_LOCATION == 'find_location'
|
||||||
|
assert types.ChatActions.RECORD_VIDEO_NOTE == 'record_video_note'
|
||||||
|
assert types.ChatActions.UPLOAD_VIDEO_NOTE == 'upload_video_note'
|
||||||
35
tests/types/test_document.py
Normal file
35
tests/types/test_document.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
from aiogram import types
|
||||||
|
from .dataset import DOCUMENT
|
||||||
|
|
||||||
|
document = types.Document(**DOCUMENT)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
exported = document.to_python()
|
||||||
|
assert isinstance(exported, dict)
|
||||||
|
assert exported == DOCUMENT
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_name():
|
||||||
|
assert isinstance(document.file_name, str)
|
||||||
|
assert document.file_name == DOCUMENT['file_name']
|
||||||
|
|
||||||
|
|
||||||
|
def test_mime_type():
|
||||||
|
assert isinstance(document.mime_type, str)
|
||||||
|
assert document.mime_type == DOCUMENT['mime_type']
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_id():
|
||||||
|
assert isinstance(document.file_id, str)
|
||||||
|
# assert hash(document) == DOCUMENT['file_id']
|
||||||
|
assert document.file_id == DOCUMENT['file_id']
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_size():
|
||||||
|
assert isinstance(document.file_size, int)
|
||||||
|
assert document.file_size == DOCUMENT['file_size']
|
||||||
|
|
||||||
|
|
||||||
|
def test_thumb():
|
||||||
|
assert document.thumb is None
|
||||||
29
tests/types/test_game.py
Normal file
29
tests/types/test_game.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
from aiogram import types
|
||||||
|
from .dataset import GAME
|
||||||
|
|
||||||
|
game = types.Game(**GAME)
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
exported = game.to_python()
|
||||||
|
assert isinstance(exported, dict)
|
||||||
|
assert exported == GAME
|
||||||
|
|
||||||
|
|
||||||
|
def test_title():
|
||||||
|
assert isinstance(game.title, str)
|
||||||
|
assert game.title == GAME['title']
|
||||||
|
|
||||||
|
|
||||||
|
def test_description():
|
||||||
|
assert isinstance(game.description, str)
|
||||||
|
assert game.description == GAME['description']
|
||||||
|
|
||||||
|
|
||||||
|
def test_photo():
|
||||||
|
assert isinstance(game.photo, list)
|
||||||
|
assert len(game.photo) == len(GAME['photo'])
|
||||||
|
assert all(map(lambda t: isinstance(t, types.PhotoSize), game.photo))
|
||||||
|
|
||||||
|
|
||||||
|
def test_animation():
|
||||||
|
assert isinstance(game.animation, types.Animation)
|
||||||
39
tests/types/test_message.py
Normal file
39
tests/types/test_message.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from aiogram import types
|
||||||
|
from .dataset import MESSAGE
|
||||||
|
|
||||||
|
message = types.Message(**MESSAGE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
exported_chat = message.to_python()
|
||||||
|
assert isinstance(exported_chat, dict)
|
||||||
|
assert exported_chat == MESSAGE
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_id():
|
||||||
|
assert hash(message) == MESSAGE['message_id']
|
||||||
|
assert message.message_id == MESSAGE['message_id']
|
||||||
|
assert message['message_id'] == MESSAGE['message_id']
|
||||||
|
|
||||||
|
|
||||||
|
def test_from():
|
||||||
|
assert isinstance(message.from_user, types.User)
|
||||||
|
assert message.from_user == message['from']
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat():
|
||||||
|
assert isinstance(message.chat, types.Chat)
|
||||||
|
assert message.chat == message['chat']
|
||||||
|
|
||||||
|
|
||||||
|
def test_date():
|
||||||
|
assert isinstance(message.date, datetime.datetime)
|
||||||
|
assert int(message.date.timestamp()) == MESSAGE['date']
|
||||||
|
assert message.date == message['date']
|
||||||
|
|
||||||
|
|
||||||
|
def test_text():
|
||||||
|
assert message.text == MESSAGE['text']
|
||||||
|
assert message['text'] == MESSAGE['text']
|
||||||
27
tests/types/test_photo.py
Normal file
27
tests/types/test_photo.py
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
from aiogram import types
|
||||||
|
from .dataset import PHOTO
|
||||||
|
|
||||||
|
photo = types.PhotoSize(**PHOTO)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
exported = photo.to_python()
|
||||||
|
assert isinstance(exported, dict)
|
||||||
|
assert exported == PHOTO
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_id():
|
||||||
|
assert isinstance(photo.file_id, str)
|
||||||
|
assert photo.file_id == PHOTO['file_id']
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_size():
|
||||||
|
assert isinstance(photo.file_size, int)
|
||||||
|
assert photo.file_size == PHOTO['file_size']
|
||||||
|
|
||||||
|
|
||||||
|
def test_size():
|
||||||
|
assert isinstance(photo.width, int)
|
||||||
|
assert isinstance(photo.height, int)
|
||||||
|
assert photo.width == PHOTO['width']
|
||||||
|
assert photo.height == PHOTO['height']
|
||||||
20
tests/types/test_update.py
Normal file
20
tests/types/test_update.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
from aiogram import types
|
||||||
|
from .dataset import UPDATE
|
||||||
|
|
||||||
|
update = types.Update(**UPDATE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
exported = update.to_python()
|
||||||
|
assert isinstance(exported, dict)
|
||||||
|
assert exported == UPDATE
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_id():
|
||||||
|
assert isinstance(update.update_id, int)
|
||||||
|
assert hash(update) == UPDATE['update_id']
|
||||||
|
assert update.update_id == UPDATE['update_id']
|
||||||
|
|
||||||
|
|
||||||
|
def test_message():
|
||||||
|
assert isinstance(update.message, types.Message)
|
||||||
48
tests/types/test_user.py
Normal file
48
tests/types/test_user.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
from babel import Locale
|
||||||
|
|
||||||
|
from aiogram import types
|
||||||
|
from .dataset import USER
|
||||||
|
|
||||||
|
user = types.User(**USER)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
exported = user.to_python()
|
||||||
|
assert isinstance(exported, dict)
|
||||||
|
assert exported == USER
|
||||||
|
|
||||||
|
|
||||||
|
def test_id():
|
||||||
|
assert isinstance(user.id, int)
|
||||||
|
assert user.id == USER['id']
|
||||||
|
assert hash(user) == USER['id']
|
||||||
|
|
||||||
|
|
||||||
|
def test_bot():
|
||||||
|
assert isinstance(user.is_bot, bool)
|
||||||
|
assert user.is_bot == USER['is_bot']
|
||||||
|
|
||||||
|
|
||||||
|
def test_name():
|
||||||
|
assert user.first_name == USER['first_name']
|
||||||
|
assert user.last_name == USER['last_name']
|
||||||
|
assert user.username == USER['username']
|
||||||
|
|
||||||
|
|
||||||
|
def test_language_code():
|
||||||
|
assert user.language_code == USER['language_code']
|
||||||
|
assert user.locale == Locale.parse(USER['language_code'], sep='-')
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_name():
|
||||||
|
assert user.full_name == f"{USER['first_name']} {USER['last_name']}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mention():
|
||||||
|
assert user.mention == f"@{USER['username']}"
|
||||||
|
assert user.get_mention('foo') == f"[foo](tg://user?id={USER['id']})"
|
||||||
|
assert user.get_mention('foo', as_html=True) == f"<a href=\"tg://user?id={USER['id']}\">foo</a>"
|
||||||
|
|
||||||
|
|
||||||
|
def test_url():
|
||||||
|
assert user.url == f"tg://user?id={USER['id']}"
|
||||||
|
|
@ -1,2 +0,0 @@
|
||||||
def out(*message, sep=' '):
|
|
||||||
print('Test', sep.join(message))
|
|
||||||
7
tox.ini
Normal file
7
tox.ini
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
[tox]
|
||||||
|
envlist = py36
|
||||||
|
|
||||||
|
[testenv]
|
||||||
|
deps = -rdev_requirements.txt
|
||||||
|
commands = pytest
|
||||||
|
skip_install = true
|
||||||
Loading…
Add table
Add a link
Reference in a new issue