Added ability to specify which update bot need to receive and process while using polling mode (#617)

* provide allowed_updates in polling mode
This commit is contained in:
Forevka 2021-07-05 00:41:27 +02:00 committed by GitHub
parent eee6589a2c
commit 125fc22ff9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 184 additions and 4 deletions

3
.gitignore vendored
View file

@ -18,4 +18,5 @@ aiogram/_meta.py
.coverage .coverage
reports reports
dev/ dev/
.venv/

View file

@ -4,7 +4,7 @@ import asyncio
import contextvars import contextvars
import warnings import warnings
from asyncio import CancelledError, Future, Lock from asyncio import CancelledError, Future, Lock
from typing import Any, AsyncGenerator, Dict, Optional, Union from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from .. import loggers from .. import loggers
from ..client.bot import Bot from ..client.bot import Bot
@ -130,6 +130,7 @@ class Dispatcher(Router):
bot: Bot, bot: Bot,
polling_timeout: int = 30, polling_timeout: int = 30,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
) -> AsyncGenerator[Update, None]: ) -> AsyncGenerator[Update, None]:
""" """
Endless updates reader with correctly handling any server-side or connection errors. Endless updates reader with correctly handling any server-side or connection errors.
@ -137,7 +138,7 @@ class Dispatcher(Router):
So you may not worry that the polling will stop working. So you may not worry that the polling will stop working.
""" """
backoff = Backoff(config=backoff_config) backoff = Backoff(config=backoff_config)
get_updates = GetUpdates(timeout=polling_timeout) get_updates = GetUpdates(timeout=polling_timeout, allowed_updates=allowed_updates)
kwargs = {} kwargs = {}
if bot.session.timeout: if bot.session.timeout:
# Request timeout can be lower than session timeout ant that's OK. # Request timeout can be lower than session timeout ant that's OK.
@ -297,6 +298,7 @@ class Dispatcher(Router):
polling_timeout: int = 30, polling_timeout: int = 30,
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -307,7 +309,10 @@ class Dispatcher(Router):
:return: :return:
""" """
async for update in self._listen_updates( async for update in self._listen_updates(
bot, polling_timeout=polling_timeout, backoff_config=backoff_config bot,
polling_timeout=polling_timeout,
backoff_config=backoff_config,
allowed_updates=allowed_updates,
): ):
handle_update = self._process_update(bot=bot, update=update, **kwargs) handle_update = self._process_update(bot=bot, update=update, **kwargs)
if handle_as_tasks: if handle_as_tasks:
@ -397,6 +402,7 @@ class Dispatcher(Router):
polling_timeout: int = 10, polling_timeout: int = 10,
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -427,6 +433,7 @@ class Dispatcher(Router):
handle_as_tasks=handle_as_tasks, handle_as_tasks=handle_as_tasks,
polling_timeout=polling_timeout, polling_timeout=polling_timeout,
backoff_config=backoff_config, backoff_config=backoff_config,
allowed_updates=allowed_updates,
**kwargs, **kwargs,
) )
) )
@ -443,6 +450,7 @@ class Dispatcher(Router):
polling_timeout: int = 30, polling_timeout: int = 30,
handle_as_tasks: bool = True, handle_as_tasks: bool = True,
backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG,
allowed_updates: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
@ -452,6 +460,7 @@ class Dispatcher(Router):
:param polling_timeout: Poling timeout :param polling_timeout: Poling timeout
:param backoff_config: :param backoff_config:
:param handle_as_tasks: Run task for each event and no wait result :param handle_as_tasks: Run task for each event and no wait result
:param allowed_updates: List of the update types you want your bot to receive
:param kwargs: contextual data :param kwargs: contextual data
:return: :return:
""" """
@ -463,6 +472,7 @@ class Dispatcher(Router):
polling_timeout=polling_timeout, polling_timeout=polling_timeout,
handle_as_tasks=handle_as_tasks, handle_as_tasks=handle_as_tasks,
backoff_config=backoff_config, backoff_config=backoff_config,
allowed_updates=allowed_updates,
) )
) )
except (KeyboardInterrupt, SystemExit): # pragma: no cover except (KeyboardInterrupt, SystemExit): # pragma: no cover

View file

@ -0,0 +1,28 @@
from itertools import chain
from typing import List, cast
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.router import Router
INTERNAL_HANDLERS = [
"update",
"error",
]
def get_handlers_in_use(
dispatcher: Dispatcher, handlers_to_skip: List[str] = INTERNAL_HANDLERS
) -> List[str]:
handlers_in_use: List[str] = []
for router in [dispatcher.sub_routers, dispatcher]:
if isinstance(router, list):
if router:
handlers_in_use.extend(chain(*list(map(get_handlers_in_use, router))))
else:
router = cast(Router, router)
for update_name, observer in router.observers.items():
if observer.handlers and update_name not in [*handlers_to_skip, *handlers_in_use]:
handlers_in_use.append(update_name)
return handlers_in_use

View file

@ -0,0 +1,87 @@
from aiogram.types.inline_keyboard_button import InlineKeyboardButton
from aiogram.types.inline_keyboard_markup import InlineKeyboardMarkup
from aiogram.dispatcher.router import Router
from aiogram.utils.handlers_in_use import get_handlers_in_use
import logging
from aiogram import Bot, Dispatcher
from aiogram.types import Message, ChatMemberUpdated, CallbackQuery
TOKEN = "6wo"
dp = Dispatcher()
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@dp.message(commands={"start"})
async def command_start_handler(message: Message) -> None:
"""
This handler receive messages with `/start` command
"""
await message.answer(
f"Hello, <b>{message.from_user.full_name}!</b>",
reply_markup=InlineKeyboardMarkup(
inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]]
),
)
@dp.chat_member()
async def chat_member_update(chat_member: ChatMemberUpdated, bot: Bot) -> None:
await bot.send_message(
chat_member.chat.id,
"Member {chat_member.from_user.id} was changed "
+ f"from {chat_member.old_chat_member.is_chat_member} to {chat_member.new_chat_member.is_chat_member}",
)
# this router will use only callback_query updates
sub_router = Router()
@sub_router.callback_query()
async def callback_tap_me(callback_query: CallbackQuery) -> None:
await callback_query.answer("Yeah good, now i'm fine")
# this router will use only edited_message updates
sub_sub_router = Router()
@sub_sub_router.edited_message()
async def callback_tap_me(edited_message: Message) -> None:
await edited_message.reply("Message was edited, big brother watch you")
# this router will use only my_chat_member updates
deep_dark_router = Router()
@deep_dark_router.my_chat_member()
async def my_chat_member_change(chat_member: ChatMemberUpdated, bot: Bot) -> None:
await bot.send_message(
chat_member.chat.id,
"Member was changed from "
+ f"{chat_member.old_chat_member.is_chat_member} to {chat_member.new_chat_member.is_chat_member}",
)
def main() -> None:
# Initialize Bot instance with an default parse mode which will be passed to all API calls
bot = Bot(TOKEN, parse_mode="HTML")
sub_router.include_router(deep_dark_router)
dp.include_router(sub_router)
dp.include_router(sub_sub_router)
useful_updates = get_handlers_in_use(dp)
# And the run events dispatching
dp.run_polling(bot, allowed_updates=useful_updates)
if __name__ == "__main__":
main()

View file

@ -28,6 +28,7 @@ from aiogram.types import (
Update, Update,
User, User,
) )
from aiogram.utils.handlers_in_use import get_handlers_in_use
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot
try: try:
@ -659,3 +660,56 @@ class TestDispatcher:
log_records = [rec.message for rec in caplog.records] log_records = [rec.message for rec in caplog.records]
assert "Cause exception while process update" in log_records[0] assert "Cause exception while process update" in log_records[0]
def test_specify_updates_calculation(self):
def simple_msg_handler() -> None:
...
def simple_callback_query_handler() -> None:
...
def simple_poll_handler() -> None:
...
def simple_edited_msg_handler() -> None:
...
dispatcher = Dispatcher()
dispatcher.message.register(simple_msg_handler)
router1 = Router()
router1.callback_query.register(simple_callback_query_handler)
router2 = Router()
router2.poll.register(simple_poll_handler)
router21 = Router()
router21.edited_message.register(simple_edited_msg_handler)
useful_updates1 = get_handlers_in_use(dispatcher)
assert sorted(useful_updates1) == sorted(["message"])
dispatcher.include_router(router1)
useful_updates2 = get_handlers_in_use(dispatcher)
assert sorted(useful_updates2) == sorted(["message", "callback_query"])
dispatcher.include_router(router2)
useful_updates3 = get_handlers_in_use(dispatcher)
assert sorted(useful_updates3) == sorted(["message", "callback_query", "poll"])
router2.include_router(router21)
useful_updates4 = get_handlers_in_use(dispatcher)
assert sorted(useful_updates4) == sorted(
["message", "callback_query", "poll", "edited_message"]
)
useful_updates5 = get_handlers_in_use(router2)
assert sorted(useful_updates5) == sorted(["poll", "edited_message"])