diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 37a9ecb7..7937d209 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -64,6 +64,7 @@ class FilterObject(CallableMixin): class HandlerObject(CallableMixin): callback: HandlerType filters: Optional[List[FilterObject]] = None + flags: Dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: super(HandlerObject, self).__post_init__() diff --git a/aiogram/dispatcher/event/telegram.py b/aiogram/dispatcher/event/telegram.py index fd9ba63a..5535bb3e 100644 --- a/aiogram/dispatcher/event/telegram.py +++ b/aiogram/dispatcher/event/telegram.py @@ -175,16 +175,25 @@ class TelegramEventObserver: return bound_filters def register( - self, callback: HandlerType, *filters: FilterType, **bound_filters: Any + self, + callback: HandlerType, + *filters: FilterType, + flags: Optional[Dict[str, Any]] = None, + **bound_filters: Any, ) -> HandlerType: """ Register event handler """ + if flags is None: + flags = {} resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False) + for resolved_filter in resolved_filters: + resolved_filter.update_handler_flags(flags=flags) self.handlers.append( HandlerObject( callback=callback, filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)], + flags=flags, ) ) return callback @@ -222,7 +231,7 @@ class TelegramEventObserver: for handler in self.handlers: result, data = await handler.check(event, **kwargs) if result: - kwargs.update(data) + kwargs.update(data, handler=handler) try: wrapped_inner = self._wrap_middleware( self._resolve_middlewares(), handler.call @@ -234,14 +243,14 @@ class TelegramEventObserver: return UNHANDLED def __call__( - self, *args: FilterType, **bound_filters: Any + self, *args: FilterType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any ) -> Callable[[CallbackType], CallbackType]: """ Decorator for registering event handlers """ def wrapper(callback: CallbackType) -> CallbackType: - self.register(callback, *args, **bound_filters) + self.register(callback, *args, flags=flags, **bound_filters) return callback return wrapper diff --git a/aiogram/dispatcher/filters/base.py b/aiogram/dispatcher/filters/base.py index 769e887c..d2bb99cf 100644 --- a/aiogram/dispatcher/filters/base.py +++ b/aiogram/dispatcher/filters/base.py @@ -32,6 +32,9 @@ class BaseFilter(ABC, BaseModel): """ pass + def update_handler_flags(self, flags: Dict[str, Any]) -> None: + pass + def __await__(self): # type: ignore # pragma: no cover # Is needed only for inspection and this method is never be called return self.__call__ diff --git a/aiogram/dispatcher/filters/command.py b/aiogram/dispatcher/filters/command.py index 0e46c1ec..03cb41c7 100644 --- a/aiogram/dispatcher/filters/command.py +++ b/aiogram/dispatcher/filters/command.py @@ -38,6 +38,10 @@ class Command(BaseFilter): command_magic: Optional[MagicFilter] = None """Validate command object via Magic filter after all checks done""" + def update_handler_flags(self, flags: Dict[str, Any]) -> None: + commands = flags.setdefault("commands", []) + commands.append(self) + @validator("commands", always=True) def _validate_commands( cls, value: Union[Sequence[CommandPatterType], CommandPatterType] diff --git a/tests/test_dispatcher/test_event/test_telegram.py b/tests/test_dispatcher/test_event/test_telegram.py index 4d0e11e1..df9e128d 100644 --- a/tests/test_dispatcher/test_event/test_telegram.py +++ b/tests/test_dispatcher/test_event/test_telegram.py @@ -287,6 +287,8 @@ class TestTelegramEventObserver: observer.register(pipe_handler, mix_data) results = await observer.trigger(42) + assert len(results) == 2 + assert results[1].pop("handler") assert results == ((42,), {"b": 2}) @pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware")) diff --git a/tests/test_dispatcher/test_filters/test_command.py b/tests/test_dispatcher/test_filters/test_command.py index 18888fe1..d7a0ef55 100644 --- a/tests/test_dispatcher/test_filters/test_command.py +++ b/tests/test_dispatcher/test_filters/test_command.py @@ -126,3 +126,16 @@ class TestCommandObject: ) def test_text(self, obj: CommandObject, result: str): assert obj.text == result + + def test_update_handler_flags(self): + cmd = Command(commands=["start"]) + flags = {} + cmd.update_handler_flags(flags) + + assert "commands" in flags + assert isinstance(flags["commands"], list) + assert len(flags["commands"]) == 1 + assert flags["commands"][0] is cmd + + cmd.update_handler_flags(flags) + assert len(flags["commands"]) == 2