mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 18:01:04 +00:00
Added lost files
This commit is contained in:
parent
6253b25158
commit
79f21416c8
14 changed files with 586 additions and 0 deletions
64
aiogram/client/errors_middleware.py
Normal file
64
aiogram/client/errors_middleware.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Type
|
||||
|
||||
from aiogram.methods import Response, TelegramMethod
|
||||
from aiogram.types import TelegramObject
|
||||
from aiogram.utils.exceptions.base import TelegramAPIError
|
||||
from aiogram.utils.exceptions.exceptions import (
|
||||
CantParseEntitiesStartTag,
|
||||
CantParseEntitiesUnclosed,
|
||||
CantParseEntitiesUnmatchedTags,
|
||||
CantParseEntitiesUnsupportedTag,
|
||||
DetailedTelegramAPIError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram.client.bot import Bot
|
||||
from aiogram.client.session.base import NextRequestMiddlewareType
|
||||
|
||||
|
||||
class RequestErrorMiddleware:
|
||||
def __init__(self) -> None:
|
||||
self._registry: List[Type[DetailedTelegramAPIError]] = [
|
||||
CantParseEntitiesStartTag,
|
||||
CantParseEntitiesUnmatchedTags,
|
||||
CantParseEntitiesUnclosed,
|
||||
CantParseEntitiesUnsupportedTag,
|
||||
]
|
||||
|
||||
def mount(self, error: Type[DetailedTelegramAPIError]) -> Type[DetailedTelegramAPIError]:
|
||||
if error in self:
|
||||
raise ValueError(f"{error!r} is already registered")
|
||||
if not hasattr(error, "patterns"):
|
||||
raise ValueError(f"{error!r} has no attribute 'patterns'")
|
||||
self._registry.append(error)
|
||||
return error
|
||||
|
||||
def detect_error(self, err: TelegramAPIError) -> TelegramAPIError:
|
||||
message = err.message
|
||||
for variant in self._registry:
|
||||
for pattern in variant.patterns:
|
||||
if match := re.match(pattern, message):
|
||||
return variant(
|
||||
method=err.method,
|
||||
message=err.message,
|
||||
match=match,
|
||||
)
|
||||
return err
|
||||
|
||||
def __contains__(self, item: Type[DetailedTelegramAPIError]) -> bool:
|
||||
return item in self._registry
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
bot: Bot,
|
||||
method: TelegramMethod[TelegramObject],
|
||||
make_request: NextRequestMiddlewareType,
|
||||
) -> Response[TelegramObject]:
|
||||
try:
|
||||
return await make_request(bot, method)
|
||||
except TelegramAPIError as e:
|
||||
detected_err = self.detect_error(err=e)
|
||||
raise detected_err from e
|
||||
113
aiogram/dispatcher/filters/callback_data.py
Normal file
113
aiogram/dispatcher/filters/callback_data.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from magic_filter import MagicFilter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aiogram.dispatcher.filters import BaseFilter
|
||||
from aiogram.types import CallbackQuery
|
||||
|
||||
T = TypeVar("T", bound="CallbackData")
|
||||
|
||||
MAX_CALLBACK_LENGTH: int = 64
|
||||
|
||||
|
||||
class CallbackDataException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CallbackData(BaseModel):
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
sep: str
|
||||
prefix: str
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
if "prefix" not in kwargs:
|
||||
raise ValueError(
|
||||
f"prefix required, usage example: "
|
||||
f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`"
|
||||
)
|
||||
cls.sep = kwargs.pop("sep", ":")
|
||||
cls.prefix = kwargs.pop("prefix")
|
||||
if cls.sep in cls.prefix:
|
||||
raise ValueError(
|
||||
f"Separator symbol {cls.sep!r} can not be used inside prefix {cls.prefix!r}"
|
||||
)
|
||||
|
||||
def _encode_value(self, key: str, value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, Enum):
|
||||
return str(value.value)
|
||||
if isinstance(value, (int, str, float, Decimal, Fraction, UUID)):
|
||||
return str(value)
|
||||
raise ValueError(
|
||||
f"Attribute {key}={value!r} of type {type(value).__name__!r}"
|
||||
f" can not be packed to callback data"
|
||||
)
|
||||
|
||||
def pack(self) -> str:
|
||||
result = [self.prefix]
|
||||
for key, value in self.dict().items():
|
||||
encoded = self._encode_value(key, value)
|
||||
if self.sep in encoded:
|
||||
raise ValueError(
|
||||
f"Separator symbol {self.sep!r} can not be used in value {key}={encoded!r}"
|
||||
)
|
||||
result.append(encoded)
|
||||
callback_data = self.sep.join(result)
|
||||
if len(callback_data.encode()) > MAX_CALLBACK_LENGTH:
|
||||
raise ValueError(
|
||||
f"Resulted callback data is too long! len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
|
||||
)
|
||||
return callback_data
|
||||
|
||||
@classmethod
|
||||
def unpack(cls: Type[T], value: str) -> T:
|
||||
prefix, *parts = value.split(cls.sep)
|
||||
names = cls.__fields__.keys()
|
||||
if len(parts) != len(names):
|
||||
raise TypeError(
|
||||
f"Callback data {cls.__name__!r} takes {len(names)} arguments but {len(parts)} were given"
|
||||
)
|
||||
if prefix != cls.prefix:
|
||||
raise ValueError(f"Bad prefix ({prefix!r} != {cls.prefix!r})")
|
||||
payload = {}
|
||||
for k, v in zip(names, parts): # type: str, Optional[str]
|
||||
if field := cls.__fields__.get(k):
|
||||
if v == "" and not field.required:
|
||||
v = None
|
||||
payload[k] = v
|
||||
return cls(**payload)
|
||||
|
||||
@classmethod
|
||||
def filter(cls, rule: MagicFilter) -> CallbackQueryFilter:
|
||||
return CallbackQueryFilter(callback_data=cls, rule=rule)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class CallbackQueryFilter(BaseFilter):
|
||||
callback_data: Type[CallbackData]
|
||||
rule: MagicFilter
|
||||
|
||||
async def __call__(self, query: CallbackQuery) -> Union[bool, Dict[str, Any]]:
|
||||
if not isinstance(query, CallbackQuery) or not query.data:
|
||||
return False
|
||||
try:
|
||||
callback_data = self.callback_data.unpack(query.data)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
if self.rule.resolve(callback_data):
|
||||
return {"callback_data": callback_data}
|
||||
return False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
0
aiogram/utils/exceptions/__init__.py
Normal file
0
aiogram/utils/exceptions/__init__.py
Normal file
5
aiogram/utils/exceptions/bad_request.py
Normal file
5
aiogram/utils/exceptions/bad_request.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from aiogram.utils.exceptions.base import DetailedTelegramAPIError
|
||||
|
||||
|
||||
class BadRequest(DetailedTelegramAPIError):
|
||||
pass
|
||||
40
aiogram/utils/exceptions/base.py
Normal file
40
aiogram/utils/exceptions/base.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
from typing import ClassVar, List, Match, Optional, TypeVar
|
||||
|
||||
from aiogram.methods import TelegramMethod
|
||||
from aiogram.methods.base import TelegramType
|
||||
|
||||
ErrorType = TypeVar("ErrorType")
|
||||
|
||||
|
||||
class TelegramAPIError(Exception):
|
||||
url: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: TelegramMethod[TelegramType],
|
||||
message: str,
|
||||
) -> None:
|
||||
self.method = method
|
||||
self.message = message
|
||||
|
||||
def render_description(self) -> str:
|
||||
return self.message
|
||||
|
||||
def __str__(self) -> str:
|
||||
message = [self.render_description()]
|
||||
if self.url:
|
||||
message.append(f"(background on this error at: {self.url})")
|
||||
return "\n".join(message)
|
||||
|
||||
|
||||
class DetailedTelegramAPIError(TelegramAPIError):
|
||||
patterns: ClassVar[List[str]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: TelegramMethod[TelegramType],
|
||||
message: str,
|
||||
match: Match[str],
|
||||
) -> None:
|
||||
super().__init__(method=method, message=message)
|
||||
self.match: Match[str] = match
|
||||
0
aiogram/utils/exceptions/conflict.py
Normal file
0
aiogram/utils/exceptions/conflict.py
Normal file
5
aiogram/utils/exceptions/network.py
Normal file
5
aiogram/utils/exceptions/network.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from aiogram.utils.exceptions.base import DetailedTelegramAPIError
|
||||
|
||||
|
||||
class NetworkError(DetailedTelegramAPIError):
|
||||
pass
|
||||
5
aiogram/utils/exceptions/not_found.py
Normal file
5
aiogram/utils/exceptions/not_found.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from aiogram.utils.exceptions.base import DetailedTelegramAPIError
|
||||
|
||||
|
||||
class NotFound(DetailedTelegramAPIError):
|
||||
pass
|
||||
0
aiogram/utils/exceptions/server.py
Normal file
0
aiogram/utils/exceptions/server.py
Normal file
46
aiogram/utils/exceptions/special.py
Normal file
46
aiogram/utils/exceptions/special.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from typing import Optional
|
||||
|
||||
from aiogram.methods import TelegramMethod
|
||||
from aiogram.methods.base import TelegramType
|
||||
from aiogram.utils.exceptions.base import TelegramAPIError
|
||||
|
||||
|
||||
class RetryAfter(TelegramAPIError):
|
||||
url = "https://core.telegram.org/bots/faq#my-bot-is-hitting-limits-how-do-i-avoid-this"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: TelegramMethod[TelegramType],
|
||||
message: str,
|
||||
retry_after: int,
|
||||
) -> None:
|
||||
super().__init__(method=method, message=message)
|
||||
self.retry_after = retry_after
|
||||
|
||||
def render_description(self) -> str:
|
||||
description = f"Flood control exceeded on method {type(self.method).__name__!r}"
|
||||
if chat_id := getattr(self.method, "chat_id", None):
|
||||
description += f" in chat {chat_id}"
|
||||
description += f". Retry in {self.retry_after} seconds."
|
||||
return description
|
||||
|
||||
|
||||
class MigrateToChat(TelegramAPIError):
|
||||
url = "https://core.telegram.org/bots/api#responseparameters"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: TelegramMethod[TelegramType],
|
||||
message: str,
|
||||
migrate_to_chat_id: int,
|
||||
) -> None:
|
||||
super().__init__(method=method, message=message)
|
||||
self.migrate_to_chat_id = migrate_to_chat_id
|
||||
|
||||
def render_message(self) -> Optional[str]:
|
||||
description = (
|
||||
f"The group has been migrated to a supergroup with id {self.migrate_to_chat_id}"
|
||||
)
|
||||
if chat_id := getattr(self.method, "chat_id", None):
|
||||
description += f" from {chat_id}"
|
||||
return description
|
||||
0
aiogram/utils/exceptions/unauthorized.py
Normal file
0
aiogram/utils/exceptions/unauthorized.py
Normal file
20
aiogram/utils/exceptions/util.py
Normal file
20
aiogram/utils/exceptions/util.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
def mark_line(text: str, offset: int, length: int = 1) -> str:
|
||||
try:
|
||||
if offset > 0 and (new_line_pos := text[:offset].rindex("\n")):
|
||||
text = "..." + text[:new_line_pos]
|
||||
offset -= new_line_pos - 3
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if offset > 10:
|
||||
text = "..." + text[offset - 10 :]
|
||||
offset = 13
|
||||
|
||||
mark = " " * offset
|
||||
mark += "^" * length
|
||||
try:
|
||||
if new_line_pos := text[len(mark) :].index("\n"):
|
||||
text = text[:new_line_pos].rstrip() + "..."
|
||||
except ValueError:
|
||||
pass
|
||||
return text + "\n" + mark
|
||||
111
examples/finite_state_machine.py
Normal file
111
examples/finite_state_machine.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from os import getenv
|
||||
|
||||
from aiogram import Bot, Dispatcher, F
|
||||
from aiogram.dispatcher.filters import Command
|
||||
from aiogram.dispatcher.fsm.context import FSMContext
|
||||
from aiogram.dispatcher.fsm.state import State, StatesGroup
|
||||
from aiogram.types import Message, ReplyKeyboardRemove, ReplyKeyboardMarkup, KeyboardButton
|
||||
from aiogram.utils.markdown import hbold
|
||||
from aiogram.utils.markup import KeyboardConstructor
|
||||
|
||||
GENDERS = ["Male", "Female", "Helicopter", "Other"]
|
||||
|
||||
dp = Dispatcher()
|
||||
|
||||
|
||||
# States
|
||||
class Form(StatesGroup):
|
||||
name = State() # Will be represented in storage as 'Form:name'
|
||||
age = State() # Will be represented in storage as 'Form:age'
|
||||
gender = State() # Will be represented in storage as 'Form:gender'
|
||||
|
||||
|
||||
@dp.message(Command(commands=["start"]))
|
||||
async def cmd_start(message: Message, state: FSMContext):
|
||||
"""
|
||||
Conversation's entry point
|
||||
"""
|
||||
# Set state
|
||||
await state.set_state(Form.name)
|
||||
await message.answer("Hi there! What's your name?")
|
||||
|
||||
|
||||
@dp.message(Command(commands=["cancel"]))
|
||||
@dp.message(F.text.lower() == "cancel")
|
||||
async def cancel_handler(message: Message, state: FSMContext):
|
||||
"""
|
||||
Allow user to cancel any action
|
||||
"""
|
||||
current_state = await state.get_state()
|
||||
if current_state is None:
|
||||
return
|
||||
|
||||
logging.info("Cancelling state %r", current_state)
|
||||
# Cancel state and inform user about it
|
||||
await state.clear()
|
||||
# And remove keyboard (just in case)
|
||||
await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove())
|
||||
|
||||
|
||||
@dp.message(Form.name)
|
||||
async def process_name(message: Message, state: FSMContext):
|
||||
"""
|
||||
Process user name
|
||||
"""
|
||||
await state.update_data(name=message.text)
|
||||
await state.set_state(Form.age)
|
||||
await message.answer("How old are you?")
|
||||
|
||||
|
||||
# Check age. Age gotta be digit
|
||||
@dp.message(Form.age, ~F.text.isdigit())
|
||||
async def process_age_invalid(message: Message):
|
||||
"""
|
||||
If age is invalid
|
||||
"""
|
||||
return await message.answer("Age gotta be a number.\nHow old are you? (digits only)")
|
||||
|
||||
|
||||
@dp.message(Form.age)
|
||||
async def process_age(message: Message, state: FSMContext):
|
||||
# Update state and data
|
||||
await state.set_state(Form.gender)
|
||||
await state.update_data(age=int(message.text))
|
||||
|
||||
# Configure ReplyKeyboardMarkup
|
||||
constructor = KeyboardConstructor(KeyboardButton)
|
||||
constructor.add(*(KeyboardButton(text=text) for text in GENDERS)).adjust(2)
|
||||
markup = ReplyKeyboardMarkup(
|
||||
resize_keyboard=True, selective=True, keyboard=constructor.export()
|
||||
)
|
||||
await message.reply("What is your gender?", reply_markup=markup)
|
||||
|
||||
|
||||
@dp.message(Form.gender)
|
||||
async def process_gender(message: Message, state: FSMContext):
|
||||
data = await state.update_data(gender=message.text)
|
||||
await state.clear()
|
||||
|
||||
# And send message
|
||||
await message.answer(
|
||||
(
|
||||
f'Hi, nice to meet you, {hbold(data["name"])}\n'
|
||||
f'Age: {hbold(data["age"])}\n'
|
||||
f'Gender: {hbold(data["gender"])}\n'
|
||||
),
|
||||
reply_markup=ReplyKeyboardRemove(),
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
bot = Bot(token=getenv("TELEGRAM_TOKEN"), parse_mode="HTML")
|
||||
|
||||
await dp.start_polling(bot)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
||||
asyncio.run(main())
|
||||
177
tests/test_dispatcher/test_filters/test_callback_data.py
Normal file
177
tests/test_dispatcher/test_filters/test_callback_data.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
from decimal import Decimal
|
||||
from enum import Enum, auto
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from magic_filter import MagicFilter
|
||||
from pydantic import ValidationError
|
||||
|
||||
from aiogram import F
|
||||
from aiogram.dispatcher.filters.callback_data import CallbackData
|
||||
from aiogram.types import CallbackQuery, User
|
||||
|
||||
|
||||
class MyIntEnum(Enum):
|
||||
FOO = auto()
|
||||
|
||||
|
||||
class MyStringEnum(str, Enum):
|
||||
FOO = "FOO"
|
||||
|
||||
|
||||
class MyCallback(CallbackData, prefix="test"):
|
||||
foo: str
|
||||
bar: int
|
||||
|
||||
|
||||
class TestCallbackData:
|
||||
def test_init_subclass_prefix_required(self):
|
||||
assert MyCallback.prefix == "test"
|
||||
|
||||
with pytest.raises(ValueError, match="prefix required.+"):
|
||||
|
||||
class MyInvalidCallback(CallbackData):
|
||||
pass
|
||||
|
||||
def test_init_subclass_sep_validation(self):
|
||||
assert MyCallback.sep == ":"
|
||||
|
||||
class MyCallback2(CallbackData, prefix="test2", sep="@"):
|
||||
pass
|
||||
|
||||
assert MyCallback2.sep == "@"
|
||||
|
||||
with pytest.raises(ValueError, match="Separator symbol '@' .+ 'sp@m'"):
|
||||
|
||||
class MyInvalidCallback(CallbackData, prefix="sp@m", sep="@"):
|
||||
pass
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,success,expected",
|
||||
[
|
||||
[None, True, ""],
|
||||
[42, True, "42"],
|
||||
["test", True, "test"],
|
||||
[9.99, True, "9.99"],
|
||||
[Decimal("9.99"), True, "9.99"],
|
||||
[Fraction("3/2"), True, "3/2"],
|
||||
[
|
||||
UUID("123e4567-e89b-12d3-a456-426655440000"),
|
||||
True,
|
||||
"123e4567-e89b-12d3-a456-426655440000",
|
||||
],
|
||||
[MyIntEnum.FOO, True, "1"],
|
||||
[MyStringEnum.FOO, True, "FOO"],
|
||||
[..., False, "..."],
|
||||
[object, False, "..."],
|
||||
[object(), False, "..."],
|
||||
[User(id=42, is_bot=False, first_name="test"), False, "..."],
|
||||
],
|
||||
)
|
||||
def test_encode_value(self, value, success, expected):
|
||||
callback = MyCallback(foo="test", bar=42)
|
||||
if success:
|
||||
assert callback._encode_value("test", value) == expected
|
||||
else:
|
||||
with pytest.raises(ValueError):
|
||||
assert callback._encode_value("test", value) == expected
|
||||
|
||||
def test_pack(self):
|
||||
with pytest.raises(ValueError, match="Separator symbol .+"):
|
||||
assert MyCallback(foo="te:st", bar=42).pack()
|
||||
|
||||
with pytest.raises(ValueError, match=".+is too long.+"):
|
||||
assert MyCallback(foo="test" * 32, bar=42).pack()
|
||||
|
||||
assert MyCallback(foo="test", bar=42).pack() == "test:test:42"
|
||||
|
||||
def test_pack_optional(self):
|
||||
class MyCallback1(CallbackData, prefix="test1"):
|
||||
foo: str
|
||||
bar: Optional[int] = None
|
||||
|
||||
assert MyCallback1(foo="spam").pack() == "test1:spam:"
|
||||
assert MyCallback1(foo="spam", bar=42).pack() == "test1:spam:42"
|
||||
|
||||
class MyCallback2(CallbackData, prefix="test2"):
|
||||
foo: Optional[str] = None
|
||||
bar: int
|
||||
|
||||
assert MyCallback2(bar=42).pack() == "test2::42"
|
||||
assert MyCallback2(foo="spam", bar=42).pack() == "test2:spam:42"
|
||||
|
||||
class MyCallback3(CallbackData, prefix="test3"):
|
||||
foo: Optional[str] = "experiment"
|
||||
bar: int
|
||||
|
||||
assert MyCallback3(bar=42).pack() == "test3:experiment:42"
|
||||
assert MyCallback3(foo="spam", bar=42).pack() == "test3:spam:42"
|
||||
|
||||
def test_unpack(self):
|
||||
with pytest.raises(TypeError, match=".+ takes 2 arguments but 3 were given"):
|
||||
MyCallback.unpack("test:test:test:test")
|
||||
|
||||
with pytest.raises(ValueError, match="Bad prefix .+"):
|
||||
MyCallback.unpack("spam:test:test")
|
||||
|
||||
assert MyCallback.unpack("test:test:42") == MyCallback(foo="test", bar=42)
|
||||
|
||||
def test_unpack_optional(self):
|
||||
with pytest.raises(ValidationError):
|
||||
assert MyCallback.unpack("test:test:")
|
||||
|
||||
class MyCallback1(CallbackData, prefix="test1"):
|
||||
foo: str
|
||||
bar: Optional[int] = None
|
||||
|
||||
assert MyCallback1.unpack("test1:spam:") == MyCallback1(foo="spam")
|
||||
assert MyCallback1.unpack("test1:spam:42") == MyCallback1(foo="spam", bar=42)
|
||||
|
||||
class MyCallback2(CallbackData, prefix="test2"):
|
||||
foo: Optional[str] = None
|
||||
bar: int
|
||||
|
||||
assert MyCallback2.unpack("test2::42") == MyCallback2(bar=42)
|
||||
assert MyCallback2.unpack("test2:spam:42") == MyCallback2(foo="spam", bar=42)
|
||||
|
||||
class MyCallback3(CallbackData, prefix="test3"):
|
||||
foo: Optional[str] = "experiment"
|
||||
bar: int
|
||||
|
||||
assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42)
|
||||
assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42)
|
||||
|
||||
def test_build_filter(self):
|
||||
filter_object = MyCallback.filter(F.foo == "test")
|
||||
assert isinstance(filter_object.rule, MagicFilter)
|
||||
assert filter_object.callback_data is MyCallback
|
||||
|
||||
|
||||
class TestCallbackDataFilter:
|
||||
@pytest.mark.parametrize(
|
||||
"query,rule,result",
|
||||
[
|
||||
["test", F.foo == "test", False],
|
||||
["test:spam:42", F.foo == "test", False],
|
||||
["test:test:42", F.foo == "test", {"callback_data": MyCallback(foo="test", bar=42)}],
|
||||
["test:test:", F.foo == "test", False],
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_call(self, query, rule, result):
|
||||
callback_query = CallbackQuery(
|
||||
id="1",
|
||||
from_user=User(id=42, is_bot=False, first_name="test"),
|
||||
data=query,
|
||||
chat_instance="test",
|
||||
)
|
||||
|
||||
filter_object = MyCallback.filter(rule)
|
||||
assert await filter_object(callback_query) == result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_call(self):
|
||||
filter_object = MyCallback.filter(F.test)
|
||||
assert not await filter_object(User(id=42, is_bot=False, first_name="test"))
|
||||
Loading…
Add table
Add a link
Reference in a new issue