Merge pull request #176 from Birdi7/add-multiple-text-filter

Add multiple text filter
This commit is contained in:
Alex Root Junior 2019-08-05 10:21:23 +03:00 committed by GitHub
commit 8823c8f22a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 257 additions and 24 deletions

View file

@ -206,18 +206,19 @@ class Text(Filter):
""" """
def __init__(self, def __init__(self,
equals: Optional[Union[str, LazyProxy]] = None, equals: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
contains: Optional[Union[str, LazyProxy]] = None, contains: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
startswith: Optional[Union[str, LazyProxy]] = None, startswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
endswith: Optional[Union[str, LazyProxy]] = None, endswith: Optional[Union[str, LazyProxy, Iterable[Union[str, LazyProxy]]]] = None,
ignore_case=False): ignore_case=False):
""" """
Check text for one of pattern. Only one mode can be used in one filter. Check text for one of pattern. Only one mode can be used in one filter.
In every pattern, a single string is treated as a list with 1 element.
:param equals: :param equals: True if object's text in the list
:param contains: :param contains: True if object's text contains all strings from the list
:param startswith: :param startswith: True if object's text starts with any of strings from the list
:param endswith: :param endswith: True if object's text ends with any of strings from the list
:param ignore_case: case insensitive :param ignore_case: case insensitive
""" """
# Only one mode can be used. check it. # Only one mode can be used. check it.
@ -232,6 +233,9 @@ class Text(Filter):
elif check == 0: elif check == 0:
raise ValueError(f"No one mode is specified!") raise ValueError(f"No one mode is specified!")
equals, contains, endswith, startswith = map(lambda e: [e] if isinstance(e, str) or isinstance(e, LazyProxy)
else e,
(equals, contains, endswith, startswith))
self.equals = equals self.equals = equals
self.contains = contains self.contains = contains
self.endswith = endswith self.endswith = endswith
@ -267,25 +271,17 @@ class Text(Filter):
text = text.lower() text = text.lower()
if self.equals is not None: if self.equals is not None:
self.equals = str(self.equals) self.equals = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.equals))
if self.ignore_case: return text in self.equals
self.equals = self.equals.lower()
return text == self.equals
elif self.contains is not None: elif self.contains is not None:
self.contains = str(self.contains) self.contains = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.contains))
if self.ignore_case: return all(map(text.__contains__, self.contains))
self.contains = self.contains.lower()
return self.contains in text
elif self.startswith is not None: elif self.startswith is not None:
self.startswith = str(self.startswith) self.startswith = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.startswith))
if self.ignore_case: return any(map(text.startswith, self.startswith))
self.startswith = self.startswith.lower()
return text.startswith(self.startswith)
elif self.endswith is not None: elif self.endswith is not None:
self.endswith = str(self.endswith) self.endswith = list(map(lambda s: str(s).lower() if self.ignore_case else str(s), self.endswith))
if self.ignore_case: return any(map(text.endswith, self.endswith))
self.endswith = self.endswith.lower()
return text.endswith(self.endswith)
return False return False

View file

@ -0,0 +1,50 @@
"""
This is a bot to show the usage of the builtin Text filter
Instead of a list, a single element can be passed to any filter, it will be treated as list with an element
"""
import logging
from aiogram import Bot, Dispatcher, executor, types
API_TOKEN = 'API_TOKEN_HERE'
# Configure logging
logging.basicConfig(level=logging.INFO)
# Initialize bot and dispatcher
bot = Bot(token=API_TOKEN)
dp = Dispatcher(bot)
# if the text from user in the list
@dp.message_handler(text=['text1', 'text2'])
async def text_in_handler(message: types.Message):
await message.answer("The message text is in the list!")
# if the text contains any string
@dp.message_handler(text_contains='example1')
@dp.message_handler(text_contains='example2')
async def text_contains_any_handler(message: types.Message):
await message.answer("The message text contains any of strings")
# if the text contains all the strings from the list
@dp.message_handler(text_contains=['str1', 'str2'])
async def text_contains_all_handler(message: types.Message):
await message.answer("The message text contains all strings from the list")
# if the text starts with any string from the list
@dp.message_handler(text_startswith=['prefix1', 'prefix2'])
async def text_startswith_handler(message: types.Message):
await message.answer("The message text starts with any of prefixes")
# if the text ends with any string from the list
@dp.message_handler(text_endswith=['postfix1', 'postfix2'])
async def text_endswith_handler(message: types.Message):
await message.answer("The message text ends with any of postfixes")
if __name__ == '__main__':
executor.start_polling(dp, skip_updates=True)

View file

@ -55,6 +55,56 @@ class TestTextFilter:
assert await check(InlineQuery(query=test_text)) assert await check(InlineQuery(query=test_text))
assert await check(Poll(question=test_text)) assert await check(Poll(question=test_text))
@pytest.mark.asyncio
@pytest.mark.parametrize("test_prefix_list, test_text, ignore_case",
[(['not_example', ''], '', True),
(['', 'not_example'], 'exAmple_string', True),
(['not_example', ''], '', False),
(['', 'not_example'], 'exAmple_string', False),
(['example_string', 'not_example'], 'example_string', True),
(['not_example', 'example_string'], 'exAmple_string', True),
(['exAmple_string', 'not_example'], 'example_string', True),
(['not_example', 'example_string'], 'example_string', False),
(['example_string', 'not_example'], 'exAmple_string', False),
(['not_example', 'exAmple_string'], 'example_string', False),
(['example_string', 'not_example'], 'example_string_dsf', True),
(['not_example', 'example_string'], 'example_striNG_dsf', True),
(['example_striNG', 'not_example'], 'example_string_dsf', True),
(['not_example', 'example_string'], 'example_string_dsf', False),
(['example_string', 'not_example'], 'example_striNG_dsf', False),
(['not_example', 'example_striNG'], 'example_string_dsf', False),
(['example_string', 'not_example'], 'not_example_string', True),
(['not_example', 'example_string'], 'not_eXample_string', True),
(['EXample_string', 'not_example'], 'not_example_string', True),
(['not_example', 'example_string'], 'not_example_string', False),
(['example_string', 'not_example'], 'not_eXample_string', False),
(['not_example', 'EXample_string'], 'not_example_string', False),
])
async def test_startswith_list(self, test_prefix_list, test_text, ignore_case):
test_filter = Text(startswith=test_prefix_list, ignore_case=ignore_case)
async def check(obj):
result = await test_filter.check(obj)
if ignore_case:
_test_prefix_list = map(str.lower, test_prefix_list)
_test_text = test_text.lower()
else:
_test_prefix_list = test_prefix_list
_test_text = test_text
return result is any(map(_test_text.startswith, _test_prefix_list))
assert await check(Message(text=test_text))
assert await check(CallbackQuery(data=test_text))
assert await check(InlineQuery(query=test_text))
assert await check(Poll(question=test_text))
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("test_postfix, test_text, ignore_case", @pytest.mark.parametrize("test_postfix, test_text, ignore_case",
[('', '', True), [('', '', True),
@ -105,6 +155,55 @@ class TestTextFilter:
assert await check(InlineQuery(query=test_text)) assert await check(InlineQuery(query=test_text))
assert await check(Poll(question=test_text)) assert await check(Poll(question=test_text))
@pytest.mark.asyncio
@pytest.mark.parametrize("test_postfix_list, test_text, ignore_case",
[(['', 'not_example'], '', True),
(['not_example', ''], 'exAmple_string', True),
(['', 'not_example'], '', False),
(['not_example', ''], 'exAmple_string', False),
(['example_string', 'not_example'], 'example_string', True),
(['not_example', 'example_string'], 'exAmple_string', True),
(['exAmple_string', 'not_example'], 'example_string', True),
(['example_string', 'not_example'], 'example_string', False),
(['not_example', 'example_string'], 'exAmple_string', False),
(['exAmple_string', 'not_example'], 'example_string', False),
(['example_string', 'not_example'], 'example_string_dsf', True),
(['not_example', 'example_string'], 'example_striNG_dsf', True),
(['example_striNG', 'not_example'], 'example_string_dsf', True),
(['not_example', 'example_string'], 'example_string_dsf', False),
(['example_string', 'not_example'], 'example_striNG_dsf', False),
(['not_example', 'example_striNG'], 'example_string_dsf', False),
(['not_example', 'example_string'], 'not_example_string', True),
(['example_string', 'not_example'], 'not_eXample_string', True),
(['not_example', 'EXample_string'], 'not_eXample_string', True),
(['not_example', 'example_string'], 'not_example_string', False),
(['example_string', 'not_example'], 'not_eXample_string', False),
(['not_example', 'EXample_string'], 'not_example_string', False),
])
async def test_endswith_list(self, test_postfix_list, test_text, ignore_case):
test_filter = Text(endswith=test_postfix_list, ignore_case=ignore_case)
async def check(obj):
result = await test_filter.check(obj)
if ignore_case:
_test_postfix_list = map(str.lower, test_postfix_list)
_test_text = test_text.lower()
else:
_test_postfix_list = test_postfix_list
_test_text = test_text
return result is any(map(_test_text.endswith, _test_postfix_list))
assert await check(Message(text=test_text))
assert await check(CallbackQuery(data=test_text))
assert await check(InlineQuery(query=test_text))
assert await check(Poll(question=test_text))
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("test_string, test_text, ignore_case", @pytest.mark.parametrize("test_string, test_text, ignore_case",
[('', '', True), [('', '', True),
@ -155,6 +254,37 @@ class TestTextFilter:
assert await check(InlineQuery(query=test_text)) assert await check(InlineQuery(query=test_text))
assert await check(Poll(question=test_text)) assert await check(Poll(question=test_text))
@pytest.mark.asyncio
@pytest.mark.parametrize("test_filter_list, test_text, ignore_case",
[(['a', 'ab', 'abc'], 'A', True),
(['a', 'ab', 'abc'], 'ab', True),
(['a', 'ab', 'abc'], 'aBc', True),
(['a', 'ab', 'abc'], 'd', True),
(['a', 'ab', 'abc'], 'A', False),
(['a', 'ab', 'abc'], 'ab', False),
(['a', 'ab', 'abc'], 'aBc', False),
(['a', 'ab', 'abc'], 'd', False),
])
async def test_contains_list(self, test_filter_list, test_text, ignore_case):
test_filter = Text(contains=test_filter_list, ignore_case=ignore_case)
async def check(obj):
result = await test_filter.check(obj)
if ignore_case:
_test_filter_list = list(map(str.lower, test_filter_list))
_test_text = test_text.lower()
else:
_test_filter_list = test_filter_list
_test_text = test_text
return result is all(map(_test_text.__contains__, _test_filter_list))
assert await check(Message(text=test_text))
assert await check(CallbackQuery(data=test_text))
assert await check(InlineQuery(query=test_text))
assert await check(Poll(question=test_text))
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("test_filter_text, test_text, ignore_case", @pytest.mark.parametrize("test_filter_text, test_text, ignore_case",
[('', '', True), [('', '', True),
@ -195,3 +325,60 @@ class TestTextFilter:
assert await check(CallbackQuery(data=test_text)) assert await check(CallbackQuery(data=test_text))
assert await check(InlineQuery(query=test_text)) assert await check(InlineQuery(query=test_text))
assert await check(Poll(question=test_text)) assert await check(Poll(question=test_text))
@pytest.mark.asyncio
@pytest.mark.parametrize("test_filter_list, test_text, ignore_case",
[(['', 'new_string'], '', True),
(['new_string', ''], 'exAmple_string', True),
(['new_string', ''], '', False),
(['', 'new_string'], 'exAmple_string', False),
(['example_string'], 'example_string', True),
(['example_string'], 'exAmple_string', True),
(['exAmple_string'], 'example_string', True),
(['example_string'], 'example_string', False),
(['example_string'], 'exAmple_string', False),
(['exAmple_string'], 'example_string', False),
(['example_string'], 'not_example_string', True),
(['example_string'], 'not_eXample_string', True),
(['EXample_string'], 'not_eXample_string', True),
(['example_string'], 'not_example_string', False),
(['example_string'], 'not_eXample_string', False),
(['EXample_string'], 'not_example_string', False),
(['example_string', 'new_string'], 'example_string', True),
(['new_string', 'example_string'], 'exAmple_string', True),
(['exAmple_string', 'new_string'], 'example_string', True),
(['example_string', 'new_string'], 'example_string', False),
(['new_string', 'example_string'], 'exAmple_string', False),
(['exAmple_string', 'new_string'], 'example_string', False),
(['example_string', 'new_string'], 'not_example_string', True),
(['new_string', 'example_string'], 'not_eXample_string', True),
(['EXample_string', 'new_string'], 'not_eXample_string', True),
(['example_string', 'new_string'], 'not_example_string', False),
(['new_string', 'example_string'], 'not_eXample_string', False),
(['EXample_string', 'new_string'], 'not_example_string', False),
])
async def test_equals_list(self, test_filter_list, test_text, ignore_case):
test_filter = Text(equals=test_filter_list, ignore_case=ignore_case)
async def check(obj):
result = await test_filter.check(obj)
if ignore_case:
_test_filter_list = list(map(str.lower, test_filter_list))
_test_text = test_text.lower()
else:
_test_filter_list = test_filter_list
_test_text = test_text
assert result is (_test_text in _test_filter_list)
await check(Message(text=test_text))
await check(CallbackQuery(data=test_text))
await check(InlineQuery(query=test_text))
await check(Poll(question=test_text))