mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 18:19:34 +00:00
Implemented handler flags feature (#728)
* Implemented handler flag feature * Cover tests
This commit is contained in:
parent
5f07cb3d06
commit
3ad16be507
6 changed files with 36 additions and 4 deletions
|
|
@ -64,6 +64,7 @@ class FilterObject(CallableMixin):
|
||||||
class HandlerObject(CallableMixin):
|
class HandlerObject(CallableMixin):
|
||||||
callback: HandlerType
|
callback: HandlerType
|
||||||
filters: Optional[List[FilterObject]] = None
|
filters: Optional[List[FilterObject]] = None
|
||||||
|
flags: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super(HandlerObject, self).__post_init__()
|
super(HandlerObject, self).__post_init__()
|
||||||
|
|
|
||||||
|
|
@ -175,16 +175,25 @@ class TelegramEventObserver:
|
||||||
return bound_filters
|
return bound_filters
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
|
self,
|
||||||
|
callback: HandlerType,
|
||||||
|
*filters: FilterType,
|
||||||
|
flags: Optional[Dict[str, Any]] = None,
|
||||||
|
**bound_filters: Any,
|
||||||
) -> HandlerType:
|
) -> HandlerType:
|
||||||
"""
|
"""
|
||||||
Register event handler
|
Register event handler
|
||||||
"""
|
"""
|
||||||
|
if flags is None:
|
||||||
|
flags = {}
|
||||||
resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
|
resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
|
||||||
|
for resolved_filter in resolved_filters:
|
||||||
|
resolved_filter.update_handler_flags(flags=flags)
|
||||||
self.handlers.append(
|
self.handlers.append(
|
||||||
HandlerObject(
|
HandlerObject(
|
||||||
callback=callback,
|
callback=callback,
|
||||||
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
|
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
|
||||||
|
flags=flags,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return callback
|
return callback
|
||||||
|
|
@ -222,7 +231,7 @@ class TelegramEventObserver:
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
result, data = await handler.check(event, **kwargs)
|
result, data = await handler.check(event, **kwargs)
|
||||||
if result:
|
if result:
|
||||||
kwargs.update(data)
|
kwargs.update(data, handler=handler)
|
||||||
try:
|
try:
|
||||||
wrapped_inner = self._wrap_middleware(
|
wrapped_inner = self._wrap_middleware(
|
||||||
self._resolve_middlewares(), handler.call
|
self._resolve_middlewares(), handler.call
|
||||||
|
|
@ -234,14 +243,14 @@ class TelegramEventObserver:
|
||||||
return UNHANDLED
|
return UNHANDLED
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, *args: FilterType, **bound_filters: Any
|
self, *args: FilterType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any
|
||||||
) -> Callable[[CallbackType], CallbackType]:
|
) -> Callable[[CallbackType], CallbackType]:
|
||||||
"""
|
"""
|
||||||
Decorator for registering event handlers
|
Decorator for registering event handlers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(callback: CallbackType) -> CallbackType:
|
def wrapper(callback: CallbackType) -> CallbackType:
|
||||||
self.register(callback, *args, **bound_filters)
|
self.register(callback, *args, flags=flags, **bound_filters)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,9 @@ class BaseFilter(ABC, BaseModel):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
def __await__(self): # type: ignore # pragma: no cover
|
def __await__(self): # type: ignore # pragma: no cover
|
||||||
# Is needed only for inspection and this method is never be called
|
# Is needed only for inspection and this method is never be called
|
||||||
return self.__call__
|
return self.__call__
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,10 @@ class Command(BaseFilter):
|
||||||
command_magic: Optional[MagicFilter] = None
|
command_magic: Optional[MagicFilter] = None
|
||||||
"""Validate command object via Magic filter after all checks done"""
|
"""Validate command object via Magic filter after all checks done"""
|
||||||
|
|
||||||
|
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
|
||||||
|
commands = flags.setdefault("commands", [])
|
||||||
|
commands.append(self)
|
||||||
|
|
||||||
@validator("commands", always=True)
|
@validator("commands", always=True)
|
||||||
def _validate_commands(
|
def _validate_commands(
|
||||||
cls, value: Union[Sequence[CommandPatterType], CommandPatterType]
|
cls, value: Union[Sequence[CommandPatterType], CommandPatterType]
|
||||||
|
|
|
||||||
|
|
@ -287,6 +287,8 @@ class TestTelegramEventObserver:
|
||||||
observer.register(pipe_handler, mix_data)
|
observer.register(pipe_handler, mix_data)
|
||||||
|
|
||||||
results = await observer.trigger(42)
|
results = await observer.trigger(42)
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[1].pop("handler")
|
||||||
assert results == ((42,), {"b": 2})
|
assert results == ((42,), {"b": 2})
|
||||||
|
|
||||||
@pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware"))
|
@pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware"))
|
||||||
|
|
|
||||||
|
|
@ -126,3 +126,16 @@ class TestCommandObject:
|
||||||
)
|
)
|
||||||
def test_text(self, obj: CommandObject, result: str):
|
def test_text(self, obj: CommandObject, result: str):
|
||||||
assert obj.text == result
|
assert obj.text == result
|
||||||
|
|
||||||
|
def test_update_handler_flags(self):
|
||||||
|
cmd = Command(commands=["start"])
|
||||||
|
flags = {}
|
||||||
|
cmd.update_handler_flags(flags)
|
||||||
|
|
||||||
|
assert "commands" in flags
|
||||||
|
assert isinstance(flags["commands"], list)
|
||||||
|
assert len(flags["commands"]) == 1
|
||||||
|
assert flags["commands"][0] is cmd
|
||||||
|
|
||||||
|
cmd.update_handler_flags(flags)
|
||||||
|
assert len(flags["commands"]) == 2
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue