mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 10:11:52 +00:00
Added lost files
This commit is contained in:
parent
6253b25158
commit
79f21416c8
14 changed files with 586 additions and 0 deletions
64
aiogram/client/errors_middleware.py
Normal file
64
aiogram/client/errors_middleware.py
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, List, Type
|
||||||
|
|
||||||
|
from aiogram.methods import Response, TelegramMethod
|
||||||
|
from aiogram.types import TelegramObject
|
||||||
|
from aiogram.utils.exceptions.base import TelegramAPIError
|
||||||
|
from aiogram.utils.exceptions.exceptions import (
|
||||||
|
CantParseEntitiesStartTag,
|
||||||
|
CantParseEntitiesUnclosed,
|
||||||
|
CantParseEntitiesUnmatchedTags,
|
||||||
|
CantParseEntitiesUnsupportedTag,
|
||||||
|
DetailedTelegramAPIError,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from aiogram.client.bot import Bot
|
||||||
|
from aiogram.client.session.base import NextRequestMiddlewareType
|
||||||
|
|
||||||
|
|
||||||
|
class RequestErrorMiddleware:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._registry: List[Type[DetailedTelegramAPIError]] = [
|
||||||
|
CantParseEntitiesStartTag,
|
||||||
|
CantParseEntitiesUnmatchedTags,
|
||||||
|
CantParseEntitiesUnclosed,
|
||||||
|
CantParseEntitiesUnsupportedTag,
|
||||||
|
]
|
||||||
|
|
||||||
|
def mount(self, error: Type[DetailedTelegramAPIError]) -> Type[DetailedTelegramAPIError]:
|
||||||
|
if error in self:
|
||||||
|
raise ValueError(f"{error!r} is already registered")
|
||||||
|
if not hasattr(error, "patterns"):
|
||||||
|
raise ValueError(f"{error!r} has no attribute 'patterns'")
|
||||||
|
self._registry.append(error)
|
||||||
|
return error
|
||||||
|
|
||||||
|
def detect_error(self, err: TelegramAPIError) -> TelegramAPIError:
|
||||||
|
message = err.message
|
||||||
|
for variant in self._registry:
|
||||||
|
for pattern in variant.patterns:
|
||||||
|
if match := re.match(pattern, message):
|
||||||
|
return variant(
|
||||||
|
method=err.method,
|
||||||
|
message=err.message,
|
||||||
|
match=match,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
|
||||||
|
def __contains__(self, item: Type[DetailedTelegramAPIError]) -> bool:
|
||||||
|
return item in self._registry
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
bot: Bot,
|
||||||
|
method: TelegramMethod[TelegramObject],
|
||||||
|
make_request: NextRequestMiddlewareType,
|
||||||
|
) -> Response[TelegramObject]:
|
||||||
|
try:
|
||||||
|
return await make_request(bot, method)
|
||||||
|
except TelegramAPIError as e:
|
||||||
|
detected_err = self.detect_error(err=e)
|
||||||
|
raise detected_err from e
|
||||||
113
aiogram/dispatcher/filters/callback_data.py
Normal file
113
aiogram/dispatcher/filters/callback_data.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
|
from enum import Enum
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from magic_filter import MagicFilter
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from aiogram.dispatcher.filters import BaseFilter
|
||||||
|
from aiogram.types import CallbackQuery
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="CallbackData")
|
||||||
|
|
||||||
|
MAX_CALLBACK_LENGTH: int = 64
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackDataException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackData(BaseModel):
|
||||||
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
sep: str
|
||||||
|
prefix: str
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||||
|
if "prefix" not in kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"prefix required, usage example: "
|
||||||
|
f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`"
|
||||||
|
)
|
||||||
|
cls.sep = kwargs.pop("sep", ":")
|
||||||
|
cls.prefix = kwargs.pop("prefix")
|
||||||
|
if cls.sep in cls.prefix:
|
||||||
|
raise ValueError(
|
||||||
|
f"Separator symbol {cls.sep!r} can not be used inside prefix {cls.prefix!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _encode_value(self, key: str, value: Any) -> str:
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(value, Enum):
|
||||||
|
return str(value.value)
|
||||||
|
if isinstance(value, (int, str, float, Decimal, Fraction, UUID)):
|
||||||
|
return str(value)
|
||||||
|
raise ValueError(
|
||||||
|
f"Attribute {key}={value!r} of type {type(value).__name__!r}"
|
||||||
|
f" can not be packed to callback data"
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack(self) -> str:
|
||||||
|
result = [self.prefix]
|
||||||
|
for key, value in self.dict().items():
|
||||||
|
encoded = self._encode_value(key, value)
|
||||||
|
if self.sep in encoded:
|
||||||
|
raise ValueError(
|
||||||
|
f"Separator symbol {self.sep!r} can not be used in value {key}={encoded!r}"
|
||||||
|
)
|
||||||
|
result.append(encoded)
|
||||||
|
callback_data = self.sep.join(result)
|
||||||
|
if len(callback_data.encode()) > MAX_CALLBACK_LENGTH:
|
||||||
|
raise ValueError(
|
||||||
|
f"Resulted callback data is too long! len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
|
||||||
|
)
|
||||||
|
return callback_data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unpack(cls: Type[T], value: str) -> T:
|
||||||
|
prefix, *parts = value.split(cls.sep)
|
||||||
|
names = cls.__fields__.keys()
|
||||||
|
if len(parts) != len(names):
|
||||||
|
raise TypeError(
|
||||||
|
f"Callback data {cls.__name__!r} takes {len(names)} arguments but {len(parts)} were given"
|
||||||
|
)
|
||||||
|
if prefix != cls.prefix:
|
||||||
|
raise ValueError(f"Bad prefix ({prefix!r} != {cls.prefix!r})")
|
||||||
|
payload = {}
|
||||||
|
for k, v in zip(names, parts): # type: str, Optional[str]
|
||||||
|
if field := cls.__fields__.get(k):
|
||||||
|
if v == "" and not field.required:
|
||||||
|
v = None
|
||||||
|
payload[k] = v
|
||||||
|
return cls(**payload)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def filter(cls, rule: MagicFilter) -> CallbackQueryFilter:
|
||||||
|
return CallbackQueryFilter(callback_data=cls, rule=rule)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
use_enum_values = True
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackQueryFilter(BaseFilter):
|
||||||
|
callback_data: Type[CallbackData]
|
||||||
|
rule: MagicFilter
|
||||||
|
|
||||||
|
async def __call__(self, query: CallbackQuery) -> Union[bool, Dict[str, Any]]:
|
||||||
|
if not isinstance(query, CallbackQuery) or not query.data:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
callback_data = self.callback_data.unpack(query.data)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.rule.resolve(callback_data):
|
||||||
|
return {"callback_data": callback_data}
|
||||||
|
return False
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
0
aiogram/utils/exceptions/__init__.py
Normal file
0
aiogram/utils/exceptions/__init__.py
Normal file
5
aiogram/utils/exceptions/bad_request.py
Normal file
5
aiogram/utils/exceptions/bad_request.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
from aiogram.utils.exceptions.base import DetailedTelegramAPIError
|
||||||
|
|
||||||
|
|
||||||
|
class BadRequest(DetailedTelegramAPIError):
|
||||||
|
pass
|
||||||
40
aiogram/utils/exceptions/base.py
Normal file
40
aiogram/utils/exceptions/base.py
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
from typing import ClassVar, List, Match, Optional, TypeVar
|
||||||
|
|
||||||
|
from aiogram.methods import TelegramMethod
|
||||||
|
from aiogram.methods.base import TelegramType
|
||||||
|
|
||||||
|
ErrorType = TypeVar("ErrorType")
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramAPIError(Exception):
|
||||||
|
url: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
method: TelegramMethod[TelegramType],
|
||||||
|
message: str,
|
||||||
|
) -> None:
|
||||||
|
self.method = method
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def render_description(self) -> str:
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
message = [self.render_description()]
|
||||||
|
if self.url:
|
||||||
|
message.append(f"(background on this error at: {self.url})")
|
||||||
|
return "\n".join(message)
|
||||||
|
|
||||||
|
|
||||||
|
class DetailedTelegramAPIError(TelegramAPIError):
|
||||||
|
patterns: ClassVar[List[str]]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
method: TelegramMethod[TelegramType],
|
||||||
|
message: str,
|
||||||
|
match: Match[str],
|
||||||
|
) -> None:
|
||||||
|
super().__init__(method=method, message=message)
|
||||||
|
self.match: Match[str] = match
|
||||||
0
aiogram/utils/exceptions/conflict.py
Normal file
0
aiogram/utils/exceptions/conflict.py
Normal file
5
aiogram/utils/exceptions/network.py
Normal file
5
aiogram/utils/exceptions/network.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
from aiogram.utils.exceptions.base import DetailedTelegramAPIError
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkError(DetailedTelegramAPIError):
|
||||||
|
pass
|
||||||
5
aiogram/utils/exceptions/not_found.py
Normal file
5
aiogram/utils/exceptions/not_found.py
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
from aiogram.utils.exceptions.base import DetailedTelegramAPIError
|
||||||
|
|
||||||
|
|
||||||
|
class NotFound(DetailedTelegramAPIError):
|
||||||
|
pass
|
||||||
0
aiogram/utils/exceptions/server.py
Normal file
0
aiogram/utils/exceptions/server.py
Normal file
46
aiogram/utils/exceptions/special.py
Normal file
46
aiogram/utils/exceptions/special.py
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from aiogram.methods import TelegramMethod
|
||||||
|
from aiogram.methods.base import TelegramType
|
||||||
|
from aiogram.utils.exceptions.base import TelegramAPIError
|
||||||
|
|
||||||
|
|
||||||
|
class RetryAfter(TelegramAPIError):
|
||||||
|
url = "https://core.telegram.org/bots/faq#my-bot-is-hitting-limits-how-do-i-avoid-this"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
method: TelegramMethod[TelegramType],
|
||||||
|
message: str,
|
||||||
|
retry_after: int,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(method=method, message=message)
|
||||||
|
self.retry_after = retry_after
|
||||||
|
|
||||||
|
def render_description(self) -> str:
|
||||||
|
description = f"Flood control exceeded on method {type(self.method).__name__!r}"
|
||||||
|
if chat_id := getattr(self.method, "chat_id", None):
|
||||||
|
description += f" in chat {chat_id}"
|
||||||
|
description += f". Retry in {self.retry_after} seconds."
|
||||||
|
return description
|
||||||
|
|
||||||
|
|
||||||
|
class MigrateToChat(TelegramAPIError):
|
||||||
|
url = "https://core.telegram.org/bots/api#responseparameters"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
method: TelegramMethod[TelegramType],
|
||||||
|
message: str,
|
||||||
|
migrate_to_chat_id: int,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(method=method, message=message)
|
||||||
|
self.migrate_to_chat_id = migrate_to_chat_id
|
||||||
|
|
||||||
|
def render_message(self) -> Optional[str]:
|
||||||
|
description = (
|
||||||
|
f"The group has been migrated to a supergroup with id {self.migrate_to_chat_id}"
|
||||||
|
)
|
||||||
|
if chat_id := getattr(self.method, "chat_id", None):
|
||||||
|
description += f" from {chat_id}"
|
||||||
|
return description
|
||||||
0
aiogram/utils/exceptions/unauthorized.py
Normal file
0
aiogram/utils/exceptions/unauthorized.py
Normal file
20
aiogram/utils/exceptions/util.py
Normal file
20
aiogram/utils/exceptions/util.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
def mark_line(text: str, offset: int, length: int = 1) -> str:
|
||||||
|
try:
|
||||||
|
if offset > 0 and (new_line_pos := text[:offset].rindex("\n")):
|
||||||
|
text = "..." + text[:new_line_pos]
|
||||||
|
offset -= new_line_pos - 3
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if offset > 10:
|
||||||
|
text = "..." + text[offset - 10 :]
|
||||||
|
offset = 13
|
||||||
|
|
||||||
|
mark = " " * offset
|
||||||
|
mark += "^" * length
|
||||||
|
try:
|
||||||
|
if new_line_pos := text[len(mark) :].index("\n"):
|
||||||
|
text = text[:new_line_pos].rstrip() + "..."
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return text + "\n" + mark
|
||||||
111
examples/finite_state_machine.py
Normal file
111
examples/finite_state_machine.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from os import getenv
|
||||||
|
|
||||||
|
from aiogram import Bot, Dispatcher, F
|
||||||
|
from aiogram.dispatcher.filters import Command
|
||||||
|
from aiogram.dispatcher.fsm.context import FSMContext
|
||||||
|
from aiogram.dispatcher.fsm.state import State, StatesGroup
|
||||||
|
from aiogram.types import Message, ReplyKeyboardRemove, ReplyKeyboardMarkup, KeyboardButton
|
||||||
|
from aiogram.utils.markdown import hbold
|
||||||
|
from aiogram.utils.markup import KeyboardConstructor
|
||||||
|
|
||||||
|
GENDERS = ["Male", "Female", "Helicopter", "Other"]
|
||||||
|
|
||||||
|
dp = Dispatcher()
|
||||||
|
|
||||||
|
|
||||||
|
# States
|
||||||
|
class Form(StatesGroup):
|
||||||
|
name = State() # Will be represented in storage as 'Form:name'
|
||||||
|
age = State() # Will be represented in storage as 'Form:age'
|
||||||
|
gender = State() # Will be represented in storage as 'Form:gender'
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message(Command(commands=["start"]))
|
||||||
|
async def cmd_start(message: Message, state: FSMContext):
|
||||||
|
"""
|
||||||
|
Conversation's entry point
|
||||||
|
"""
|
||||||
|
# Set state
|
||||||
|
await state.set_state(Form.name)
|
||||||
|
await message.answer("Hi there! What's your name?")
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message(Command(commands=["cancel"]))
|
||||||
|
@dp.message(F.text.lower() == "cancel")
|
||||||
|
async def cancel_handler(message: Message, state: FSMContext):
|
||||||
|
"""
|
||||||
|
Allow user to cancel any action
|
||||||
|
"""
|
||||||
|
current_state = await state.get_state()
|
||||||
|
if current_state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.info("Cancelling state %r", current_state)
|
||||||
|
# Cancel state and inform user about it
|
||||||
|
await state.clear()
|
||||||
|
# And remove keyboard (just in case)
|
||||||
|
await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove())
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message(Form.name)
|
||||||
|
async def process_name(message: Message, state: FSMContext):
|
||||||
|
"""
|
||||||
|
Process user name
|
||||||
|
"""
|
||||||
|
await state.update_data(name=message.text)
|
||||||
|
await state.set_state(Form.age)
|
||||||
|
await message.answer("How old are you?")
|
||||||
|
|
||||||
|
|
||||||
|
# Check age. Age gotta be digit
|
||||||
|
@dp.message(Form.age, ~F.text.isdigit())
|
||||||
|
async def process_age_invalid(message: Message):
|
||||||
|
"""
|
||||||
|
If age is invalid
|
||||||
|
"""
|
||||||
|
return await message.answer("Age gotta be a number.\nHow old are you? (digits only)")
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message(Form.age)
|
||||||
|
async def process_age(message: Message, state: FSMContext):
|
||||||
|
# Update state and data
|
||||||
|
await state.set_state(Form.gender)
|
||||||
|
await state.update_data(age=int(message.text))
|
||||||
|
|
||||||
|
# Configure ReplyKeyboardMarkup
|
||||||
|
constructor = KeyboardConstructor(KeyboardButton)
|
||||||
|
constructor.add(*(KeyboardButton(text=text) for text in GENDERS)).adjust(2)
|
||||||
|
markup = ReplyKeyboardMarkup(
|
||||||
|
resize_keyboard=True, selective=True, keyboard=constructor.export()
|
||||||
|
)
|
||||||
|
await message.reply("What is your gender?", reply_markup=markup)
|
||||||
|
|
||||||
|
|
||||||
|
@dp.message(Form.gender)
|
||||||
|
async def process_gender(message: Message, state: FSMContext):
|
||||||
|
data = await state.update_data(gender=message.text)
|
||||||
|
await state.clear()
|
||||||
|
|
||||||
|
# And send message
|
||||||
|
await message.answer(
|
||||||
|
(
|
||||||
|
f'Hi, nice to meet you, {hbold(data["name"])}\n'
|
||||||
|
f'Age: {hbold(data["age"])}\n'
|
||||||
|
f'Gender: {hbold(data["gender"])}\n'
|
||||||
|
),
|
||||||
|
reply_markup=ReplyKeyboardRemove(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
bot = Bot(token=getenv("TELEGRAM_TOKEN"), parse_mode="HTML")
|
||||||
|
|
||||||
|
await dp.start_polling(bot)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
||||||
|
asyncio.run(main())
|
||||||
177
tests/test_dispatcher/test_filters/test_callback_data.py
Normal file
177
tests/test_dispatcher/test_filters/test_callback_data.py
Normal file
|
|
@ -0,0 +1,177 @@
|
||||||
|
from decimal import Decimal
|
||||||
|
from enum import Enum, auto
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from magic_filter import MagicFilter
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from aiogram import F
|
||||||
|
from aiogram.dispatcher.filters.callback_data import CallbackData
|
||||||
|
from aiogram.types import CallbackQuery, User
|
||||||
|
|
||||||
|
|
||||||
|
class MyIntEnum(Enum):
|
||||||
|
FOO = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class MyStringEnum(str, Enum):
|
||||||
|
FOO = "FOO"
|
||||||
|
|
||||||
|
|
||||||
|
class MyCallback(CallbackData, prefix="test"):
|
||||||
|
foo: str
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallbackData:
|
||||||
|
def test_init_subclass_prefix_required(self):
|
||||||
|
assert MyCallback.prefix == "test"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="prefix required.+"):
|
||||||
|
|
||||||
|
class MyInvalidCallback(CallbackData):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_init_subclass_sep_validation(self):
|
||||||
|
assert MyCallback.sep == ":"
|
||||||
|
|
||||||
|
class MyCallback2(CallbackData, prefix="test2", sep="@"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert MyCallback2.sep == "@"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Separator symbol '@' .+ 'sp@m'"):
|
||||||
|
|
||||||
|
class MyInvalidCallback(CallbackData, prefix="sp@m", sep="@"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value,success,expected",
|
||||||
|
[
|
||||||
|
[None, True, ""],
|
||||||
|
[42, True, "42"],
|
||||||
|
["test", True, "test"],
|
||||||
|
[9.99, True, "9.99"],
|
||||||
|
[Decimal("9.99"), True, "9.99"],
|
||||||
|
[Fraction("3/2"), True, "3/2"],
|
||||||
|
[
|
||||||
|
UUID("123e4567-e89b-12d3-a456-426655440000"),
|
||||||
|
True,
|
||||||
|
"123e4567-e89b-12d3-a456-426655440000",
|
||||||
|
],
|
||||||
|
[MyIntEnum.FOO, True, "1"],
|
||||||
|
[MyStringEnum.FOO, True, "FOO"],
|
||||||
|
[..., False, "..."],
|
||||||
|
[object, False, "..."],
|
||||||
|
[object(), False, "..."],
|
||||||
|
[User(id=42, is_bot=False, first_name="test"), False, "..."],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_encode_value(self, value, success, expected):
|
||||||
|
callback = MyCallback(foo="test", bar=42)
|
||||||
|
if success:
|
||||||
|
assert callback._encode_value("test", value) == expected
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert callback._encode_value("test", value) == expected
|
||||||
|
|
||||||
|
def test_pack(self):
|
||||||
|
with pytest.raises(ValueError, match="Separator symbol .+"):
|
||||||
|
assert MyCallback(foo="te:st", bar=42).pack()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=".+is too long.+"):
|
||||||
|
assert MyCallback(foo="test" * 32, bar=42).pack()
|
||||||
|
|
||||||
|
assert MyCallback(foo="test", bar=42).pack() == "test:test:42"
|
||||||
|
|
||||||
|
def test_pack_optional(self):
|
||||||
|
class MyCallback1(CallbackData, prefix="test1"):
|
||||||
|
foo: str
|
||||||
|
bar: Optional[int] = None
|
||||||
|
|
||||||
|
assert MyCallback1(foo="spam").pack() == "test1:spam:"
|
||||||
|
assert MyCallback1(foo="spam", bar=42).pack() == "test1:spam:42"
|
||||||
|
|
||||||
|
class MyCallback2(CallbackData, prefix="test2"):
|
||||||
|
foo: Optional[str] = None
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
assert MyCallback2(bar=42).pack() == "test2::42"
|
||||||
|
assert MyCallback2(foo="spam", bar=42).pack() == "test2:spam:42"
|
||||||
|
|
||||||
|
class MyCallback3(CallbackData, prefix="test3"):
|
||||||
|
foo: Optional[str] = "experiment"
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
assert MyCallback3(bar=42).pack() == "test3:experiment:42"
|
||||||
|
assert MyCallback3(foo="spam", bar=42).pack() == "test3:spam:42"
|
||||||
|
|
||||||
|
def test_unpack(self):
|
||||||
|
with pytest.raises(TypeError, match=".+ takes 2 arguments but 3 were given"):
|
||||||
|
MyCallback.unpack("test:test:test:test")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Bad prefix .+"):
|
||||||
|
MyCallback.unpack("spam:test:test")
|
||||||
|
|
||||||
|
assert MyCallback.unpack("test:test:42") == MyCallback(foo="test", bar=42)
|
||||||
|
|
||||||
|
def test_unpack_optional(self):
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
assert MyCallback.unpack("test:test:")
|
||||||
|
|
||||||
|
class MyCallback1(CallbackData, prefix="test1"):
|
||||||
|
foo: str
|
||||||
|
bar: Optional[int] = None
|
||||||
|
|
||||||
|
assert MyCallback1.unpack("test1:spam:") == MyCallback1(foo="spam")
|
||||||
|
assert MyCallback1.unpack("test1:spam:42") == MyCallback1(foo="spam", bar=42)
|
||||||
|
|
||||||
|
class MyCallback2(CallbackData, prefix="test2"):
|
||||||
|
foo: Optional[str] = None
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
assert MyCallback2.unpack("test2::42") == MyCallback2(bar=42)
|
||||||
|
assert MyCallback2.unpack("test2:spam:42") == MyCallback2(foo="spam", bar=42)
|
||||||
|
|
||||||
|
class MyCallback3(CallbackData, prefix="test3"):
|
||||||
|
foo: Optional[str] = "experiment"
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42)
|
||||||
|
assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42)
|
||||||
|
|
||||||
|
def test_build_filter(self):
|
||||||
|
filter_object = MyCallback.filter(F.foo == "test")
|
||||||
|
assert isinstance(filter_object.rule, MagicFilter)
|
||||||
|
assert filter_object.callback_data is MyCallback
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallbackDataFilter:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"query,rule,result",
|
||||||
|
[
|
||||||
|
["test", F.foo == "test", False],
|
||||||
|
["test:spam:42", F.foo == "test", False],
|
||||||
|
["test:test:42", F.foo == "test", {"callback_data": MyCallback(foo="test", bar=42)}],
|
||||||
|
["test:test:", F.foo == "test", False],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call(self, query, rule, result):
|
||||||
|
callback_query = CallbackQuery(
|
||||||
|
id="1",
|
||||||
|
from_user=User(id=42, is_bot=False, first_name="test"),
|
||||||
|
data=query,
|
||||||
|
chat_instance="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
filter_object = MyCallback.filter(rule)
|
||||||
|
assert await filter_object(callback_query) == result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_call(self):
|
||||||
|
filter_object = MyCallback.filter(F.test)
|
||||||
|
assert not await filter_object(User(id=42, is_bot=False, first_name="test"))
|
||||||
Loading…
Add table
Add a link
Reference in a new issue