mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 18:01:04 +00:00
Pydoc observer
This commit is contained in:
parent
d39bbc3952
commit
ad17143d3f
2 changed files with 72 additions and 30 deletions
|
|
@ -1,7 +1,17 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Type
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
)
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
|
@ -21,16 +31,12 @@ class EventObserver:
|
|||
Base events observer
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.handlers: List[HandlerObject] = []
|
||||
|
||||
def register(self, callback: HandlerType, *filters: FilterType):
|
||||
def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType:
|
||||
"""
|
||||
Register callback with filters
|
||||
|
||||
:param callback:
|
||||
:param filters:
|
||||
:return:
|
||||
"""
|
||||
self.handlers.append(
|
||||
HandlerObject(
|
||||
|
|
@ -39,7 +45,11 @@ class EventObserver:
|
|||
)
|
||||
return callback
|
||||
|
||||
async def trigger(self, *args, **kwargs):
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
Propagate event to handlers.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
kwargs_copy = copy.copy(kwargs)
|
||||
result, data = await handler.check(*args, **kwargs)
|
||||
|
|
@ -50,8 +60,12 @@ class EventObserver:
|
|||
except SkipHandler:
|
||||
continue
|
||||
|
||||
def __call__(self, *args: FilterType):
|
||||
def wrapper(callback: CallbackType):
|
||||
def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]:
|
||||
"""
|
||||
Decorator for registering event handlers
|
||||
"""
|
||||
|
||||
def wrapper(callback: CallbackType) -> CallbackType:
|
||||
self.register(callback, *args)
|
||||
return callback
|
||||
|
||||
|
|
@ -63,7 +77,7 @@ class TelegramEventObserver(EventObserver):
|
|||
Event observer for Telegram events
|
||||
"""
|
||||
|
||||
def __init__(self, router: Router, event_name: str):
|
||||
def __init__(self, router: Router, event_name: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.router: Router = router
|
||||
|
|
@ -71,16 +85,25 @@ class TelegramEventObserver(EventObserver):
|
|||
self.filters: List[Type[BaseFilter]] = []
|
||||
|
||||
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
|
||||
"""
|
||||
Register filter class in factory
|
||||
|
||||
:param bound_filter:
|
||||
"""
|
||||
if not issubclass(bound_filter, BaseFilter):
|
||||
raise TypeError(
|
||||
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
|
||||
)
|
||||
self.filters.append(bound_filter)
|
||||
|
||||
def _resolve_filters_chain(self):
|
||||
registry: List[FilterType] = []
|
||||
def _resolve_filters_chain(self) -> Generator[Type[BaseFilter], None, None]:
|
||||
"""
|
||||
Get all bounded filters from current observer and from the parents
|
||||
with the same event type without duplicates
|
||||
"""
|
||||
registry: List[Type[BaseFilter]] = []
|
||||
|
||||
router = self.router
|
||||
router: Optional[Router] = self.router
|
||||
while router:
|
||||
observer = router.observers[self.event_name]
|
||||
router = router.parent_router
|
||||
|
|
@ -91,16 +114,10 @@ class TelegramEventObserver(EventObserver):
|
|||
yield filter_
|
||||
registry.append(filter_)
|
||||
|
||||
def register(self, callback: HandlerType, *filters: FilterType, **bound_filters: Any):
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
return super().register(callback, *filters, *resolved_filters)
|
||||
|
||||
async def trigger(self, *args, **kwargs):
|
||||
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
|
||||
yield result
|
||||
break
|
||||
|
||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
||||
"""
|
||||
Resolve keyword filters via filters factory
|
||||
"""
|
||||
filters: List[BaseFilter] = []
|
||||
if not full_config:
|
||||
return filters
|
||||
|
|
@ -112,19 +129,44 @@ class TelegramEventObserver(EventObserver):
|
|||
except ValidationError:
|
||||
continue
|
||||
|
||||
# Clean full config to prevent to re-initialize another filter with the same configuration
|
||||
# Clean full config to prevent to re-initialize another filter
|
||||
# with the same configuration
|
||||
for key in f.__fields__:
|
||||
full_config.pop(key, None)
|
||||
|
||||
filters.append(f)
|
||||
|
||||
if full_config:
|
||||
raise ValueError(f"Unknown filters: {set(full_config.keys())}")
|
||||
raise ValueError(f"Unknown keyword filters: {set(full_config.keys())}")
|
||||
|
||||
return filters
|
||||
|
||||
def __call__(self, *args: FilterType, **bound_filters):
|
||||
def wrapper(callback: CallbackType):
|
||||
def register(
|
||||
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
|
||||
) -> HandlerType:
|
||||
"""
|
||||
Register event handler
|
||||
"""
|
||||
resolved_filters = self.resolve_filters(bound_filters)
|
||||
return super().register(callback, *filters, *resolved_filters)
|
||||
|
||||
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
Propagate event to handlers and stops propagation on first match.
|
||||
Handler will be called when all its filters is pass.
|
||||
"""
|
||||
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
|
||||
yield result
|
||||
break
|
||||
|
||||
def __call__(
|
||||
self, *args: FilterType, **bound_filters: BaseFilter
|
||||
) -> Callable[[CallbackType], CallbackType]:
|
||||
"""
|
||||
Decorator for registering event handlers
|
||||
"""
|
||||
|
||||
def wrapper(callback: CallbackType) -> CallbackType:
|
||||
self.register(callback, *args, **bound_filters)
|
||||
return callback
|
||||
|
||||
|
|
|
|||
|
|
@ -177,15 +177,15 @@ class TestTelegramEventObserver:
|
|||
assert any(isinstance(item, MyFilter1) for item in resolved)
|
||||
|
||||
# Unknown filter
|
||||
with pytest.raises(ValueError, match="Unknown filters: {'@bad'}"):
|
||||
with pytest.raises(ValueError, match="Unknown keyword filters: {'@bad'}"):
|
||||
assert observer.resolve_filters({"@bad": "very"})
|
||||
|
||||
# Unknown filter
|
||||
with pytest.raises(ValueError, match="Unknown filters: {'@bad'}"):
|
||||
with pytest.raises(ValueError, match="Unknown keyword filters: {'@bad'}"):
|
||||
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
|
||||
|
||||
# Bad argument type
|
||||
with pytest.raises(ValueError, match="Unknown filters: {'test'}"):
|
||||
with pytest.raises(ValueError, match="Unknown keyword filters: {'test'}"):
|
||||
assert observer.resolve_filters({"test": ...})
|
||||
|
||||
def test_register(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue