Improved filters.

This commit is contained in:
Alex Root Junior 2018-07-11 23:16:12 +03:00
parent cd4fee5eaa
commit 24184b1c8f
3 changed files with 155 additions and 20 deletions

View file

@ -1,20 +1,23 @@
from .builtin import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, \ from .builtin import Command, CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpCommandsFilter, RegexpFilter, \
RegexpFilter, StateFilter StateFilter, Text
from .factory import FiltersFactory from .factory import FiltersFactory
from .filters import AbstractFilter, BaseFilter, FilterNotPassed, FilterRecord, check_filter, check_filters from .filters import AbstractFilter, BaseFilter, Filter, FilterNotPassed, FilterRecord, check_filter, check_filters
__all__ = [ __all__ = [
'AbstractFilter', 'AbstractFilter',
'BaseFilter', 'BaseFilter',
'Command',
'CommandsFilter', 'CommandsFilter',
'ContentTypeFilter', 'ContentTypeFilter',
'ExceptionsFilter', 'ExceptionsFilter',
'Filter',
'FilterNotPassed',
'FilterRecord', 'FilterRecord',
'FiltersFactory', 'FiltersFactory',
'RegexpCommandsFilter', 'RegexpCommandsFilter',
'RegexpFilter', 'RegexpFilter',
'StateFilter', 'StateFilter',
'Text',
'check_filter', 'check_filter',
'check_filters', 'check_filters'
'FilterNotPassed'
] ]

View file

@ -1,12 +1,62 @@
import inspect import inspect
import re import re
from _contextvars import ContextVar from contextvars import ContextVar
from dataclasses import dataclass
from typing import Optional
from aiogram import types from aiogram import types
from aiogram.dispatcher.filters.filters import BaseFilter from aiogram.dispatcher.filters.filters import BaseFilter, Filter
from aiogram.types import CallbackQuery, ContentType, Message from aiogram.types import CallbackQuery, ContentType, Message
class Command(Filter):
def __init__(self, commands, prefixes='/', ignore_case=True, ignore_mention=False):
if isinstance(commands, str):
commands = (commands,)
self.commands = list(map(str.lower, commands)) if ignore_case else commands
self.prefixes = prefixes
self.ignore_case = ignore_case
self.ignore_mention = ignore_mention
@staticmethod
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
full_command = message.text.split()[0]
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
return False
elif prefix not in prefixes:
return False
elif (command.lower() if ignore_case else command) not in commands:
return False
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
@dataclass
class CommandObj:
prefix: str = '/'
command: str = ''
mention: str = None
args: str = None
@property
def mentioned(self) -> bool:
return bool(self.mention)
@property
def text(self) -> str:
line = self.prefix + self.command
if self.mentioned:
line += '@' + self.mention
if self.args:
line += ' ' + self.args
return line
class CommandsFilter(BaseFilter): class CommandsFilter(BaseFilter):
""" """
Check commands in message Check commands in message
@ -15,22 +65,52 @@ class CommandsFilter(BaseFilter):
def __init__(self, dispatcher, commands): def __init__(self, dispatcher, commands):
super().__init__(dispatcher) super().__init__(dispatcher)
if isinstance(commands, str):
commands = (commands,)
self.commands = commands self.commands = commands
async def check(self, message): async def check(self, message):
if not message.is_command(): return await Command.check_command(message, self.commands, '/')
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
if mention and mention != (await message.bot.me).username: class Text(Filter):
return False def __init__(self,
equals: Optional[str] = None,
contains: Optional[str] = None,
startswith: Optional[str] = None,
endswith: Optional[str] = None,
ignore_case=False):
# Only one mode can be used. check it.
check = sum(map(bool, (equals, contains, startswith, endswith)))
if check > 1:
args = "' and '".join([arg[0] for arg in [('equals', equals),
('contains', contains),
('startswith', startswith),
('endswith', endswith)
] if arg[1]])
raise ValueError(f"Arguments '{args}' cannot be used together.")
elif check == 0:
raise ValueError(f"No one mode is specified!")
if command not in self.commands: self.equals = equals
return False self.contains = contains
self.endswith = endswith
self.startswith = startswith
self.ignore_case = ignore_case
return True async def check(self, message: types.Message):
text = message.text.lower() if self.ignore_case else message.text
if self.equals:
return text == self.equals
elif self.contains:
return self.contains in text
elif self.startswith:
return text.startswith(self.startswith)
elif self.endswith:
return text.endswith(self.endswith)
return False
class RegexpFilter(BaseFilter): class RegexpFilter(BaseFilter):

View file

@ -10,6 +10,17 @@ class FilterNotPassed(Exception):
pass pass
def wrap_async(func):
async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs)
if inspect.isawaitable(func) \
or inspect.iscoroutinefunction(func) \
or isinstance(func, AbstractFilter):
return func
return async_wrapper
async def check_filter(filter_, args): async def check_filter(filter_, args):
""" """
Helper for executing filter Helper for executing filter
@ -99,10 +110,6 @@ class AbstractFilter(abc.ABC):
key = None key = None
def __init__(self, dispatcher, **config):
self.dispatcher = dispatcher
self.config = config
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]: def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
@ -127,6 +134,15 @@ class AbstractFilter(abc.ABC):
async def __call__(self, obj: TelegramObject) -> bool: async def __call__(self, obj: TelegramObject) -> bool:
return await self.check(obj) return await self.check(obj)
def __invert__(self):
return NotFilter(self)
def __and__(self, other):
return AndFilter(self, other)
def __or__(self, other):
return OrFilter(self, other)
class BaseFilter(AbstractFilter): class BaseFilter(AbstractFilter):
""" """
@ -136,6 +152,10 @@ class BaseFilter(AbstractFilter):
required = False required = False
default = None default = None
def __init__(self, dispatcher, **config):
self.dispatcher = dispatcher
self.config = config
@classmethod @classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
if cls.key is not None: if cls.key is not None:
@ -143,3 +163,35 @@ class BaseFilter(AbstractFilter):
return {cls.key: full_config[cls.key]} return {cls.key: full_config[cls.key]}
elif cls.required: elif cls.required:
return {cls.key: cls.default} return {cls.key: cls.default}
class Filter(AbstractFilter):
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
raise RuntimeError('This filter can\'t be passed as kwargs')
class NotFilter(Filter):
def __init__(self, target):
self.target = wrap_async(target)
async def check(self, *args):
return await self.target(*args)
class AndFilter(Filter):
def __init__(self, target, target2):
self.target = wrap_async(target)
self.target2 = wrap_async(target2)
async def check(self, *args):
return (await self.target(*args)) and (await self.target2(*args))
class OrFilter(Filter):
def __init__(self, target, target2):
self.target = wrap_async(target)
self.target2 = wrap_async(target2)
async def check(self, *args):
return (await self.target(*args)) or (await self.target2(*args))