diff --git a/.gitignore b/.gitignore index 1f3e1971..4ffb8359 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,5 @@ aiogram/_meta.py .coverage reports -dev/ \ No newline at end of file +dev/ +.venv/ diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 47c67096..109bb920 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -4,7 +4,7 @@ import asyncio import contextvars import warnings from asyncio import CancelledError, Future, Lock -from typing import Any, AsyncGenerator, Dict, Optional, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Union from .. import loggers from ..client.bot import Bot @@ -130,6 +130,7 @@ class Dispatcher(Router): bot: Bot, polling_timeout: int = 30, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, + allowed_updates: Optional[List[str]] = None, ) -> AsyncGenerator[Update, None]: """ Endless updates reader with correctly handling any server-side or connection errors. @@ -137,7 +138,7 @@ class Dispatcher(Router): So you may not worry that the polling will stop working. """ backoff = Backoff(config=backoff_config) - get_updates = GetUpdates(timeout=polling_timeout) + get_updates = GetUpdates(timeout=polling_timeout, allowed_updates=allowed_updates) kwargs = {} if bot.session.timeout: # Request timeout can be lower than session timeout ant that's OK. @@ -297,6 +298,7 @@ class Dispatcher(Router): polling_timeout: int = 30, handle_as_tasks: bool = True, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, + allowed_updates: Optional[List[str]] = None, **kwargs: Any, ) -> None: """ @@ -307,7 +309,10 @@ class Dispatcher(Router): :return: """ async for update in self._listen_updates( - bot, polling_timeout=polling_timeout, backoff_config=backoff_config + bot, + polling_timeout=polling_timeout, + backoff_config=backoff_config, + allowed_updates=allowed_updates, ): handle_update = self._process_update(bot=bot, update=update, **kwargs) if handle_as_tasks: @@ -397,6 +402,7 @@ class Dispatcher(Router): polling_timeout: int = 10, handle_as_tasks: bool = True, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, + allowed_updates: Optional[List[str]] = None, **kwargs: Any, ) -> None: """ @@ -427,6 +433,7 @@ class Dispatcher(Router): handle_as_tasks=handle_as_tasks, polling_timeout=polling_timeout, backoff_config=backoff_config, + allowed_updates=allowed_updates, **kwargs, ) ) @@ -443,6 +450,7 @@ class Dispatcher(Router): polling_timeout: int = 30, handle_as_tasks: bool = True, backoff_config: BackoffConfig = DEFAULT_BACKOFF_CONFIG, + allowed_updates: Optional[List[str]] = None, **kwargs: Any, ) -> None: """ @@ -452,6 +460,7 @@ class Dispatcher(Router): :param polling_timeout: Poling timeout :param backoff_config: :param handle_as_tasks: Run task for each event and no wait result + :param allowed_updates: List of the update types you want your bot to receive :param kwargs: contextual data :return: """ @@ -463,6 +472,7 @@ class Dispatcher(Router): polling_timeout=polling_timeout, handle_as_tasks=handle_as_tasks, backoff_config=backoff_config, + allowed_updates=allowed_updates, ) ) except (KeyboardInterrupt, SystemExit): # pragma: no cover diff --git a/aiogram/utils/handlers_in_use.py b/aiogram/utils/handlers_in_use.py new file mode 100644 index 00000000..c1816476 --- /dev/null +++ b/aiogram/utils/handlers_in_use.py @@ -0,0 +1,28 @@ +from itertools import chain +from typing import List, cast + +from aiogram.dispatcher.dispatcher import Dispatcher +from aiogram.dispatcher.router import Router + +INTERNAL_HANDLERS = [ + "update", + "error", +] + + +def get_handlers_in_use( + dispatcher: Dispatcher, handlers_to_skip: List[str] = INTERNAL_HANDLERS +) -> List[str]: + handlers_in_use: List[str] = [] + + for router in [dispatcher.sub_routers, dispatcher]: + if isinstance(router, list): + if router: + handlers_in_use.extend(chain(*list(map(get_handlers_in_use, router)))) + else: + router = cast(Router, router) + for update_name, observer in router.observers.items(): + if observer.handlers and update_name not in [*handlers_to_skip, *handlers_in_use]: + handlers_in_use.append(update_name) + + return handlers_in_use diff --git a/examples/specify_updates.py b/examples/specify_updates.py new file mode 100644 index 00000000..33fdd093 --- /dev/null +++ b/examples/specify_updates.py @@ -0,0 +1,87 @@ +from aiogram.types.inline_keyboard_button import InlineKeyboardButton +from aiogram.types.inline_keyboard_markup import InlineKeyboardMarkup +from aiogram.dispatcher.router import Router +from aiogram.utils.handlers_in_use import get_handlers_in_use +import logging + +from aiogram import Bot, Dispatcher +from aiogram.types import Message, ChatMemberUpdated, CallbackQuery + +TOKEN = "6wo" +dp = Dispatcher() + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +@dp.message(commands={"start"}) +async def command_start_handler(message: Message) -> None: + """ + This handler receive messages with `/start` command + """ + + await message.answer( + f"Hello, {message.from_user.full_name}!", + reply_markup=InlineKeyboardMarkup( + inline_keyboard=[[InlineKeyboardButton(text="Tap me, bro", callback_data="*")]] + ), + ) + + +@dp.chat_member() +async def chat_member_update(chat_member: ChatMemberUpdated, bot: Bot) -> None: + await bot.send_message( + chat_member.chat.id, + "Member {chat_member.from_user.id} was changed " + + f"from {chat_member.old_chat_member.is_chat_member} to {chat_member.new_chat_member.is_chat_member}", + ) + + +# this router will use only callback_query updates +sub_router = Router() + + +@sub_router.callback_query() +async def callback_tap_me(callback_query: CallbackQuery) -> None: + await callback_query.answer("Yeah good, now i'm fine") + + +# this router will use only edited_message updates +sub_sub_router = Router() + + +@sub_sub_router.edited_message() +async def callback_tap_me(edited_message: Message) -> None: + await edited_message.reply("Message was edited, big brother watch you") + + +# this router will use only my_chat_member updates +deep_dark_router = Router() + + +@deep_dark_router.my_chat_member() +async def my_chat_member_change(chat_member: ChatMemberUpdated, bot: Bot) -> None: + await bot.send_message( + chat_member.chat.id, + "Member was changed from " + + f"{chat_member.old_chat_member.is_chat_member} to {chat_member.new_chat_member.is_chat_member}", + ) + + +def main() -> None: + # Initialize Bot instance with an default parse mode which will be passed to all API calls + bot = Bot(TOKEN, parse_mode="HTML") + + sub_router.include_router(deep_dark_router) + + dp.include_router(sub_router) + dp.include_router(sub_sub_router) + + useful_updates = get_handlers_in_use(dp) + + # And the run events dispatching + dp.run_polling(bot, allowed_updates=useful_updates) + + +if __name__ == "__main__": + main() diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index 37bbf634..5f7a2f62 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -28,6 +28,7 @@ from aiogram.types import ( Update, User, ) +from aiogram.utils.handlers_in_use import get_handlers_in_use from tests.mocked_bot import MockedBot try: @@ -659,3 +660,56 @@ class TestDispatcher: log_records = [rec.message for rec in caplog.records] assert "Cause exception while process update" in log_records[0] + + def test_specify_updates_calculation(self): + def simple_msg_handler() -> None: + ... + + def simple_callback_query_handler() -> None: + ... + + def simple_poll_handler() -> None: + ... + + def simple_edited_msg_handler() -> None: + ... + + dispatcher = Dispatcher() + dispatcher.message.register(simple_msg_handler) + + router1 = Router() + router1.callback_query.register(simple_callback_query_handler) + + router2 = Router() + router2.poll.register(simple_poll_handler) + + router21 = Router() + router21.edited_message.register(simple_edited_msg_handler) + + useful_updates1 = get_handlers_in_use(dispatcher) + + assert sorted(useful_updates1) == sorted(["message"]) + + dispatcher.include_router(router1) + + useful_updates2 = get_handlers_in_use(dispatcher) + + assert sorted(useful_updates2) == sorted(["message", "callback_query"]) + + dispatcher.include_router(router2) + + useful_updates3 = get_handlers_in_use(dispatcher) + + assert sorted(useful_updates3) == sorted(["message", "callback_query", "poll"]) + + router2.include_router(router21) + + useful_updates4 = get_handlers_in_use(dispatcher) + + assert sorted(useful_updates4) == sorted( + ["message", "callback_query", "poll", "edited_message"] + ) + + useful_updates5 = get_handlers_in_use(router2) + + assert sorted(useful_updates5) == sorted(["poll", "edited_message"])