From b4ecc421e431b887c9d3a9b65485f0c16ed84572 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Mon, 25 Jun 2018 03:19:58 +0300 Subject: [PATCH] Default filters --- aiogram/dispatcher/dispatcher.py | 3 --- aiogram/dispatcher/filters/builtin.py | 14 ++++++++++++++ aiogram/dispatcher/filters/filters.py | 2 +- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index d643b3dd..e926e772 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -345,9 +345,6 @@ class Dispatcher: :param state: :return: decorated function """ - if content_types is None: - content_types = ContentType.TEXT - filters_set = self.filters_factory.resolve(self.message_handlers, *custom_filters, commands=commands, diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 5a2d9199..25aa2a05 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -1,6 +1,8 @@ import re +import typing from _contextvars import ContextVar +from aiogram import types from aiogram.dispatcher.filters.filters import BaseFilter from aiogram.types import CallbackQuery, ContentType, Message @@ -98,6 +100,12 @@ class ContentTypeFilter(BaseFilter): super().__init__(dispatcher) self.content_types = content_types + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]): + result = super(ContentTypeFilter, cls).validate(full_config) + if not result: + return {cls.key: types.ContentType.TEXT} + async def check(self, message): return ContentType.ANY[0] in self.content_types or \ message.content_type in self.content_types @@ -120,6 +128,12 @@ class StateFilter(BaseFilter): def get_target(self, obj): return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None) + @classmethod + def validate(cls, full_config: typing.Dict[str, typing.Any]): + result = super(StateFilter, cls).validate(full_config) + if not result: + return {cls.key: None} + async def check(self, obj): from ..dispatcher import Dispatcher diff --git a/aiogram/dispatcher/filters/filters.py b/aiogram/dispatcher/filters/filters.py index 716edaac..e65d50e1 100644 --- a/aiogram/dispatcher/filters/filters.py +++ b/aiogram/dispatcher/filters/filters.py @@ -130,7 +130,7 @@ class AbstractFilter(abc.ABC): class BaseFilter(AbstractFilter): """ - Abstract class for filters with default validator + Base class for filters with default validator """ key = None