diff --git a/aiogram/client/errors_middleware.py b/aiogram/client/errors_middleware.py new file mode 100644 index 00000000..59d95f07 --- /dev/null +++ b/aiogram/client/errors_middleware.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, List, Type + +from aiogram.methods import Response, TelegramMethod +from aiogram.types import TelegramObject +from aiogram.utils.exceptions.base import TelegramAPIError +from aiogram.utils.exceptions.exceptions import ( + CantParseEntitiesStartTag, + CantParseEntitiesUnclosed, + CantParseEntitiesUnmatchedTags, + CantParseEntitiesUnsupportedTag, + DetailedTelegramAPIError, +) + +if TYPE_CHECKING: + from aiogram.client.bot import Bot + from aiogram.client.session.base import NextRequestMiddlewareType + + +class RequestErrorMiddleware: + def __init__(self) -> None: + self._registry: List[Type[DetailedTelegramAPIError]] = [ + CantParseEntitiesStartTag, + CantParseEntitiesUnmatchedTags, + CantParseEntitiesUnclosed, + CantParseEntitiesUnsupportedTag, + ] + + def mount(self, error: Type[DetailedTelegramAPIError]) -> Type[DetailedTelegramAPIError]: + if error in self: + raise ValueError(f"{error!r} is already registered") + if not hasattr(error, "patterns"): + raise ValueError(f"{error!r} has no attribute 'patterns'") + self._registry.append(error) + return error + + def detect_error(self, err: TelegramAPIError) -> TelegramAPIError: + message = err.message + for variant in self._registry: + for pattern in variant.patterns: + if match := re.match(pattern, message): + return variant( + method=err.method, + message=err.message, + match=match, + ) + return err + + def __contains__(self, item: Type[DetailedTelegramAPIError]) -> bool: + return item in self._registry + + async def __call__( + self, + bot: Bot, + method: TelegramMethod[TelegramObject], + make_request: NextRequestMiddlewareType, + ) -> Response[TelegramObject]: + try: + return await make_request(bot, method) + except TelegramAPIError as e: + detected_err = self.detect_error(err=e) + raise detected_err from e diff --git a/aiogram/dispatcher/filters/callback_data.py b/aiogram/dispatcher/filters/callback_data.py new file mode 100644 index 00000000..68f5b773 --- /dev/null +++ b/aiogram/dispatcher/filters/callback_data.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from decimal import Decimal +from enum import Enum +from fractions import Fraction +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union +from uuid import UUID + +from magic_filter import MagicFilter +from pydantic import BaseModel + +from aiogram.dispatcher.filters import BaseFilter +from aiogram.types import CallbackQuery + +T = TypeVar("T", bound="CallbackData") + +MAX_CALLBACK_LENGTH: int = 64 + + +class CallbackDataException(Exception): + pass + + +class CallbackData(BaseModel): + if TYPE_CHECKING: # pragma: no cover + sep: str + prefix: str + + def __init_subclass__(cls, **kwargs: Any) -> None: + if "prefix" not in kwargs: + raise ValueError( + f"prefix required, usage example: " + f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`" + ) + cls.sep = kwargs.pop("sep", ":") + cls.prefix = kwargs.pop("prefix") + if cls.sep in cls.prefix: + raise ValueError( + f"Separator symbol {cls.sep!r} can not be used inside prefix {cls.prefix!r}" + ) + + def _encode_value(self, key: str, value: Any) -> str: + if value is None: + return "" + if isinstance(value, Enum): + return str(value.value) + if isinstance(value, (int, str, float, Decimal, Fraction, UUID)): + return str(value) + raise ValueError( + f"Attribute {key}={value!r} of type {type(value).__name__!r}" + f" can not be packed to callback data" + ) + + def pack(self) -> str: + result = [self.prefix] + for key, value in self.dict().items(): + encoded = self._encode_value(key, value) + if self.sep in encoded: + raise ValueError( + f"Separator symbol {self.sep!r} can not be used in value {key}={encoded!r}" + ) + result.append(encoded) + callback_data = self.sep.join(result) + if len(callback_data.encode()) > MAX_CALLBACK_LENGTH: + raise ValueError( + f"Resulted callback data is too long! len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}" + ) + return callback_data + + @classmethod + def unpack(cls: Type[T], value: str) -> T: + prefix, *parts = value.split(cls.sep) + names = cls.__fields__.keys() + if len(parts) != len(names): + raise TypeError( + f"Callback data {cls.__name__!r} takes {len(names)} arguments but {len(parts)} were given" + ) + if prefix != cls.prefix: + raise ValueError(f"Bad prefix ({prefix!r} != {cls.prefix!r})") + payload = {} + for k, v in zip(names, parts): # type: str, Optional[str] + if field := cls.__fields__.get(k): + if v == "" and not field.required: + v = None + payload[k] = v + return cls(**payload) + + @classmethod + def filter(cls, rule: MagicFilter) -> CallbackQueryFilter: + return CallbackQueryFilter(callback_data=cls, rule=rule) + + class Config: + use_enum_values = True + + +class CallbackQueryFilter(BaseFilter): + callback_data: Type[CallbackData] + rule: MagicFilter + + async def __call__(self, query: CallbackQuery) -> Union[bool, Dict[str, Any]]: + if not isinstance(query, CallbackQuery) or not query.data: + return False + try: + callback_data = self.callback_data.unpack(query.data) + except (TypeError, ValueError): + return False + + if self.rule.resolve(callback_data): + return {"callback_data": callback_data} + return False + + class Config: + arbitrary_types_allowed = True diff --git a/aiogram/utils/exceptions/__init__.py b/aiogram/utils/exceptions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/utils/exceptions/bad_request.py b/aiogram/utils/exceptions/bad_request.py new file mode 100644 index 00000000..9b9d878a --- /dev/null +++ b/aiogram/utils/exceptions/bad_request.py @@ -0,0 +1,5 @@ +from aiogram.utils.exceptions.base import DetailedTelegramAPIError + + +class BadRequest(DetailedTelegramAPIError): + pass diff --git a/aiogram/utils/exceptions/base.py b/aiogram/utils/exceptions/base.py new file mode 100644 index 00000000..fdbc2514 --- /dev/null +++ b/aiogram/utils/exceptions/base.py @@ -0,0 +1,40 @@ +from typing import ClassVar, List, Match, Optional, TypeVar + +from aiogram.methods import TelegramMethod +from aiogram.methods.base import TelegramType + +ErrorType = TypeVar("ErrorType") + + +class TelegramAPIError(Exception): + url: Optional[str] = None + + def __init__( + self, + method: TelegramMethod[TelegramType], + message: str, + ) -> None: + self.method = method + self.message = message + + def render_description(self) -> str: + return self.message + + def __str__(self) -> str: + message = [self.render_description()] + if self.url: + message.append(f"(background on this error at: {self.url})") + return "\n".join(message) + + +class DetailedTelegramAPIError(TelegramAPIError): + patterns: ClassVar[List[str]] + + def __init__( + self, + method: TelegramMethod[TelegramType], + message: str, + match: Match[str], + ) -> None: + super().__init__(method=method, message=message) + self.match: Match[str] = match diff --git a/aiogram/utils/exceptions/conflict.py b/aiogram/utils/exceptions/conflict.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/utils/exceptions/network.py b/aiogram/utils/exceptions/network.py new file mode 100644 index 00000000..067b1a80 --- /dev/null +++ b/aiogram/utils/exceptions/network.py @@ -0,0 +1,5 @@ +from aiogram.utils.exceptions.base import DetailedTelegramAPIError + + +class NetworkError(DetailedTelegramAPIError): + pass diff --git a/aiogram/utils/exceptions/not_found.py b/aiogram/utils/exceptions/not_found.py new file mode 100644 index 00000000..8dfb344b --- /dev/null +++ b/aiogram/utils/exceptions/not_found.py @@ -0,0 +1,5 @@ +from aiogram.utils.exceptions.base import DetailedTelegramAPIError + + +class NotFound(DetailedTelegramAPIError): + pass diff --git a/aiogram/utils/exceptions/server.py b/aiogram/utils/exceptions/server.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/utils/exceptions/special.py b/aiogram/utils/exceptions/special.py new file mode 100644 index 00000000..0568f900 --- /dev/null +++ b/aiogram/utils/exceptions/special.py @@ -0,0 +1,46 @@ +from typing import Optional + +from aiogram.methods import TelegramMethod +from aiogram.methods.base import TelegramType +from aiogram.utils.exceptions.base import TelegramAPIError + + +class RetryAfter(TelegramAPIError): + url = "https://core.telegram.org/bots/faq#my-bot-is-hitting-limits-how-do-i-avoid-this" + + def __init__( + self, + method: TelegramMethod[TelegramType], + message: str, + retry_after: int, + ) -> None: + super().__init__(method=method, message=message) + self.retry_after = retry_after + + def render_description(self) -> str: + description = f"Flood control exceeded on method {type(self.method).__name__!r}" + if chat_id := getattr(self.method, "chat_id", None): + description += f" in chat {chat_id}" + description += f". Retry in {self.retry_after} seconds." + return description + + +class MigrateToChat(TelegramAPIError): + url = "https://core.telegram.org/bots/api#responseparameters" + + def __init__( + self, + method: TelegramMethod[TelegramType], + message: str, + migrate_to_chat_id: int, + ) -> None: + super().__init__(method=method, message=message) + self.migrate_to_chat_id = migrate_to_chat_id + + def render_message(self) -> Optional[str]: + description = ( + f"The group has been migrated to a supergroup with id {self.migrate_to_chat_id}" + ) + if chat_id := getattr(self.method, "chat_id", None): + description += f" from {chat_id}" + return description diff --git a/aiogram/utils/exceptions/unauthorized.py b/aiogram/utils/exceptions/unauthorized.py new file mode 100644 index 00000000..e69de29b diff --git a/aiogram/utils/exceptions/util.py b/aiogram/utils/exceptions/util.py new file mode 100644 index 00000000..a7cb191e --- /dev/null +++ b/aiogram/utils/exceptions/util.py @@ -0,0 +1,20 @@ +def mark_line(text: str, offset: int, length: int = 1) -> str: + try: + if offset > 0 and (new_line_pos := text[:offset].rindex("\n")): + text = "..." + text[:new_line_pos] + offset -= new_line_pos - 3 + except ValueError: + pass + + if offset > 10: + text = "..." + text[offset - 10 :] + offset = 13 + + mark = " " * offset + mark += "^" * length + try: + if new_line_pos := text[len(mark) :].index("\n"): + text = text[:new_line_pos].rstrip() + "..." + except ValueError: + pass + return text + "\n" + mark diff --git a/examples/finite_state_machine.py b/examples/finite_state_machine.py new file mode 100644 index 00000000..65266b64 --- /dev/null +++ b/examples/finite_state_machine.py @@ -0,0 +1,111 @@ +import asyncio +import logging +import sys +from os import getenv + +from aiogram import Bot, Dispatcher, F +from aiogram.dispatcher.filters import Command +from aiogram.dispatcher.fsm.context import FSMContext +from aiogram.dispatcher.fsm.state import State, StatesGroup +from aiogram.types import Message, ReplyKeyboardRemove, ReplyKeyboardMarkup, KeyboardButton +from aiogram.utils.markdown import hbold +from aiogram.utils.markup import KeyboardConstructor + +GENDERS = ["Male", "Female", "Helicopter", "Other"] + +dp = Dispatcher() + + +# States +class Form(StatesGroup): + name = State() # Will be represented in storage as 'Form:name' + age = State() # Will be represented in storage as 'Form:age' + gender = State() # Will be represented in storage as 'Form:gender' + + +@dp.message(Command(commands=["start"])) +async def cmd_start(message: Message, state: FSMContext): + """ + Conversation's entry point + """ + # Set state + await state.set_state(Form.name) + await message.answer("Hi there! What's your name?") + + +@dp.message(Command(commands=["cancel"])) +@dp.message(F.text.lower() == "cancel") +async def cancel_handler(message: Message, state: FSMContext): + """ + Allow user to cancel any action + """ + current_state = await state.get_state() + if current_state is None: + return + + logging.info("Cancelling state %r", current_state) + # Cancel state and inform user about it + await state.clear() + # And remove keyboard (just in case) + await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove()) + + +@dp.message(Form.name) +async def process_name(message: Message, state: FSMContext): + """ + Process user name + """ + await state.update_data(name=message.text) + await state.set_state(Form.age) + await message.answer("How old are you?") + + +# Check age. Age gotta be digit +@dp.message(Form.age, ~F.text.isdigit()) +async def process_age_invalid(message: Message): + """ + If age is invalid + """ + return await message.answer("Age gotta be a number.\nHow old are you? (digits only)") + + +@dp.message(Form.age) +async def process_age(message: Message, state: FSMContext): + # Update state and data + await state.set_state(Form.gender) + await state.update_data(age=int(message.text)) + + # Configure ReplyKeyboardMarkup + constructor = KeyboardConstructor(KeyboardButton) + constructor.add(*(KeyboardButton(text=text) for text in GENDERS)).adjust(2) + markup = ReplyKeyboardMarkup( + resize_keyboard=True, selective=True, keyboard=constructor.export() + ) + await message.reply("What is your gender?", reply_markup=markup) + + +@dp.message(Form.gender) +async def process_gender(message: Message, state: FSMContext): + data = await state.update_data(gender=message.text) + await state.clear() + + # And send message + await message.answer( + ( + f'Hi, nice to meet you, {hbold(data["name"])}\n' + f'Age: {hbold(data["age"])}\n' + f'Gender: {hbold(data["gender"])}\n' + ), + reply_markup=ReplyKeyboardRemove(), + ) + + +async def main(): + bot = Bot(token=getenv("TELEGRAM_TOKEN"), parse_mode="HTML") + + await dp.start_polling(bot) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, stream=sys.stdout) + asyncio.run(main()) diff --git a/tests/test_dispatcher/test_filters/test_callback_data.py b/tests/test_dispatcher/test_filters/test_callback_data.py new file mode 100644 index 00000000..f71ca706 --- /dev/null +++ b/tests/test_dispatcher/test_filters/test_callback_data.py @@ -0,0 +1,177 @@ +from decimal import Decimal +from enum import Enum, auto +from fractions import Fraction +from typing import Optional +from uuid import UUID + +import pytest +from magic_filter import MagicFilter +from pydantic import ValidationError + +from aiogram import F +from aiogram.dispatcher.filters.callback_data import CallbackData +from aiogram.types import CallbackQuery, User + + +class MyIntEnum(Enum): + FOO = auto() + + +class MyStringEnum(str, Enum): + FOO = "FOO" + + +class MyCallback(CallbackData, prefix="test"): + foo: str + bar: int + + +class TestCallbackData: + def test_init_subclass_prefix_required(self): + assert MyCallback.prefix == "test" + + with pytest.raises(ValueError, match="prefix required.+"): + + class MyInvalidCallback(CallbackData): + pass + + def test_init_subclass_sep_validation(self): + assert MyCallback.sep == ":" + + class MyCallback2(CallbackData, prefix="test2", sep="@"): + pass + + assert MyCallback2.sep == "@" + + with pytest.raises(ValueError, match="Separator symbol '@' .+ 'sp@m'"): + + class MyInvalidCallback(CallbackData, prefix="sp@m", sep="@"): + pass + + @pytest.mark.parametrize( + "value,success,expected", + [ + [None, True, ""], + [42, True, "42"], + ["test", True, "test"], + [9.99, True, "9.99"], + [Decimal("9.99"), True, "9.99"], + [Fraction("3/2"), True, "3/2"], + [ + UUID("123e4567-e89b-12d3-a456-426655440000"), + True, + "123e4567-e89b-12d3-a456-426655440000", + ], + [MyIntEnum.FOO, True, "1"], + [MyStringEnum.FOO, True, "FOO"], + [..., False, "..."], + [object, False, "..."], + [object(), False, "..."], + [User(id=42, is_bot=False, first_name="test"), False, "..."], + ], + ) + def test_encode_value(self, value, success, expected): + callback = MyCallback(foo="test", bar=42) + if success: + assert callback._encode_value("test", value) == expected + else: + with pytest.raises(ValueError): + assert callback._encode_value("test", value) == expected + + def test_pack(self): + with pytest.raises(ValueError, match="Separator symbol .+"): + assert MyCallback(foo="te:st", bar=42).pack() + + with pytest.raises(ValueError, match=".+is too long.+"): + assert MyCallback(foo="test" * 32, bar=42).pack() + + assert MyCallback(foo="test", bar=42).pack() == "test:test:42" + + def test_pack_optional(self): + class MyCallback1(CallbackData, prefix="test1"): + foo: str + bar: Optional[int] = None + + assert MyCallback1(foo="spam").pack() == "test1:spam:" + assert MyCallback1(foo="spam", bar=42).pack() == "test1:spam:42" + + class MyCallback2(CallbackData, prefix="test2"): + foo: Optional[str] = None + bar: int + + assert MyCallback2(bar=42).pack() == "test2::42" + assert MyCallback2(foo="spam", bar=42).pack() == "test2:spam:42" + + class MyCallback3(CallbackData, prefix="test3"): + foo: Optional[str] = "experiment" + bar: int + + assert MyCallback3(bar=42).pack() == "test3:experiment:42" + assert MyCallback3(foo="spam", bar=42).pack() == "test3:spam:42" + + def test_unpack(self): + with pytest.raises(TypeError, match=".+ takes 2 arguments but 3 were given"): + MyCallback.unpack("test:test:test:test") + + with pytest.raises(ValueError, match="Bad prefix .+"): + MyCallback.unpack("spam:test:test") + + assert MyCallback.unpack("test:test:42") == MyCallback(foo="test", bar=42) + + def test_unpack_optional(self): + with pytest.raises(ValidationError): + assert MyCallback.unpack("test:test:") + + class MyCallback1(CallbackData, prefix="test1"): + foo: str + bar: Optional[int] = None + + assert MyCallback1.unpack("test1:spam:") == MyCallback1(foo="spam") + assert MyCallback1.unpack("test1:spam:42") == MyCallback1(foo="spam", bar=42) + + class MyCallback2(CallbackData, prefix="test2"): + foo: Optional[str] = None + bar: int + + assert MyCallback2.unpack("test2::42") == MyCallback2(bar=42) + assert MyCallback2.unpack("test2:spam:42") == MyCallback2(foo="spam", bar=42) + + class MyCallback3(CallbackData, prefix="test3"): + foo: Optional[str] = "experiment" + bar: int + + assert MyCallback3.unpack("test3:experiment:42") == MyCallback3(bar=42) + assert MyCallback3.unpack("test3:spam:42") == MyCallback3(foo="spam", bar=42) + + def test_build_filter(self): + filter_object = MyCallback.filter(F.foo == "test") + assert isinstance(filter_object.rule, MagicFilter) + assert filter_object.callback_data is MyCallback + + +class TestCallbackDataFilter: + @pytest.mark.parametrize( + "query,rule,result", + [ + ["test", F.foo == "test", False], + ["test:spam:42", F.foo == "test", False], + ["test:test:42", F.foo == "test", {"callback_data": MyCallback(foo="test", bar=42)}], + ["test:test:", F.foo == "test", False], + ], + ) + @pytest.mark.asyncio + async def test_call(self, query, rule, result): + callback_query = CallbackQuery( + id="1", + from_user=User(id=42, is_bot=False, first_name="test"), + data=query, + chat_instance="test", + ) + + filter_object = MyCallback.filter(rule) + assert await filter_object(callback_query) == result + + @pytest.mark.asyncio + async def test_invalid_call(self): + filter_object = MyCallback.filter(F.test) + assert not await filter_object(User(id=42, is_bot=False, first_name="test"))