mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-08 17:13:56 +00:00
Optimize filter checking
This commit is contained in:
parent
78aee861bb
commit
9ee1a0cbed
3 changed files with 47 additions and 20 deletions
|
|
@ -1,7 +1,8 @@
|
|||
from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \
|
||||
ExceptionsFilter, FuncFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, Text
|
||||
from .factory import FiltersFactory
|
||||
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
|
||||
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, FilterObj, execute_filter, \
|
||||
check_filters, get_filter_spec, get_filters_spec
|
||||
|
||||
__all__ = [
|
||||
'AbstractFilter',
|
||||
|
|
@ -23,6 +24,9 @@ __all__ = [
|
|||
'Regexp',
|
||||
'StateFilter',
|
||||
'Text',
|
||||
'check_filter',
|
||||
'FilterObj',
|
||||
'get_filters_spec',
|
||||
'get_filters_spec',
|
||||
'execute_filter',
|
||||
'check_filters'
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import abc
|
|||
import inspect
|
||||
import typing
|
||||
|
||||
from ..handler import Handler
|
||||
from ..handler import Handler, FilterObj
|
||||
|
||||
|
||||
class FilterNotPassed(Exception):
|
||||
|
|
@ -20,15 +20,7 @@ def wrap_async(func):
|
|||
return async_wrapper
|
||||
|
||||
|
||||
async def check_filter(dispatcher, filter_, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
:param dispatcher:
|
||||
:param filter_:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
def get_filter_spec(dispatcher, filter_: callable):
|
||||
kwargs = {}
|
||||
if not callable(filter_):
|
||||
raise TypeError('Filter must be callable and/or awaitable!')
|
||||
|
|
@ -39,16 +31,37 @@ async def check_filter(dispatcher, filter_, args):
|
|||
if inspect.isawaitable(filter_) \
|
||||
or inspect.iscoroutinefunction(filter_) \
|
||||
or isinstance(filter_, AbstractFilter):
|
||||
return await filter_(*args, **kwargs)
|
||||
return FilterObj(filter=filter_, kwargs=kwargs, is_async=True)
|
||||
else:
|
||||
return filter_(*args, **kwargs)
|
||||
return FilterObj(filter=filter_, kwargs=kwargs, is_async=False)
|
||||
|
||||
|
||||
async def check_filters(dispatcher, filters, args):
|
||||
def get_filters_spec(dispatcher, filters: typing.Iterable[callable]):
|
||||
data = []
|
||||
if filters is not None:
|
||||
for i in filters:
|
||||
data.append(get_filter_spec(dispatcher, i))
|
||||
return data
|
||||
|
||||
|
||||
async def execute_filter(filter_: FilterObj, args):
|
||||
"""
|
||||
Helper for executing filter
|
||||
|
||||
:param filter_:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
if filter_.is_async:
|
||||
return await filter_.filter(*args, **filter_.kwargs)
|
||||
else:
|
||||
return filter_.filter(*args, **filter_.kwargs)
|
||||
|
||||
|
||||
async def check_filters(filters: typing.Iterable[FilterObj], args):
|
||||
"""
|
||||
Check list of filters
|
||||
|
||||
:param dispatcher:
|
||||
:param filters:
|
||||
:param args:
|
||||
:return:
|
||||
|
|
@ -56,7 +69,7 @@ async def check_filters(dispatcher, filters, args):
|
|||
data = {}
|
||||
if filters is not None:
|
||||
for filter_ in filters:
|
||||
f = await check_filter(dispatcher, filter_, args)
|
||||
f = await execute_filter(filter_, args)
|
||||
if not f:
|
||||
raise FilterNotPassed()
|
||||
elif isinstance(f, dict):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,19 @@
|
|||
import inspect
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from typing import Optional, Iterable
|
||||
|
||||
ctx_data = ContextVar('ctx_handler_data')
|
||||
current_handler = ContextVar('current_handler')
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterObj:
|
||||
filter: callable
|
||||
kwargs: dict
|
||||
is_async: bool
|
||||
|
||||
|
||||
class SkipHandler(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -39,6 +46,7 @@ class Handler:
|
|||
self.middleware_key = middleware_key
|
||||
|
||||
def register(self, handler, filters=None, index=None):
|
||||
from .filters import get_filters_spec
|
||||
"""
|
||||
Register callback
|
||||
|
||||
|
|
@ -52,6 +60,8 @@ class Handler:
|
|||
|
||||
if filters and not isinstance(filters, (list, tuple, set)):
|
||||
filters = [filters]
|
||||
filters = get_filters_spec(self.dispatcher, filters)
|
||||
|
||||
record = Handler.HandlerObj(handler=handler, spec=spec, filters=filters)
|
||||
if index is None:
|
||||
self.handlers.append(record)
|
||||
|
|
@ -95,7 +105,7 @@ class Handler:
|
|||
try:
|
||||
for handler_obj in self.handlers:
|
||||
try:
|
||||
data.update(await check_filters(self.dispatcher, handler_obj.filters, args))
|
||||
data.update(await check_filters(handler_obj.filters, args))
|
||||
except FilterNotPassed:
|
||||
continue
|
||||
else:
|
||||
|
|
@ -126,4 +136,4 @@ class Handler:
|
|||
class HandlerObj:
|
||||
handler: callable
|
||||
spec: inspect.FullArgSpec
|
||||
filters: Union[list, tuple, set, None] = None
|
||||
filters: Optional[Iterable[FilterObj]] = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue