Refactor filters.

This commit is contained in:
Alex Root Junior 2018-07-13 22:58:47 +03:00
parent 539c76a062
commit f957883082
5 changed files with 191 additions and 80 deletions

View file

@ -6,8 +6,9 @@ import time
import typing
from contextvars import ContextVar
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \
RegexpFilter, StateFilter
from aiogram.dispatcher.filters import Command
from .filters import ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \
Regexp, StateFilter
from .handler import Handler
from .middlewares import MiddlewareManager
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
@ -84,10 +85,10 @@ class Dispatcher:
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
])
filters_factory.bind(CommandsFilter, event_handlers=[
filters_factory.bind(Command, event_handlers=[
self.message_handlers, self.edited_message_handlers
])
filters_factory.bind(RegexpFilter, event_handlers=[
filters_factory.bind(Regexp, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
self.callback_query_handlers

View file

@ -1,13 +1,14 @@
from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \
StateFilter, Text
from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \
RegexpCommandsFilter, StateFilter, Text
from .factory import FiltersFactory
from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
__all__ = [
'AbstractFilter',
'BaseFilter',
'BoundFilter',
'Command',
'CommandsFilter',
'CommandStart',
'CommandHelp',
'ContentTypeFilter',
'ExceptionsFilter',
'Filter',
@ -15,7 +16,7 @@ __all__ = [
'FilterRecord',
'FiltersFactory',
'RegexpCommandsFilter',
'RegexpFilter',
'Regexp',
'StateFilter',
'Text',
'check_filter',

View file

@ -2,15 +2,30 @@ import inspect
import re
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Iterable, Optional, Union
from aiogram import types
from aiogram.dispatcher.filters.filters import BaseFilter, Filter
from aiogram.dispatcher.filters.filters import BoundFilter, Filter
from aiogram.types import CallbackQuery, ContentType, Message
class Command(Filter):
def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False):
"""
You can handle commands by using this filter
"""
def __init__(self, commands: Union[Iterable, str],
prefixes: Union[Iterable, str] = '/',
ignore_case: bool = True,
ignore_mention: bool = False):
"""
Filter can be initialized from filters factory or by simply creating instance of this class
:param commands: command or list of commands
:param prefixes:
:param ignore_case:
:param ignore_mention:
"""
if isinstance(commands, str):
commands = (commands,)
@ -19,6 +34,26 @@ class Command(Filter):
self.ignore_case = ignore_case
self.ignore_mention = ignore_mention
@classmethod
def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Validator for filters factory
:param full_config:
:return: config or empty dict
"""
config = {}
if 'commands' in full_config:
config['commands'] = full_config.pop('commands')
if 'commands_prefix' in full_config:
config['prefixes'] = full_config.pop('commands_prefix')
if 'commands_ignore_mention' in full_config:
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
return config
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
@staticmethod
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
full_command = message.text.split()[0]
@ -33,9 +68,6 @@ class Command(Filter):
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
@dataclass
class CommandObj:
prefix: str = '/'
@ -57,29 +89,36 @@ class Command(Filter):
return line
class CommandsFilter(BaseFilter):
"""
Check commands in message
"""
key = 'commands'
class CommandStart(Command):
def __init__(self):
super(CommandStart, self).__init__(['start'])
def __init__(self, dispatcher, commands):
super().__init__(dispatcher)
if isinstance(commands, str):
commands = (commands,)
self.commands = commands
async def check(self, message):
return await Command.check_command(message, self.commands, '/')
class CommandHelp(Command):
def __init__(self):
super(CommandHelp, self).__init__(['help'])
class Text(Filter):
"""
Simple text filter
"""
def __init__(self,
equals: Optional[str] = None,
contains: Optional[str] = None,
startswith: Optional[str] = None,
endswith: Optional[str] = None,
ignore_case=False):
"""
Check text for one of pattern. Only one mode can be used in one filter.
:param equals:
:param contains:
:param startswith:
:param endswith:
:param ignore_case: case insensitive
"""
# Only one mode can be used. check it.
check = sum(map(bool, (equals, contains, startswith, endswith)))
if check > 1:
@ -98,8 +137,27 @@ class Text(Filter):
self.startswith = startswith
self.ignore_case = ignore_case
async def check(self, message: types.Message):
text = message.text.lower() if self.ignore_case else message.text
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'text' in full_config:
return {'equals': full_config.pop('text')}
elif 'text_contains' in full_config:
return {'contains': full_config.pop('text_contains')}
elif 'text_startswith' in full_config:
return {'startswith': full_config.pop('text_startswith')}
elif 'text_endswith' in full_config:
return {'endswith': full_config.pop('text_endswith')}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message):
text = obj.text or obj.caption or ''
elif isinstance(obj, CallbackQuery):
text = obj.data
else:
return False
if self.ignore_case:
text = text.lower()
if self.equals:
return text == self.equals
@ -113,24 +171,24 @@ class Text(Filter):
return False
class RegexpFilter(BaseFilter):
class Regexp(Filter):
"""
Regexp filter for messages and callback query
"""
key = 'regexp'
def __init__(self, dispatcher, regexp):
super().__init__(dispatcher)
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
def __init__(self, regexp):
if not isinstance(regexp, re.Pattern):
regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
self.regexp = regexp
async def check(self, obj):
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'regexp' in full_config:
return {'regexp': full_config.pop('regexp')}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message):
if obj.text:
match = self.regexp.search(obj.text)
elif obj.caption:
match = self.regexp.search(obj.caption)
else:
return False
match = self.regexp.search(obj.text or obj.caption or '')
elif isinstance(obj, CallbackQuery) and obj.data:
match = self.regexp.search(obj.data)
else:
@ -141,15 +199,14 @@ class RegexpFilter(BaseFilter):
return False
class RegexpCommandsFilter(BaseFilter):
class RegexpCommandsFilter(BoundFilter):
"""
Check commands by regexp in message
"""
key = 'regexp_commands'
def __init__(self, dispatcher, regexp_commands):
super().__init__(dispatcher)
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
async def check(self, message):
@ -169,7 +226,7 @@ class RegexpCommandsFilter(BaseFilter):
return False
class ContentTypeFilter(BaseFilter):
class ContentTypeFilter(BoundFilter):
"""
Check message content type
"""
@ -178,8 +235,7 @@ class ContentTypeFilter(BaseFilter):
required = True
default = types.ContentType.TEXT
def __init__(self, dispatcher, content_types):
super().__init__(dispatcher)
def __init__(self, content_types):
self.content_types = content_types
async def check(self, message):
@ -187,7 +243,7 @@ class ContentTypeFilter(BaseFilter):
message.content_type in self.content_types
class StateFilter(BaseFilter):
class StateFilter(BoundFilter):
"""
Check user state
"""
@ -199,7 +255,7 @@ class StateFilter(BaseFilter):
def __init__(self, dispatcher, state):
from aiogram.dispatcher.filters.state import State, StatesGroup
super().__init__(dispatcher)
self.dispatcher = dispatcher
states = []
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
state = [state, ]
@ -237,7 +293,7 @@ class StateFilter(BaseFilter):
return False
class ExceptionsFilter(BaseFilter):
class ExceptionsFilter(BoundFilter):
"""
Filter for exceptions
"""

View file

@ -21,29 +21,35 @@ def wrap_async(func):
return async_wrapper
async def check_filter(filter_, args):
async def check_filter(dispatcher, filter_, args):
"""
Helper for executing filter
:param dispatcher:
:param filter_:
:param args:
:return:
"""
kwargs = {}
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
spec = inspect.getfullargspec(filter_)
if 'dispatcher' in spec:
kwargs['dispatcher'] = dispatcher
if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, AbstractFilter):
return await filter_(*args)
return await filter_(*args, **kwargs)
else:
return filter_(*args)
return filter_(*args, **kwargs)
async def check_filters(filters, args):
async def check_filters(dispatcher, filters, args):
"""
Check list of filters
:param dispatcher:
:param filters:
:param args:
:return:
@ -51,7 +57,7 @@ async def check_filters(filters, args):
data = {}
if filters is not None:
for filter_ in filters:
f = await check_filter(filter_, args)
f = await check_filter(dispatcher, filter_, args)
if not f:
raise FilterNotPassed()
elif isinstance(f, dict):
@ -89,11 +95,16 @@ class FilterRecord:
return
config = self.resolver(full_config)
if config:
if 'dispatcher' not in config:
spec = inspect.getfullargspec(self.callback)
if 'dispatcher' in spec.args:
config['dispatcher'] = dispatcher
for key in config:
if key in full_config:
full_config.pop(key)
return self.callback(dispatcher, **config)
return self.callback(**config)
def _check_event_handler(self, event_handler) -> bool:
if self.event_handlers:
@ -108,8 +119,6 @@ class AbstractFilter(abc.ABC):
Abstract class for custom filters
"""
key = None
@classmethod
@abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
@ -138,13 +147,29 @@ class AbstractFilter(abc.ABC):
return NotFilter(self)
def __and__(self, other):
if isinstance(self, AndFilter):
self.append(other)
return self
return AndFilter(self, other)
def __or__(self, other):
if isinstance(self, OrFilter):
self.append(other)
return self
return OrFilter(self, other)
class BaseFilter(AbstractFilter):
class Filter(AbstractFilter):
"""
You can make subclasses of that class for custom filters
"""
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
pass
class BoundFilter(Filter):
"""
Base class for filters with default validator
"""
@ -152,10 +177,6 @@ class BaseFilter(AbstractFilter):
required = False
default = None
def __init__(self, dispatcher, **config):
self.dispatcher = dispatcher
self.config = config
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
if cls.key is not None:
@ -165,33 +186,65 @@ class BaseFilter(AbstractFilter):
return {cls.key: cls.default}
class Filter(AbstractFilter):
class _LogicFilter(Filter):
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
raise RuntimeError('This filter can\'t be passed as kwargs')
def validate(cls, full_config: typing.Dict[str, typing.Any]):
raise ValueError('That filter can\'t be used in filters factory!')
class NotFilter(Filter):
class NotFilter(_LogicFilter):
def __init__(self, target):
self.target = wrap_async(target)
async def check(self, *args):
return await self.target(*args)
return not bool(await self.target(*args))
class AndFilter(Filter):
def __init__(self, target, target2):
self.target = wrap_async(target)
self.target2 = wrap_async(target2)
class AndFilter(_LogicFilter):
def __init__(self, *targets):
self.targets = list(wrap_async(target) for target in targets)
async def check(self, *args):
return (await self.target(*args)) and (await self.target2(*args))
"""
All filters must return a positive result
:param args:
:return:
"""
data = {}
for target in self.targets:
result = await target(*args)
if not result:
return False
if isinstance(result, dict):
data.update(result)
if not data:
return True
return data
def append(self, target):
self.targets.append(wrap_async(target))
class OrFilter(Filter):
def __init__(self, target, target2):
self.target = wrap_async(target)
self.target2 = wrap_async(target2)
class OrFilter(_LogicFilter):
def __init__(self, *targets):
self.targets = list(wrap_async(target) for target in targets)
async def check(self, *args):
return (await self.target(*args)) or (await self.target2(*args))
"""
One of filters must return a positive result
:param args:
:return:
"""
for target in self.targets:
result = await target(*args)
if result:
if isinstance(result, dict):
return result
return True
return False
def append(self, target):
self.targets.append(wrap_async(target))

View file

@ -83,7 +83,7 @@ class Handler:
try:
for filters, handler in self.handlers:
try:
data.update(await check_filters(filters, args))
data.update(await check_filters(self.dispatcher, filters, args))
except FilterNotPassed:
continue
else: