From 24184b1c8f0698ff42395a583bfa663042e06be2 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Wed, 11 Jul 2018 23:16:12 +0300 Subject: [PATCH] Improved filters. --- aiogram/dispatcher/filters/__init__.py | 13 ++-- aiogram/dispatcher/filters/builtin.py | 102 ++++++++++++++++++++++--- aiogram/dispatcher/filters/filters.py | 60 ++++++++++++++- 3 files changed, 155 insertions(+), 20 deletions(-) diff --git a/aiogram/dispatcher/filters/__init__.py b/aiogram/dispatcher/filters/__init__.py index 5a4bbaa8..c4058abd 100644 --- a/aiogram/dispatcher/filters/__init__.py +++ b/aiogram/dispatcher/filters/__init__.py @@ -1,20 +1,23 @@ -from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \ - RegexpFilter, StateFilter +from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \ + StateFilter, Text from .factory import FiltersFactory -from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters +from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters __all__ = [ 'AbstractFilter', 'BaseFilter', + 'Command', 'CommandsFilter', 'ContentTypeFilter', 'ExceptionsFilter', + 'Filter', + 'FilterNotPassed', 'FilterRecord', 'FiltersFactory', 'RegexpCommandsFilter', 'RegexpFilter', 'StateFilter', + 'Text', 'check_filter', - 'check_filters', - 'FilterNotPassed' + 'check_filters' ] diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index afac9feb..c8392294 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -1,12 +1,62 @@ import inspect import re -from _contextvars import ContextVar +from contextvars import ContextVar +from dataclasses import dataclass +from typing import Optional from aiogram import types -from aiogram.dispatcher.filters.filters import BaseFilter +from aiogram.dispatcher.filters.filters import BaseFilter, Filter from aiogram.types import CallbackQuery, ContentType, Message +class Command(Filter): + def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False): + if isinstance(commands, str): + commands = (commands,) + + self.commands = list(map(str.lower, commands)) if ignore_case else commands + self.prefixes = prefixes + self.ignore_case = ignore_case + self.ignore_mention = ignore_mention + + @staticmethod + async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False): + full_command = message.text.split()[0] + prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@') + + if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower(): + return False + elif prefix not in prefixes: + return False + elif (command.lower() if ignore_case else command) not in commands: + return False + + return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)} + + async def check(self, message: types.Message): + return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention) + + @dataclass + class CommandObj: + prefix: str = '/' + command: str = '' + mention: str = None + args: str = None + + @property + def mentioned(self) -> bool: + return bool(self.mention) + + @property + def text(self) -> str: + line = self.prefix + self.command + if self.mentioned: + line += '@' + self.mention + if self.args: + line += ' ' + self.args + return line + + class CommandsFilter(BaseFilter): """ Check commands in message @@ -15,22 +65,52 @@ class CommandsFilter(BaseFilter): def __init__(self, dispatcher, commands): super().__init__(dispatcher) + if isinstance(commands, str): + commands = (commands,) self.commands = commands async def check(self, message): - if not message.is_command(): - return False + return await Command.check_command(message, self.commands, '/') - command = message.text.split()[0][1:] - command, _, mention = command.partition('@') - if mention and mention != (await message.bot.me).username: - return False +class Text(Filter): + def __init__(self, + equals: Optional[str] = None, + contains: Optional[str] = None, + startswith: Optional[str] = None, + endswith: Optional[str] = None, + ignore_case=False): + # Only one mode can be used. check it. + check = sum(map(bool, (equals, contains, startswith, endswith))) + if check > 1: + args = "' and '".join([arg[0] for arg in [('equals', equals), + ('contains', contains), + ('startswith', startswith), + ('endswith', endswith) + ] if arg[1]]) + raise ValueError(f"Arguments '{args}' cannot be used together.") + elif check == 0: + raise ValueError(f"No one mode is specified!") - if command not in self.commands: - return False + self.equals = equals + self.contains = contains + self.endswith = endswith + self.startswith = startswith + self.ignore_case = ignore_case - return True + async def check(self, message: types.Message): + text = message.text.lower() if self.ignore_case else message.text + + if self.equals: + return text == self.equals + elif self.contains: + return self.contains in text + elif self.startswith: + return text.startswith(self.startswith) + elif self.endswith: + return text.endswith(self.endswith) + + return False class RegexpFilter(BaseFilter): diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py index 09f433ea..01b8722a 100644 --- a/aiogram/dispatcher/filters/filters.py +++ b/aiogram/dispatcher/filters/filters.py @@ -10,6 +10,17 @@ class FilterNotPassed(Exception): pass +def wrap_async(func): + async def async_wrapper(*args, **kwargs): + return func(*args, **kwargs) + + if inspect.isawaitable(func) \ + or inspect.iscoroutinefunction(func) \ + or isinstance(func, AbstractFilter): + return func + return async_wrapper + + async def check_filter(filter_, args): """ Helper for executing filter @@ -99,10 +110,6 @@ class AbstractFilter(abc.ABC): key = None - def __init__(self, dispatcher, **config): - self.dispatcher = dispatcher - self.config = config - @classmethod @abc.abstractmethod def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: @@ -127,6 +134,15 @@ class AbstractFilter(abc.ABC): async def __call__(self, obj: TelegramObject) -> bool: return await self.check(obj) + def __invert__(self): + return NotFilter(self) + + def __and__(self, other): + return AndFilter(self, other) + + def __or__(self, other): + return OrFilter(self, other) + class BaseFilter(AbstractFilter): """ @@ -136,6 +152,10 @@ class BaseFilter(AbstractFilter): required = False default = None + def __init__(self, dispatcher, **config): + self.dispatcher = dispatcher + self.config = config + @classmethod def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: if cls.key is not None: @@ -143,3 +163,35 @@ class BaseFilter(AbstractFilter): return {cls.key: full_config[cls.key]} elif cls.required: return {cls.key: cls.default} + + +class Filter(AbstractFilter): + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: + raise RuntimeError('This filter can\'t be passed as kwargs') + + +class NotFilter(Filter): + def __init__(self, target): + self.target = wrap_async(target) + + async def check(self, *args): + return await self.target(*args) + + +class AndFilter(Filter): + def __init__(self, target, target2): + self.target = wrap_async(target) + self.target2 = wrap_async(target2) + + async def check(self, *args): + return (await self.target(*args)) and (await self.target2(*args)) + + +class OrFilter(Filter): + def __init__(self, target, target2): + self.target = wrap_async(target) + self.target2 = wrap_async(target2) + + async def check(self, *args): + return (await self.target(*args)) or (await self.target2(*args))