mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 18:01:04 +00:00
Pydoc observer
This commit is contained in:
parent
d39bbc3952
commit
ad17143d3f
2 changed files with 72 additions and 30 deletions
|
|
@ -1,7 +1,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Type
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
|
@ -21,16 +31,12 @@ class EventObserver:
|
||||||
Base events observer
|
Base events observer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.handlers: List[HandlerObject] = []
|
self.handlers: List[HandlerObject] = []
|
||||||
|
|
||||||
def register(self, callback: HandlerType, *filters: FilterType):
|
def register(self, callback: HandlerType, *filters: FilterType) -> HandlerType:
|
||||||
"""
|
"""
|
||||||
Register callback with filters
|
Register callback with filters
|
||||||
|
|
||||||
:param callback:
|
|
||||||
:param filters:
|
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
self.handlers.append(
|
self.handlers.append(
|
||||||
HandlerObject(
|
HandlerObject(
|
||||||
|
|
@ -39,7 +45,11 @@ class EventObserver:
|
||||||
)
|
)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
async def trigger(self, *args, **kwargs):
|
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||||
|
"""
|
||||||
|
Propagate event to handlers.
|
||||||
|
Handler will be called when all its filters is pass.
|
||||||
|
"""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
kwargs_copy = copy.copy(kwargs)
|
kwargs_copy = copy.copy(kwargs)
|
||||||
result, data = await handler.check(*args, **kwargs)
|
result, data = await handler.check(*args, **kwargs)
|
||||||
|
|
@ -50,8 +60,12 @@ class EventObserver:
|
||||||
except SkipHandler:
|
except SkipHandler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def __call__(self, *args: FilterType):
|
def __call__(self, *args: FilterType) -> Callable[[CallbackType], CallbackType]:
|
||||||
def wrapper(callback: CallbackType):
|
"""
|
||||||
|
Decorator for registering event handlers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(callback: CallbackType) -> CallbackType:
|
||||||
self.register(callback, *args)
|
self.register(callback, *args)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
|
|
@ -63,7 +77,7 @@ class TelegramEventObserver(EventObserver):
|
||||||
Event observer for Telegram events
|
Event observer for Telegram events
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, router: Router, event_name: str):
|
def __init__(self, router: Router, event_name: str) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.router: Router = router
|
self.router: Router = router
|
||||||
|
|
@ -71,16 +85,25 @@ class TelegramEventObserver(EventObserver):
|
||||||
self.filters: List[Type[BaseFilter]] = []
|
self.filters: List[Type[BaseFilter]] = []
|
||||||
|
|
||||||
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
|
def bind_filter(self, bound_filter: Type[BaseFilter]) -> None:
|
||||||
|
"""
|
||||||
|
Register filter class in factory
|
||||||
|
|
||||||
|
:param bound_filter:
|
||||||
|
"""
|
||||||
if not issubclass(bound_filter, BaseFilter):
|
if not issubclass(bound_filter, BaseFilter):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
|
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
|
||||||
)
|
)
|
||||||
self.filters.append(bound_filter)
|
self.filters.append(bound_filter)
|
||||||
|
|
||||||
def _resolve_filters_chain(self):
|
def _resolve_filters_chain(self) -> Generator[Type[BaseFilter], None, None]:
|
||||||
registry: List[FilterType] = []
|
"""
|
||||||
|
Get all bounded filters from current observer and from the parents
|
||||||
|
with the same event type without duplicates
|
||||||
|
"""
|
||||||
|
registry: List[Type[BaseFilter]] = []
|
||||||
|
|
||||||
router = self.router
|
router: Optional[Router] = self.router
|
||||||
while router:
|
while router:
|
||||||
observer = router.observers[self.event_name]
|
observer = router.observers[self.event_name]
|
||||||
router = router.parent_router
|
router = router.parent_router
|
||||||
|
|
@ -91,16 +114,10 @@ class TelegramEventObserver(EventObserver):
|
||||||
yield filter_
|
yield filter_
|
||||||
registry.append(filter_)
|
registry.append(filter_)
|
||||||
|
|
||||||
def register(self, callback: HandlerType, *filters: FilterType, **bound_filters: Any):
|
|
||||||
resolved_filters = self.resolve_filters(bound_filters)
|
|
||||||
return super().register(callback, *filters, *resolved_filters)
|
|
||||||
|
|
||||||
async def trigger(self, *args, **kwargs):
|
|
||||||
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
|
|
||||||
yield result
|
|
||||||
break
|
|
||||||
|
|
||||||
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
def resolve_filters(self, full_config: Dict[str, Any]) -> List[BaseFilter]:
|
||||||
|
"""
|
||||||
|
Resolve keyword filters via filters factory
|
||||||
|
"""
|
||||||
filters: List[BaseFilter] = []
|
filters: List[BaseFilter] = []
|
||||||
if not full_config:
|
if not full_config:
|
||||||
return filters
|
return filters
|
||||||
|
|
@ -112,19 +129,44 @@ class TelegramEventObserver(EventObserver):
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Clean full config to prevent to re-initialize another filter with the same configuration
|
# Clean full config to prevent to re-initialize another filter
|
||||||
|
# with the same configuration
|
||||||
for key in f.__fields__:
|
for key in f.__fields__:
|
||||||
full_config.pop(key, None)
|
full_config.pop(key, None)
|
||||||
|
|
||||||
filters.append(f)
|
filters.append(f)
|
||||||
|
|
||||||
if full_config:
|
if full_config:
|
||||||
raise ValueError(f"Unknown filters: {set(full_config.keys())}")
|
raise ValueError(f"Unknown keyword filters: {set(full_config.keys())}")
|
||||||
|
|
||||||
return filters
|
return filters
|
||||||
|
|
||||||
def __call__(self, *args: FilterType, **bound_filters):
|
def register(
|
||||||
def wrapper(callback: CallbackType):
|
self, callback: HandlerType, *filters: FilterType, **bound_filters: Any
|
||||||
|
) -> HandlerType:
|
||||||
|
"""
|
||||||
|
Register event handler
|
||||||
|
"""
|
||||||
|
resolved_filters = self.resolve_filters(bound_filters)
|
||||||
|
return super().register(callback, *filters, *resolved_filters)
|
||||||
|
|
||||||
|
async def trigger(self, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||||
|
"""
|
||||||
|
Propagate event to handlers and stops propagation on first match.
|
||||||
|
Handler will be called when all its filters is pass.
|
||||||
|
"""
|
||||||
|
async for result in super(TelegramEventObserver, self).trigger(*args, **kwargs):
|
||||||
|
yield result
|
||||||
|
break
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, *args: FilterType, **bound_filters: BaseFilter
|
||||||
|
) -> Callable[[CallbackType], CallbackType]:
|
||||||
|
"""
|
||||||
|
Decorator for registering event handlers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(callback: CallbackType) -> CallbackType:
|
||||||
self.register(callback, *args, **bound_filters)
|
self.register(callback, *args, **bound_filters)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -177,15 +177,15 @@ class TestTelegramEventObserver:
|
||||||
assert any(isinstance(item, MyFilter1) for item in resolved)
|
assert any(isinstance(item, MyFilter1) for item in resolved)
|
||||||
|
|
||||||
# Unknown filter
|
# Unknown filter
|
||||||
with pytest.raises(ValueError, match="Unknown filters: {'@bad'}"):
|
with pytest.raises(ValueError, match="Unknown keyword filters: {'@bad'}"):
|
||||||
assert observer.resolve_filters({"@bad": "very"})
|
assert observer.resolve_filters({"@bad": "very"})
|
||||||
|
|
||||||
# Unknown filter
|
# Unknown filter
|
||||||
with pytest.raises(ValueError, match="Unknown filters: {'@bad'}"):
|
with pytest.raises(ValueError, match="Unknown keyword filters: {'@bad'}"):
|
||||||
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
|
assert observer.resolve_filters({"test": "ok", "@bad": "very"})
|
||||||
|
|
||||||
# Bad argument type
|
# Bad argument type
|
||||||
with pytest.raises(ValueError, match="Unknown filters: {'test'}"):
|
with pytest.raises(ValueError, match="Unknown keyword filters: {'test'}"):
|
||||||
assert observer.resolve_filters({"test": ...})
|
assert observer.resolve_filters({"test": ...})
|
||||||
|
|
||||||
def test_register(self):
|
def test_register(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue