diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 7f140352..0b0eecff 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -6,8 +6,9 @@ import time import typing from contextvars import ContextVar -from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \ - RegexpFilter, StateFilter +from aiogram.dispatcher.filters import Command +from .filters import ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpCommandsFilter, \ + Regexp, StateFilter from .handler import Handler from .middlewares import MiddlewareManager from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \ @@ -84,10 +85,10 @@ class Dispatcher: self.message_handlers, self.edited_message_handlers, self.channel_post_handlers, self.edited_channel_post_handlers, ]) - filters_factory.bind(CommandsFilter, event_handlers=[ + filters_factory.bind(Command, event_handlers=[ self.message_handlers, self.edited_message_handlers ]) - filters_factory.bind(RegexpFilter, event_handlers=[ + filters_factory.bind(Regexp, event_handlers=[ self.message_handlers, self.edited_message_handlers, self.channel_post_handlers, self.edited_channel_post_handlers, self.callback_query_handlers diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index c4058abd..aa3a3ecf 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -1,13 +1,14 @@ -from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \ - StateFilter, Text +from .builtin import Command, CommandHelp, CommandStart, ContentTypeFilter, ExceptionsFilter, Regexp, \ + RegexpCommandsFilter, StateFilter, Text from .factory import FiltersFactory -from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters +from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters __all__ = [ 'AbstractFilter', - 'BaseFilter', + 'BoundFilter', 'Command', - 'CommandsFilter', + 'CommandStart', + 'CommandHelp', 'ContentTypeFilter', 'ExceptionsFilter', 'Filter', @@ -15,7 +16,7 @@ __all__ = [ 'FilterRecord', 'FiltersFactory', 'RegexpCommandsFilter', - 'RegexpFilter', + 'Regexp', 'StateFilter', 'Text', 'check_filter', diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index c8392294..420672a9 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -2,15 +2,30 @@ import inspect import re from contextvars import ContextVar from dataclasses import dataclass -from typing import Optional +from typing import Any, Dict, Iterable, Optional, Union from aiogram import types -from aiogram.dispatcher.filters.filters import BaseFilter, Filter +from aiogram.dispatcher.filters.filters import BoundFilter, Filter from aiogram.types import CallbackQuery, ContentType, Message class Command(Filter): - def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False): + """ + You can handle commands by using this filter + """ + + def __init__(self, commands: Union[Iterable, str], + prefixes: Union[Iterable, str] = '/', + ignore_case: bool = True, + ignore_mention: bool = False): + """ + Filter can be initialized from filters factory or by simply creating instance of this class + + :param commands: command or list of commands + :param prefixes: + :param ignore_case: + :param ignore_mention: + """ if isinstance(commands, str): commands = (commands,) @@ -19,6 +34,26 @@ class Command(Filter): self.ignore_case = ignore_case self.ignore_mention = ignore_mention + @classmethod + def validate(cls, full_config: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Validator for filters factory + + :param full_config: + :return: config or empty dict + """ + config = {} + if 'commands' in full_config: + config['commands'] = full_config.pop('commands') + if 'commands_prefix' in full_config: + config['prefixes'] = full_config.pop('commands_prefix') + if 'commands_ignore_mention' in full_config: + config['ignore_mention'] = full_config.pop('commands_ignore_mention') + return config + + async def check(self, message: types.Message): + return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention) + @staticmethod async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False): full_command = message.text.split()[0] @@ -33,9 +68,6 @@ class Command(Filter): return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)} - async def check(self, message: types.Message): - return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention) - @dataclass class CommandObj: prefix: str = '/' @@ -57,29 +89,36 @@ class Command(Filter): return line -class CommandsFilter(BaseFilter): - """ - Check commands in message - """ - key = 'commands' +class CommandStart(Command): + def __init__(self): + super(CommandStart, self).__init__(['start']) - def __init__(self, dispatcher, commands): - super().__init__(dispatcher) - if isinstance(commands, str): - commands = (commands,) - self.commands = commands - async def check(self, message): - return await Command.check_command(message, self.commands, '/') +class CommandHelp(Command): + def __init__(self): + super(CommandHelp, self).__init__(['help']) class Text(Filter): + """ + Simple text filter + """ + def __init__(self, equals: Optional[str] = None, contains: Optional[str] = None, startswith: Optional[str] = None, endswith: Optional[str] = None, ignore_case=False): + """ + Check text for one of pattern. Only one mode can be used in one filter. + + :param equals: + :param contains: + :param startswith: + :param endswith: + :param ignore_case: case insensitive + """ # Only one mode can be used. check it. check = sum(map(bool, (equals, contains, startswith, endswith))) if check > 1: @@ -98,8 +137,27 @@ class Text(Filter): self.startswith = startswith self.ignore_case = ignore_case - async def check(self, message: types.Message): - text = message.text.lower() if self.ignore_case else message.text + @classmethod + def validate(cls, full_config: Dict[str, Any]): + if 'text' in full_config: + return {'equals': full_config.pop('text')} + elif 'text_contains' in full_config: + return {'contains': full_config.pop('text_contains')} + elif 'text_startswith' in full_config: + return {'startswith': full_config.pop('text_startswith')} + elif 'text_endswith' in full_config: + return {'endswith': full_config.pop('text_endswith')} + + async def check(self, obj: Union[Message, CallbackQuery]): + if isinstance(obj, Message): + text = obj.text or obj.caption or '' + elif isinstance(obj, CallbackQuery): + text = obj.data + else: + return False + + if self.ignore_case: + text = text.lower() if self.equals: return text == self.equals @@ -113,24 +171,24 @@ class Text(Filter): return False -class RegexpFilter(BaseFilter): +class Regexp(Filter): """ Regexp filter for messages and callback query """ - key = 'regexp' - def __init__(self, dispatcher, regexp): - super().__init__(dispatcher) - self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) + def __init__(self, regexp): + if not isinstance(regexp, re.Pattern): + regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE) + self.regexp = regexp - async def check(self, obj): + @classmethod + def validate(cls, full_config: Dict[str, Any]): + if 'regexp' in full_config: + return {'regexp': full_config.pop('regexp')} + + async def check(self, obj: Union[Message, CallbackQuery]): if isinstance(obj, Message): - if obj.text: - match = self.regexp.search(obj.text) - elif obj.caption: - match = self.regexp.search(obj.caption) - else: - return False + match = self.regexp.search(obj.text or obj.caption or '') elif isinstance(obj, CallbackQuery) and obj.data: match = self.regexp.search(obj.data) else: @@ -141,15 +199,14 @@ class RegexpFilter(BaseFilter): return False -class RegexpCommandsFilter(BaseFilter): +class RegexpCommandsFilter(BoundFilter): """ Check commands by regexp in message """ key = 'regexp_commands' - def __init__(self, dispatcher, regexp_commands): - super().__init__(dispatcher) + def __init__(self, regexp_commands): self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands] async def check(self, message): @@ -169,7 +226,7 @@ class RegexpCommandsFilter(BaseFilter): return False -class ContentTypeFilter(BaseFilter): +class ContentTypeFilter(BoundFilter): """ Check message content type """ @@ -178,8 +235,7 @@ class ContentTypeFilter(BaseFilter): required = True default = types.ContentType.TEXT - def __init__(self, dispatcher, content_types): - super().__init__(dispatcher) + def __init__(self, content_types): self.content_types = content_types async def check(self, message): @@ -187,7 +243,7 @@ class ContentTypeFilter(BaseFilter): message.content_type in self.content_types -class StateFilter(BaseFilter): +class StateFilter(BoundFilter): """ Check user state """ @@ -199,7 +255,7 @@ class StateFilter(BaseFilter): def __init__(self, dispatcher, state): from aiogram.dispatcher.filters.state import State, StatesGroup - super().__init__(dispatcher) + self.dispatcher = dispatcher states = [] if not isinstance(state, (list, set, tuple, frozenset)) or state is None: state = [state, ] @@ -237,7 +293,7 @@ class StateFilter(BaseFilter): return False -class ExceptionsFilter(BaseFilter): +class ExceptionsFilter(BoundFilter): """ Filter for exceptions """ diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py index 01b8722a..816f4722 100644 --- a/aiogram/dispatcher/filters/filters.py +++ b/aiogram/dispatcher/filters/filters.py @@ -21,29 +21,35 @@ def wrap_async(func): return async_wrapper -async def check_filter(filter_, args): +async def check_filter(dispatcher, filter_, args): """ Helper for executing filter + :param dispatcher: :param filter_: :param args: :return: """ + kwargs = {} if not callable(filter_): raise TypeError('Filter must be callable and/or awaitable!') + spec = inspect.getfullargspec(filter_) + if 'dispatcher' in spec: + kwargs['dispatcher'] = dispatcher if inspect.isawaitable(filter_) \ or inspect.iscoroutinefunction(filter_) \ or isinstance(filter_, AbstractFilter): - return await filter_(*args) + return await filter_(*args, **kwargs) else: - return filter_(*args) + return filter_(*args, **kwargs) -async def check_filters(filters, args): +async def check_filters(dispatcher, filters, args): """ Check list of filters + :param dispatcher: :param filters: :param args: :return: @@ -51,7 +57,7 @@ async def check_filters(filters, args): data = {} if filters is not None: for filter_ in filters: - f = await check_filter(filter_, args) + f = await check_filter(dispatcher, filter_, args) if not f: raise FilterNotPassed() elif isinstance(f, dict): @@ -89,11 +95,16 @@ class FilterRecord: return config = self.resolver(full_config) if config: + if 'dispatcher' not in config: + spec = inspect.getfullargspec(self.callback) + if 'dispatcher' in spec.args: + config['dispatcher'] = dispatcher + for key in config: if key in full_config: full_config.pop(key) - return self.callback(dispatcher, **config) + return self.callback(**config) def _check_event_handler(self, event_handler) -> bool: if self.event_handlers: @@ -108,8 +119,6 @@ class AbstractFilter(abc.ABC): Abstract class for custom filters """ - key = None - @classmethod @abc.abstractmethod def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: @@ -138,13 +147,29 @@ class AbstractFilter(abc.ABC): return NotFilter(self) def __and__(self, other): + if isinstance(self, AndFilter): + self.append(other) + return self return AndFilter(self, other) def __or__(self, other): + if isinstance(self, OrFilter): + self.append(other) + return self return OrFilter(self, other) -class BaseFilter(AbstractFilter): +class Filter(AbstractFilter): + """ + You can make subclasses of that class for custom filters + """ + + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + pass + + +class BoundFilter(Filter): """ Base class for filters with default validator """ @@ -152,10 +177,6 @@ class BaseFilter(AbstractFilter): required = False default = None - def __init__(self, dispatcher, **config): - self.dispatcher = dispatcher - self.config = config - @classmethod def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: if cls.key is not None: @@ -165,33 +186,65 @@ class BaseFilter(AbstractFilter): return {cls.key: cls.default} -class Filter(AbstractFilter): +class _LogicFilter(Filter): @classmethod - def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: - raise RuntimeError('This filter can\'t be passed as kwargs') + def validate(cls, full_config: typing.Dict[str, typing.Any]): + raise ValueError('That filter can\'t be used in filters factory!') -class NotFilter(Filter): +class NotFilter(_LogicFilter): def __init__(self, target): self.target = wrap_async(target) async def check(self, *args): - return await self.target(*args) + return not bool(await self.target(*args)) -class AndFilter(Filter): - def __init__(self, target, target2): - self.target = wrap_async(target) - self.target2 = wrap_async(target2) +class AndFilter(_LogicFilter): + + def __init__(self, *targets): + self.targets = list(wrap_async(target) for target in targets) async def check(self, *args): - return (await self.target(*args)) and (await self.target2(*args)) + """ + All filters must return a positive result + + :param args: + :return: + """ + data = {} + for target in self.targets: + result = await target(*args) + if not result: + return False + if isinstance(result, dict): + data.update(result) + if not data: + return True + return data + + def append(self, target): + self.targets.append(wrap_async(target)) -class OrFilter(Filter): - def __init__(self, target, target2): - self.target = wrap_async(target) - self.target2 = wrap_async(target2) +class OrFilter(_LogicFilter): + def __init__(self, *targets): + self.targets = list(wrap_async(target) for target in targets) async def check(self, *args): - return (await self.target(*args)) or (await self.target2(*args)) + """ + One of filters must return a positive result + + :param args: + :return: + """ + for target in self.targets: + result = await target(*args) + if result: + if isinstance(result, dict): + return result + return True + return False + + def append(self, target): + self.targets.append(wrap_async(target)) diff --git a/aiogram/dispatcher/handler.py b/aiogram/dispatcher/handler.py index fc98da2a..4ded9316 100644 --- a/aiogram/dispatcher/handler.py +++ b/aiogram/dispatcher/handler.py @@ -83,7 +83,7 @@ class Handler: try: for filters, handler in self.handlers: try: - data.update(await check_filters(filters, args)) + data.update(await check_filters(self.dispatcher, filters, args)) except FilterNotPassed: continue else: