From 1a9a11f3fd8e6c6d1a02b757ca313478e3e4d7d3 Mon Sep 17 00:00:00 2001 From: birdi Date: Sat, 27 Jul 2019 12:20:08 +0300 Subject: [PATCH] add multiple text filter --- aiogram/dispatcher/filters/builtin.py | 35 ++++++++++++--------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/aiogram/dispatcher/filters/builtin.py b/aiogram/dispatcher/filters/builtin.py index 15cd73dd..eb3845e1 100644 --- a/aiogram/dispatcher/filters/builtin.py +++ b/aiogram/dispatcher/filters/builtin.py @@ -206,10 +206,10 @@ class Text(Filter): """ def __init__(self, - equals: Optional[Union[str, LazyProxy]] = None, - contains: Optional[Union[str, LazyProxy]] = None, - startswith: Optional[Union[str, LazyProxy]] = None, - endswith: Optional[Union[str, LazyProxy]] = None, + equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, + contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, + startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, + endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None, ignore_case=False): """ Check text for one of pattern. Only one mode can be used in one filter. @@ -232,6 +232,9 @@ class Text(Filter): elif check == 0: raise ValueError(f"No one mode is specified!") + equals, contains, endswith, startswith = map(lambda e: [e] if isinstance(e, str) or isinstance(e, LazyProxy) + else e, + (equals, contains, endswith, startswith)) self.equals = equals self.contains = contains self.endswith = endswith @@ -267,25 +270,17 @@ class Text(Filter): text = text.lower() if self.equals is not None: - self.equals = str(self.equals) - if self.ignore_case: - self.equals = self.equals.lower() - return text == self.equals + self.equals = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.equals)) + return text in self.equals elif self.contains is not None: - self.contains = str(self.contains) - if self.ignore_case: - self.contains = self.contains.lower() - return self.contains in text + self.contains = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.contains)) + return any(map(text.__contains__, self.contains)) elif self.startswith is not None: - self.startswith = str(self.startswith) - if self.ignore_case: - self.startswith = self.startswith.lower() - return text.startswith(self.startswith) + self.startswith = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.startswith)) + return any(map(text.startswith, self.startswith)) elif self.endswith is not None: - self.endswith = str(self.endswith) - if self.ignore_case: - self.endswith = self.endswith.lower() - return text.endswith(self.endswith) + self.endswith = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.endswith)) + return any(map(text.endswith, self.endswith)) return False