mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-16 20:23:32 +00:00
Implement filters factory.
This commit is contained in:
parent
a3856e33bd
commit
8d7a00204d
2 changed files with 121 additions and 19 deletions
|
|
@ -5,8 +5,8 @@ import logging
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, RegexpFilter, \
|
from .filters import CommandsFilter, ContentTypeFilter, ExceptionsFilter, FiltersFactory, RegexpFilter, USER_STATE, \
|
||||||
USER_STATE, generate_default_filters
|
generate_default_filters
|
||||||
from .handler import CancelHandler, Handler, SkipHandler
|
from .handler import CancelHandler, Handler, SkipHandler
|
||||||
from .middlewares import MiddlewareManager
|
from .middlewares import MiddlewareManager
|
||||||
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
|
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
|
||||||
|
|
@ -37,12 +37,15 @@ class Dispatcher:
|
||||||
|
|
||||||
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
|
||||||
run_tasks_by_default: bool = False,
|
run_tasks_by_default: bool = False,
|
||||||
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False):
|
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False,
|
||||||
|
filters_factory=None):
|
||||||
|
|
||||||
if loop is None:
|
if loop is None:
|
||||||
loop = bot.loop
|
loop = bot.loop
|
||||||
if storage is None:
|
if storage is None:
|
||||||
storage = DisabledStorage()
|
storage = DisabledStorage()
|
||||||
|
if filters_factory is None:
|
||||||
|
filters_factory = FiltersFactory(self)
|
||||||
|
|
||||||
self.bot: Bot = bot
|
self.bot: Bot = bot
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
|
@ -54,6 +57,7 @@ class Dispatcher:
|
||||||
|
|
||||||
self.last_update_id = 0
|
self.last_update_id = 0
|
||||||
|
|
||||||
|
self.filters_factory: FiltersFactory = filters_factory
|
||||||
self.updates_handler = Handler(self, middleware_key='update')
|
self.updates_handler = Handler(self, middleware_key='update')
|
||||||
self.message_handlers = Handler(self, middleware_key='message')
|
self.message_handlers = Handler(self, middleware_key='message')
|
||||||
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
|
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
|
||||||
|
|
@ -74,6 +78,12 @@ class Dispatcher:
|
||||||
self._closed = True
|
self._closed = True
|
||||||
self._close_waiter = loop.create_future()
|
self._close_waiter = loop.create_future()
|
||||||
|
|
||||||
|
filters_factory.bind(filters.CommandsFilter, 'commands')
|
||||||
|
filters_factory.bind(filters.RegexpFilter, 'regexp')
|
||||||
|
filters_factory.bind(filters.RegexpCommandsFilter, 'regexp_commands')
|
||||||
|
filters_factory.bind(filters.ContentTypeFilter, 'content_types')
|
||||||
|
filters_factory.bind(filters.StateFilter, 'state', with_dispatcher=True, default=True)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.stop_polling()
|
self.stop_polling()
|
||||||
|
|
||||||
|
|
@ -310,8 +320,8 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
return self._polling
|
return self._polling
|
||||||
|
|
||||||
def register_message_handler(self, callback, *, commands=None, regexp=None, content_types=None, func=None,
|
def register_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None,
|
||||||
state=None, custom_filters=None, run_task=None, **kwargs):
|
func=None, state=None, run_task=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Register handler for message
|
Register handler for message
|
||||||
|
|
||||||
|
|
@ -334,21 +344,21 @@ class Dispatcher:
|
||||||
:param content_types: List of content types.
|
:param content_types: List of content types.
|
||||||
:param func: custom any callable object
|
:param func: custom any callable object
|
||||||
:param custom_filters: list of custom filters
|
:param custom_filters: list of custom filters
|
||||||
|
:param run_task:
|
||||||
:param kwargs:
|
:param kwargs:
|
||||||
:param state:
|
:param state:
|
||||||
:return: decorated function
|
:return: decorated function
|
||||||
"""
|
"""
|
||||||
if content_types is None:
|
if content_types is None:
|
||||||
content_types = ContentType.TEXT
|
content_types = ContentType.TEXT
|
||||||
if custom_filters is None:
|
if func is not None:
|
||||||
custom_filters = []
|
custom_filters = list(custom_filters)
|
||||||
|
custom_filters.append(func)
|
||||||
|
|
||||||
filters_set = generate_default_filters(self,
|
filters_set = self.filters_factory.parse(*custom_filters,
|
||||||
*custom_filters,
|
|
||||||
commands=commands,
|
commands=commands,
|
||||||
regexp=regexp,
|
regexp=regexp,
|
||||||
content_types=content_types,
|
content_types=content_types,
|
||||||
func=func,
|
|
||||||
state=state,
|
state=state,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
|
self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
|
||||||
|
|
@ -426,9 +436,9 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(callback):
|
def decorator(callback):
|
||||||
self.register_message_handler(callback,
|
self.register_message_handler(callback, *custom_filters,
|
||||||
commands=commands, regexp=regexp, content_types=content_types,
|
commands=commands, regexp=regexp, content_types=content_types,
|
||||||
func=func, state=state, custom_filters=custom_filters, run_task=run_task,
|
func=func, state=state, run_task=run_task,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,84 @@ from ..utils.helper import Helper, HelperMode, Item
|
||||||
USER_STATE = 'USER_STATE'
|
USER_STATE = 'USER_STATE'
|
||||||
|
|
||||||
|
|
||||||
|
class FiltersFactory:
|
||||||
|
def __init__(self, dispatcher):
|
||||||
|
self._dispatcher = dispatcher
|
||||||
|
self._filters = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_filters(self):
|
||||||
|
return tuple(filter(lambda item: item[-1], self._filters))
|
||||||
|
|
||||||
|
def bind(self, filter_, *args, default=False, with_dispatcher=False):
|
||||||
|
self._filters.append((filter_, args, with_dispatcher, default))
|
||||||
|
|
||||||
|
def unbind(self, filter_):
|
||||||
|
for item in self._filters:
|
||||||
|
if filter_ is item[0]:
|
||||||
|
self._filters.remove(item)
|
||||||
|
return True
|
||||||
|
raise ValueError(f'{filter_} is not binded.')
|
||||||
|
|
||||||
|
def replace(self, original, new):
|
||||||
|
for item in self._filters:
|
||||||
|
if original is item[0]:
|
||||||
|
item[0] = new
|
||||||
|
return True
|
||||||
|
raise ValueError(f'{original} is not binded.')
|
||||||
|
|
||||||
|
def parse(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Generate filters list
|
||||||
|
|
||||||
|
:param args:
|
||||||
|
:param kwargs:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
used = []
|
||||||
|
filters = []
|
||||||
|
|
||||||
|
filters.extend(args)
|
||||||
|
|
||||||
|
# Registered filters filters
|
||||||
|
for filter_, args_list, with_dispatcher, default in self._filters:
|
||||||
|
config = {}
|
||||||
|
accept = True
|
||||||
|
|
||||||
|
for item in args_list:
|
||||||
|
value = kwargs.pop(item, None)
|
||||||
|
if value is None:
|
||||||
|
accept = False
|
||||||
|
break
|
||||||
|
config[item] = value
|
||||||
|
|
||||||
|
if accept:
|
||||||
|
if with_dispatcher:
|
||||||
|
config['dispatcher'] = self._dispatcher
|
||||||
|
|
||||||
|
filters.append(filter_(**config))
|
||||||
|
used.append(filter_)
|
||||||
|
|
||||||
|
elif default:
|
||||||
|
if filter_ not in used:
|
||||||
|
used.append(filter_)
|
||||||
|
if isinstance(filter_, Filter):
|
||||||
|
if with_dispatcher:
|
||||||
|
filters.append(filter_(dispatcher=self._dispatcher))
|
||||||
|
else:
|
||||||
|
filters.append(filter_())
|
||||||
|
|
||||||
|
# Not registered filters
|
||||||
|
for key, filter_ in kwargs.items():
|
||||||
|
if isinstance(filter_, Filter):
|
||||||
|
filters.append(filter_)
|
||||||
|
used.append(filter_.__class__)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown filter with key '{key}'")
|
||||||
|
|
||||||
|
return filters
|
||||||
|
|
||||||
|
|
||||||
async def check_filter(filter_, args):
|
async def check_filter(filter_, args):
|
||||||
"""
|
"""
|
||||||
Helper for executing filter
|
Helper for executing filter
|
||||||
|
|
@ -48,6 +126,9 @@ class Filter:
|
||||||
Base class for filters
|
Base class for filters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.check(*args, **kwargs)
|
return self.check(*args, **kwargs)
|
||||||
|
|
||||||
|
|
@ -77,6 +158,7 @@ class AnyFilter(AsyncFilter):
|
||||||
|
|
||||||
def __init__(self, *filters: callable):
|
def __init__(self, *filters: callable):
|
||||||
self.filters = filters
|
self.filters = filters
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
async def check(self, *args):
|
async def check(self, *args):
|
||||||
f = (check_filter(filter_, args) for filter_ in self.filters)
|
f = (check_filter(filter_, args) for filter_ in self.filters)
|
||||||
|
|
@ -90,6 +172,7 @@ class NotFilter(AsyncFilter):
|
||||||
|
|
||||||
def __init__(self, filter_: callable):
|
def __init__(self, filter_: callable):
|
||||||
self.filter = filter_
|
self.filter = filter_
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
async def check(self, *args):
|
async def check(self, *args):
|
||||||
return not await check_filter(self.filter, args)
|
return not await check_filter(self.filter, args)
|
||||||
|
|
@ -102,6 +185,7 @@ class CommandsFilter(AsyncFilter):
|
||||||
|
|
||||||
def __init__(self, commands):
|
def __init__(self, commands):
|
||||||
self.commands = commands
|
self.commands = commands
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
async def check(self, message):
|
async def check(self, message):
|
||||||
if not message.is_command():
|
if not message.is_command():
|
||||||
|
|
@ -126,6 +210,7 @@ class RegexpFilter(Filter):
|
||||||
|
|
||||||
def __init__(self, regexp):
|
def __init__(self, regexp):
|
||||||
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
self.regexp = re.compile(regexp, flags=re.IGNORECASE | re.MULTILINE)
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def check(self, message):
|
def check(self, message):
|
||||||
if message.text:
|
if message.text:
|
||||||
|
|
@ -139,6 +224,7 @@ class RegexpCommandsFilter(AsyncFilter):
|
||||||
|
|
||||||
def __init__(self, regexp_commands):
|
def __init__(self, regexp_commands):
|
||||||
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
async def check(self, message):
|
async def check(self, message):
|
||||||
if not message.is_command():
|
if not message.is_command():
|
||||||
|
|
@ -165,6 +251,7 @@ class ContentTypeFilter(Filter):
|
||||||
|
|
||||||
def __init__(self, content_types):
|
def __init__(self, content_types):
|
||||||
self.content_types = content_types
|
self.content_types = content_types
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def check(self, message):
|
def check(self, message):
|
||||||
return ContentType.ANY[0] in self.content_types or \
|
return ContentType.ANY[0] in self.content_types or \
|
||||||
|
|
@ -180,6 +267,7 @@ class CancelFilter(Filter):
|
||||||
if cancel_set is None:
|
if cancel_set is None:
|
||||||
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
cancel_set = ['/cancel', 'cancel', 'cancel.']
|
||||||
self.cancel_set = cancel_set
|
self.cancel_set = cancel_set
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def check(self, message):
|
def check(self, message):
|
||||||
if message.text:
|
if message.text:
|
||||||
|
|
@ -193,7 +281,10 @@ class StateFilter(AsyncFilter):
|
||||||
|
|
||||||
def __init__(self, dispatcher, state):
|
def __init__(self, dispatcher, state):
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
|
if isinstance(state, str):
|
||||||
|
state = (state,)
|
||||||
self.state = state
|
self.state = state
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def get_target(self, obj):
|
def get_target(self, obj):
|
||||||
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
|
||||||
|
|
@ -209,7 +300,7 @@ class StateFilter(AsyncFilter):
|
||||||
chat, user = self.get_target(obj)
|
chat, user = self.get_target(obj)
|
||||||
|
|
||||||
if chat or user:
|
if chat or user:
|
||||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
|
return await self.dispatcher.storage.get_state(chat=chat, user=user) in self.state
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -233,6 +324,7 @@ class ExceptionsFilter(Filter):
|
||||||
|
|
||||||
def __init__(self, exception):
|
def __init__(self, exception):
|
||||||
self.exception = exception
|
self.exception = exception
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def check(self, dispatcher, update, exception):
|
def check(self, dispatcher, update, exception):
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue