Refactor filters.

This commit is contained in:
Alex Root Junior 2018-07-13 22:58:47 +03:00
parent 539c76a062
commit f957883082
5 changed files with 191 additions and 80 deletions

View file

@ -6,8 +6,9 @@ import time
import typing import typing
from contextvars import ContextVar from contextvars import ContextVar
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \ from aiogram.dispatcher.filters import Command
RegexpFilter, StateFilter from .filters import ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \
Regexp, StateFilter
from .handler import Handler from .handler import Handler
from .middlewares import MiddlewareManager from .middlewares import MiddlewareManager
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
@ -84,10 +85,10 @@ class Dispatcher:
self.message_handlers, self.edited_message_handlers, self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers, self.channel_post_handlers, self.edited_channel_post_handlers,
]) ])
filters_factory.bind(CommandsFilter, event_handlers=[ filters_factory.bind(Command, event_handlers=[
self.message_handlers, self.edited_message_handlers self.message_handlers, self.edited_message_handlers
]) ])
filters_factory.bind(RegexpFilter, event_handlers=[ filters_factory.bind(Regexp, event_handlers=[
self.message_handlers, self.edited_message_handlers, self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers, self.channel_post_handlers, self.edited_channel_post_handlers,
self.callback_query_handlers self.callback_query_handlers

View file

@ -1,13 +1,14 @@
from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \ from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \
StateFilter, Text RegexpCommandsFilter, StateFilter, Text
from .factory import FiltersFactory from .factory import FiltersFactory
from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
__all__ = [ __all__ = [
'AbstractFilter', 'AbstractFilter',
'BaseFilter', 'BoundFilter',
'Command', 'Command',
'CommandsFilter', 'CommandStart',
'CommandHelp',
'ContentTypeFilter', 'ContentTypeFilter',
'ExceptionsFilter', 'ExceptionsFilter',
'Filter', 'Filter',
@ -15,7 +16,7 @@ __all__ = [
'FilterRecord', 'FilterRecord',
'FiltersFactory', 'FiltersFactory',
'RegexpCommandsFilter', 'RegexpCommandsFilter',
'RegexpFilter', 'Regexp',
'StateFilter', 'StateFilter',
'Text', 'Text',
'check_filter', 'check_filter',

View file

@ -2,15 +2,30 @@ import inspect
import re import re
from contextvars import ContextVar from contextvars import ContextVar
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Any, Dict, Iterable, Optional, Union
from aiogram import types from aiogram import types
from aiogram.dispatcher.filters.filters import BaseFilter, Filter from aiogram.dispatcher.filters.filters import BoundFilter, Filter
from aiogram.types import CallbackQuery, ContentType, Message from aiogram.types import CallbackQuery, ContentType, Message
class Command(Filter): class Command(Filter):
def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False): """
You can handle commands by using this filter
"""
def __init__(self, commands: Union[Iterable, str],
prefixes: Union[Iterable, str] = '/',
ignore_case: bool = True,
ignore_mention: bool = False):
"""
Filter can be initialized from filters factory or by simply creating instance of this class
:param commands: command or list of commands
:param prefixes:
:param ignore_case:
:param ignore_mention:
"""
if isinstance(commands, str): if isinstance(commands, str):
commands = (commands,) commands = (commands,)
@ -19,6 +34,26 @@ class Command(Filter):
self.ignore_case = ignore_case self.ignore_case = ignore_case
self.ignore_mention = ignore_mention self.ignore_mention = ignore_mention
@classmethod
def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Validator for filters factory
:param full_config:
:return: config or empty dict
"""
config = {}
if 'commands' in full_config:
config['commands'] = full_config.pop('commands')
if 'commands_prefix' in full_config:
config['prefixes'] = full_config.pop('commands_prefix')
if 'commands_ignore_mention' in full_config:
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
return config
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
@staticmethod @staticmethod
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False): async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
full_command = message.text.split()[0] full_command = message.text.split()[0]
@ -33,9 +68,6 @@ class Command(Filter):
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)} return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
@dataclass @dataclass
class CommandObj: class CommandObj:
prefix: str = '/' prefix: str = '/'
@ -57,29 +89,36 @@ class Command(Filter):
return line return line
class CommandsFilter(BaseFilter): class CommandStart(Command):
""" def __init__(self):
Check commands in message super(CommandStart, self).__init__(['start'])
"""
key = 'commands'
def __init__(self, dispatcher, commands):
super().__init__(dispatcher)
if isinstance(commands, str):
commands = (commands,)
self.commands = commands
async def check(self, message): class CommandHelp(Command):
return await Command.check_command(message, self.commands, '/') def __init__(self):
super(CommandHelp, self).__init__(['help'])
class Text(Filter): class Text(Filter):
"""
Simple text filter
"""
def __init__(self, def __init__(self,
equals: Optional[str] = None, equals: Optional[str] = None,
contains: Optional[str] = None, contains: Optional[str] = None,
startswith: Optional[str] = None, startswith: Optional[str] = None,
endswith: Optional[str] = None, endswith: Optional[str] = None,
ignore_case=False): ignore_case=False):
"""
Check text for one of pattern. Only one mode can be used in one filter.
:param equals:
:param contains:
:param startswith:
:param endswith:
:param ignore_case: case insensitive
"""
# Only one mode can be used. check it. # Only one mode can be used. check it.
check = sum(map(bool, (equals, contains, startswith, endswith))) check = sum(map(bool, (equals, contains, startswith, endswith)))
if check > 1: if check > 1:
@ -98,8 +137,27 @@ class Text(Filter):
self.startswith = startswith self.startswith = startswith
self.ignore_case = ignore_case self.ignore_case = ignore_case
async def check(self, message: types.Message): @classmethod
text = message.text.lower() if self.ignore_case else message.text def validate(cls, full_config: Dict[str, Any]):
if 'text' in full_config:
return {'equals': full_config.pop('text')}
elif 'text_contains' in full_config:
return {'contains': full_config.pop('text_contains')}
elif 'text_startswith' in full_config:
return {'startswith': full_config.pop('text_startswith')}
elif 'text_endswith' in full_config:
return {'endswith': full_config.pop('text_endswith')}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message):
text = obj.text or obj.caption or ''
elif isinstance(obj, CallbackQuery):
text = obj.data
else:
return False
if self.ignore_case:
text = text.lower()
if self.equals: if self.equals:
return text == self.equals return text == self.equals
@ -113,24 +171,24 @@ class Text(Filter):
return False return False
class RegexpFilter(BaseFilter): class Regexp(Filter):
""" """
Regexp filter for messages and callback query Regexp filter for messages and callback query
""" """
key = 'regexp'
def __init__(self, dispatcher, regexp): def __init__(self, regexp):
super().__init__(dispatcher) if not isinstance(regexp, re.Pattern):
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
self.regexp = regexp
async def check(self, obj): @classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'regexp' in full_config:
return {'regexp': full_config.pop('regexp')}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message): if isinstance(obj, Message):
if obj.text: match = self.regexp.search(obj.text or obj.caption or '')
match = self.regexp.search(obj.text)
elif obj.caption:
match = self.regexp.search(obj.caption)
else:
return False
elif isinstance(obj, CallbackQuery) and obj.data: elif isinstance(obj, CallbackQuery) and obj.data:
match = self.regexp.search(obj.data) match = self.regexp.search(obj.data)
else: else:
@ -141,15 +199,14 @@ class RegexpFilter(BaseFilter):
return False return False
class RegexpCommandsFilter(BaseFilter): class RegexpCommandsFilter(BoundFilter):
""" """
Check commands by regexp in message Check commands by regexp in message
""" """
key = 'regexp_commands' key = 'regexp_commands'
def __init__(self, dispatcher, regexp_commands): def __init__(self, regexp_commands):
super().__init__(dispatcher)
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
async def check(self, message): async def check(self, message):
@ -169,7 +226,7 @@ class RegexpCommandsFilter(BaseFilter):
return False return False
class ContentTypeFilter(BaseFilter): class ContentTypeFilter(BoundFilter):
""" """
Check message content type Check message content type
""" """
@ -178,8 +235,7 @@ class ContentTypeFilter(BaseFilter):
required = True required = True
default = types.ContentType.TEXT default = types.ContentType.TEXT
def __init__(self, dispatcher, content_types): def __init__(self, content_types):
super().__init__(dispatcher)
self.content_types = content_types self.content_types = content_types
async def check(self, message): async def check(self, message):
@ -187,7 +243,7 @@ class ContentTypeFilter(BaseFilter):
message.content_type in self.content_types message.content_type in self.content_types
class StateFilter(BaseFilter): class StateFilter(BoundFilter):
""" """
Check user state Check user state
""" """
@ -199,7 +255,7 @@ class StateFilter(BaseFilter):
def __init__(self, dispatcher, state): def __init__(self, dispatcher, state):
from aiogram.dispatcher.filters.state import State, StatesGroup from aiogram.dispatcher.filters.state import State, StatesGroup
super().__init__(dispatcher) self.dispatcher = dispatcher
states = [] states = []
if not isinstance(state, (list, set, tuple, frozenset)) or state is None: if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
state = [state, ] state = [state, ]
@ -237,7 +293,7 @@ class StateFilter(BaseFilter):
return False return False
class ExceptionsFilter(BaseFilter): class ExceptionsFilter(BoundFilter):
""" """
Filter for exceptions Filter for exceptions
""" """

View file

@ -21,29 +21,35 @@ def wrap_async(func):
return async_wrapper return async_wrapper
async def check_filter(filter_, args): async def check_filter(dispatcher, filter_, args):
""" """
Helper for executing filter Helper for executing filter
:param dispatcher:
:param filter_: :param filter_:
:param args: :param args:
:return: :return:
""" """
kwargs = {}
if not callable(filter_): if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!') raise TypeError('Filter must be callable and/or awaitable!')
spec = inspect.getfullargspec(filter_)
if 'dispatcher' in spec:
kwargs['dispatcher'] = dispatcher
if inspect.isawaitable(filter_) \ if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \ or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, AbstractFilter): or isinstance(filter_, AbstractFilter):
return await filter_(*args) return await filter_(*args, **kwargs)
else: else:
return filter_(*args) return filter_(*args, **kwargs)
async def check_filters(filters, args): async def check_filters(dispatcher, filters, args):
""" """
Check list of filters Check list of filters
:param dispatcher:
:param filters: :param filters:
:param args: :param args:
:return: :return:
@ -51,7 +57,7 @@ async def check_filters(filters, args):
data = {} data = {}
if filters is not None: if filters is not None:
for filter_ in filters: for filter_ in filters:
f = await check_filter(filter_, args) f = await check_filter(dispatcher, filter_, args)
if not f: if not f:
raise FilterNotPassed() raise FilterNotPassed()
elif isinstance(f, dict): elif isinstance(f, dict):
@ -89,11 +95,16 @@ class FilterRecord:
return return
config = self.resolver(full_config) config = self.resolver(full_config)
if config: if config:
if 'dispatcher' not in config:
spec = inspect.getfullargspec(self.callback)
if 'dispatcher' in spec.args:
config['dispatcher'] = dispatcher
for key in config: for key in config:
if key in full_config: if key in full_config:
full_config.pop(key) full_config.pop(key)
return self.callback(dispatcher, **config) return self.callback(**config)
def _check_event_handler(self, event_handler) -> bool: def _check_event_handler(self, event_handler) -> bool:
if self.event_handlers: if self.event_handlers:
@ -108,8 +119,6 @@ class AbstractFilter(abc.ABC):
Abstract class for custom filters Abstract class for custom filters
""" """
key = None
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
@ -138,13 +147,29 @@ class AbstractFilter(abc.ABC):
return NotFilter(self) return NotFilter(self)
def __and__(self, other): def __and__(self, other):
if isinstance(self, AndFilter):
self.append(other)
return self
return AndFilter(self, other) return AndFilter(self, other)
def __or__(self, other): def __or__(self, other):
if isinstance(self, OrFilter):
self.append(other)
return self
return OrFilter(self, other) return OrFilter(self, other)
class BaseFilter(AbstractFilter): class Filter(AbstractFilter):
"""
You can make subclasses of that class for custom filters
"""
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
pass
class BoundFilter(Filter):
""" """
Base class for filters with default validator Base class for filters with default validator
""" """
@ -152,10 +177,6 @@ class BaseFilter(AbstractFilter):
required = False required = False
default = None default = None
def __init__(self, dispatcher, **config):
self.dispatcher = dispatcher
self.config = config
@classmethod @classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
if cls.key is not None: if cls.key is not None:
@ -165,33 +186,65 @@ class BaseFilter(AbstractFilter):
return {cls.key: cls.default} return {cls.key: cls.default}
class Filter(AbstractFilter): class _LogicFilter(Filter):
@classmethod @classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: def validate(cls, full_config: typing.Dict[str, typing.Any]):
raise RuntimeError('This filter can\'t be passed as kwargs') raise ValueError('That filter can\'t be used in filters factory!')
class NotFilter(Filter): class NotFilter(_LogicFilter):
def __init__(self, target): def __init__(self, target):
self.target = wrap_async(target) self.target = wrap_async(target)
async def check(self, *args): async def check(self, *args):
return await self.target(*args) return not bool(await self.target(*args))
class AndFilter(Filter): class AndFilter(_LogicFilter):
def __init__(self, target, target2):
self.target = wrap_async(target) def __init__(self, *targets):
self.target2 = wrap_async(target2) self.targets = list(wrap_async(target) for target in targets)
async def check(self, *args): async def check(self, *args):
return (await self.target(*args)) and (await self.target2(*args)) """
All filters must return a positive result
:param args:
:return:
"""
data = {}
for target in self.targets:
result = await target(*args)
if not result:
return False
if isinstance(result, dict):
data.update(result)
if not data:
return True
return data
def append(self, target):
self.targets.append(wrap_async(target))
class OrFilter(Filter): class OrFilter(_LogicFilter):
def __init__(self, target, target2): def __init__(self, *targets):
self.target = wrap_async(target) self.targets = list(wrap_async(target) for target in targets)
self.target2 = wrap_async(target2)
async def check(self, *args): async def check(self, *args):
return (await self.target(*args)) or (await self.target2(*args)) """
One of filters must return a positive result
:param args:
:return:
"""
for target in self.targets:
result = await target(*args)
if result:
if isinstance(result, dict):
return result
return True
return False
def append(self, target):
self.targets.append(wrap_async(target))

View file

@ -83,7 +83,7 @@ class Handler:
try: try:
for filters, handler in self.handlers: for filters, handler in self.handlers:
try: try:
data.update(await check_filters(filters, args)) data.update(await check_filters(self.dispatcher, filters, args))
except FilterNotPassed: except FilterNotPassed:
continue continue
else: else: