Implement filters factory.

This commit is contained in:
Alex Root Junior 2018-02-23 16:56:31 +02:00
parent a3856e33bd
commit 8d7a00204d
2 changed files with 121 additions and 19 deletions

View file

@ -5,8 +5,8 @@ import logging
import time
import typing
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, \
USER_STATE, generate_default_filters
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter, USER_STATE, \
generate_default_filters
from .handler import CancelHandler, Handler, SkipHandler
from .middlewares import MiddlewareManager
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
@ -37,12 +37,15 @@ class Dispatcher:
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
run_tasks_by_default: bool = False,
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False,
filters_factory=None):
if loop is None:
loop = bot.loop
if storage is None:
storage = DisabledStorage()
if filters_factory is None:
filters_factory = FiltersFactory(self)
self.bot: Bot = bot
self.loop = loop
@ -54,6 +57,7 @@ class Dispatcher:
self.last_update_id = 0
self.filters_factory: FiltersFactory = filters_factory
self.updates_handler = Handler(self, middleware_key='update')
self.message_handlers = Handler(self, middleware_key='message')
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
@ -74,6 +78,12 @@ class Dispatcher:
self._closed = True
self._close_waiter = loop.create_future()
filters_factory.bind(filters.CommandsFilter, 'commands')
filters_factory.bind(filters.RegexpFilter, 'regexp')
filters_factory.bind(filters.RegexpCommandsFilter, 'regexp_commands')
filters_factory.bind(filters.ContentTypeFilter, 'content_types')
filters_factory.bind(filters.StateFilter, 'state', with_dispatcher=True, default=True)
def __del__(self):
self.stop_polling()
@ -310,8 +320,8 @@ class Dispatcher:
"""
return self._polling
def register_message_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None,
state=None, custom_filters=None, run_task=None, **kwargs):
def register_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None,
func=None, state=None, run_task=None, **kwargs):
"""
Register handler for message
@ -334,23 +344,23 @@ class Dispatcher:
:param content_types: List of content types.
:param func: custom any callable object
:param custom_filters: list of custom filters
:param run_task:
:param kwargs:
:param state:
:return: decorated function
"""
if content_types is None:
content_types = ContentType.TEXT
if custom_filters is None:
custom_filters = []
if func is not None:
custom_filters = list(custom_filters)
custom_filters.append(func)
filters_set = generate_default_filters(self,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
func=func,
state=state,
**kwargs)
filters_set = self.filters_factory.parse(*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs)
self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, func=None, state=None,
@ -426,9 +436,9 @@ class Dispatcher:
"""
def decorator(callback):
self.register_message_handler(callback,
self.register_message_handler(callback, *custom_filters,
commands=commands, regexp=regexp, content_types=content_types,
func=func, state=state, custom_filters=custom_filters, run_task=run_task,
func=func, state=state, run_task=run_task,
**kwargs)
return callback

View file

@ -9,6 +9,84 @@ from ..utils.helper import Helper, HelperMode, Item
USER_STATE = 'USER_STATE'
class FiltersFactory:
def __init__(self, dispatcher):
self._dispatcher = dispatcher
self._filters = []
@property
def _default_filters(self):
return tuple(filter(lambda item: item[-1], self._filters))
def bind(self, filter_, *args, default=False, with_dispatcher=False):
self._filters.append((filter_, args, with_dispatcher, default))
def unbind(self, filter_):
for item in self._filters:
if filter_ is item[0]:
self._filters.remove(item)
return True
raise ValueError(f'{filter_} is not binded.')
def replace(self, original, new):
for item in self._filters:
if original is item[0]:
item[0] = new
return True
raise ValueError(f'{original} is not binded.')
def parse(self, *args, **kwargs):
"""
Generate filters list
:param args:
:param kwargs:
:return:
"""
used = []
filters = []
filters.extend(args)
# Registered filters filters
for filter_, args_list, with_dispatcher, default in self._filters:
config = {}
accept = True
for item in args_list:
value = kwargs.pop(item, None)
if value is None:
accept = False
break
config[item] = value
if accept:
if with_dispatcher:
config['dispatcher'] = self._dispatcher
filters.append(filter_(**config))
used.append(filter_)
elif default:
if filter_ not in used:
used.append(filter_)
if isinstance(filter_, Filter):
if with_dispatcher:
filters.append(filter_(dispatcher=self._dispatcher))
else:
filters.append(filter_())
# Not registered filters
for key, filter_ in kwargs.items():
if isinstance(filter_, Filter):
filters.append(filter_)
used.append(filter_.__class__)
else:
raise ValueError(f"Unknown filter with key '{key}'")
return filters
async def check_filter(filter_, args):
"""
Helper for executing filter
@ -48,6 +126,9 @@ class Filter:
Base class for filters
"""
def __init__(self, *args, **kwargs):
pass
def __call__(self, *args, **kwargs):
return self.check(*args, **kwargs)
@ -77,6 +158,7 @@ class AnyFilter(AsyncFilter):
def __init__(self, *filters: callable):
self.filters = filters
super().__init__()
async def check(self, *args):
f = (check_filter(filter_, args) for filter_ in self.filters)
@ -90,6 +172,7 @@ class NotFilter(AsyncFilter):
def __init__(self, filter_: callable):
self.filter = filter_
super().__init__()
async def check(self, *args):
return not await check_filter(self.filter, args)
@ -102,6 +185,7 @@ class CommandsFilter(AsyncFilter):
def __init__(self, commands):
self.commands = commands
super().__init__()
async def check(self, message):
if not message.is_command():
@ -126,6 +210,7 @@ class RegexpFilter(Filter):
def __init__(self, regexp):
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
super().__init__()
def check(self, message):
if message.text:
@ -139,6 +224,7 @@ class RegexpCommandsFilter(AsyncFilter):
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
super().__init__()
async def check(self, message):
if not message.is_command():
@ -165,10 +251,11 @@ class ContentTypeFilter(Filter):
def __init__(self, content_types):
self.content_types = content_types
super().__init__()
def check(self, message):
return ContentType.ANY[0] in self.content_types or \
message.content_type in self.content_types
message.content_type in self.content_types
class CancelFilter(Filter):
@ -180,6 +267,7 @@ class CancelFilter(Filter):
if cancel_set is None:
cancel_set = ['/cancel', 'cancel', 'cancel.']
self.cancel_set = cancel_set
super().__init__()
def check(self, message):
if message.text:
@ -193,7 +281,10 @@ class StateFilter(AsyncFilter):
def __init__(self, dispatcher, state):
self.dispatcher = dispatcher
if isinstance(state, str):
state = (state,)
self.state = state
super().__init__()
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
@ -209,7 +300,7 @@ class StateFilter(AsyncFilter):
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
return False
@ -233,6 +324,7 @@ class ExceptionsFilter(Filter):
def __init__(self, exception):
self.exception = exception
super().__init__()
def check(self, dispatcher, update, exception):
try: