* Rework middlewares, separate management to `MiddlewareManager` class

* Rework middlewares

* Added changes description for redis

* Added changes description for redis

* Fixed tests with Redis // aioredis replacement

* Changed msg.<html/md>_text attributes behaviour

* Added changelog for spoilers

* Added possibility to get command magic result as handler arguments
This commit is contained in:
Alex Root Junior 2022-04-16 19:07:32 +03:00 committed by GitHub
parent 930bca0876
commit 286cf39c8a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
51 changed files with 1380 additions and 804 deletions

View file

@ -2,20 +2,20 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v4.2.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-merge-conflict
- repo: https://github.com/psf/black
rev: 21.8b0
rev: 22.3.0
hooks:
- id: black
files: &files '^(aiogram|tests|examples)'
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.9.3
rev: v5.10.1
hooks:
- id: isort
additional_dependencies: [ toml ]

1
CHANGES/865.bugfix.rst Normal file
View file

@ -0,0 +1 @@
Added parsing of spoiler message entity

2
CHANGES/874.misc.rst Normal file
View file

@ -0,0 +1,2 @@
Changed :code:`Message.html_text` and :code:`Message.md_text` attributes behaviour when message has no text.
The empty string will be used instead of raising error.

1
CHANGES/882.misc.rst Normal file
View file

@ -0,0 +1 @@
Used `redis-py` instead of `aioredis` package in due to this packages was merged into single one

3
CHANGES/883.misc.rst Normal file
View file

@ -0,0 +1,3 @@
Solved common naming problem with middlewares that confusing too much developers
- now you can't see the `middleware` and `middlewares` attributes at the same point
because this functionality encapsulated to special interface.

1
CHANGES/885.bugfix.rst Normal file
View file

@ -0,0 +1 @@
Fixed CallbackData factory parsing IntEnum's

1
CHANGES/889.feature.rst Normal file
View file

@ -0,0 +1 @@
Added possibility to get command magic result as handler argument

View file

@ -47,7 +47,7 @@ help:
.PHONY: install
install:
poetry install -E fast -E redis -E proxy -E i18n -E docs
poetry install -E fast -E redis -E proxy -E i18n -E docs --remove-untracked
$(py) pre-commit install
.PHONY: clean
@ -94,9 +94,6 @@ test: test-run-services
test-coverage: test-run-services
mkdir -p $(reports_dir)/tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ --redis $(redis_connection)
.PHONY: test-coverage-report
test-coverage-report:
$(py) coverage html -d $(reports_dir)/coverage
.PHONY: test-coverage-view

View file

@ -3,22 +3,9 @@ from __future__ import annotations
import abc
import datetime
import json
from functools import partial
from http import HTTPStatus
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Awaitable,
Callable,
Final,
List,
Optional,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Final, Optional, Type, Union, cast
from aiogram.exceptions import (
RestartingTelegram,
@ -36,26 +23,15 @@ from aiogram.exceptions import (
from ...methods import Response, TelegramMethod
from ...methods.base import TelegramType
from ...types import UNSET, TelegramObject
from ...types import UNSET
from ..telegram import PRODUCTION, TelegramAPIServer
from .middlewares.base import BaseRequestMiddleware
from .middlewares.manager import RequestMiddlewareManager
if TYPE_CHECKING:
from ..bot import Bot
_JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]
NextRequestMiddlewareType = Callable[
["Bot", TelegramMethod[TelegramObject]], Awaitable[Response[TelegramObject]]
]
RequestMiddlewareType = Union[
BaseRequestMiddleware,
Callable[
[NextRequestMiddlewareType, "Bot", TelegramMethod[TelegramType]],
Awaitable[Response[TelegramType]],
],
]
DEFAULT_TIMEOUT: Final[float] = 60.0
@ -80,7 +56,7 @@ class BaseSession(abc.ABC):
self.json_dumps = json_dumps
self.timeout = timeout
self.middlewares: List[RequestMiddlewareType[TelegramObject]] = []
self.middleware = RequestMiddlewareManager()
def check_response(
self, method: TelegramMethod[TelegramType], status_code: int, content: str
@ -185,19 +161,11 @@ class BaseSession(abc.ABC):
return {k: self.clean_json(v) for k, v in value.items() if v is not None}
return value
def middleware(
self, middleware: RequestMiddlewareType[TelegramObject]
) -> RequestMiddlewareType[TelegramObject]:
self.middlewares.append(middleware)
return middleware
async def __call__(
self, bot: Bot, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET
) -> TelegramType:
middleware = partial(self.make_request, timeout=timeout)
for m in reversed(self.middlewares):
middleware = partial(m, middleware) # type: ignore
return await middleware(bot, method)
middleware = self.middleware.wrap_middlewares(self.make_request, timeout=timeout)
return cast(TelegramType, await middleware(bot, method))
async def __aenter__(self) -> BaseSession:
return self

View file

@ -1,15 +1,23 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Awaitable, Callable
from typing import TYPE_CHECKING, Awaitable, Callable, Union
from aiogram.methods import Response, TelegramMethod
from aiogram.types import TelegramObject
from aiogram.methods.base import TelegramType
if TYPE_CHECKING:
from ...bot import Bot
NextRequestMiddlewareType = Callable[
["Bot", TelegramMethod[TelegramObject]], Awaitable[Response[TelegramObject]]
["Bot", TelegramMethod[TelegramType]], Awaitable[Response[TelegramType]]
]
RequestMiddlewareType = Union[
"BaseRequestMiddleware",
Callable[
[NextRequestMiddlewareType[TelegramType], "Bot", TelegramMethod[TelegramType]],
Awaitable[Response[TelegramType]],
],
]
@ -21,10 +29,10 @@ class BaseRequestMiddleware(ABC):
@abstractmethod
async def __call__(
self,
make_request: NextRequestMiddlewareType,
make_request: NextRequestMiddlewareType[TelegramType],
bot: "Bot",
method: TelegramMethod[TelegramObject],
) -> Response[TelegramObject]:
method: TelegramMethod[TelegramType],
) -> Response[TelegramType]:
"""
Execute middleware

View file

@ -0,0 +1,79 @@
from __future__ import annotations
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Optional,
Sequence,
Union,
overload,
)
from aiogram.client.session.middlewares.base import (
NextRequestMiddlewareType,
RequestMiddlewareType,
)
from aiogram.methods import Response
from aiogram.methods.base import TelegramMethod, TelegramType
from aiogram.types import TelegramObject
if TYPE_CHECKING:
from aiogram import Bot
class RequestMiddlewareManager(Sequence[RequestMiddlewareType[TelegramObject]]):
def __init__(self) -> None:
self._middlewares: List[RequestMiddlewareType[TelegramObject]] = []
def register(
self,
middleware: RequestMiddlewareType[TelegramObject],
) -> RequestMiddlewareType[TelegramObject]:
self._middlewares.append(middleware)
return middleware
def unregister(self, middleware: RequestMiddlewareType[TelegramObject]) -> None:
self._middlewares.remove(middleware)
def __call__(
self,
middleware: Optional[RequestMiddlewareType[TelegramObject]] = None,
) -> Union[
Callable[[RequestMiddlewareType[TelegramObject]], RequestMiddlewareType[TelegramObject]],
RequestMiddlewareType[TelegramObject],
]:
if middleware is None:
return self.register
return self.register(middleware)
@overload
def __getitem__(self, item: int) -> RequestMiddlewareType[TelegramObject]:
pass
@overload
def __getitem__(self, item: slice) -> Sequence[RequestMiddlewareType[TelegramObject]]:
pass
def __getitem__(
self, item: Union[int, slice]
) -> Union[
RequestMiddlewareType[TelegramObject], Sequence[RequestMiddlewareType[TelegramObject]]
]:
return self._middlewares[item]
def __len__(self) -> int:
return len(self._middlewares)
def wrap_middlewares(
self,
callback: Callable[[Bot, TelegramMethod[TelegramType]], Awaitable[Response[TelegramType]]],
**kwargs: Any,
) -> NextRequestMiddlewareType[TelegramType]:
middleware = partial(callback, **kwargs)
for m in reversed(self._middlewares):
middleware = partial(m, middleware) # type: ignore
return middleware

View file

@ -3,8 +3,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Type
from aiogram import loggers
from aiogram.methods import TelegramMethod
from aiogram.methods.base import Response
from aiogram.types import TelegramObject
from aiogram.methods.base import Response, TelegramType
from .base import BaseRequestMiddleware, NextRequestMiddlewareType
@ -25,10 +24,10 @@ class RequestLogging(BaseRequestMiddleware):
async def __call__(
self,
make_request: NextRequestMiddlewareType,
make_request: NextRequestMiddlewareType[TelegramType],
bot: "Bot",
method: TelegramMethod[TelegramObject],
) -> Response[TelegramObject]:
method: TelegramMethod[TelegramType],
) -> Response[TelegramType]:
if type(method) not in self.ignore_methods:
loggers.middlewares.info(
"Make request with method=%r by bot id=%d",

View file

@ -36,8 +36,19 @@ class Dispatcher(Router):
storage: Optional[BaseStorage] = None,
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
events_isolation: Optional[BaseEventIsolation] = None,
disable_fsm: bool = False,
**kwargs: Any,
) -> None:
"""
Root router
:param storage: Storage for FSM
:param fsm_strategy: FSM strategy
:param events_isolation: Events isolation
:param disable_fsm: Disable FSM, note that if you disable FSM
then you should not use storage and events isolation
:param kwargs: Other arguments, will be passed as keyword arguments to handlers
"""
super(Dispatcher, self).__init__(**kwargs)
# Telegram API provides originally only one event type - Update
@ -48,7 +59,8 @@ class Dispatcher(Router):
)
self.update.register(self._listen_update)
# Error handlers should work is out of all other functions and be registered before all others middlewares
# Error handlers should work is out of all other functions
# and should be registered before all others middlewares
self.update.outer_middleware(ErrorsMiddleware(self))
# User context middleware makes small optimization for all other builtin
@ -62,11 +74,31 @@ class Dispatcher(Router):
strategy=fsm_strategy,
events_isolation=events_isolation if events_isolation else DisabledEventIsolation(),
)
self.update.outer_middleware(self.fsm)
if not disable_fsm:
# Note that when FSM middleware is disabled, the event isolation is also disabled
# Because the isolation mechanism is a part of the FSM
self.update.outer_middleware(self.fsm)
self.shutdown.register(self.fsm.close)
self._data: Dict[str, Any] = {}
self._running_lock = Lock()
def __getitem__(self, item: str) -> Any:
return self._data[item]
def __setitem__(self, key: str, value: Any) -> None:
self._data[key] = value
def __delitem__(self, key: str) -> None:
del self._data[key]
def get(self, key: str, /, default: Optional[Any] = None) -> Optional[Any]:
return self._data.get(key, default)
@property
def storage(self) -> BaseStorage:
return self.fsm.storage
@property
def parent_router(self) -> None:
"""
@ -100,8 +132,15 @@ class Dispatcher(Router):
token = Bot.set_current(bot)
try:
kwargs.update(bot=bot)
response = await self.update.wrap_outer_middleware(self.update.trigger, update, kwargs)
response = await self.update.wrap_outer_middleware(
self.update.trigger,
update,
{
**self._data,
**kwargs,
"bot": bot,
},
)
handled = response is not UNHANDLED
return response
finally:

View file

@ -1,33 +1,17 @@
from __future__ import annotations
import functools
from inspect import isclass
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple, Type
from pydantic import ValidationError
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
from ...exceptions import FiltersResolveError
from ...types import TelegramObject
from ..filters.base import BaseFilter
from .bases import (
REJECTED,
UNHANDLED,
MiddlewareEventType,
MiddlewareType,
NextMiddlewareType,
SkipHandler,
)
from .bases import REJECTED, UNHANDLED, MiddlewareType, SkipHandler
from .handler import CallbackType, FilterObject, FilterType, HandlerObject, HandlerType
if TYPE_CHECKING:
@ -48,8 +32,9 @@ class TelegramEventObserver:
self.handlers: List[HandlerObject] = []
self.filters: List[Type[BaseFilter]] = []
self.outer_middlewares: List[MiddlewareType[TelegramObject]] = []
self.middlewares: List[MiddlewareType[TelegramObject]] = []
self.middleware = MiddlewareManager()
self.outer_middleware = MiddlewareManager()
# Re-used filters check method from already implemented handler object
# with dummy callback which never will be used
@ -75,7 +60,11 @@ class TelegramEventObserver:
:param bound_filter:
"""
if not issubclass(bound_filter, BaseFilter):
# TODO: This functionality should be deprecated in the future
# in due to bound filter has uncontrollable ordering and
# makes debugging process is harder that explicit using filters
if not isclass(bound_filter) or not issubclass(bound_filter, BaseFilter):
raise TypeError(
"bound_filter() argument 'bound_filter' must be subclass of BaseFilter"
)
@ -97,18 +86,11 @@ class TelegramEventObserver:
yield filter_
registry.append(filter_)
def _resolve_middlewares(self, *, outer: bool = False) -> List[MiddlewareType[TelegramObject]]:
"""
Get all middlewares in a tree
:param *:
"""
middlewares = []
if outer:
middlewares.extend(self.outer_middlewares)
else:
for router in reversed(tuple(self.router.chain_head)):
observer = router.observers[self.event_name]
middlewares.extend(observer.middlewares)
def _resolve_middlewares(self) -> List[MiddlewareType[TelegramObject]]:
middlewares: List[MiddlewareType[TelegramObject]] = []
for router in reversed(tuple(self.router.chain_head)):
observer = router.observers[self.event_name]
middlewares.extend(observer.middleware)
return middlewares
@ -198,23 +180,13 @@ class TelegramEventObserver:
)
return callback
@classmethod
def _wrap_middleware(
cls, middlewares: List[MiddlewareType[MiddlewareEventType]], handler: HandlerType
) -> NextMiddlewareType[MiddlewareEventType]:
@functools.wraps(handler)
def mapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
return handler(event, **kwargs)
middleware = mapper
for m in reversed(middlewares):
middleware = functools.partial(m, middleware)
return middleware
def wrap_outer_middleware(
self, callback: Any, event: TelegramObject, data: Dict[str, Any]
) -> Any:
wrapped_outer = self._wrap_middleware(self._resolve_middlewares(outer=True), callback)
wrapped_outer = self.middleware.wrap_middlewares(
self.outer_middleware,
callback,
)
return wrapped_outer(event, data)
async def trigger(self, event: TelegramObject, **kwargs: Any) -> Any:
@ -233,8 +205,9 @@ class TelegramEventObserver:
if result:
kwargs.update(data, handler=handler)
try:
wrapped_inner = self._wrap_middleware(
self._resolve_middlewares(), handler.call
wrapped_inner = self.outer_middleware.wrap_middlewares(
self._resolve_middlewares(),
handler.call,
)
return await wrapped_inner(event, kwargs)
except SkipHandler:
@ -254,71 +227,3 @@ class TelegramEventObserver:
return callback
return wrapper
def middleware(
self,
middleware: Optional[MiddlewareType[TelegramObject]] = None,
) -> Union[
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
MiddlewareType[TelegramObject],
]:
"""
Decorator for registering inner middlewares
Usage:
.. code-block:: python
@<event>.middleware() # via decorator (variant 1)
.. code-block:: python
@<event>.middleware # via decorator (variant 2)
.. code-block:: python
async def my_middleware(handler, event, data): ...
<event>.middleware(my_middleware) # via method
"""
def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]:
self.middlewares.append(m)
return m
if middleware is None:
return wrapper
return wrapper(middleware)
def outer_middleware(
self,
middleware: Optional[MiddlewareType[TelegramObject]] = None,
) -> Union[
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
MiddlewareType[TelegramObject],
]:
"""
Decorator for registering outer middlewares
Usage:
.. code-block:: python
@<event>.outer_middleware() # via decorator (variant 1)
.. code-block:: python
@<event>.outer_middleware # via decorator (variant 2)
.. code-block:: python
async def my_middleware(handler, event, data): ...
<event>.outer_middleware(my_middleware) # via method
"""
def wrapper(m: MiddlewareType[TelegramObject]) -> MiddlewareType[TelegramObject]:
self.outer_middlewares.append(m)
return m
if middleware is None:
return wrapper
return wrapper(middleware)

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from decimal import Decimal
from enum import Enum
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Literal, Optional, Type, TypeVar, Union
from uuid import UUID
from magic_filter import MagicFilter
@ -22,9 +22,20 @@ class CallbackDataException(Exception):
class CallbackData(BaseModel):
"""
Base class for callback data wrapper
This class should be used as super-class of user-defined callbacks.
The class-keyword :code:`prefix` is required to define prefix
and also the argument :code:`sep` can be passed to define separator (default is :code:`:`).
"""
if TYPE_CHECKING:
sep: str
prefix: str
__separator__: ClassVar[str]
"""Data separator (default is :code:`:`)"""
__prefix__: ClassVar[str]
"""Callback prefix"""
def __init_subclass__(cls, **kwargs: Any) -> None:
if "prefix" not in kwargs:
@ -32,12 +43,14 @@ class CallbackData(BaseModel):
f"prefix required, usage example: "
f"`class {cls.__name__}(CallbackData, prefix='my_callback'): ...`"
)
cls.sep = kwargs.pop("sep", ":")
cls.prefix = kwargs.pop("prefix")
if cls.sep in cls.prefix:
cls.__separator__ = kwargs.pop("sep", ":")
cls.__prefix__ = kwargs.pop("prefix")
if cls.__separator__ in cls.__prefix__:
raise ValueError(
f"Separator symbol {cls.sep!r} can not be used inside prefix {cls.prefix!r}"
f"Separator symbol {cls.__separator__!r} can not be used "
f"inside prefix {cls.__prefix__!r}"
)
super().__init_subclass__(**kwargs)
def _encode_value(self, key: str, value: Any) -> str:
if value is None:
@ -52,31 +65,45 @@ class CallbackData(BaseModel):
)
def pack(self) -> str:
result = [self.prefix]
"""
Generate callback data string
:return: valid callback data for Telegram Bot API
"""
result = [self.__prefix__]
for key, value in self.dict().items():
encoded = self._encode_value(key, value)
if self.sep in encoded:
if self.__separator__ in encoded:
raise ValueError(
f"Separator symbol {self.sep!r} can not be used in value {key}={encoded!r}"
f"Separator symbol {self.__separator__!r} can not be used "
f"in value {key}={encoded!r}"
)
result.append(encoded)
callback_data = self.sep.join(result)
callback_data = self.__separator__.join(result)
if len(callback_data.encode()) > MAX_CALLBACK_LENGTH:
raise ValueError(
f"Resulted callback data is too long! len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
f"Resulted callback data is too long! "
f"len({callback_data!r}.encode()) > {MAX_CALLBACK_LENGTH}"
)
return callback_data
@classmethod
def unpack(cls: Type[T], value: str) -> T:
prefix, *parts = value.split(cls.sep)
"""
Parse callback data string
:param value: value from Telegram
:return: instance of CallbackData
"""
prefix, *parts = value.split(cls.__separator__)
names = cls.__fields__.keys()
if len(parts) != len(names):
raise TypeError(
f"Callback data {cls.__name__!r} takes {len(names)} arguments but {len(parts)} were given"
f"Callback data {cls.__name__!r} takes {len(names)} arguments "
f"but {len(parts)} were given"
)
if prefix != cls.prefix:
raise ValueError(f"Bad prefix ({prefix!r} != {cls.prefix!r})")
if prefix != cls.__prefix__:
raise ValueError(f"Bad prefix ({prefix!r} != {cls.__prefix__!r})")
payload = {}
for k, v in zip(names, parts): # type: str, Optional[str]
if field := cls.__fields__.get(k):
@ -87,15 +114,30 @@ class CallbackData(BaseModel):
@classmethod
def filter(cls, rule: Optional[MagicFilter] = None) -> CallbackQueryFilter:
"""
Generates a filter for callback query with rule
:param rule: magic rule
:return: instance of filter
"""
return CallbackQueryFilter(callback_data=cls, rule=rule)
class Config:
use_enum_values = True
# class Config:
# use_enum_values = True
class CallbackQueryFilter(BaseFilter):
"""
This filter helps to handle callback query.
Should not be used directly, you should create the instance of this filter
via callback data instance
"""
callback_data: Type[CallbackData]
"""Expected type of callback data"""
rule: Optional[MagicFilter] = None
"""Magic rule"""
async def __call__(self, query: CallbackQuery) -> Union[Literal[False], Dict[str, Any]]:
if not isinstance(query, CallbackQuery) or not query.data:
@ -111,3 +153,4 @@ class CallbackQueryFilter(BaseFilter):
class Config:
arbitrary_types_allowed = True
use_enum_values = True

View file

@ -59,7 +59,10 @@ class Command(BaseFilter):
command = await self.parse_command(text=text, bot=bot)
except CommandException:
return False
return {"command": command}
result = {"command": command}
if command.magic_result and isinstance(command.magic_result, dict):
result.update(command.magic_result)
return result
def extract_command(self, text: str) -> CommandObject:
# First step: separate command with arguments
@ -110,20 +113,22 @@ class Command(BaseFilter):
self.validate_prefix(command=command)
await self.validate_mention(bot=bot, command=command)
command = self.validate_command(command)
self.do_magic(command=command)
command = self.do_magic(command=command)
return command
def do_magic(self, command: CommandObject) -> None:
def do_magic(self, command: CommandObject) -> Any:
if not self.command_magic:
return
if not self.command_magic.resolve(command):
return command
result = self.command_magic.resolve(command)
if not result:
raise CommandException("Rejected via magic filter")
return replace(command, magic_result=result)
class Config:
arbitrary_types_allowed = True
@dataclass
@dataclass(frozen=True)
class CommandObject:
"""
Instance of this object is always has command and it prefix.
@ -140,6 +145,7 @@ class CommandObject:
"""Command argument"""
regexp_match: Optional[Match[str]] = field(repr=False, default=None)
"""Will be presented match result if the command is presented as regexp in filter"""
magic_result: Optional[Any] = field(repr=False, default=None)
@property
def mentioned(self) -> bool:

View file

@ -12,9 +12,7 @@ class MagicData(BaseFilter):
class Config:
arbitrary_types_allowed = True
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> bool:
return bool(
self.magic_data.resolve(
AttrDict({"event": event, **{k: v for k, v in enumerate(args)}, **kwargs})
)
async def __call__(self, event: TelegramObject, *args: Any, **kwargs: Any) -> Any:
return self.magic_data.resolve(
AttrDict({"event": event, **{k: v for k, v in enumerate(args)}, **kwargs})
)

View file

@ -1,10 +1,13 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union, cast, overload
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, overload
from magic_filter import AttrDict
from aiogram.dispatcher.flags.getter import extract_flags_from_object
if TYPE_CHECKING:
pass
@dataclass(frozen=True)
class Flag:
@ -25,11 +28,11 @@ class FlagDecorator:
return self._with_flag(new_flag)
@overload
def __call__(self, value: Callable[..., Any]) -> Callable[..., Any]: # type: ignore
def __call__(self, value: Callable[..., Any], /) -> Callable[..., Any]: # type: ignore
pass
@overload
def __call__(self, value: Any) -> "FlagDecorator":
def __call__(self, value: Any, /) -> "FlagDecorator":
pass
@overload
@ -53,8 +56,24 @@ class FlagDecorator:
return self._with_value(AttrDict(kwargs) if value is None else value)
if TYPE_CHECKING:
class _ChatActionFlagProtocol(FlagDecorator):
def __call__( # type: ignore[override]
self,
action: str = ...,
interval: float = ...,
initial_sleep: float = ...,
**kwargs: Any,
) -> FlagDecorator:
pass
class FlagGenerator:
def __getattr__(self, name: str) -> FlagDecorator:
if name[0] == "_":
raise AttributeError("Flag name must NOT start with underscore")
return FlagDecorator(Flag(name, True))
if TYPE_CHECKING:
chat_action: _ChatActionFlagProtocol

View file

@ -2,7 +2,9 @@ from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Literal, Optional, cast
from aioredis import ConnectionPool, Redis
from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool
from redis.asyncio.lock import Lock
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
@ -131,7 +133,7 @@ class RedisStorage(BaseStorage):
return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs)
async def close(self) -> None:
await self.redis.close() # type: ignore
await self.redis.close()
async def set_state(
self,
@ -223,7 +225,7 @@ class RedisEventIsolation(BaseEventIsolation):
key: StorageKey,
) -> AsyncGenerator[None, None]:
redis_key = self.key_builder.build(key, "lock")
async with self.redis.lock(name=redis_key, **self.lock_kwargs):
async with self.redis.lock(name=redis_key, **self.lock_kwargs, lock_class=Lock):
yield None
async def close(self) -> None:

View file

@ -0,0 +1,61 @@
import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, overload
from aiogram.dispatcher.event.bases import MiddlewareEventType, MiddlewareType, NextMiddlewareType
from aiogram.dispatcher.event.handler import HandlerType
from aiogram.types import TelegramObject
class MiddlewareManager(Sequence[MiddlewareType[TelegramObject]]):
def __init__(self) -> None:
self._middlewares: List[MiddlewareType[TelegramObject]] = []
def register(
self,
middleware: MiddlewareType[TelegramObject],
) -> MiddlewareType[TelegramObject]:
self._middlewares.append(middleware)
return middleware
def unregister(self, middleware: MiddlewareType[TelegramObject]) -> None:
self._middlewares.remove(middleware)
def __call__(
self,
middleware: Optional[MiddlewareType[TelegramObject]] = None,
) -> Union[
Callable[[MiddlewareType[TelegramObject]], MiddlewareType[TelegramObject]],
MiddlewareType[TelegramObject],
]:
if middleware is None:
return self.register
return self.register(middleware)
@overload
def __getitem__(self, item: int) -> MiddlewareType[TelegramObject]:
pass
@overload
def __getitem__(self, item: slice) -> Sequence[MiddlewareType[TelegramObject]]:
pass
def __getitem__(
self, item: Union[int, slice]
) -> Union[MiddlewareType[TelegramObject], Sequence[MiddlewareType[TelegramObject]]]:
return self._middlewares[item]
def __len__(self) -> int:
return len(self._middlewares)
@staticmethod
def wrap_middlewares(
middlewares: Sequence[MiddlewareType[MiddlewareEventType]], handler: HandlerType
) -> NextMiddlewareType[MiddlewareEventType]:
@functools.wraps(handler)
def handler_wrapper(event: TelegramObject, kwargs: Dict[str, Any]) -> Any:
return handler(event, **kwargs)
middleware = handler_wrapper
for m in reversed(middlewares):
middleware = functools.partial(m, middleware)
return middleware

View file

@ -8,6 +8,7 @@ from ..utils.imports import import_module
from ..utils.warnings import CodeHasNoEffect
from .event.bases import REJECTED, UNHANDLED
from .event.event import EventObserver
from .event.handler import HandlerType
from .event.telegram import TelegramEventObserver
from .filters import BUILTIN_FILTERS
@ -253,7 +254,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.message
@property
@ -264,7 +264,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.edited_message
@property
@ -275,7 +274,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.channel_post
@property
@ -286,7 +284,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.edited_channel_post
@property
@ -297,7 +294,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.inline_query
@property
@ -308,7 +304,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.chosen_inline_result
@property
@ -319,7 +314,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.callback_query
@property
@ -330,7 +324,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.shipping_query
@property
@ -341,7 +334,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.pre_checkout_query
@property
@ -352,7 +344,6 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.poll
@property
@ -363,9 +354,38 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.poll_answer
@property
def my_chat_member_handler(self) -> TelegramEventObserver:
warnings.warn(
"`Router.my_chat_member_handler(...)` is deprecated and will be removed in version 3.2 "
"use `Router.my_chat_member(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.my_chat_member
@property
def chat_member_handler(self) -> TelegramEventObserver:
warnings.warn(
"`Router.chat_member_handler(...)` is deprecated and will be removed in version 3.2 "
"use `Router.chat_member(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.chat_member
@property
def chat_join_request_handler(self) -> TelegramEventObserver:
warnings.warn(
"`Router.chat_join_request_handler(...)` is deprecated and will be removed in version 3.2 "
"use `Router.chat_join_request(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.chat_join_request
@property
def errors_handler(self) -> TelegramEventObserver:
warnings.warn(
@ -374,5 +394,139 @@ class Router:
DeprecationWarning,
stacklevel=2,
)
return self.errors
def register_message(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_message(...)` is deprecated and will be removed in version 3.2 "
"use `Router.message.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.message.register(*args, **kwargs)
def register_edited_message(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_edited_message(...)` is deprecated and will be removed in version 3.2 "
"use `Router.edited_message.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.edited_message.register(*args, **kwargs)
def register_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_channel_post(...)` is deprecated and will be removed in version 3.2 "
"use `Router.channel_post.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.channel_post.register(*args, **kwargs)
def register_edited_channel_post(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_edited_channel_post(...)` is deprecated and will be removed in version 3.2 "
"use `Router.edited_channel_post.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.edited_channel_post.register(*args, **kwargs)
def register_inline_query(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_inline_query(...)` is deprecated and will be removed in version 3.2 "
"use `Router.inline_query.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.inline_query.register(*args, **kwargs)
def register_chosen_inline_result(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_chosen_inline_result(...)` is deprecated and will be removed in version 3.2 "
"use `Router.chosen_inline_result.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.chosen_inline_result.register(*args, **kwargs)
def register_callback_query(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_callback_query(...)` is deprecated and will be removed in version 3.2 "
"use `Router.callback_query.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.callback_query.register(*args, **kwargs)
def register_shipping_query(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_shipping_query(...)` is deprecated and will be removed in version 3.2 "
"use `Router.shipping_query.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.shipping_query.register(*args, **kwargs)
def register_pre_checkout_query(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_pre_checkout_query(...)` is deprecated and will be removed in version 3.2 "
"use `Router.pre_checkout_query.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.pre_checkout_query.register(*args, **kwargs)
def register_poll(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_poll(...)` is deprecated and will be removed in version 3.2 "
"use `Router.poll.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.poll.register(*args, **kwargs)
def register_poll_answer(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_poll_answer(...)` is deprecated and will be removed in version 3.2 "
"use `Router.poll_answer.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.poll_answer.register(*args, **kwargs)
def register_my_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_my_chat_member(...)` is deprecated and will be removed in version 3.2 "
"use `Router.my_chat_member.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.my_chat_member.register(*args, **kwargs)
def register_chat_member(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_chat_member(...)` is deprecated and will be removed in version 3.2 "
"use `Router.chat_member.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.chat_member.register(*args, **kwargs)
def register_chat_join_request(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_chat_join_request(...)` is deprecated and will be removed in version 3.2 "
"use `Router.chat_join_request.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.chat_join_request.register(*args, **kwargs)
def register_errors(self, *args: Any, **kwargs: Any) -> HandlerType:
warnings.warn(
"`Router.register_errors(...)` is deprecated and will be removed in version 3.2 "
"use `Router.errors.register(...)`",
DeprecationWarning,
stacklevel=2,
)
return self.errors.register(*args, **kwargs)

View file

@ -84,6 +84,8 @@ class TelegramMethod(abc.ABC, BaseModel, Generic[TelegramType]):
async def emit(self, bot: Bot) -> TelegramType:
return await bot(self)
as_ = emit
def __await__(self) -> Generator[Any, None, TelegramType]:
from aiogram.client.bot import Bot

View file

@ -8,6 +8,10 @@ if TYPE_CHECKING:
from .keyboard_button_poll_type import KeyboardButtonPollType
class WebApp(MutableTelegramObject):
url: str
class KeyboardButton(MutableTelegramObject):
"""
This object represents one button of the reply keyboard. For simple text buttons *String* can be used instead of this object to specify text of the button. Optional fields *request_contact*, *request_location*, and *request_poll* are mutually exclusive.
@ -26,3 +30,4 @@ class KeyboardButton(MutableTelegramObject):
"""*Optional*. If :code:`True`, the user's current location will be sent when the button is pressed. Available in private chats only"""
request_poll: Optional[KeyboardButtonPollType] = None
"""*Optional*. If specified, the user will be asked to create a poll and send it to the bot when the button is pressed. Available in private chats only"""
web_app: Optional[WebApp] = None

View file

@ -71,7 +71,7 @@ if TYPE_CHECKING:
from .voice_chat_started import VoiceChatStarted
class Message(TelegramObject):
class _BaseMessage(TelegramObject):
"""
This object represents a message.
@ -195,6 +195,8 @@ class Message(TelegramObject):
reply_markup: Optional[InlineKeyboardMarkup] = None
"""*Optional*. Inline keyboard attached to the message. :code:`login_url` buttons are represented as ordinary :code:`url` buttons."""
class Message(_BaseMessage):
@property
def content_type(self) -> str:
if self.text:
@ -265,11 +267,8 @@ class Message(TelegramObject):
return ContentType.UNKNOWN
def _unparse_entities(self, text_decoration: TextDecoration) -> str:
text = self.text or self.caption
if text is None:
raise TypeError("This message doesn't have any text.")
entities = self.entities or self.caption_entities
text = self.text or self.caption or ""
entities = self.entities or self.caption_entities or []
return text_decoration.unparse(text=text, entities=entities)
@property

View file

@ -12,7 +12,7 @@ from aiogram.types import Message, TelegramObject
logger = logging.getLogger(__name__)
DEFAULT_INTERVAL = 5.0
DEFAULT_INITIAL_SLEEP = 0.1
DEFAULT_INITIAL_SLEEP = 0.0
class ChatActionSender:

View file

@ -1,18 +1,28 @@
from typing import Any
from typing import Any, Optional
from urllib.parse import urlencode, urljoin
BASE_DOCS_URL = "https://docs.aiogram.dev/"
BRANCH = "dev-3.x"
BASE_PAGE_URL = f"{BASE_DOCS_URL}/en/{BRANCH}/"
def _format_url(url: str, *path: str, fragment_: Optional[str] = None, **query: Any) -> str:
url = urljoin(url, "/".join(path), allow_fragments=True)
if query:
url += "?" + urlencode(query)
if fragment_:
url += "#" + fragment_
return url
def docs_url(*path: str, fragment_: Optional[str] = None, **query: Any) -> str:
return _format_url(BASE_PAGE_URL, *path, fragment_=fragment_, **query)
def create_tg_link(link: str, **kwargs: Any) -> str:
url = f"tg://{link}"
if kwargs:
query = urlencode(kwargs)
url += f"?{query}"
return url
return _format_url(f"tg://{link}", **kwargs)
def create_telegram_link(uri: str, **kwargs: Any) -> str:
url = urljoin("https://t.me", uri)
if kwargs:
query = urlencode(query=kwargs)
url += f"?{query}"
return url
def create_telegram_link(*path: str, **kwargs: Any) -> str:
return _format_url("https://t.me", *path, **kwargs)

View file

@ -37,7 +37,7 @@ class TextDecoration(ABC):
if entity.type in {"bot_command", "url", "mention", "phone_number"}:
# This entities should not be changed
return text
if entity.type in {"bold", "italic", "code", "underline", "strikethrough"}:
if entity.type in {"bold", "italic", "code", "underline", "strikethrough", "spoiler"}:
return cast(str, getattr(self, entity.type)(value=text))
if entity.type == "pre":
return (
@ -102,35 +102,39 @@ class TextDecoration(ABC):
yield self.quote(remove_surrogates(text[offset:length]))
@abstractmethod
def link(self, value: str, link: str) -> str: # pragma: no cover
def link(self, value: str, link: str) -> str:
pass
@abstractmethod
def bold(self, value: str) -> str: # pragma: no cover
def bold(self, value: str) -> str:
pass
@abstractmethod
def italic(self, value: str) -> str: # pragma: no cover
def italic(self, value: str) -> str:
pass
@abstractmethod
def code(self, value: str) -> str: # pragma: no cover
def code(self, value: str) -> str:
pass
@abstractmethod
def pre(self, value: str) -> str: # pragma: no cover
def pre(self, value: str) -> str:
pass
@abstractmethod
def pre_language(self, value: str, language: str) -> str: # pragma: no cover
def pre_language(self, value: str, language: str) -> str:
pass
@abstractmethod
def underline(self, value: str) -> str: # pragma: no cover
def underline(self, value: str) -> str:
pass
@abstractmethod
def strikethrough(self, value: str) -> str: # pragma: no cover
def strikethrough(self, value: str) -> str:
pass
@abstractmethod
def spoiler(self, value: str) -> str:
pass
@abstractmethod
@ -139,14 +143,20 @@ class TextDecoration(ABC):
class HtmlDecoration(TextDecoration):
BOLD_TAG = "b"
ITALIC_TAG = "i"
UNDERLINE_TAG = "u"
STRIKETHROUGH_TAG = "s"
SPOILER_TAG = ('span class="tg-spoiler"', "span")
def link(self, value: str, link: str) -> str:
return f'<a href="{link}">{value}</a>'
def bold(self, value: str) -> str:
return f"<b>{value}</b>"
return f"<{self.BOLD_TAG}>{value}</{self.BOLD_TAG}>"
def italic(self, value: str) -> str:
return f"<i>{value}</i>"
return f"<{self.ITALIC_TAG}>{value}</{self.ITALIC_TAG}>"
def code(self, value: str) -> str:
return f"<code>{value}</code>"
@ -158,10 +168,13 @@ class HtmlDecoration(TextDecoration):
return f'<pre><code class="language-{language}">{value}</code></pre>'
def underline(self, value: str) -> str:
return f"<u>{value}</u>"
return f"<{self.UNDERLINE_TAG}>{value}</{self.UNDERLINE_TAG}>"
def strikethrough(self, value: str) -> str:
return f"<s>{value}</s>"
return f"<{self.STRIKETHROUGH_TAG}>{value}</{self.STRIKETHROUGH_TAG}>"
def spoiler(self, value: str) -> str:
return f"<{self.SPOILER_TAG[0]}>{value}</{self.SPOILER_TAG[1]}>"
def quote(self, value: str) -> str:
return html.escape(value, quote=False)
@ -194,6 +207,9 @@ class MarkdownDecoration(TextDecoration):
def strikethrough(self, value: str) -> str:
return f"~{value}~"
def spoiler(self, value: str) -> str:
return f"|{value}|"
def quote(self, value: str) -> str:
return re.sub(pattern=self.MARKDOWN_QUOTE_PATTERN, repl=r"\\\1", string=value)

View file

@ -0,0 +1,118 @@
==============================
Callback Data Factory & Filter
==============================
.. autoclass:: aiogram.dispatcher.filters.callback_data.CallbackData
:members:
:member-order: bysource
:undoc-members: False
Usage
=====
Create subclass of :code:`CallbackData`:
.. code-block:: python
class MyCallback(CallbackData, prefix="my"):
foo: str
bar: int
After that you can generate any callback based on this class, for example:
.. code-block:: python
cb1 = MyCallback(foo="demo", bar=42)
cb1.pack() # returns 'my:demo:42'
cb1.unpack('my:demo:42') # returns <MyCallback(foo="demo", bar=42)>
So... Now you can use this class to generate any callbacks with defined structure
.. code-block:: python
...
# Pass it into the markup
InlineKeyboardButton(
text="demo",
callback_data=MyCallback(foo="demo", bar="42").pack() # value should be packed to string
)
...
... and handle by specific rules
.. code-block:: python
# Filter callback by type and value of field :code:`foo`
@router.callback_query(MyCallback.filter(F.foo == "demo"))
async def my_callback_foo(query: CallbackQuery, callback_data: MyCallback):
await query.answer(...)
...
print("bar =", callback_data.bar)
Also can be used in :doc:`Keyboard builder </utils/keyboard>`:
.. code-block:: python
builder = InlineKeyboardBuilder()
builder.button(
text="demo",
callback_data=MyCallback(foo="demo", bar="42") # Value can be not packed to string inplace, because builder knows what to do with callback instance
)
Another abstract example:
.. code-block:: python
class Action(str, Enum):
ban = "ban"
kick = "kick"
warn = "warn"
class AdminAction(CallbackData, prefix="adm"):
action: Action
chat_id: int
user_id: int
...
# Inside handler
builder = InlineKeyboardBuilder()
for action in Action:
builder.button(
text=action.value.title(),
callback_data=AdminAction(action=action, chat_id=chat_id, user_id=user_id),
)
await bot.send_message(
chat_id=admins_chat,
text=f"What do you want to do with {html.quote(name)}",
reply_markup=builder.as_markup(),
)
...
@router.callback_query(AdminAction.filter(F.action == Action.ban))
async def ban_user(query: CallbackQuery, callback_data: AdminAction, bot: Bot):
await bot.ban_chat_member(
chat_id=callback_data.chat_id,
user_id=callback_data.user_id,
...
)
Known limitations
=================
Allowed types and their subclasses:
- :code:`str`
- :code:`int`
- :code:`bool`
- :code:`float`
- :code:`Decimal` (:code:`from decimal import Decimal`)
- :code:`Fraction` (:code:`from fractions import Fraction`)
- :code:`UUID` (:code:`from uuid import UUID`)
- :code:`Enum` (:code:`from enum import Enum`, only for string enums)
- :code:`IntEnum` (:code:`from enum import IntEnum`, only for int enums)
.. note::
Note that the integer Enum's should be always is subclasses of :code:`IntEnum` in due to parsing issues.

View file

@ -22,6 +22,7 @@ Here is list of builtin filters:
exception
magic_filters
magic_data
callback_data
Own filters specification
=========================

View file

@ -4,6 +4,6 @@ Utils
.. toctree::
i18n
keyboard
i18n
chat_action

View file

@ -1,5 +1,4 @@
import logging
from typing import Any
from aiogram import Bot, Dispatcher, types
from aiogram.types import Message

View file

@ -30,7 +30,7 @@ ignore_missing_imports = True
[mypy-uvloop]
ignore_missing_imports = True
[mypy-aioredis]
[mypy-redis.*]
ignore_missing_imports = True
[mypy-babel.*]

970
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -37,7 +37,7 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.8"
magic-filter = "^1.0.5"
magic-filter = "^1.0.6"
aiohttp = "^3.8.1"
pydantic = "^1.9.0"
aiofiles = "^0.8.0"
@ -46,30 +46,29 @@ uvloop = { version = "^0.16.0", markers = "sys_platform == 'darwin' or sys_platf
# i18n
Babel = { version = "^2.9.1", optional = true }
# Proxy
aiohttp-socks = {version = "^0.7.1", optional = true}
aiohttp-socks = { version = "^0.7.1", optional = true }
# Redis
aioredis = {version = "^2.0.1", optional = true}
redis = { version = "^4.2.2", optional = true }
# Docs
Sphinx = { version = "^4.2.0", optional = true }
sphinx-intl = { version = "^2.0.1", optional = true }
sphinx-autobuild = { version = "^2021.3.14", optional = true }
sphinx-copybutton = {version = "^0.5.0", optional = true}
furo = {version = "^2022.2.14", optional = true}
sphinx-copybutton = { version = "^0.5.0", optional = true }
furo = { version = "^2022.4.7", optional = true }
sphinx-prompt = { version = "^1.5.0", optional = true }
Sphinx-Substitution-Extensions = { version = "^2020.9.30", optional = true }
towncrier = {version = "^21.9.0", optional = true}
towncrier = { version = "^21.9.0", optional = true }
pygments = { version = "^2.4", optional = true }
pymdown-extensions = {version = "^9.2", optional = true}
pymdown-extensions = { version = "^9.3", optional = true }
markdown-include = { version = "^0.6", optional = true }
Pygments = {version = "^2.11.2", optional = true}
Pygments = { version = "^2.11.2", optional = true }
[tool.poetry.dev-dependencies]
ipython = "^8.0.1"
ipython = "^8.1.1"
black = "^22.1.0"
isort = "^5.10.1"
flake8 = "^4.0.1"
flake8-html = "^0.4.1"
mypy = "^0.931"
mypy = "^0.942"
pytest = "^7.0.1"
pytest-html = "^3.1.1"
pytest-asyncio = "^0.18.1"
@ -90,7 +89,7 @@ sentry-sdk = "^1.5.5"
[tool.poetry.extras]
fast = ["uvloop"]
redis = ["aioredis"]
redis = ["redis"]
proxy = ["aiohttp-socks"]
i18n = ["Babel"]
docs = [
@ -110,7 +109,7 @@ docs = [
[tool.black]
line-length = 99
target-version = ['py37', 'py38']
target-version = ['py38', 'py39', 'py310']
exclude = '''
(
\.eggs

View file

@ -2,7 +2,7 @@ from pathlib import Path
import pytest
from _pytest.config import UsageError
from aioredis.connection import parse_url as parse_redis_url
from redis.asyncio.connection import parse_url as parse_redis_url
from aiogram import Bot, Dispatcher
from aiogram.dispatcher.fsm.storage.memory import (

View file

@ -215,11 +215,11 @@ class TestBaseSession:
return await make_request(bot, method)
session = CustomSession()
assert not session.middlewares
assert not session.middleware._middlewares
session.middleware(my_middleware)
assert my_middleware in session.middlewares
assert len(session.middlewares) == 1
assert my_middleware in session.middleware
assert len(session.middleware) == 1
async def test_use_middleware(self, bot: MockedBot):
flag_before = False

View file

@ -0,0 +1,45 @@
from aiogram import Bot
from aiogram.client.session.middlewares.base import (
BaseRequestMiddleware,
NextRequestMiddlewareType,
)
from aiogram.client.session.middlewares.manager import RequestMiddlewareManager
from aiogram.methods import Response, TelegramMethod
from aiogram.types import TelegramObject
class TestMiddlewareManager:
async def test_register(self):
manager = RequestMiddlewareManager()
@manager
async def middleware(handler, event, data):
await handler(event, data)
assert middleware in manager._middlewares
manager.unregister(middleware)
assert middleware not in manager._middlewares
async def test_wrap_middlewares(self):
manager = RequestMiddlewareManager()
class MyMiddleware(BaseRequestMiddleware):
async def __call__(
self,
make_request: NextRequestMiddlewareType,
bot: Bot,
method: TelegramMethod[TelegramObject],
) -> Response[TelegramObject]:
return await make_request(bot, method)
manager.register(MyMiddleware())
@manager()
@manager
async def middleware(make_request, bot, method):
return await make_request(bot, method)
async def target_call(bot, method, timeout: int = None):
return timeout
assert await manager.wrap_middlewares(target_call, timeout=42)(None, None) == 42

View file

@ -641,13 +641,15 @@ class TestMessage:
assert method.message_id == message.message_id
@pytest.mark.parametrize(
"text,entities,correct",
"text,entities,mode,expected_value",
[
["test", [MessageEntity(type="bold", offset=0, length=4)], True],
["", [], False],
["test", [MessageEntity(type="bold", offset=0, length=4)], "html", "<b>test</b>"],
["test", [MessageEntity(type="bold", offset=0, length=4)], "md", "*test*"],
["", [], "html", ""],
["", [], "md", ""],
],
)
def test_html_text(self, text, entities, correct):
def test_html_text(self, text, entities, mode, expected_value):
message = Message(
message_id=42,
chat=Chat(id=42, type="private"),
@ -655,11 +657,4 @@ class TestMessage:
text=text,
entities=entities,
)
if correct:
assert message.html_text
assert message.md_text
else:
with pytest.raises(TypeError):
assert message.html_text
with pytest.raises(TypeError):
assert message.md_text
assert getattr(message, f"{mode}_text") == expected_value

View file

@ -5,27 +5,38 @@ from aiogram.dispatcher.router import Router
from tests.deprecated import check_deprecated
OBSERVERS = {
"callback_query",
"channel_post",
"chosen_inline_result",
"edited_channel_post",
"edited_message",
"errors",
"inline_query",
"message",
"edited_message",
"channel_post",
"edited_channel_post",
"inline_query",
"chosen_inline_result",
"callback_query",
"shipping_query",
"pre_checkout_query",
"poll",
"poll_answer",
"pre_checkout_query",
"shipping_query",
"my_chat_member",
"chat_member",
"chat_join_request",
"errors",
}
DEPRECATED_OBSERVERS = {observer + "_handler" for observer in OBSERVERS}
@pytest.mark.parametrize("observer_name", DEPRECATED_OBSERVERS)
@pytest.mark.parametrize("observer_name", OBSERVERS)
def test_deprecated_handlers_name(observer_name: str):
router = Router()
with check_deprecated("3.2", exception=AttributeError):
observer = getattr(router, observer_name)
observer = getattr(router, f"{observer_name}_handler")
assert isinstance(observer, TelegramEventObserver)
@pytest.mark.parametrize("observer_name", OBSERVERS)
def test_deprecated_register_handlers(observer_name: str):
router = Router()
with check_deprecated("3.2", exception=AttributeError):
register = getattr(router, f"register_{observer_name}")
register(lambda event: True)
assert callable(register)

View file

@ -74,7 +74,22 @@ class TestDispatcher:
assert dp.update.handlers
assert dp.update.handlers[0].callback == dp._listen_update
assert dp.update.outer_middlewares
assert dp.update.outer_middleware
def test_data_bind(self):
dp = Dispatcher()
assert dp.get("foo") is None
assert dp.get("foo", 42) == 42
dp["foo"] = 1
assert dp._data["foo"] == 1
assert dp["foo"] == 1
del dp["foo"]
assert "foo" not in dp._data
def test_storage_property(self, dispatcher: Dispatcher):
assert dispatcher.storage is dispatcher.fsm.storage
def test_parent_router(self, dispatcher: Dispatcher):
with pytest.raises(RuntimeError):

View file

@ -0,0 +1,42 @@
from functools import partial
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
class TestMiddlewareManager:
async def test_register(self):
manager = MiddlewareManager()
@manager
async def middleware(handler, event, data):
await handler(event, data)
assert middleware in manager._middlewares
manager.unregister(middleware)
assert middleware not in manager._middlewares
async def test_wrap_middlewares(self):
manager = MiddlewareManager()
async def target(*args, **kwargs):
kwargs["target"] = True
kwargs["stack"].append(-1)
return kwargs
async def middleware1(handler, event, data):
data["mw1"] = True
data["stack"].append(1)
return await handler(event, data)
async def middleware2(handler, event, data):
data["mw2"] = True
data["stack"].append(2)
return await handler(event, data)
wrapped = manager.wrap_middlewares([middleware1, middleware2], target)
assert isinstance(wrapped, partial)
assert wrapped.func is middleware1
result = await wrapped(None, {"stack": []})
assert result == {"mw1": True, "mw2": True, "target": True, "stack": [1, 2, -1]}

View file

@ -297,10 +297,9 @@ class TestTelegramEventObserver:
def test_register_middleware(self, middleware_type):
event_observer = TelegramEventObserver(Router(), "test")
middlewares = getattr(event_observer, f"{middleware_type}s")
decorator = getattr(event_observer, middleware_type)
middlewares = getattr(event_observer, middleware_type)
@decorator
@middlewares
async def my_middleware1(handler, event, data):
pass
@ -308,7 +307,7 @@ class TestTelegramEventObserver:
assert my_middleware1.__name__ == "my_middleware1"
assert my_middleware1 in middlewares
@decorator()
@middlewares()
async def my_middleware2(handler, event, data):
pass
@ -319,13 +318,13 @@ class TestTelegramEventObserver:
async def my_middleware3(handler, event, data):
pass
decorator(my_middleware3)
middlewares(my_middleware3)
assert my_middleware3 is not None
assert my_middleware3.__name__ == "my_middleware3"
assert my_middleware3 in middlewares
assert middlewares == [my_middleware1, my_middleware2, my_middleware3]
assert list(middlewares) == [my_middleware1, my_middleware2, my_middleware3]
def test_register_global_filters(self):
router = Router(use_builtin_filters=False)

View file

@ -30,7 +30,7 @@ class MyCallback(CallbackData, prefix="test"):
class TestCallbackData:
def test_init_subclass_prefix_required(self):
assert MyCallback.prefix == "test"
assert MyCallback.__prefix__ == "test"
with pytest.raises(ValueError, match="prefix required.+"):
@ -38,12 +38,12 @@ class TestCallbackData:
pass
def test_init_subclass_sep_validation(self):
assert MyCallback.sep == ":"
assert MyCallback.__separator__ == ":"
class MyCallback2(CallbackData, prefix="test2", sep="@"):
pass
assert MyCallback2.sep == "@"
assert MyCallback2.__separator__ == "@"
with pytest.raises(ValueError, match="Separator symbol '@' .+ 'sp@m'"):

View file

@ -92,6 +92,18 @@ class TestCommandFilter:
command = Command(commands=["test"])
assert bool(await command(message=message, bot=bot)) is result
async def test_command_magic_result(self, bot: MockedBot):
message = Message(
message_id=0,
text="/test 42",
chat=Chat(id=42, type="private"),
date=datetime.datetime.now(),
)
command = Command(commands=["test"], command_magic=(F.args.as_("args")))
result = await command(message=message, bot=bot)
assert "args" in result
assert result["args"] == "42"
class TestCommandObject:
@pytest.mark.parametrize(

View file

@ -2,7 +2,7 @@ import re
import pytest
from aiogram import Dispatcher, F
from aiogram import Dispatcher
from aiogram.dispatcher.filters import ExceptionMessageFilter, ExceptionTypeFilter
from aiogram.types import Update

View file

@ -22,9 +22,9 @@ class TestMagicDataFilter:
assert value.spam is True
return value
f = MagicData(magic_data=F.func(check))
f = MagicData(magic_data=F.func(check).as_("test"))
result = await f(Update(update_id=123), "foo", "bar", spam=True)
assert called
assert isinstance(result, bool)
assert result
assert isinstance(result, dict)
assert result["test"]

View file

@ -111,8 +111,8 @@ class TestSimpleI18nMiddleware:
middleware = SimpleI18nMiddleware(i18n=i18n)
middleware.setup(router=dp)
assert middleware not in dp.update.outer_middlewares
assert middleware in dp.message.outer_middlewares
assert middleware not in dp.update.outer_middleware
assert middleware in dp.message.outer_middleware
async def test_get_unknown_locale(self, i18n: I18n):
dp = Dispatcher()

View file

@ -2,7 +2,7 @@ from typing import Any, Dict
import pytest
from aiogram.utils.link import create_telegram_link, create_tg_link
from aiogram.utils.link import BRANCH, create_telegram_link, create_tg_link, docs_url
class TestLink:
@ -22,3 +22,12 @@ class TestLink:
)
def test_create_telegram_link(self, base: str, params: Dict[str, Any], result: str):
assert create_telegram_link(base, **params) == result
def test_fragment(self):
assert (
docs_url("test.html", fragment_="test")
== f"https://docs.aiogram.dev/en/{BRANCH}/test.html#test"
)
def test_docs(self):
assert docs_url("test.html") == f"https://docs.aiogram.dev/en/{BRANCH}/test.html"

View file

@ -47,6 +47,11 @@ class TestTextDecoration:
'<a href="tg://user?id=42">test</a>',
],
[html_decoration, MessageEntity(type="url", offset=0, length=5), "test"],
[
html_decoration,
MessageEntity(type="spoiler", offset=0, length=5),
'<span class="tg-spoiler">test</span>',
],
[
html_decoration,
MessageEntity(type="text_link", offset=0, length=5, url="https://aiogram.dev"),
@ -76,6 +81,7 @@ class TestTextDecoration:
[markdown_decoration, MessageEntity(type="bot_command", offset=0, length=5), "test"],
[markdown_decoration, MessageEntity(type="email", offset=0, length=5), "test"],
[markdown_decoration, MessageEntity(type="phone_number", offset=0, length=5), "test"],
[markdown_decoration, MessageEntity(type="spoiler", offset=0, length=5), "|test|"],
[
markdown_decoration,
MessageEntity(