Merge branch 'dev-1.x'

# Conflicts:
#	aiogram/__init__.py
Author: Alex Root Junior, 2018-01-07 18:22:56 +02:00
Commit: e881064f6c
43 changed files with 1527 additions and 218 deletions


@ -2,17 +2,18 @@ VENV_NAME := venv
PYTHON := $(VENV_NAME)/bin/python
AIOGRAM_VERSION := $(shell $(PYTHON) -c "import aiogram;print(aiogram.__version__)")
RM := rm -rf
mkvenv:
virtualenv $(VENV_NAME)
$(PYTHON) -m pip install -r requirements.txt
clean:
find . -name '*.pyc' -exec rm --force {} +
find . -name '*.pyo' -exec rm --force {} +
find . -name '*~' -exec rm --force {} +
rm --force --recursive build/
rm --force --recursive dist/
rm --force --recursive *.egg-info
find . -name '*.pyc' -exec $(RM) {} +
find . -name '*.pyo' -exec $(RM) {} +
find . -name '*~' -exec $(RM) {} +
find . -name '__pycache__' -exec $(RM) {} +
$(RM) build/ dist/ docs/build/ .tox/ .cache/ *.egg-info
tag:
@echo "Add tag: '$(AIOGRAM_VERSION)'"
@ -26,14 +27,23 @@ upload:
release:
make clean
make tag
make test
make build
make tag
@echo "Released aiogram $(AIOGRAM_VERSION)"
full-release:
make release
make upload
make install:
install:
$(PYTHON) setup.py install
test:
tox
summary:
cloc aiogram/ tests/ examples/ setup.py
docs: docs/source/*
cd docs && $(MAKE) html


@ -10,7 +10,16 @@ except ImportError as e:
from .utils.versions import Stage, Version
VERSION = Version(1, 0, 2, stage=Stage.FINAL, build=0)
try:
import uvloop
except ImportError:
pass
else:
import asyncio
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
VERSION = Version(1, 0, 3, stage=Stage.FINAL, build=0)
API_VERSION = Version(3, 5)
__version__ = VERSION.version
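Note what the added block above does: importing aiogram now switches asyncio to uvloop's event loop policy whenever uvloop is importable, and silently falls back to the default loop otherwise. A quick, hedged way to verify this (assuming uvloop is installed) could be:

import asyncio

import aiogram  # importing aiogram installs uvloop's policy when uvloop is available

# Prints uvloop's EventLoopPolicy if uvloop is installed, otherwise asyncio's default policy.
print(type(asyncio.get_event_loop_policy()))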


@ -124,7 +124,7 @@ def _compose_data(params=None, files=None):
return data
async def request(session, token, method, data=None, files=None, continue_retry=False, **kwargs) -> bool or dict:
async def request(session, token, method, data=None, files=None, **kwargs) -> bool or dict:
"""
Make request to API
@ -144,8 +144,6 @@ async def request(session, token, method, data=None, files=None, continue_retry=
:type data: :obj:`dict`
:param files: files
:type files: :obj:`dict`
:param continue_retry:
:type continue_retry: :obj:`dict`
:return: result
:rtype :obj:`bool` or :obj:`dict`
"""
@ -158,18 +156,13 @@ async def request(session, token, method, data=None, files=None, continue_retry=
return await _check_result(method, response)
except aiohttp.ClientError as e:
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
except exceptions.RetryAfter as e:
if continue_retry:
await asyncio.sleep(e.timeout)
return await request(session, token, method, data, files, **kwargs)
raise
class Methods(Helper):
"""
Helper for Telegram API Methods listed on https://core.telegram.org/bots/api
List is updated to Bot API 3.4
List is updated to Bot API 3.5
"""
mode = HelperMode.lowerCamelCase


@ -18,7 +18,6 @@ class BaseBot:
loop: Optional[Union[asyncio.BaseEventLoop, asyncio.AbstractEventLoop]] = None,
connections_limit: Optional[base.Integer] = 10,
proxy: str = None, proxy_auth: Optional[aiohttp.BasicAuth] = None,
continue_retry: Optional[bool] = False,
validate_token: Optional[bool] = True):
"""
Instructions how to get Bot token is found here: https://core.telegram.org/bots#3-how-do-i-create-a-bot
@ -33,8 +32,6 @@ class BaseBot:
:type proxy: :obj:`str`
:param proxy_auth: Authentication information
:type proxy_auth: Optional :obj:`aiohttp.BasicAuth`
:param continue_retry: automatic retry sent request when flood control exceeded
:type continue_retry: :obj:`bool`
:param validate_token: Validate token.
:type validate_token: :obj:`bool`
:raise: when token is invalid throw an :obj:`aiogram.utils.exceptions.ValidationError`
@ -48,9 +45,6 @@ class BaseBot:
self.proxy = proxy
self.proxy_auth = proxy_auth
# Action on error
self.continue_retry = continue_retry
# Asyncio loop instance
if loop is None:
loop = asyncio.get_event_loop()
@ -68,8 +62,11 @@ class BaseBot:
self._data = {}
def __del__(self):
self.close()
def close(self):
"""
When bot object is deleting - need close all sessions
Close all client sessions
"""
for session in self._temp_sessions:
if not session.closed:
@ -77,17 +74,19 @@ class BaseBot:
if self.session and not self.session.closed:
self.session.close()
def create_temp_session(self, limit: int = 1) -> aiohttp.ClientSession:
def create_temp_session(self, limit: int = 1, force_close: bool = False) -> aiohttp.ClientSession:
"""
Create temporary session
:param limit: Limit of connections
:type limit: :obj:`int`
:param force_close: Set to True to force close and do reconnect after each request (and between redirects).
:type force_close: :obj:`bool`
:return: New session
:rtype: :obj:`aiohttp.TCPConnector`
"""
session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=limit, force_close=True),
connector=aiohttp.TCPConnector(limit=limit, force_close=force_close),
loop=self.loop, json_serialize=json.dumps)
self._temp_sessions.append(session)
return session
@ -123,8 +122,7 @@ class BaseBot:
:raise: :obj:`aiogram.exceptions.TelegramApiError`
"""
return await api.request(self.session, self.__token, method, data, files,
proxy=self.proxy, proxy_auth=self.proxy_auth,
continue_retry=self.continue_retry)
proxy=self.proxy, proxy_auth=self.proxy_auth)
async def download_file(self, file_path: base.String,
destination: Optional[base.InputFile] = None,


@ -31,7 +31,8 @@ class Bot(BaseBot):
delattr(self, '_me')
async def download_file_by_id(self, file_id: base.String, destination=None,
timeout: base.Integer=30, chunk_size: base.Integer=65536, seek: base.Boolean=True):
timeout: base.Integer = 30, chunk_size: base.Integer = 65536,
seek: base.Boolean = True):
"""
Download file by file_id to destination


@ -30,7 +30,7 @@ class MemoryStorage(BaseStorage):
chat_id = str(chat_id)
user_id = str(user_id)
if user_id not in self.data[chat_id]:
self.data[chat_id][user_id] = {'state': None, 'data': {}}
self.data[chat_id][user_id] = {'state': None, 'data': {}, 'bucket': {}}
return self.data[chat_id][user_id]
async def get_state(self, *,
@ -82,3 +82,27 @@ class MemoryStorage(BaseStorage):
await self.set_state(chat=chat, user=user, state=None)
if with_data:
await self.set_data(chat=chat, user=user, data={})
def has_bucket(self):
return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
user = self._get_user(chat, user)
return user['bucket']
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
user = self._get_user(chat, user)
user['bucket'] = bucket
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs):
chat, user = self.check_address(chat=chat, user=user)
user = self._get_user(chat, user)
if bucket is None:
bucket = {}
user['bucket'].update(bucket, **kwargs)
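For illustration, a small hedged sketch of the new bucket API on MemoryStorage (only the storage itself is exercised; the chat/user ids are arbitrary):

import asyncio

from aiogram.contrib.fsm_storage.memory import MemoryStorage

async def demo():
    storage = MemoryStorage()
    assert storage.has_bucket()
    # Store, then merge additional throttling data for a chat/user pair.
    await storage.set_bucket(chat=-1001, user=42, bucket={'hits': 1})
    await storage.update_bucket(chat=-1001, user=42, bucket={'last': 'now'})
    print(await storage.get_bucket(chat=-1001, user=42))  # {'hits': 1, 'last': 'now'}

asyncio.get_event_loop().run_until_complete(demo())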


@ -3,6 +3,7 @@ This module has redis storage for finite-state machine based on `aioredis <https
"""
import asyncio
import logging
import typing
import aioredis
@ -10,6 +11,10 @@ import aioredis
from ...dispatcher.storage import BaseStorage
from ...utils import json
STATE_KEY = 'state'
STATE_DATA_KEY = 'data'
STATE_BUCKET_KEY = 'bucket'
class RedisStorage(BaseStorage):
"""
@ -90,10 +95,11 @@ class RedisStorage(BaseStorage):
return json.loads(data)
async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state=None, data=None):
state=None, data=None, bucket=None):
"""
Write record to storage
:param bucket:
:param chat:
:param user:
:param state:
@ -102,11 +108,13 @@ class RedisStorage(BaseStorage):
"""
if data is None:
data = {}
if bucket is None:
bucket = {}
chat, user = self.check_address(chat=chat, user=user)
addr = f"fsm:{chat}:{user}"
record = {'state': state, 'data': data}
record = {'state': state, 'data': data, 'bucket': bucket}
conn = await self.redis
await conn.execute('SET', addr, json.dumps(record))
@ -168,3 +176,220 @@ class RedisStorage(BaseStorage):
else:
keys = await conn.execute('KEYS', 'fsm:*')
conn.execute('DEL', *keys)
def has_bucket(self):
return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
record = await self.get_record(chat=chat, user=user)
return record.get('bucket', {})
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
record = await self.get_record(chat=chat, user=user)
await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket)
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs):
record = await self.get_record(chat=chat, user=user)
record_bucket = record.get('bucket', {})
record_bucket.update(bucket, **kwargs)
await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=record_bucket)
class RedisStorage2(BaseStorage):
"""
Redis-based storage for FSM.
Works with Redis connection pool and customizable keys prefix.
Usage:
.. code-block:: python3
storage = RedisStorage2('localhost', 6379, db=5, pool_size=10, prefix='my_fsm_key')
dp = Dispatcher(bot, storage=storage)
The Redis connection also needs to be closed on shutdown:
.. code-block:: python3
await dp.storage.close()
await dp.storage.wait_closed()
"""
def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None,
pool_size=10, loop=None, prefix='fsm', **kwargs):
self._host = host
self._port = port
self._db = db
self._password = password
self._ssl = ssl
self._pool_size = pool_size
self._loop = loop or asyncio.get_event_loop()
self._kwargs = kwargs
self._prefix = (prefix,)
self._redis: aioredis.RedisConnection = None
self._connection_lock = asyncio.Lock(loop=self._loop)
@property
async def redis(self) -> aioredis.Redis:
"""
Get Redis connection
This property is awaitable.
"""
# Guard lazy pool creation with an asyncio Lock; this method is not safe without it
async with self._connection_lock:
if self._redis is None:
self._redis = await aioredis.create_redis_pool((self._host, self._port),
db=self._db, password=self._password, ssl=self._ssl,
minsize=1, maxsize=self._pool_size,
loop=self._loop, **self._kwargs)
return self._redis
def generate_key(self, *parts):
return ':'.join(self._prefix + tuple(map(str, parts)))
async def close(self):
async with self._connection_lock:
if self._redis and not self._redis.closed:
self._redis.close()
del self._redis
self._redis = None
async def wait_closed(self):
async with self._connection_lock:
if self._redis:
return await self._redis.wait_closed()
return True
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self.redis
return await redis.get(key, encoding='utf8') or None
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self.redis
raw_result = await redis.get(key, encoding='utf8')
if raw_result:
return json.loads(raw_result)
return default or {}
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self.redis
if state is None:
await redis.delete(key)
else:
await redis.set(key, state)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self.redis
await redis.set(key, json.dumps(data))
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
temp_data = await self.get_data(chat=chat, user=user, default={})
temp_data.update(data, **kwargs)
await self.set_data(chat=chat, user=user, data=temp_data)
def has_bucket(self):
return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self.redis
raw_result = await redis.get(key, encoding='utf8')
if raw_result:
return json.loads(raw_result)
return default or {}
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self.redis
await redis.set(key, json.dumps(bucket))
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs):
temp_bucket = await self.get_bucket(chat=chat, user=user)
temp_bucket.update(bucket, **kwargs)
await self.set_bucket(chat=chat, user=user, bucket=temp_bucket)
async def reset_all(self, full=True):
"""
Reset states in DB
:param full: clean DB or clean only states
:return:
"""
conn = await self.redis
if full:
conn.flushdb()
else:
keys = await conn.keys(self.generate_key('*'))
conn.delete(*keys)
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
"""
Get a list of all stored chats and users
:return: list of tuples where first element is chat id and second is user id
"""
conn = await self.redis
result = []
keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8')
for item in keys:
*_, chat, user, _ = item.split(':')
result.append((chat, user))
return result
async def import_redis1(self, redis1):
await migrate_redis1_to_redis2(redis1, self)
async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2):
"""
Helper for migrating from RedisStorage to RedisStorage2
:param storage1: instance of RedisStorage
:param storage2: instance of RedisStorage2
:return:
"""
assert isinstance(storage1, RedisStorage)
assert isinstance(storage2, RedisStorage2)
log = logging.getLogger('aiogram.RedisStorage')
for chat, user in await storage1.get_states_list():
state = await storage1.get_state(chat=chat, user=user)
await storage2.set_state(chat=chat, user=user, state=state)
data = await storage1.get_data(chat=chat, user=user)
await storage2.set_data(chat=chat, user=user, data=data)
bucket = await storage1.get_bucket(chat=chat, user=user)
await storage2.set_bucket(chat=chat, user=user, bucket=bucket)
log.info(f"Migrated user {user} in chat {chat}")


@ -0,0 +1,132 @@
import logging
import time
from aiogram import types
from aiogram.dispatcher.middlewares import BaseMiddleware
HANDLED_STR = ['Unhandled', 'Handled']
class LoggingMiddleware(BaseMiddleware):
def __init__(self, logger=__name__):
if not isinstance(logger, logging.Logger):
logger = logging.getLogger(logger)
self.logger = logger
super(LoggingMiddleware, self).__init__()
def check_timeout(self, obj):
start = obj.conf.get('_start', None)
if start:
del obj.conf['_start']
return round((time.time() - start) * 1000)
return -1
async def on_pre_process_update(self, update: types.Update):
update.conf['_start'] = time.time()
self.logger.debug(f"Received update [ID:{update.update_id}]")
async def on_post_process_update(self, update: types.Update, result):
timeout = self.check_timeout(update)
if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
async def on_pre_process_message(self, message: types.Message):
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
async def on_post_process_message(self, message: types.Message, results):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
async def on_pre_process_edited_message(self, edited_message):
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
async def on_post_process_edited_message(self, edited_message, results):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
async def on_pre_process_channel_post(self, channel_post: types.Message):
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
f"in channel [ID:{channel_post.chat.id}]")
async def on_post_process_channel_post(self, channel_post: types.Message, results):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"channel post [ID:{channel_post.message_id}] "
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message):
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery):
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]")
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]")
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult):
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]")
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]")
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery):
if callback_query.message:
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
f"from user [ID:{callback_query.message.from_user.id}]")
else:
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]")
async def on_post_process_callback_query(self, callback_query, results):
if callback_query.message:
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
f"from user [ID:{callback_query.message.from_user.id}]")
else:
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]")
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery):
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]")
async def on_post_process_shipping_query(self, shipping_query, results):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]")
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery):
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_pre_process_error(self, dispatcher, update, error):
timeout = self.check_timeout(update)
if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [failed] (in {timeout} ms)")


@ -1,20 +1,21 @@
import asyncio
import functools
import itertools
import logging
import typing
import time
import typing
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, USER_STATE, \
generate_default_filters
from .handler import Handler
from .storage import BaseStorage, DisabledStorage, FSMContext
from .handler import CancelHandler, Handler, SkipHandler
from .middlewares import MiddlewareManager
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, LAST_CALL, RATE_LIMIT, RESULT
from .webhook import BaseResponse
from ..bot import Bot
from ..types.message import ContentType
from ..utils import context
from ..utils.deprecated import deprecated
from ..utils.exceptions import NetworkError, TelegramAPIError
from ..utils.exceptions import NetworkError, TelegramAPIError, Throttled
log = logging.getLogger(__name__)
@ -22,6 +23,8 @@ MODE = 'MODE'
LONG_POLLING = 'long-polling'
UPDATE_OBJECT = 'update_object'
DEFAULT_RATE_LIMIT = .1
class Dispatcher:
"""
@ -33,7 +36,9 @@ class Dispatcher:
"""
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
run_tasks_by_default: bool = False):
run_tasks_by_default: bool = False,
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
if loop is None:
loop = bot.loop
if storage is None:
@ -44,27 +49,33 @@ class Dispatcher:
self.storage = storage
self.run_tasks_by_default = run_tasks_by_default
self.throttling_rate_limit = throttling_rate_limit
self.no_throttle_error = no_throttle_error
self.last_update_id = 0
self.updates_handler = Handler(self)
self.message_handlers = Handler(self)
self.edited_message_handlers = Handler(self)
self.channel_post_handlers = Handler(self)
self.edited_channel_post_handlers = Handler(self)
self.inline_query_handlers = Handler(self)
self.chosen_inline_result_handlers = Handler(self)
self.callback_query_handlers = Handler(self)
self.shipping_query_handlers = Handler(self)
self.pre_checkout_query_handlers = Handler(self)
self.updates_handler = Handler(self, middleware_key='update')
self.message_handlers = Handler(self, middleware_key='message')
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
self.channel_post_handlers = Handler(self, middleware_key='channel_post')
self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post')
self.inline_query_handlers = Handler(self, middleware_key='inline_query')
self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result')
self.callback_query_handlers = Handler(self, middleware_key='callback_query')
self.shipping_query_handlers = Handler(self, middleware_key='shipping_query')
self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query')
self.errors_handlers = Handler(self, once=False, middleware_key='error')
self.middleware = MiddlewareManager(self)
self.updates_handler.register(self.process_update)
self.errors_handlers = Handler(self, once=False)
self._polling = False
self._closed = True
self._close_waiter = loop.create_future()
def __del__(self):
self._polling = False
self.stop_polling()
@property
def data(self):
@ -105,7 +116,7 @@ class Dispatcher:
"""
tasks = []
for update in updates:
tasks.append(self.process_update(update))
tasks.append(self.updates_handler.notify(update))
return await asyncio.gather(*tasks)
async def process_update(self, update):
@ -115,72 +126,59 @@ class Dispatcher:
:param update:
:return:
"""
start = time.time()
success = True
self.last_update_id = update.update_id
context.set_value(UPDATE_OBJECT, update)
try:
self.last_update_id = update.update_id
has_context = context.check_configured()
if has_context:
context.set_value(UPDATE_OBJECT, update)
if update.message:
if has_context:
state = await self.storage.get_state(chat=update.message.chat.id,
user=update.message.from_user.id)
context.update_state(chat=update.message.chat.id,
user=update.message.from_user.id,
state=state)
state = await self.storage.get_state(chat=update.message.chat.id,
user=update.message.from_user.id)
context.update_state(chat=update.message.chat.id,
user=update.message.from_user.id,
state=state)
return await self.message_handlers.notify(update.message)
if update.edited_message:
if has_context:
state = await self.storage.get_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id)
context.update_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id,
state=state)
state = await self.storage.get_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id)
context.update_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id,
state=state)
return await self.edited_message_handlers.notify(update.edited_message)
if update.channel_post:
if has_context:
state = await self.storage.get_state(chat=update.channel_post.chat.id)
context.update_state(chat=update.channel_post.chat.id,
state=state)
state = await self.storage.get_state(chat=update.channel_post.chat.id)
context.update_state(chat=update.channel_post.chat.id,
state=state)
return await self.channel_post_handlers.notify(update.channel_post)
if update.edited_channel_post:
if has_context:
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
context.update_state(chat=update.edited_channel_post.chat.id,
state=state)
state = await self.storage.get_state(chat=update.edited_channel_post.chat.id)
context.update_state(chat=update.edited_channel_post.chat.id,
state=state)
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
if update.inline_query:
if has_context:
state = await self.storage.get_state(user=update.inline_query.from_user.id)
context.update_state(user=update.inline_query.from_user.id,
state=state)
state = await self.storage.get_state(user=update.inline_query.from_user.id)
context.update_state(user=update.inline_query.from_user.id,
state=state)
return await self.inline_query_handlers.notify(update.inline_query)
if update.chosen_inline_result:
if has_context:
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
context.update_state(user=update.chosen_inline_result.from_user.id,
state=state)
state = await self.storage.get_state(user=update.chosen_inline_result.from_user.id)
context.update_state(user=update.chosen_inline_result.from_user.id,
state=state)
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
if update.callback_query:
if has_context:
state = await self.storage.get_state(chat=update.callback_query.message.chat.id,
user=update.callback_query.from_user.id)
context.update_state(user=update.callback_query.from_user.id,
state=state)
state = await self.storage.get_state(
chat=update.callback_query.message.chat.id if update.callback_query.message else None,
user=update.callback_query.from_user.id)
context.update_state(user=update.callback_query.from_user.id,
state=state)
return await self.callback_query_handlers.notify(update.callback_query)
if update.shipping_query:
if has_context:
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
context.update_state(user=update.shipping_query.from_user.id,
state=state)
state = await self.storage.get_state(user=update.shipping_query.from_user.id)
context.update_state(user=update.shipping_query.from_user.id,
state=state)
return await self.shipping_query_handlers.notify(update.shipping_query)
if update.pre_checkout_query:
if has_context:
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
context.update_state(user=update.pre_checkout_query.from_user.id,
state=state)
state = await self.storage.get_state(user=update.pre_checkout_query.from_user.id)
context.update_state(user=update.pre_checkout_query.from_user.id,
state=state)
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
except Exception as e:
success = False
@ -188,10 +186,6 @@ class Dispatcher:
if err:
return err
raise
finally:
log.info(f"Process update [ID:{update.update_id}]: "
f"{['failed', 'success'][success]} "
f"(in {round((time.time() - start) * 1000)} ms)")
async def reset_webhook(self, check=True) -> bool:
"""
@ -244,24 +238,26 @@ class Dispatcher:
self._polling = True
offset = None
while self._polling:
try:
updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout)
except NetworkError:
log.exception('Cause exception while getting updates.')
await asyncio.sleep(15)
continue
try:
while self._polling:
try:
updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout)
except NetworkError:
log.exception('Cause exception while getting updates.')
await asyncio.sleep(15)
continue
if updates:
log.debug(f"Received {len(updates)} updates.")
offset = updates[-1].update_id + 1
if updates:
log.debug(f"Received {len(updates)} updates.")
offset = updates[-1].update_id + 1
self.loop.create_task(self._process_polling_updates(updates))
self.loop.create_task(self._process_polling_updates(updates))
if relax:
await asyncio.sleep(relax)
log.warning('Polling is stopped.')
if relax:
await asyncio.sleep(relax)
finally:
self._close_waiter.set_result(None)
log.warning('Polling is stopped.')
async def _process_polling_updates(self, updates):
"""
@ -270,8 +266,8 @@ class Dispatcher:
:param updates: list of updates.
"""
need_to_call = []
for response in await self.process_updates(updates):
for response in response:
for responses in itertools.chain.from_iterable(await self.process_updates(updates)):
for response in responses:
if not isinstance(response, BaseResponse):
continue
need_to_call.append(response.execute_response(self.bot))
@ -288,12 +284,21 @@ class Dispatcher:
def stop_polling(self):
"""
Break long-polling process.
:return:
"""
if self._polling:
log.info('Stop polling.')
log.info('Stop polling...')
self._polling = False
async def wait_closed(self):
"""
Wait for long polling to close
:return:
"""
await asyncio.shield(self._close_waiter, loop=self.loop)
@deprecated('The old method was renamed to `is_polling`')
def is_pooling(self):
return self.is_polling()
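Taken together with the new `_close_waiter`, this allows a graceful shutdown; a hedged sketch of the intended sequence, mirroring the calls used in the executor and the RedisStorage2 docstring:

async def shutdown(dp):
    # Stop fetching updates, wait until the polling loop actually exits,
    # then release the storage connections.
    dp.stop_polling()
    await dp.wait_closed()
    await dp.storage.close()
    await dp.storage.wait_closed()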
@ -897,7 +902,8 @@ class Dispatcher:
"""
def decorator(callback):
self.register_errors_handler(callback, func=func, exception=exception)
self.register_errors_handler(self._wrap_async_task(callback, run_task),
func=func, exception=exception)
return callback
return decorator
@ -929,6 +935,109 @@ class Dispatcher:
return FSMContext(storage=self.storage, chat=chat, user=user)
async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool:
"""
Execute throttling manager.
Return True if the limit is not exceeded; otherwise raise Throttled or, with no_error, return False
:param key: key in storage
:param rate: limit (defaults to the dispatcher's default rate limit)
:param user: user id
:param chat: chat id
:param no_error: return boolean value instead of raising error
:return: bool
"""
if not self.storage.has_bucket():
raise RuntimeError('This storage does not provide Leaky Bucket')
if no_error is None:
no_error = self.no_throttle_error
if rate is None:
rate = self.throttling_rate_limit
if user is None and chat is None:
from . import ctx
user = ctx.get_user()
chat = ctx.get_chat()
# Detect current time
now = time.time()
bucket = await self.storage.get_bucket(chat=chat, user=user)
# Fix bucket
if bucket is None:
bucket = {key: {}}
if key not in bucket:
bucket[key] = {}
data = bucket[key]
# Calculate
called = data.get(LAST_CALL, now)
delta = now - called
result = delta >= rate or delta <= 0
# Save results
data[RESULT] = result
data[RATE_LIMIT] = rate
data[LAST_CALL] = now
data[DELTA] = delta
if not result:
data[EXCEEDED_COUNT] += 1
else:
data[EXCEEDED_COUNT] = 1
bucket[key].update(data)
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
if not result and not no_error:
# Raise if that is allowed
raise Throttled(key=key, chat=chat, user=user, **data)
return result
async def check_key(self, key, chat=None, user=None):
"""
Get information about key in bucket
:param key:
:param chat:
:param user:
:return:
"""
if not self.storage.has_bucket():
raise RuntimeError('This storage does not provide Leaky Bucket')
if user is None and chat is None:
from . import ctx
user = ctx.get_user()
chat = ctx.get_chat()
bucket = await self.storage.get_bucket(chat=chat, user=user)
data = bucket.get(key, {})
return Throttled(key=key, chat=chat, user=user, **data)
async def release_key(self, key, chat=None, user=None):
"""
Release blocked key
:param key:
:param chat:
:param user:
:return:
"""
if not self.storage.has_bucket():
raise RuntimeError('This storage does not provide Leaky Bucket')
if user is None and chat is None:
from . import ctx
user = ctx.get_user()
chat = ctx.get_chat()
bucket = await self.storage.get_bucket(chat=chat, user=user)
if bucket and key in bucket:
del bucket[key]
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
return True
return False
def async_task(self, func):
"""
Execute handler as task and return None.
@ -947,10 +1056,14 @@ class Dispatcher:
"""
def process_response(task):
response = task.result()
if isinstance(response, BaseResponse):
self.loop.create_task(response.execute_response(self.bot))
try:
response = task.result()
except Exception as e:
self.loop.create_task(
self.errors_handlers.notify(self, task.context.get(UPDATE_OBJECT, None), e))
else:
if isinstance(response, BaseResponse):
self.loop.create_task(response.execute_response(self.bot))
@functools.wraps(func)
async def wrapper(*args, **kwargs):


@ -7,7 +7,10 @@ from ..utils import context
def _get(key, default=None, no_error=False):
result = context.get_value(key, default)
if not no_error and result is None:
raise RuntimeError(f"Context is not configured for '{key}'")
raise RuntimeError(f"Key '{key}' does not exist in the current execution context!\n"
f"Maybe asyncio task factory is not configured!\n"
f"\t>>> from aiogram.utils import context\n"
f"\t>>> loop.set_task_factory(context.task_factory)")
return result
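The reworded error above already points at the remedy; for completeness, a minimal hedged sketch of wiring the task factory into the event loop before starting the bot:

import asyncio

from aiogram.utils import context

loop = asyncio.get_event_loop()
# Without this, context.get_value() and the ctx helpers raise the RuntimeError shown above.
loop.set_task_factory(context.task_factory)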


@ -9,26 +9,45 @@ from ..utils.helper import Helper, HelperMode, Item
USER_STATE = 'USER_STATE'
async def check_filter(filter_, args, kwargs):
async def check_filter(filter_, args):
"""
Helper for executing filter
:param filter_:
:param args:
:param kwargs:
:return:
"""
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
if inspect.isawaitable(filter_) or inspect.iscoroutinefunction(filter_):
return await filter_(*args, **kwargs)
return await filter_(*args)
else:
return filter_(*args, **kwargs)
return filter_(*args)
async def check_filters(filters, args, kwargs):
async def check_filters(filters, args):
"""
Check list of filters
:param filters:
:param args:
:return:
"""
if filters is not None:
for filter_ in filters:
f = await check_filter(filter_, args, kwargs)
f = await check_filter(filter_, args)
if not f:
return False
return True
class Filter:
"""
Base class for filters
"""
def __call__(self, *args, **kwargs):
return self.check(*args, **kwargs)
@ -37,6 +56,10 @@ class Filter:
class AsyncFilter(Filter):
"""
Base class for asynchronous filters
"""
def __aiter__(self):
return None
@ -48,23 +71,35 @@ class AsyncFilter(Filter):
class AnyFilter(AsyncFilter):
"""
One filter from many
"""
def __init__(self, *filters: callable):
self.filters = filters
async def check(self, *args, **kwargs):
f = (check_filter(filter_, args, kwargs) for filter_ in self.filters)
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
return any(await asyncio.gather(*f))
class NotFilter(AsyncFilter):
"""
Reverse filter
"""
def __init__(self, filter_: callable):
self.filter = filter_
async def check(self, *args, **kwargs):
return not await check_filter(self.filter, args, kwargs)
async def check(self, *args):
return not await check_filter(self.filter, args)
class CommandsFilter(AsyncFilter):
"""
Check commands in message
"""
def __init__(self, commands):
self.commands = commands
@ -85,6 +120,10 @@ class CommandsFilter(AsyncFilter):
class RegexpFilter(Filter):
"""
Regexp filter for messages
"""
def __init__(self, regexp):
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
@ -94,6 +133,10 @@ class RegexpFilter(Filter):
class ContentTypeFilter(Filter):
"""
Check message content type
"""
def __init__(self, content_types):
self.content_types = content_types
@ -103,6 +146,10 @@ class ContentTypeFilter(Filter):
class CancelFilter(Filter):
"""
Find cancel in message text
"""
def __init__(self, cancel_set=None):
if cancel_set is None:
cancel_set = ['/cancel', 'cancel', 'cancel.']
@ -114,6 +161,10 @@ class CancelFilter(Filter):
class StateFilter(AsyncFilter):
"""
Check user state
"""
def __init__(self, dispatcher, state):
self.dispatcher = dispatcher
self.state = state
@ -137,6 +188,10 @@ class StateFilter(AsyncFilter):
class StatesListFilter(StateFilter):
"""
List of states
"""
async def check(self, obj):
chat, user = self.get_target(obj)
@ -146,6 +201,10 @@ class StatesListFilter(StateFilter):
class ExceptionsFilter(Filter):
"""
Filter for exceptions
"""
def __init__(self, exception):
self.exception = exception
@ -159,6 +218,14 @@ class ExceptionsFilter(Filter):
def generate_default_filters(dispatcher, *args, **kwargs):
"""
Prepare filters
:param dispatcher:
:param args:
:param kwargs:
:return:
"""
filters_set = []
for name, filter_ in kwargs.items():


@ -1,3 +1,4 @@
from aiogram.utils import context
from .filters import check_filters
@ -10,11 +11,12 @@ class CancelHandler(BaseException):
class Handler:
def __init__(self, dispatcher, once=True):
def __init__(self, dispatcher, once=True, middleware_key=None):
self.dispatcher = dispatcher
self.once = once
self.handlers = []
self.middleware_key = middleware_key
def register(self, handler, filters=None, index=None):
"""
@ -48,20 +50,24 @@ class Handler:
return True
raise ValueError('This handler is not registered!')
async def notify(self, *args, **kwargs):
async def notify(self, *args):
"""
Notify handlers
:param args:
:param kwargs:
:return:
"""
results = []
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args)
for filters, handler in self.handlers:
if await check_filters(filters, args, kwargs):
if await check_filters(filters, args):
try:
response = await handler(*args, **kwargs)
if self.middleware_key:
context.set_value('handler', handler)
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args)
response = await handler(*args)
if results is not None:
results.append(response)
if self.once:
@ -70,5 +76,8 @@ class Handler:
continue
except CancelHandler:
break
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results,))
return results


@ -0,0 +1,101 @@
import logging
import typing
log = logging.getLogger('aiogram.Middleware')
class MiddlewareManager:
"""
Middlewares manager. Works only with dispatcher.
"""
def __init__(self, dispatcher):
"""
Init
:param dispatcher: instance of Dispatcher
"""
self.dispatcher = dispatcher
self.loop = dispatcher.loop
self.bot = dispatcher.bot
self.storage = dispatcher.storage
self.applications = []
def setup(self, middleware):
"""
Setup middleware
:param middleware:
:return:
"""
assert isinstance(middleware, BaseMiddleware)
if middleware.is_configured():
raise ValueError('That middleware is already used!')
self.applications.append(middleware)
middleware.setup(self)
log.debug(f"Loaded middleware '{middleware.__class__.__name__}'")
async def trigger(self, action: str, args: typing.Iterable):
"""
Call action on all registered middlewares with the args list.
:param action:
:param args:
:return:
"""
for app in self.applications:
await app.trigger(action, args)
class BaseMiddleware:
"""
Base class for middleware.
All handler methods on a middleware must be coroutines whose names start with "on_", e.g. "on_process_message".
"""
def __init__(self):
self._configured = False
self._manager = None
@property
def manager(self) -> MiddlewareManager:
"""
Instance of MiddlewareManager
"""
if self._manager is None:
raise RuntimeError('Middleware is not configured!')
return self._manager
def setup(self, manager):
"""
Mark middleware as configured
:param manager:
:return:
"""
self._manager = manager
self._configured = True
def is_configured(self) -> bool:
"""
Check middleware is configured
:return:
"""
return self._configured
async def trigger(self, action, args):
"""
Trigger action.
:param action:
:param args:
:return:
"""
handler_name = f"on_{action}"
handler = getattr(self, handler_name, None)
if not handler:
return None
await handler(*args)
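Because trigger() resolves handlers purely by the "on_" + action naming rule, a custom middleware only has to define appropriately named coroutines; a minimal hedged example (the class and its log lines are illustrative, not part of the library):

class TimingMiddleware(BaseMiddleware):
    """Hypothetical middleware relying only on the naming convention above."""

    async def on_pre_process_update(self, update):
        log.debug(f"Update {update.update_id} entered the pipeline")

    async def on_post_process_update(self, update, results):
        log.debug(f"Update {update.update_id} left the pipeline")

# Registration (dp is an existing Dispatcher): dp.middleware.setup(TimingMiddleware())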


@ -1,5 +1,14 @@
import typing
# Leak bucket
KEY = 'key'
LAST_CALL = 'called_at'
RATE_LIMIT = 'rate_limit'
RESULT = 'result'
EXCEEDED_COUNT = 'exceeded'
DELTA = 'delta'
THROTTLE_MANAGER = '$throttle_manager'
class BaseStorage:
"""
@ -184,6 +193,78 @@ class BaseStorage:
"""
await self.reset_state(chat=chat, user=user, with_data=True)
def has_bucket(self):
return False
async def get_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
"""
Get throttling bucket for user in chat. Return `default` if the bucket is not present in storage.
Chat or user is always required. If one of them is not provided,
the missing one is filled in from the provided value.
:param chat:
:param user:
:param default:
:return:
"""
raise NotImplementedError
async def set_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
"""
Set throttling bucket for user in chat
Chat or user is always required. If one of them is not provided,
the missing one is filled in from the provided value.
:param chat:
:param user:
:param bucket:
"""
raise NotImplementedError
async def update_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
**kwargs):
"""
Update throttling bucket for user in chat
You can use the bucket parameter and/or kwargs.
Chat or user is always required. If one of them is not provided,
the missing one is filled in from the provided value.
:param bucket:
:param chat:
:param user:
:param kwargs:
:return:
"""
raise NotImplementedError
async def reset_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None):
"""
Reset throttling bucket for user in chat.
Chat or user is always required. If one of them is not provided,
the missing one is filled in from the provided value.
:param chat:
:param user:
:return:
"""
await self.set_bucket(chat=chat, user=user, bucket={})
class FSMContext:
def __init__(self, storage, chat, user):


@ -186,10 +186,15 @@ class WebhookRequestHandler(web.View):
dispatcher = self.get_dispatcher()
loop = dispatcher.loop
results = task.result()
response = self.get_response(results)
if response is not None:
asyncio.ensure_future(response.execute_response(self.get_dispatcher().bot), loop=loop)
try:
results = task.result()
except Exception as e:
loop.create_task(
dispatcher.errors_handlers.notify(dispatcher, context.get_value('update_object'), e))
else:
response = self.get_response(results)
if response is not None:
asyncio.ensure_future(response.execute_response(dispatcher.bot), loop=loop)
def get_response(self, results):
"""


@ -138,10 +138,8 @@ class TelegramObject(metaclass=MetaTelegramObject):
@property
def bot(self):
bot = get_value('bot')
if bot is None:
raise RuntimeError('Can not found bot instance in current context!')
return bot
from ..dispatcher import ctx
return ctx.get_bot()
def to_python(self) -> typing.Dict:
"""


@ -211,7 +211,7 @@ class ChatActions(helper.Helper):
@classmethod
async def _do(cls, action: str, sleep=None):
from aiogram.dispatcher.ctx import get_bot, get_chat
from ..dispatcher.ctx import get_bot, get_chat
await get_bot().send_chat_action(get_chat(), action)
if sleep:
await asyncio.sleep(sleep)


@ -1,9 +1,9 @@
import datetime
from aiogram.utils import helper
from . import base
from . import fields
from .user import User
from ..utils import helper
class ChatMember(base.TelegramObject):


@ -76,6 +76,7 @@ class MediaGroup(base.TelegramObject):
"""
Helper for sending media group
"""
def __init__(self, medias: typing.Optional[typing.List[typing.Union[InputMedia, typing.Dict]]] = None):
super(MediaGroup, self).__init__()
self.media = []


@ -31,4 +31,4 @@ class PreCheckoutQuery(base.TelegramObject):
def __eq__(self, other):
if isinstance(other, type(self)):
return other.id == self.id
return self.id == other
return self.id == other


@ -31,7 +31,7 @@ def task_factory(loop: asyncio.BaseEventLoop, coro: typing.Coroutine):
del task._source_traceback[-1]
try:
task.context = asyncio.Task.current_task().context
task.context = asyncio.Task.current_task().context.copy()
except AttributeError:
task.context = {CONFIGURED: True}
@ -114,3 +114,25 @@ def check_configured():
:return:
"""
return get_value(CONFIGURED)
class _Context:
"""
Other things for interactions with the execution context.
"""
def __getitem__(self, item):
return get_value(item)
def __setitem__(self, key, value):
set_value(key, value)
def __delitem__(self, key):
del_value(key)
@staticmethod
def get_context():
return get_current_state()
context = _Context()


@ -1,3 +1,5 @@
import time
_PREFIXES = ['Error: ', '[Error]: ', 'Bad Request: ', 'Conflict: ']
@ -51,3 +53,21 @@ class MigrateToChat(TelegramAPIError):
def __init__(self, chat_id):
super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.")
self.migrate_to_chat_id = chat_id
class Throttled(Exception):
def __init__(self, **kwargs):
from ..dispatcher.storage import DELTA, EXCEEDED_COUNT, KEY, LAST_CALL, RATE_LIMIT, RESULT
self.key = kwargs.pop(KEY, '<None>')
self.called_at = kwargs.pop(LAST_CALL, time.time())
self.rate = kwargs.pop(RATE_LIMIT, None)
self.result = kwargs.pop(RESULT, False)
self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0)
self.delta = kwargs.pop(DELTA, 0)
self.user = kwargs.pop('user', None)
self.chat = kwargs.pop('chat', None)
def __str__(self):
return f"Rate limit exceeded! (Limit: {self.rate} s, " \
f"exceeded: {self.exceeded_count}, " \
f"time delta: {round(self.delta, 3)} s)"


@ -34,6 +34,7 @@ async def _shutdown(dispatcher: Dispatcher, callback=None):
if dispatcher.is_polling():
dispatcher.stop_polling()
# await dispatcher.wait_closed()
await dispatcher.storage.close()
await dispatcher.storage.wait_closed()

dev_requirements.txt (new file)

@ -0,0 +1,7 @@
-r requirements.txt
ujson
emoji
pytest
pytest-asyncio
uvloop
aioredis


@ -0,0 +1,123 @@
import asyncio
from aiogram import Bot, types
from aiogram.contrib.fsm_storage.redis import RedisStorage2
from aiogram.dispatcher import CancelHandler, DEFAULT_RATE_LIMIT, Dispatcher, ctx
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils import context, executor
from aiogram.utils.exceptions import Throttled
TOKEN = 'BOT TOKEN HERE'
loop = asyncio.get_event_loop()
# This example uses Redis storage
storage = RedisStorage2(db=5)
bot = Bot(token=TOKEN, loop=loop)
dp = Dispatcher(bot, storage=storage)
def rate_limit(limit: int, key=None):
"""
Decorator for configuring rate limit and key in different functions.
:param limit:
:param key:
:return:
"""
def decorator(func):
setattr(func, 'throttling_rate_limit', limit)
if key:
setattr(func, 'throttling_key', key)
return func
return decorator
class ThrottlingMiddleware(BaseMiddleware):
"""
Simple middleware
"""
def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
self.rate_limit = limit
self.prefix = key_prefix
super(ThrottlingMiddleware, self).__init__()
async def on_process_message(self, message: types.Message):
"""
This handler is called when the dispatcher receives a message
:param message:
"""
# Get current handler
handler = context.get_value('handler')
# Get dispatcher from context
dispatcher = ctx.get_dispatcher()
# If handler was configured get rate limit and key from handler
if handler:
limit = getattr(handler, 'throttling_rate_limit', self.rate_limit)
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else:
limit = self.rate_limit
key = f"{self.prefix}_message"
# Use Dispatcher.throttle method.
try:
await dispatcher.throttle(key, rate=limit)
except Throttled as t:
# Execute action
await self.message_throttled(message, t)
# Cancel current handler
raise CancelHandler()
async def message_throttled(self, message: types.Message, throttled: Throttled):
"""
Notify the user only on the first exceeded request, and notify about unlocking only on the last one
:param message:
:param throttled:
"""
handler = context.get_value('handler')
dispatcher = ctx.get_dispatcher()
if handler:
key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
else:
key = f"{self.prefix}_message"
# Calculate how much time is left until the end of the block.
delta = throttled.rate - throttled.delta
# Prevent flooding
if throttled.exceeded_count <= 2:
await message.reply('Too many requests! ')
# Sleep.
await asyncio.sleep(delta)
# Check lock status
thr = await dispatcher.check_key(key)
# Notify about unlocking only if the current message is the last one throttled with this key
if thr.exceeded_count == throttled.exceeded_count:
await message.reply('Unlocked.')
@dp.message_handler(commands=['start'])
@rate_limit(5, 'start')  # not required, but lets you configure the throttling manager for this handler
async def cmd_test(message: types.Message):
# You can use that command every 5 seconds
await message.reply('Test passed! You can use that command every 5 seconds.')
if __name__ == '__main__':
# Setup middleware
dp.middleware.setup(ThrottlingMiddleware())
# Start long-polling
executor.start_polling(dp, loop=loop)


@ -0,0 +1,43 @@
"""
Example for throttling manager.
You can use it for flood control.
"""
import asyncio
import logging
from aiogram import Bot, types
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from aiogram.dispatcher import Dispatcher
from aiogram.utils.exceptions import Throttled
from aiogram.utils.executor import start_polling
API_TOKEN = 'BOT TOKEN HERE'
logging.basicConfig(level=logging.INFO)
loop = asyncio.get_event_loop()
bot = Bot(token=API_TOKEN, loop=loop)
# The throttling manager does not work without a leaky bucket,
# so a storage backend is required. This example uses the simple in-memory storage.
storage = MemoryStorage()
dp = Dispatcher(bot, storage=storage)
@dp.message_handler(commands=['start', 'help'])
async def send_welcome(message: types.Message):
try:
# Execute the throttling manager with a rate limit of 2 seconds for the key 'start'
await dp.throttle('start', rate=2)
except Throttled:
# If request is throttled the `Throttled` exception will be raised.
await message.reply('Too many requests!')
else:
# Otherwise do something.
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
if __name__ == '__main__':
start_polling(dp, loop=loop, skip_updates=True)

tests/conftest.py (new file)

@ -0,0 +1 @@
# pytest_plugins = "pytest_asyncio.plugin"


@ -1,25 +0,0 @@
UPDATE = {
"update_id": 128526,
"message": {
"message_id": 11223,
"from": {
"id": 12345678,
"is_bot": False,
"first_name": "FirstName",
"last_name": "LastName",
"username": "username",
"language_code": "ru"
},
"chat": {
"id": 12345678,
"first_name": "FirstName",
"last_name": "LastName",
"username": "username",
"type": "private"
},
"date": 1508709711,
"text": "Hi, world!"
}
}
MESSAGE = UPDATE['message']

tests/test_bot.py (new file)

@ -0,0 +1,4 @@
import aiogram
# bot = aiogram.Bot('123456789:AABBCCDDEEFFaabbccddeeff-1234567890')
# TODO: mock aiogram.bot.api.request and then test all API methods.


@ -1,35 +0,0 @@
import datetime
import unittest
from aiogram import types
from dataset import MESSAGE
class TestMessage(unittest.TestCase):
def setUp(self):
self.message = types.Message(**MESSAGE)
def test_update_id(self):
self.assertEqual(self.message.message_id, MESSAGE['message_id'], 'test')
self.assertEqual(self.message['message_id'], MESSAGE['message_id'])
def test_from(self):
self.assertIsInstance(self.message.from_user, types.User)
self.assertEqual(self.message.from_user, self.message['from'])
def test_chat(self):
self.assertIsInstance(self.message.chat, types.Chat)
self.assertEqual(self.message.chat, self.message['chat'])
def test_date(self):
self.assertIsInstance(self.message.date, datetime.datetime)
self.assertEqual(int(self.message.date.timestamp()), MESSAGE['date'])
self.assertEqual(self.message.date, self.message['date'])
def test_text(self):
self.assertEqual(self.message.text, MESSAGE['text'])
self.assertEqual(self.message['text'], MESSAGE['text'])
if __name__ == '__main__':
unittest.main()

tests/types/__init__.py (new, empty file)

tests/types/dataset.py (new file)

@ -0,0 +1,75 @@
USER = {
"id": 12345678,
"is_bot": False,
"first_name": "FirstName",
"last_name": "LastName",
"username": "username",
"language_code": "ru-RU"
}
CHAT = {
"id": 12345678,
"first_name": "FirstName",
"last_name": "LastName",
"username": "username",
"type": "private"
}
MESSAGE = {
"message_id": 11223,
"from": USER,
"chat": CHAT,
"date": 1508709711,
"text": "Hi, world!"
}
DOCUMENT = {
"file_name": "test.docx",
"mime_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"file_id": "BQADAgADpgADy_JxS66XQTBRHFleAg",
"file_size": 21331
}
MESSAGE_WITH_DOCUMENT = {
"message_id": 12345,
"from": USER,
"chat": CHAT,
"date": 1508768012,
"document": DOCUMENT,
"caption": "doc description"
}
UPDATE = {
"update_id": 128526,
"message": MESSAGE
}
PHOTO = {
"file_id": "AgADBAADFak0G88YZAf8OAug7bHyS9x2ZxkABHVfpJywcloRAAGAAQABAg",
"file_size": 1101,
"width": 90,
"height": 51
}
ANIMATION = {
"file_name": "a9b0e0ca537aa344338f80978f0896b7.gif.mp4",
"mime_type": "video/mp4",
"thumb": PHOTO,
"file_id": "CgADBAAD4DUAAoceZAe2WiE9y0crrAI",
"file_size": 65837
}
GAME = {
"title": "Karate Kido",
"description": "No trees were harmed in the making of this game :)",
"photo": [PHOTO, PHOTO, PHOTO],
"animation": ANIMATION
}
MESSAGE_WITH_GAME = {
"message_id": 12345,
"from": USER,
"chat": CHAT,
"date": 1508824810,
"game": GAME
}


@ -0,0 +1,39 @@
from aiogram import types
from .dataset import ANIMATION
animation = types.Animation(**ANIMATION)
def test_export():
exported = animation.to_python()
assert isinstance(exported, dict)
assert exported == ANIMATION
def test_file_name():
assert isinstance(animation.file_name, str)
assert animation.file_name == ANIMATION['file_name']
def test_mime_type():
assert isinstance(animation.mime_type, str)
assert animation.mime_type == ANIMATION['mime_type']
def test_file_id():
assert isinstance(animation.file_id, str)
# assert hash(animation) == ANIMATION['file_id']
assert animation.file_id == ANIMATION['file_id']
def test_file_size():
assert isinstance(animation.file_size, int)
assert animation.file_size == ANIMATION['file_size']
def test_thumb():
assert isinstance(animation.thumb, types.PhotoSize)
assert animation.thumb.file_id == ANIMATION['thumb']['file_id']
assert animation.thumb.width == ANIMATION['thumb']['width']
assert animation.thumb.height == ANIMATION['thumb']['height']
assert animation.thumb.file_size == ANIMATION['thumb']['file_size']

tests/types/test_chat.py (new file)

@ -0,0 +1,61 @@
from aiogram import types
from .dataset import CHAT
chat = types.Chat(**CHAT)
def test_export():
exported = chat.to_python()
assert isinstance(exported, dict)
assert exported == CHAT
def test_id():
assert isinstance(chat.id, int)
assert chat.id == CHAT['id']
assert hash(chat) == CHAT['id']
def test_name():
assert isinstance(chat.first_name, str)
assert chat.first_name == CHAT['first_name']
assert isinstance(chat.last_name, str)
assert chat.last_name == CHAT['last_name']
assert isinstance(chat.username, str)
assert chat.username == CHAT['username']
def test_type():
assert isinstance(chat.type, str)
assert chat.type == CHAT['type']
def test_chat_types():
assert types.ChatType.PRIVATE == 'private'
assert types.ChatType.GROUP == 'group'
assert types.ChatType.SUPER_GROUP == 'supergroup'
assert types.ChatType.CHANNEL == 'channel'
def test_chat_type_filters():
from . import test_message
assert types.ChatType.is_private(test_message.message)
assert not types.ChatType.is_group(test_message.message)
assert not types.ChatType.is_super_group(test_message.message)
assert not types.ChatType.is_group_or_super_group(test_message.message)
assert not types.ChatType.is_channel(test_message.message)
def test_chat_actions():
assert types.ChatActions.TYPING == 'typing'
assert types.ChatActions.UPLOAD_PHOTO == 'upload_photo'
assert types.ChatActions.RECORD_VIDEO == 'record_video'
assert types.ChatActions.UPLOAD_VIDEO == 'upload_video'
assert types.ChatActions.RECORD_AUDIO == 'record_audio'
assert types.ChatActions.UPLOAD_AUDIO == 'upload_audio'
assert types.ChatActions.UPLOAD_DOCUMENT == 'upload_document'
assert types.ChatActions.FIND_LOCATION == 'find_location'
assert types.ChatActions.RECORD_VIDEO_NOTE == 'record_video_note'
assert types.ChatActions.UPLOAD_VIDEO_NOTE == 'upload_video_note'


@ -0,0 +1,35 @@
from aiogram import types
from .dataset import DOCUMENT
document = types.Document(**DOCUMENT)
def test_export():
exported = document.to_python()
assert isinstance(exported, dict)
assert exported == DOCUMENT
def test_file_name():
assert isinstance(document.file_name, str)
assert document.file_name == DOCUMENT['file_name']
def test_mime_type():
assert isinstance(document.mime_type, str)
assert document.mime_type == DOCUMENT['mime_type']
def test_file_id():
assert isinstance(document.file_id, str)
# assert hash(document) == DOCUMENT['file_id']
assert document.file_id == DOCUMENT['file_id']
def test_file_size():
assert isinstance(document.file_size, int)
assert document.file_size == DOCUMENT['file_size']
def test_thumb():
assert document.thumb is None

tests/types/test_game.py (new file)

@ -0,0 +1,29 @@
from aiogram import types
from .dataset import GAME
game = types.Game(**GAME)
def test_export():
exported = game.to_python()
assert isinstance(exported, dict)
assert exported == GAME
def test_title():
assert isinstance(game.title, str)
assert game.title == GAME['title']
def test_description():
assert isinstance(game.description, str)
assert game.description == GAME['description']
def test_photo():
assert isinstance(game.photo, list)
assert len(game.photo) == len(GAME['photo'])
assert all(map(lambda t: isinstance(t, types.PhotoSize), game.photo))
def test_animation():
assert isinstance(game.animation, types.Animation)


@ -0,0 +1,39 @@
import datetime
from aiogram import types
from .dataset import MESSAGE
message = types.Message(**MESSAGE)
def test_export():
exported_chat = message.to_python()
assert isinstance(exported_chat, dict)
assert exported_chat == MESSAGE
def test_message_id():
assert hash(message) == MESSAGE['message_id']
assert message.message_id == MESSAGE['message_id']
assert message['message_id'] == MESSAGE['message_id']
def test_from():
assert isinstance(message.from_user, types.User)
assert message.from_user == message['from']
def test_chat():
assert isinstance(message.chat, types.Chat)
assert message.chat == message['chat']
def test_date():
assert isinstance(message.date, datetime.datetime)
assert int(message.date.timestamp()) == MESSAGE['date']
assert message.date == message['date']
def test_text():
assert message.text == MESSAGE['text']
assert message['text'] == MESSAGE['text']

tests/types/test_photo.py (new file)

@ -0,0 +1,27 @@
from aiogram import types
from .dataset import PHOTO
photo = types.PhotoSize(**PHOTO)
def test_export():
exported = photo.to_python()
assert isinstance(exported, dict)
assert exported == PHOTO
def test_file_id():
assert isinstance(photo.file_id, str)
assert photo.file_id == PHOTO['file_id']
def test_file_size():
assert isinstance(photo.file_size, int)
assert photo.file_size == PHOTO['file_size']
def test_size():
assert isinstance(photo.width, int)
assert isinstance(photo.height, int)
assert photo.width == PHOTO['width']
assert photo.height == PHOTO['height']


@ -0,0 +1,20 @@
from aiogram import types
from .dataset import UPDATE
update = types.Update(**UPDATE)
def test_export():
exported = update.to_python()
assert isinstance(exported, dict)
assert exported == UPDATE
def test_update_id():
assert isinstance(update.update_id, int)
assert hash(update) == UPDATE['update_id']
assert update.update_id == UPDATE['update_id']
def test_message():
assert isinstance(update.message, types.Message)

tests/types/test_user.py (new file)

@ -0,0 +1,48 @@
from babel import Locale
from aiogram import types
from .dataset import USER
user = types.User(**USER)
def test_export():
exported = user.to_python()
assert isinstance(exported, dict)
assert exported == USER
def test_id():
assert isinstance(user.id, int)
assert user.id == USER['id']
assert hash(user) == USER['id']
def test_bot():
assert isinstance(user.is_bot, bool)
assert user.is_bot == USER['is_bot']
def test_name():
assert user.first_name == USER['first_name']
assert user.last_name == USER['last_name']
assert user.username == USER['username']
def test_language_code():
assert user.language_code == USER['language_code']
assert user.locale == Locale.parse(USER['language_code'], sep='-')
def test_full_name():
assert user.full_name == f"{USER['first_name']} {USER['last_name']}"
def test_mention():
assert user.mention == f"@{USER['username']}"
assert user.get_mention('foo') == f"[foo](tg://user?id={USER['id']})"
assert user.get_mention('foo', as_html=True) == f"<a href=\"tg://user?id={USER['id']}\">foo</a>"
def test_url():
assert user.url == f"tg://user?id={USER['id']}"


@ -1,2 +0,0 @@
def out(*message, sep=' '):
print('Test', sep.join(message))

tox.ini (new file)

@ -0,0 +1,7 @@
[tox]
envlist = py36
[testenv]
deps = -rdev_requirements.txt
commands = pytest
skip_install = true