Optimize filter checking

This commit is contained in:
Nikita 2019-04-06 22:47:33 +05:00
parent 78aee861bb
commit 9ee1a0cbed
3 changed files with 47 additions and 20 deletions

View file

@ -1,7 +1,8 @@
from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \
ExceptionsFilter, FuncFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, Text
from .factory import FiltersFactory
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, FilterObj, execute_filter, \
check_filters, get_filter_spec, get_filters_spec
__all__ = [
'AbstractFilter',
@ -23,6 +24,9 @@ __all__ = [
'Regexp',
'StateFilter',
'Text',
'check_filter',
'FilterObj',
'get_filters_spec',
'get_filters_spec',
'execute_filter',
'check_filters'
]

View file

@ -2,7 +2,7 @@ import abc
import inspect
import typing
from ..handler import Handler
from ..handler import Handler, FilterObj
class FilterNotPassed(Exception):
@ -20,15 +20,7 @@ def wrap_async(func):
return async_wrapper
async def check_filter(dispatcher, filter_, args):
"""
Helper for executing filter
:param dispatcher:
:param filter_:
:param args:
:return:
"""
def get_filter_spec(dispatcher, filter_: callable):
kwargs = {}
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
@ -39,16 +31,37 @@ async def check_filter(dispatcher, filter_, args):
if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, AbstractFilter):
return await filter_(*args, **kwargs)
return FilterObj(filter=filter_, kwargs=kwargs, is_async=True)
else:
return filter_(*args, **kwargs)
return FilterObj(filter=filter_, kwargs=kwargs, is_async=False)
async def check_filters(dispatcher, filters, args):
def get_filters_spec(dispatcher, filters: typing.Iterable[callable]):
data = []
if filters is not None:
for i in filters:
data.append(get_filter_spec(dispatcher, i))
return data
async def execute_filter(filter_: FilterObj, args):
"""
Helper for executing filter
:param filter_:
:param args:
:return:
"""
if filter_.is_async:
return await filter_.filter(*args, **filter_.kwargs)
else:
return filter_.filter(*args, **filter_.kwargs)
async def check_filters(filters: typing.Iterable[FilterObj], args):
"""
Check list of filters
:param dispatcher:
:param filters:
:param args:
:return:
@ -56,7 +69,7 @@ async def check_filters(dispatcher, filters, args):
data = {}
if filters is not None:
for filter_ in filters:
f = await check_filter(dispatcher, filter_, args)
f = await execute_filter(filter_, args)
if not f:
raise FilterNotPassed()
elif isinstance(f, dict):

View file

@ -1,12 +1,19 @@
import inspect
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Union
from typing import Optional, Iterable
ctx_data = ContextVar('ctx_handler_data')
current_handler = ContextVar('current_handler')
@dataclass
class FilterObj:
filter: callable
kwargs: dict
is_async: bool
class SkipHandler(Exception):
pass
@ -39,6 +46,7 @@ class Handler:
self.middleware_key = middleware_key
def register(self, handler, filters=None, index=None):
from .filters import get_filters_spec
"""
Register callback
@ -52,6 +60,8 @@ class Handler:
if filters and not isinstance(filters, (list, tuple, set)):
filters = [filters]
filters = get_filters_spec(self.dispatcher, filters)
record = Handler.HandlerObj(handler=handler, spec=spec, filters=filters)
if index is None:
self.handlers.append(record)
@ -95,7 +105,7 @@ class Handler:
try:
for handler_obj in self.handlers:
try:
data.update(await check_filters(self.dispatcher, handler_obj.filters, args))
data.update(await check_filters(handler_obj.filters, args))
except FilterNotPassed:
continue
else:
@ -126,4 +136,4 @@ class Handler:
class HandlerObj:
handler: callable
spec: inspect.FullArgSpec
filters: Union[list, tuple, set, None] = None
filters: Optional[Iterable[FilterObj]] = None