Changed the order of filters and optimize creation of default filters.

This commit is contained in:
Alex Root Junior 2018-06-27 01:46:35 +03:00
parent b4d8ac2c0a
commit fe6ae4863a
4 changed files with 19 additions and 24 deletions

View file

@ -78,6 +78,13 @@ class Dispatcher:
self._closed = True
self._close_waiter = loop.create_future()
filters_factory.bind(StateFilter, exclude_event_handlers=[
self.errors_handlers
])
filters_factory.bind(ContentTypeFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
])
filters_factory.bind(CommandsFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers
])
@ -89,13 +96,6 @@ class Dispatcher:
filters_factory.bind(RegexpCommandsFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers
])
filters_factory.bind(ContentTypeFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
])
filters_factory.bind(StateFilter, exclude_event_handlers=[
self.errors_handlers
])
filters_factory.bind(ExceptionsFilter, event_handlers=[
self.errors_handlers
])

View file

@ -95,17 +95,13 @@ class ContentTypeFilter(BaseFilter):
"""
key = 'content_types'
required = True
default = types.ContentType.TEXT
def __init__(self, dispatcher, content_types):
super().__init__(dispatcher)
self.content_types = content_types
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]):
result = super(ContentTypeFilter, cls).validate(full_config)
if not result:
return {cls.key: types.ContentType.TEXT}
async def check(self, message):
return ContentType.ANY[0] in self.content_types or \
message.content_type in self.content_types
@ -116,6 +112,7 @@ class StateFilter(BaseFilter):
Check user state
"""
key = 'state'
required = True
ctx_state = ContextVar('user_state')
@ -128,13 +125,6 @@ class StateFilter(BaseFilter):
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]):
result = super(StateFilter, cls).validate(full_config)
if not result:
return {cls.key: None}
return result
async def check(self, obj):
if '*' in self.state:
return {'state': self.dispatcher.current_state()}

View file

@ -53,10 +53,10 @@ class FiltersFactory:
:return:
"""
filters_set = []
if custom_filters:
filters_set.extend(custom_filters)
filters_set.extend(self._resolve_registered(event_handler,
{k: v for k, v in full_config.items() if v is not None}))
if custom_filters:
filters_set.extend(custom_filters)
return filters_set

View file

@ -133,8 +133,13 @@ class BaseFilter(AbstractFilter):
Base class for filters with default validator
"""
key = None
required = False
default = None
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
if cls.key is not None and cls.key in full_config:
return {cls.key: full_config[cls.key]}
if cls.key is not None:
if cls.key in full_config:
return {cls.key: full_config[cls.key]}
elif cls.required:
return {cls.key: cls.default}