Implemented handler flags feature (#728)

* Implemented handler flag feature

* Cover tests
This commit is contained in:
Alex Root Junior 2021-10-25 23:37:14 +03:00 committed by GitHub
parent 5f07cb3d06
commit 3ad16be507
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 4 deletions

View file

@ -64,6 +64,7 @@ class FilterObject(CallableMixin):
class HandlerObject(CallableMixin): class HandlerObject(CallableMixin):
callback: HandlerType callback: HandlerType
filters: Optional[List[FilterObject]] = None filters: Optional[List[FilterObject]] = None
flags: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
super(HandlerObject, self).__post_init__() super(HandlerObject, self).__post_init__()

View file

@ -175,16 +175,25 @@ class TelegramEventObserver:
return bound_filters return bound_filters
def register( def register(
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any self,
callback: HandlerType,
*filters: FilterType,
flags: Optional[Dict[str, Any]] = None,
**bound_filters: Any,
) -> HandlerType: ) -> HandlerType:
""" """
Register event handler Register event handler
""" """
if flags is None:
flags = {}
resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False) resolved_filters = self.resolve_filters(filters, bound_filters, ignore_default=False)
for resolved_filter in resolved_filters:
resolved_filter.update_handler_flags(flags=flags)
self.handlers.append( self.handlers.append(
HandlerObject( HandlerObject(
callback=callback, callback=callback,
filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)], filters=[FilterObject(filter_) for filter_ in chain(resolved_filters, filters)],
flags=flags,
) )
) )
return callback return callback
@ -222,7 +231,7 @@ class TelegramEventObserver:
for handler in self.handlers: for handler in self.handlers:
result, data = await handler.check(event, **kwargs) result, data = await handler.check(event, **kwargs)
if result: if result:
kwargs.update(data) kwargs.update(data, handler=handler)
try: try:
wrapped_inner = self._wrap_middleware( wrapped_inner = self._wrap_middleware(
self._resolve_middlewares(), handler.call self._resolve_middlewares(), handler.call
@ -234,14 +243,14 @@ class TelegramEventObserver:
return UNHANDLED return UNHANDLED
def __call__( def __call__(
self, *args: FilterType, **bound_filters: Any self, *args: FilterType, flags: Optional[Dict[str, Any]] = None, **bound_filters: Any
) -> Callable[[CallbackType], CallbackType]: ) -> Callable[[CallbackType], CallbackType]:
""" """
Decorator for registering event handlers Decorator for registering event handlers
""" """
def wrapper(callback: CallbackType) -> CallbackType: def wrapper(callback: CallbackType) -> CallbackType:
self.register(callback, *args, **bound_filters) self.register(callback, *args, flags=flags, **bound_filters)
return callback return callback
return wrapper return wrapper

View file

@ -32,6 +32,9 @@ class BaseFilter(ABC, BaseModel):
""" """
pass pass
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
pass
def __await__(self): # type: ignore # pragma: no cover def __await__(self): # type: ignore # pragma: no cover
# Is needed only for inspection and this method is never be called # Is needed only for inspection and this method is never be called
return self.__call__ return self.__call__

View file

@ -38,6 +38,10 @@ class Command(BaseFilter):
command_magic: Optional[MagicFilter] = None command_magic: Optional[MagicFilter] = None
"""Validate command object via Magic filter after all checks done""" """Validate command object via Magic filter after all checks done"""
def update_handler_flags(self, flags: Dict[str, Any]) -> None:
commands = flags.setdefault("commands", [])
commands.append(self)
@validator("commands", always=True) @validator("commands", always=True)
def _validate_commands( def _validate_commands(
cls, value: Union[Sequence[CommandPatterType], CommandPatterType] cls, value: Union[Sequence[CommandPatterType], CommandPatterType]

View file

@ -287,6 +287,8 @@ class TestTelegramEventObserver:
observer.register(pipe_handler, mix_data) observer.register(pipe_handler, mix_data)
results = await observer.trigger(42) results = await observer.trigger(42)
assert len(results) == 2
assert results[1].pop("handler")
assert results == ((42,), {"b": 2}) assert results == ((42,), {"b": 2})
@pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware")) @pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware"))

View file

@ -126,3 +126,16 @@ class TestCommandObject:
) )
def test_text(self, obj: CommandObject, result: str): def test_text(self, obj: CommandObject, result: str):
assert obj.text == result assert obj.text == result
def test_update_handler_flags(self):
cmd = Command(commands=["start"])
flags = {}
cmd.update_handler_flags(flags)
assert "commands" in flags
assert isinstance(flags["commands"], list)
assert len(flags["commands"]) == 1
assert flags["commands"][0] is cmd
cmd.update_handler_flags(flags)
assert len(flags["commands"]) == 2