diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f35271f0..38066ebe 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,6 +43,12 @@ jobs: virtualenvs-create: true virtualenvs-in-project: true + - name: Setup redis + if: ${{ matrix.os != 'windows-latest' }} + uses: shogo82148/actions-setup-redis@v1 + with: + redis-version: 6 + - name: Load cached venv id: cached-poetry-dependencies uses: actions/cache@v2 @@ -64,7 +70,14 @@ jobs: run: | poetry run black --check --diff aiogram tests - - name: Run tests + - name: Run tests (with Redis) + if: ${{ matrix.os != 'windows-latest' }} + run: | + poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml --redis redis://localhost:6379/0 + + - name: Run tests (without Redis) + # Redis can't be used on GitHub Windows Runners + if: ${{ matrix.os == 'windows-latest' }} run: | poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml diff --git a/Makefile b/Makefile index c1fa9797..da73f8b7 100644 --- a/Makefile +++ b/Makefile @@ -4,8 +4,14 @@ base_python := python3 py := poetry run python := $(py) python +package_dir := aiogram +tests_dir := tests +scripts_dir := scripts +code_dir := $(package_dir) $(tests_dir) $(scripts_dir) reports_dir := reports +redis_connection := redis://localhost:6379 + .PHONY: help help: @echo "=======================================================================================" @@ -17,13 +23,8 @@ help: @echo " clean: Delete temporary files" @echo "" @echo "Code quality:" - @echo " isort: Run isort tool" - @echo " black: Run black tool" - @echo " flake8: Run flake8 tool" - @echo " flake8-report: Run flake8 with HTML reporting" - @echo " mypy: Run mypy tool" - @echo " mypy-report: Run mypy tool with HTML reporting" - @echo " lint: Run isort, black, flake8 and mypy tools" + @echo " lint: Lint code by isort, black, flake8 and mypy tools" + @echo " reformat: Reformat code by isort and black tools" @echo "" @echo "Tests:" @echo " test: Run tests" @@ -31,8 +32,8 @@ help: @echo " test-coverage-report: Open coverage report in default system web browser" @echo "" @echo "Documentation:" - @echo " docs: Build docs" - @echo " docs-serve: Serve docs for local development" + @echo " docs: Build docs" + @echo " docs-serve: Serve docs for local development" @echo " docs-prepare-reports: Move all HTML reports to docs dir" @echo "" @echo "Project" @@ -65,33 +66,17 @@ clean: # Code quality # ================================================================================================= -.PHONY: isort -isort: - $(py) isort aiogram tests scripts - -.PHONY: black -black: - $(py) black aiogram tests scripts - -.PHONY: flake8 -flake8: - $(py) flake8 aiogram - -.PHONY: flake8-report -flake8-report: - mkdir -p $(reports_dir)/flake8 - $(py) flake8 --format=html --htmldir=$(reports_dir)/flake8 aiogram - -.PHONY: mypy -mypy: - $(py) mypy aiogram - -.PHONY: mypy-report -mypy-report: - $(py) mypy aiogram --html-report $(reports_dir)/typechecking - .PHONY: lint -lint: isort black flake8 mypy +lint: + $(py) isort --check-only $(code_dir) + $(py) black --check --diff $(code_dir) + $(py) flake8 $(code_dir) + $(py) mypy $(package_dir) + +.PHONY: reformat +reformat: + $(py) black $(code_dir) + $(py) isort $(code_dir) # ================================================================================================= # Tests @@ -99,12 +84,12 @@ lint: isort black flake8 mypy .PHONY: test test: - $(py) pytest --cov=aiogram --cov-config .coveragerc tests/ + $(py) pytest --cov=aiogram --cov-config .coveragerc tests/ --redis $(redis_connection) .PHONY: test-coverage test-coverage: mkdir -p $(reports_dir)/tests/ - $(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ + $(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ --redis $(redis_connection) .PHONY: test-coverage-report test-coverage-report: diff --git a/aiogram/__init__.py b/aiogram/__init__.py index 31b52552..639e68c9 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -6,6 +6,8 @@ from .dispatcher import filters, handler from .dispatcher.dispatcher import Dispatcher from .dispatcher.middlewares.base import BaseMiddleware from .dispatcher.router import Router +from .utils.text_decorations import html_decoration as _html_decoration +from .utils.text_decorations import markdown_decoration as _markdown_decoration try: import uvloop as _uvloop @@ -15,6 +17,8 @@ except ImportError: # pragma: no cover pass F = MagicFilter() +html = _html_decoration +md = _markdown_decoration __all__ = ( "__api_version__", @@ -29,6 +33,8 @@ __all__ = ( "filters", "handler", "F", + "html", + "md", ) __version__ = "3.0.0-alpha.8" diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index 78ff5aaf..95c721a1 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -4,7 +4,7 @@ import asyncio import contextvars import warnings from asyncio import CancelledError, Future, Lock -from typing import Any, AsyncGenerator, Dict, Optional, Union, cast +from typing import Any, AsyncGenerator, Dict, Optional, Union from .. import loggers from ..client.bot import Bot @@ -13,7 +13,6 @@ from ..types import TelegramObject, Update, User from ..utils.exceptions.base import TelegramAPIError from .event.bases import UNHANDLED, SkipHandler from .event.telegram import TelegramEventObserver -from .fsm.context import FSMContext from .fsm.middleware import FSMContextMiddleware from .fsm.storage.base import BaseStorage from .fsm.storage.memory import MemoryStorage @@ -32,7 +31,7 @@ class Dispatcher(Router): self, storage: Optional[BaseStorage] = None, fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT, - isolate_events: bool = True, + isolate_events: bool = False, **kwargs: Any, ) -> None: super(Dispatcher, self).__init__(**kwargs) @@ -255,7 +254,9 @@ class Dispatcher(Router): ) return True # because update was processed but unsuccessful - async def _polling(self, bot: Bot, polling_timeout: int = 30, **kwargs: Any) -> None: + async def _polling( + self, bot: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any + ) -> None: """ Internal polling process @@ -264,7 +265,11 @@ class Dispatcher(Router): :return: """ async for update in self._listen_updates(bot, polling_timeout=polling_timeout): - await self._process_update(bot=bot, update=update, **kwargs) + handle_update = self._process_update(bot=bot, update=update, **kwargs) + if handle_as_tasks: + asyncio.create_task(handle_update) + else: + await handle_update async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: """ @@ -342,11 +347,15 @@ class Dispatcher(Router): return None - async def start_polling(self, *bots: Bot, polling_timeout: int = 10, **kwargs: Any) -> None: + async def start_polling( + self, *bots: Bot, polling_timeout: int = 10, handle_as_tasks: bool = True, **kwargs: Any + ) -> None: """ Polling runner :param bots: + :param polling_timeout: + :param handle_as_tasks: :param kwargs: :return: """ @@ -363,7 +372,12 @@ class Dispatcher(Router): "Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name ) coro_list.append( - self._polling(bot=bot, polling_timeout=polling_timeout, **kwargs) + self._polling( + bot=bot, + handle_as_tasks=handle_as_tasks, + polling_timeout=polling_timeout, + **kwargs, + ) ) await asyncio.gather(*coro_list) finally: @@ -372,22 +386,27 @@ class Dispatcher(Router): loggers.dispatcher.info("Polling stopped") await self.emit_shutdown(**workflow_data) - def run_polling(self, *bots: Bot, polling_timeout: int = 30, **kwargs: Any) -> None: + def run_polling( + self, *bots: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any + ) -> None: """ Run many bots with polling :param bots: Bot instances :param polling_timeout: Poling timeout + :param handle_as_tasks: Run task for each event and no wait result :param kwargs: contextual data :return: """ try: return asyncio.run( - self.start_polling(*bots, **kwargs, polling_timeout=polling_timeout) + self.start_polling( + *bots, + **kwargs, + polling_timeout=polling_timeout, + handle_as_tasks=handle_as_tasks, + ) ) except (KeyboardInterrupt, SystemExit): # pragma: no cover # Allow to graceful shutdown pass - - def current_state(self, chat_id: int, user_id: int) -> FSMContext: - return cast(FSMContext, self.fsm.resolve_context(chat_id=chat_id, user_id=user_id)) diff --git a/aiogram/dispatcher/filters/command.py b/aiogram/dispatcher/filters/command.py index 899b09be..0e584c99 100644 --- a/aiogram/dispatcher/filters/command.py +++ b/aiogram/dispatcher/filters/command.py @@ -1,18 +1,24 @@ from __future__ import annotations import re -from dataclasses import dataclass, field -from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast +from dataclasses import dataclass, field, replace +from typing import Any, Dict, Match, Optional, Pattern, Sequence, Tuple, Union, cast -from pydantic import validator +from magic_filter import MagicFilter +from pydantic import Field, validator from aiogram import Bot from aiogram.dispatcher.filters import BaseFilter from aiogram.types import Message +from aiogram.utils.deep_linking import decode_payload CommandPatterType = Union[str, re.Pattern] +class CommandException(Exception): + pass + + class Command(BaseFilter): """ This filter can be helpful for handling commands from the text messages. @@ -29,6 +35,8 @@ class Command(BaseFilter): """Ignore case (Does not work with regexp, use flags instead)""" commands_ignore_mention: bool = False """Ignore bot mention. By default bot can not handle commands intended for other bots""" + command_magic: Optional[MagicFilter] = None + """Validate command object via Magic filter after all checks done""" @validator("commands", always=True) def _validate_commands( @@ -39,12 +47,54 @@ class Command(BaseFilter): return value async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]: - if not message.text: + text = message.text or message.caption + if not text: return False - return await self.parse_command(text=message.text, bot=bot) + try: + command = await self.parse_command(text=cast(str, message.text), bot=bot) + except CommandException: + return False + return {"command": command} - async def parse_command(self, text: str, bot: Bot) -> Union[bool, Dict[str, CommandObject]]: + def extract_command(self, text: str) -> CommandObject: + # First step: separate command with arguments + # "/command@mention arg1 arg2" -> "/command@mention", ["arg1 arg2"] + try: + full_command, *args = text.split(maxsplit=1) + except ValueError: + raise CommandException("not enough values to unpack") + + # Separate command into valuable parts + # "/command@mention" -> "/", ("command", "@", "mention") + prefix, (command, _, mention) = full_command[0], full_command[1:].partition("@") + return CommandObject( + prefix=prefix, command=command, mention=mention, args=args[0] if args else None + ) + + def validate_prefix(self, command: CommandObject) -> None: + if command.prefix not in self.commands_prefix: + raise CommandException("Invalid command prefix") + + async def validate_mention(self, bot: Bot, command: CommandObject) -> None: + if command.mention and not self.commands_ignore_mention: + me = await bot.me() + if me.username and command.mention.lower() != me.username.lower(): + raise CommandException("Mention did not match") + + def validate_command(self, command: CommandObject) -> CommandObject: + for allowed_command in cast(Sequence[CommandPatterType], self.commands): + # Command can be presented as regexp pattern or raw string + # then need to validate that in different ways + if isinstance(allowed_command, Pattern): # Regexp + result = allowed_command.match(command.command) + if result: + return replace(command, match=result) + elif command.command == allowed_command: # String + return command + raise CommandException("Command did not match pattern") + + async def parse_command(self, text: str, bot: Bot) -> CommandObject: """ Extract command from the text and validate @@ -52,56 +102,18 @@ class Command(BaseFilter): :param bot: :return: """ - if not text.strip(): - return False + command = self.extract_command(text) + self.validate_prefix(command=command) + await self.validate_mention(bot=bot, command=command) + command = self.validate_command(command) + self.do_magic(command=command) + return command - # First step: separate command with arguments - # "/command@mention arg1 arg2" -> "/command@mention", ["arg1 arg2"] - full_command, *args = text.split(maxsplit=1) - - # Separate command into valuable parts - # "/command@mention" -> "/", ("command", "@", "mention") - prefix, (command, _, mention) = full_command[0], full_command[1:].partition("@") - - # Validate prefixes - if prefix not in self.commands_prefix: - return False - - # Validate mention - if mention and not self.commands_ignore_mention: - me = await bot.me() - if me.username and mention.lower() != me.username.lower(): - return False - - # Validate command - for allowed_command in cast(Sequence[CommandPatterType], self.commands): - # Command can be presented as regexp pattern or raw string - # then need to validate that in different ways - if isinstance(allowed_command, Pattern): # Regexp - result = allowed_command.match(command) - if result: - return { - "command": CommandObject( - prefix=prefix, - command=command, - mention=mention, - args=args[0] if args else None, - match=result, - ) - } - - elif command == allowed_command: # String - return { - "command": CommandObject( - prefix=prefix, - command=command, - mention=mention, - args=args[0] if args else None, - match=None, - ) - } - - return False + def do_magic(self, command: CommandObject) -> None: + if not self.command_magic: + return + if not self.command_magic.resolve(command): + raise CommandException("Rejected via magic filter") class Config: arbitrary_types_allowed = True @@ -143,3 +155,40 @@ class CommandObject: if self.args: line += " " + self.args return line + + +class CommandStart(Command): + commands: Tuple[str] = Field(("start",), const=True) + commands_prefix: str = Field("/", const=True) + deep_link: bool = False + deep_link_encoded: bool = False + + async def parse_command(self, text: str, bot: Bot) -> CommandObject: + """ + Extract command from the text and validate + + :param text: + :param bot: + :return: + """ + command = self.extract_command(text) + self.validate_prefix(command=command) + await self.validate_mention(bot=bot, command=command) + command = self.validate_command(command) + command = self.validate_deeplink(command=command) + self.do_magic(command=command) + return command + + def validate_deeplink(self, command: CommandObject) -> CommandObject: + if not self.deep_link: + return command + if not command.args: + raise CommandException("Deep-link was missing") + args = command.args + if self.deep_link_encoded: + try: + args = decode_payload(args) + except UnicodeDecodeError as e: + raise CommandException(f"Failed to decode Base64: {e}") + return replace(command, args=args) + return command diff --git a/aiogram/dispatcher/fsm/context.py b/aiogram/dispatcher/fsm/context.py index 78ed480b..dc4e4030 100644 --- a/aiogram/dispatcher/fsm/context.py +++ b/aiogram/dispatcher/fsm/context.py @@ -1,25 +1,35 @@ from typing import Any, Dict, Optional +from aiogram import Bot from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType class FSMContext: - def __init__(self, storage: BaseStorage, chat_id: int, user_id: int) -> None: + def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None: + self.bot = bot self.storage = storage self.chat_id = chat_id self.user_id = user_id async def set_state(self, state: StateType = None) -> None: - await self.storage.set_state(chat_id=self.chat_id, user_id=self.user_id, state=state) + await self.storage.set_state( + bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state + ) async def get_state(self) -> Optional[str]: - return await self.storage.get_state(chat_id=self.chat_id, user_id=self.user_id) + return await self.storage.get_state( + bot=self.bot, chat_id=self.chat_id, user_id=self.user_id + ) async def set_data(self, data: Dict[str, Any]) -> None: - await self.storage.set_data(chat_id=self.chat_id, user_id=self.user_id, data=data) + await self.storage.set_data( + bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data + ) async def get_data(self) -> Dict[str, Any]: - return await self.storage.get_data(chat_id=self.chat_id, user_id=self.user_id) + return await self.storage.get_data( + bot=self.bot, chat_id=self.chat_id, user_id=self.user_id + ) async def update_data( self, data: Optional[Dict[str, Any]] = None, **kwargs: Any @@ -27,7 +37,7 @@ class FSMContext: if data: kwargs.update(data) return await self.storage.update_data( - chat_id=self.chat_id, user_id=self.user_id, data=kwargs + bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs ) async def clear(self) -> None: diff --git a/aiogram/dispatcher/fsm/middleware.py b/aiogram/dispatcher/fsm/middleware.py index 1e3ba91c..734c5825 100644 --- a/aiogram/dispatcher/fsm/middleware.py +++ b/aiogram/dispatcher/fsm/middleware.py @@ -1,5 +1,6 @@ -from typing import Any, Awaitable, Callable, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, Optional, cast +from aiogram import Bot from aiogram.dispatcher.fsm.context import FSMContext from aiogram.dispatcher.fsm.storage.base import BaseStorage from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy @@ -24,24 +25,27 @@ class FSMContextMiddleware(BaseMiddleware[Update]): event: Update, data: Dict[str, Any], ) -> Any: - context = self.resolve_event_context(data) + bot: Bot = cast(Bot, data["bot"]) + context = self.resolve_event_context(bot, data) data["fsm_storage"] = self.storage if context: data.update({"state": context, "raw_state": await context.get_state()}) if self.isolate_events: - async with self.storage.lock(chat_id=context.chat_id, user_id=context.user_id): + async with self.storage.lock( + bot=bot, chat_id=context.chat_id, user_id=context.user_id + ): return await handler(event, data) return await handler(event, data) - def resolve_event_context(self, data: Dict[str, Any]) -> Optional[FSMContext]: + def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]: user = data.get("event_from_user") chat = data.get("event_chat") chat_id = chat.id if chat else None user_id = user.id if user else None - return self.resolve_context(chat_id=chat_id, user_id=user_id) + return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id) def resolve_context( - self, chat_id: Optional[int], user_id: Optional[int] + self, bot: Bot, chat_id: Optional[int], user_id: Optional[int] ) -> Optional[FSMContext]: if chat_id is None: chat_id = user_id @@ -50,8 +54,8 @@ class FSMContextMiddleware(BaseMiddleware[Update]): chat_id, user_id = apply_strategy( chat_id=chat_id, user_id=user_id, strategy=self.strategy ) - return self.get_context(chat_id=chat_id, user_id=user_id) + return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id) return None - def get_context(self, chat_id: int, user_id: int) -> FSMContext: - return FSMContext(storage=self.storage, chat_id=chat_id, user_id=user_id) + def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext: + return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id) diff --git a/aiogram/dispatcher/fsm/storage/base.py b/aiogram/dispatcher/fsm/storage/base.py index f394cd61..42826915 100644 --- a/aiogram/dispatcher/fsm/storage/base.py +++ b/aiogram/dispatcher/fsm/storage/base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from contextlib import asynccontextmanager from typing import Any, AsyncGenerator, Dict, Optional, Union +from aiogram import Bot from aiogram.dispatcher.fsm.state import State StateType = Optional[Union[str, State]] @@ -11,34 +12,42 @@ class BaseStorage(ABC): @abstractmethod @asynccontextmanager async def lock( - self, chat_id: int, user_id: int + self, bot: Bot, chat_id: int, user_id: int ) -> AsyncGenerator[None, None]: # pragma: no cover yield None @abstractmethod async def set_state( - self, chat_id: int, user_id: int, state: StateType = None + self, bot: Bot, chat_id: int, user_id: int, state: StateType = None ) -> None: # pragma: no cover pass @abstractmethod - async def get_state(self, chat_id: int, user_id: int) -> Optional[str]: # pragma: no cover + async def get_state( + self, bot: Bot, chat_id: int, user_id: int + ) -> Optional[str]: # pragma: no cover pass @abstractmethod async def set_data( - self, chat_id: int, user_id: int, data: Dict[str, Any] + self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any] ) -> None: # pragma: no cover pass @abstractmethod - async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]: # pragma: no cover + async def get_data( + self, bot: Bot, chat_id: int, user_id: int + ) -> Dict[str, Any]: # pragma: no cover pass async def update_data( - self, chat_id: int, user_id: int, data: Dict[str, Any] + self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any] ) -> Dict[str, Any]: - current_data = await self.get_data(chat_id=chat_id, user_id=user_id) + current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id) current_data.update(data) - await self.set_data(chat_id=chat_id, user_id=user_id, data=current_data) + await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data) return current_data.copy() + + @abstractmethod + async def close(self) -> None: # pragma: no cover + pass diff --git a/aiogram/dispatcher/fsm/storage/memory.py b/aiogram/dispatcher/fsm/storage/memory.py index 933e225c..3e82d306 100644 --- a/aiogram/dispatcher/fsm/storage/memory.py +++ b/aiogram/dispatcher/fsm/storage/memory.py @@ -4,6 +4,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional +from aiogram import Bot from aiogram.dispatcher.fsm.state import State from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType @@ -17,23 +18,30 @@ class MemoryStorageRecord: class MemoryStorage(BaseStorage): def __init__(self) -> None: - self.storage: DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] = defaultdict( - lambda: defaultdict(MemoryStorageRecord) - ) + self.storage: DefaultDict[ + Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] + ] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord))) + + async def close(self) -> None: + pass @asynccontextmanager - async def lock(self, chat_id: int, user_id: int) -> AsyncGenerator[None, None]: - async with self.storage[chat_id][user_id].lock: + async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]: + async with self.storage[bot][chat_id][user_id].lock: yield None - async def set_state(self, chat_id: int, user_id: int, state: StateType = None) -> None: - self.storage[chat_id][user_id].state = state.state if isinstance(state, State) else state + async def set_state( + self, bot: Bot, chat_id: int, user_id: int, state: StateType = None + ) -> None: + self.storage[bot][chat_id][user_id].state = ( + state.state if isinstance(state, State) else state + ) - async def get_state(self, chat_id: int, user_id: int) -> Optional[str]: - return self.storage[chat_id][user_id].state + async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]: + return self.storage[bot][chat_id][user_id].state - async def set_data(self, chat_id: int, user_id: int, data: Dict[str, Any]) -> None: - self.storage[chat_id][user_id].data = data.copy() + async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None: + self.storage[bot][chat_id][user_id].data = data.copy() - async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]: - return self.storage[chat_id][user_id].data.copy() + async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]: + return self.storage[bot][chat_id][user_id].data.copy() diff --git a/aiogram/dispatcher/fsm/storage/redis.py b/aiogram/dispatcher/fsm/storage/redis.py new file mode 100644 index 00000000..64c832f9 --- /dev/null +++ b/aiogram/dispatcher/fsm/storage/redis.py @@ -0,0 +1,101 @@ +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast + +from aioredis import ConnectionPool, Redis + +from aiogram import Bot +from aiogram.dispatcher.fsm.state import State +from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType + +PrefixFactoryType = Callable[[Bot], str] +STATE_KEY = "state" +STATE_DATA_KEY = "data" +STATE_LOCK_KEY = "lock" + +DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60} + + +class RedisStorage(BaseStorage): + def __init__( + self, + redis: Redis, + prefix: str = "fsm", + prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False, + state_ttl: Optional[int] = None, + data_ttl: Optional[int] = None, + lock_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + if lock_kwargs is None: + lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS + self.redis = redis + self.prefix = prefix + self.prefix_bot = prefix_bot + self.state_ttl = state_ttl + self.data_ttl = data_ttl + self.lock_kwargs = lock_kwargs + + @classmethod + def from_url( + cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> "RedisStorage": + if connection_kwargs is None: + connection_kwargs = {} + pool = ConnectionPool.from_url(url, **connection_kwargs) + redis = Redis(connection_pool=pool) + return cls(redis=redis, **kwargs) + + async def close(self) -> None: + await self.redis.close() + + def generate_key(self, bot: Bot, *parts: Any) -> str: + prefix_parts = [self.prefix] + if self.prefix_bot: + if isinstance(self.prefix_bot, dict): + prefix_parts.append(self.prefix_bot[bot.id]) + elif callable(self.prefix_bot): + prefix_parts.append(self.prefix_bot(bot)) + else: + prefix_parts.append(str(bot.id)) + prefix_parts.extend(parts) + return ":".join(map(str, prefix_parts)) + + @asynccontextmanager + async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]: + key = self.generate_key(bot, chat_id, user_id, STATE_LOCK_KEY) + async with self.redis.lock(name=key, **self.lock_kwargs): + yield None + + async def set_state( + self, bot: Bot, chat_id: int, user_id: int, state: StateType = None + ) -> None: + key = self.generate_key(bot, chat_id, user_id, STATE_KEY) + if state is None: + await self.redis.delete(key) + else: + await self.redis.set( + key, state.state if isinstance(state, State) else state, ex=self.state_ttl + ) + + async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]: + key = self.generate_key(bot, chat_id, user_id, STATE_KEY) + value = await self.redis.get(key) + if isinstance(value, bytes): + return value.decode("utf-8") + return cast(Optional[str], value) + + async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None: + key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY) + if not data: + await self.redis.delete(key) + return + json_data = bot.session.json_dumps(data) + await self.redis.set(key, json_data, ex=self.data_ttl) + + async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]: + key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY) + value = await self.redis.get(key) + if value is None: + return {} + if isinstance(value, bytes): + value = value.decode("utf-8") + return cast(Dict[str, Any], bot.session.json_loads(value)) diff --git a/aiogram/utils/auth_widget.py b/aiogram/utils/auth_widget.py new file mode 100644 index 00000000..a67afe65 --- /dev/null +++ b/aiogram/utils/auth_widget.py @@ -0,0 +1,34 @@ +import hashlib +import hmac +from typing import Any, Dict + + +def check_signature(token: str, hash: str, **kwargs: Any) -> bool: + """ + Generate hexadecimal representation + of the HMAC-SHA-256 signature of the data-check-string + with the SHA256 hash of the bot's token used as a secret key + + :param token: + :param hash: + :param kwargs: all params received on auth + :return: + """ + secret = hashlib.sha256(token.encode("utf-8")) + check_string = "\n".join(map(lambda k: f"{k}={kwargs[k]}", sorted(kwargs))) + hmac_string = hmac.new( + secret.digest(), check_string.encode("utf-8"), digestmod=hashlib.sha256 + ).hexdigest() + return hmac_string == hash + + +def check_integrity(token: str, data: Dict[str, Any]) -> bool: + """ + Verify the authentication and the integrity + of the data received on user's auth + + :param token: Bot's token + :param data: all data that came on auth + :return: + """ + return check_signature(token, **data) diff --git a/aiogram/utils/deep_linking.py b/aiogram/utils/deep_linking.py new file mode 100644 index 00000000..caac2c26 --- /dev/null +++ b/aiogram/utils/deep_linking.py @@ -0,0 +1,131 @@ +""" +Deep linking + +Telegram bots have a deep linking mechanism, that allows for passing +additional parameters to the bot on startup. It could be a command that +launches the bot — or an auth token to connect the user's Telegram +account to their account on some external service. + +You can read detailed description in the source: +https://core.telegram.org/bots#deep-linking + +We have add some utils to get deep links more handy. + +Basic link example: + + .. code-block:: python + + from aiogram.utils.deep_linking import get_start_link + link = await get_start_link('foo') + + # result: 'https://t.me/MyBot?start=foo' + +Encoded link example: + + .. code-block:: python + + from aiogram.utils.deep_linking import get_start_link + + link = await get_start_link('foo', encode=True) + # result: 'https://t.me/MyBot?start=Zm9v' + +Decode it back example: + .. code-block:: python + + from aiogram.utils.deep_linking import decode_payload + from aiogram.types import Message + + @dp.message_handler(commands=["start"]) + async def handler(message: Message): + args = message.get_args() + payload = decode_payload(args) + await message.answer(f"Your payload: {payload}") + +""" +import re +from base64 import urlsafe_b64decode, urlsafe_b64encode +from typing import Literal, cast + +from aiogram import Bot +from aiogram.utils.link import create_telegram_link + +BAD_PATTERN = re.compile(r"[^_A-z0-9-]") + + +async def create_start_link(bot: Bot, payload: str, encode: bool = False) -> str: + """ + Create 'start' deep link with your payload. + + If you need to encode payload or pass special characters - + set encode as True + + :param bot: bot instance + :param payload: args passed with /start + :param encode: encode payload with base64url + :return: link + """ + username = (await bot.me()).username + return create_deep_link(username=username, link_type="start", payload=payload, encode=encode) + + +async def create_startgroup_link(bot: Bot, payload: str, encode: bool = False) -> str: + """ + Create 'startgroup' deep link with your payload. + + If you need to encode payload or pass special characters - + set encode as True + + :param bot: bot instance + :param payload: args passed with /start + :param encode: encode payload with base64url + :return: link + """ + username = (await bot.me()).username + return create_deep_link( + username=username, link_type="startgroup", payload=payload, encode=encode + ) + + +def create_deep_link( + username: str, link_type: Literal["start", "startgroup"], payload: str, encode: bool = False +) -> str: + """ + Create deep link. + + :param username: + :param link_type: `start` or `startgroup` + :param payload: any string-convertible data + :param encode: pass True to encode the payload + :return: deeplink + """ + if not isinstance(payload, str): + payload = str(payload) + + if encode: + payload = encode_payload(payload) + + if re.search(BAD_PATTERN, payload): + raise ValueError( + "Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. " + "Pass `encode=True` or encode payload manually." + ) + + if len(payload) > 64: + raise ValueError("Payload must be up to 64 characters long.") + + return create_telegram_link(username, **{cast(str, link_type): payload}) + + +def encode_payload(payload: str) -> str: + """Encode payload with URL-safe base64url.""" + payload = str(payload) + bytes_payload: bytes = urlsafe_b64encode(payload.encode()) + str_payload = bytes_payload.decode() + return str_payload.replace("=", "") + + +def decode_payload(payload: str) -> str: + """Decode payload with URL-safe base64url.""" + payload += "=" * (4 - len(payload) % 4) + result: bytes = urlsafe_b64decode(payload) + return result.decode() diff --git a/aiogram/utils/keyboard.py b/aiogram/utils/keyboard.py index 19409c94..9cb10b02 100644 --- a/aiogram/utils/keyboard.py +++ b/aiogram/utils/keyboard.py @@ -2,13 +2,27 @@ from __future__ import annotations from itertools import chain from itertools import cycle as repeat_all -from typing import Any, Generator, Generic, Iterable, List, Optional, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Generator, + Generic, + Iterable, + List, + Optional, + Type, + TypeVar, + Union, + no_type_check, +) from aiogram.dispatcher.filters.callback_data import CallbackData from aiogram.types import ( + CallbackGame, InlineKeyboardButton, InlineKeyboardMarkup, KeyboardButton, + LoginUrl, ReplyKeyboardMarkup, ) @@ -239,3 +253,28 @@ def repeat_last(items: Iterable[T]) -> Generator[T, None, None]: except StopIteration: finished = True yield value + + +class InlineKeyboardConstructor(KeyboardConstructor[InlineKeyboardButton]): + if TYPE_CHECKING: # pragma: no cover + + @no_type_check + def button( + self, + text: str, + url: Optional[str] = None, + login_url: Optional[LoginUrl] = None, + callback_data: Optional[Union[str, CallbackData]] = None, + switch_inline_query: Optional[str] = None, + switch_inline_query_current_chat: Optional[str] = None, + callback_game: Optional[CallbackGame] = None, + pay: Optional[bool] = None, + **kwargs: Any, + ) -> "KeyboardConstructor[InlineKeyboardButton]": + ... + + def as_markup(self, **kwargs: Any) -> InlineKeyboardMarkup: + ... + + def __init__(self) -> None: + super().__init__(InlineKeyboardButton) diff --git a/aiogram/utils/link.py b/aiogram/utils/link.py new file mode 100644 index 00000000..87d402e2 --- /dev/null +++ b/aiogram/utils/link.py @@ -0,0 +1,18 @@ +from typing import Any +from urllib.parse import urlencode, urljoin + + +def create_tg_link(link: str, **kwargs: Any) -> str: + url = f"tg://{link}" + if kwargs: + query = urlencode(kwargs) + url += f"?{query}" + return url + + +def create_telegram_link(uri: str, **kwargs: Any) -> str: + url = urljoin("https://t.me", uri) + if kwargs: + query = urlencode(query=kwargs) + url += f"?{query}" + return url diff --git a/aiogram/utils/text_decorations.py b/aiogram/utils/text_decorations.py index a41e481f..23c9c2a7 100644 --- a/aiogram/utils/text_decorations.py +++ b/aiogram/utils/text_decorations.py @@ -183,7 +183,7 @@ class MarkdownDecoration(TextDecoration): return f"`{value}`" def pre(self, value: str) -> str: - return f"```{value}```" + return f"```\n{value}\n```" def pre_language(self, value: str, language: str) -> str: return f"```{language}\n{value}\n```" diff --git a/mypy.ini b/mypy.ini index afe61218..a75c96cb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,6 @@ [mypy] ;plugins = pydantic.mypy -python_version = 3.7 +python_version = 3.8 show_error_codes = True show_error_context = True pretty = True @@ -29,3 +29,6 @@ ignore_missing_imports = True [mypy-uvloop] ignore_missing_imports = True + +[mypy-aioredis] +ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index 7c3acd24..cb05c237 100644 --- a/poetry.lock +++ b/poetry.lock @@ -38,6 +38,21 @@ aiohttp = ">=2.3.2" attrs = ">=19.2.0" python-socks = {version = ">=1.0.1", extras = ["asyncio"]} +[[package]] +name = "aioredis" +version = "2.0.0a1" +description = "asyncio (PEP 3156) Redis support" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +async-timeout = "*" +typing-extensions = "*" + +[package.extras] +hiredis = ["hiredis (>=1.0)"] + [[package]] name = "alabaster" version = "0.7.12" @@ -156,7 +171,7 @@ lxml = ["lxml"] [[package]] name = "black" -version = "21.5b1" +version = "21.6b0" description = "The uncompromising code formatter." category = "dev" optional = false @@ -172,8 +187,9 @@ toml = ">=0.10.1" [package.extras] colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.6.0)", "aiohttp-cors"] +d = ["aiohttp (>=3.6.0)", "aiohttp-cors (>=0.4.0)"] python2 = ["typed-ast (>=1.4.2)"] +uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "cfgv" @@ -218,9 +234,6 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" -[package.dependencies] -toml = {version = "*", optional = true, markers = "extra == \"toml\""} - [package.extras] toml = ["toml"] @@ -234,7 +247,7 @@ python-versions = ">=3.5" [[package]] name = "distlib" -version = "0.3.1" +version = "0.3.2" description = "Distribution utilities" category = "dev" optional = false @@ -301,7 +314,7 @@ test = ["pytest", "pytest-cov", "pytest-xdist"] [[package]] name = "identify" -version = "2.2.5" +version = "2.2.10" description = "File identification library for Python" category = "dev" optional = false @@ -312,11 +325,11 @@ license = ["editdistance-s"] [[package]] name = "idna" -version = "3.1" +version = "3.2" description = "Internationalized Domain Names in Applications (IDNA)" category = "main" optional = false -python-versions = ">=3.4" +python-versions = ">=3.5" [[package]] name = "imagesize" @@ -328,7 +341,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] name = "importlib-metadata" -version = "4.0.1" +version = "4.5.0" description = "Read metadata from Python packages" category = "dev" optional = false @@ -351,7 +364,7 @@ python-versions = "*" [[package]] name = "ipython" -version = "7.23.1" +version = "7.24.1" description = "IPython: Productive Interactive Computing" category = "dev" optional = false @@ -371,7 +384,7 @@ pygments = "*" traitlets = ">=4.2" [package.extras] -all = ["Sphinx (>=1.3)", "ipykernel", "ipyparallel", "ipywidgets", "nbconvert", "nbformat", "nose (>=0.10.1)", "notebook", "numpy (>=1.16)", "pygments", "qtconsole", "requests", "testpath"] +all = ["Sphinx (>=1.3)", "ipykernel", "ipyparallel", "ipywidgets", "nbconvert", "nbformat", "nose (>=0.10.1)", "notebook", "numpy (>=1.17)", "pygments", "qtconsole", "requests", "testpath"] doc = ["Sphinx (>=1.3)"] kernel = ["ipykernel"] nbconvert = ["nbconvert"] @@ -379,7 +392,7 @@ nbformat = ["nbformat"] notebook = ["notebook", "ipywidgets"] parallel = ["ipyparallel"] qtconsole = ["qtconsole"] -test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.16)"] +test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.17)"] [[package]] name = "ipython-genutils" @@ -739,18 +752,19 @@ testing = ["coverage", "hypothesis (>=5.7.1)"] [[package]] name = "pytest-cov" -version = "2.12.0" +version = "2.12.1" description = "Pytest plugin for measuring coverage." category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.dependencies] -coverage = {version = ">=5.2.1", extras = ["toml"]} +coverage = ">=5.2.1" pytest = ">=4.6" +toml = "*" [package.extras] -testing = ["fields", "hunter", "process-tests (==2.0.2)", "six", "pytest-xdist", "virtualenv"] +testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"] [[package]] name = "pytest-html" @@ -764,6 +778,17 @@ python-versions = ">=3.6" pytest = ">=5.0,<6.0.0 || >6.0.0" pytest-metadata = "*" +[[package]] +name = "pytest-lazy-fixture" +version = "0.6.3" +description = "It helps to use fixtures in pytest.mark.parametrize" +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +pytest = ">=3.2.5" + [[package]] name = "pytest-metadata" version = "1.11.0" @@ -930,17 +955,17 @@ test = ["pytest", "pytest-cov"] [[package]] name = "sphinx-copybutton" -version = "0.3.1" +version = "0.3.2" description = "Add a copy button to each of your code cells." category = "main" optional = false -python-versions = "*" +python-versions = ">=3.6" [package.dependencies] sphinx = ">=1.8" [package.extras] -code_style = ["flake8 (>=3.7.0,<3.8.0)", "black", "pre-commit (==1.17.0)"] +code_style = ["pre-commit (==2.12.1)"] [[package]] name = "sphinx-intl" @@ -1171,11 +1196,12 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pyt docs = ["sphinx", "sphinx-intl", "sphinx-autobuild", "sphinx-copybutton", "furo", "sphinx-prompt", "Sphinx-Substitution-Extensions"] fast = [] proxy = ["aiohttp-socks"] +redis = ["aioredis"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "2fcd44a8937b3ea48196c8eba8ceb0533281af34c884103bcc5b4f5f16b817d5" +content-hash = "362a6caf937b1c457599cbf2cd5d000eab4cac529bd7fe8c257ae713ebc63331" [metadata.files] aiofiles = [ @@ -1225,6 +1251,10 @@ aiohttp-socks = [ {file = "aiohttp_socks-0.5.5-py3-none-any.whl", hash = "sha256:faaa25ed4dc34440ca888d23e089420f3b1918dc4ecf062c3fd9474827ad6a39"}, {file = "aiohttp_socks-0.5.5.tar.gz", hash = "sha256:2eb2059756bde34c55bb429541cbf2eba3fd53e36ac80875b461221e2858b04a"}, ] +aioredis = [ + {file = "aioredis-2.0.0a1-py3-none-any.whl", hash = "sha256:32d7910724282a475c91b8b34403867069a4f07bf0c5ad5fe66cd797322f9a0d"}, + {file = "aioredis-2.0.0a1.tar.gz", hash = "sha256:5884f384b8ecb143bb73320a96e7c464fd38e117950a7d48340a35db8e35e7d2"}, +] alabaster = [ {file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"}, {file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"}, @@ -1274,8 +1304,8 @@ beautifulsoup4 = [ {file = "beautifulsoup4-4.9.3.tar.gz", hash = "sha256:84729e322ad1d5b4d25f805bfa05b902dd96450f43842c4e99067d5e1369eb25"}, ] black = [ - {file = "black-21.5b1-py3-none-any.whl", hash = "sha256:8a60071a0043876a4ae96e6c69bd3a127dad2c1ca7c8083573eb82f92705d008"}, - {file = "black-21.5b1.tar.gz", hash = "sha256:23695358dbcb3deafe7f0a3ad89feee5999a46be5fec21f4f1d108be0bcdb3b1"}, + {file = "black-21.6b0-py3-none-any.whl", hash = "sha256:dfb8c5a069012b2ab1e972e7b908f5fb42b6bbabcba0a788b86dc05067c7d9c7"}, + {file = "black-21.6b0.tar.gz", hash = "sha256:dc132348a88d103016726fe360cb9ede02cecf99b76e3660ce6c596be132ce04"}, ] cfgv = [ {file = "cfgv-3.3.0-py2.py3-none-any.whl", hash = "sha256:b449c9c6118fe8cca7fa5e00b9ec60ba08145d281d52164230a69211c5d597a1"}, @@ -1352,8 +1382,8 @@ decorator = [ {file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"}, ] distlib = [ - {file = "distlib-0.3.1-py2.py3-none-any.whl", hash = "sha256:8c09de2c67b3e7deef7184574fc060ab8a793e7adbb183d942c389c8b13c52fb"}, - {file = "distlib-0.3.1.zip", hash = "sha256:edf6116872c863e1aa9d5bb7cb5e05a022c519a4594dc703843343a9ddd9bff1"}, + {file = "distlib-0.3.2-py2.py3-none-any.whl", hash = "sha256:23e223426b28491b1ced97dc3bbe183027419dfc7982b4fa2f05d5f3ff10711c"}, + {file = "distlib-0.3.2.zip", hash = "sha256:106fef6dc37dd8c0e2c0a60d3fca3e77460a48907f335fa28420463a6f799736"}, ] docutils = [ {file = "docutils-0.17.1-py2.py3-none-any.whl", hash = "sha256:cf316c8370a737a022b72b56874f6602acf974a37a9fba42ec2876387549fc61"}, @@ -1376,28 +1406,28 @@ furo = [ {file = "furo-2020.12.30b24.tar.gz", hash = "sha256:30171899c9c06d692a778e6daf6cb2e5cbb05efc6006e1692e5e776007dc8a8c"}, ] identify = [ - {file = "identify-2.2.5-py2.py3-none-any.whl", hash = "sha256:9c3ab58543c03bd794a1735e4552ef6dec49ec32053278130d525f0982447d47"}, - {file = "identify-2.2.5.tar.gz", hash = "sha256:bc1705694253763a3160b943316867792ec00ba7a0ee40b46e20aebaf4e0c46a"}, + {file = "identify-2.2.10-py2.py3-none-any.whl", hash = "sha256:18d0c531ee3dbc112fa6181f34faa179de3f57ea57ae2899754f16a7e0ff6421"}, + {file = "identify-2.2.10.tar.gz", hash = "sha256:5b41f71471bc738e7b586308c3fca172f78940195cb3bf6734c1e66fdac49306"}, ] idna = [ - {file = "idna-3.1-py3-none-any.whl", hash = "sha256:5205d03e7bcbb919cc9c19885f9920d622ca52448306f2377daede5cf3faac16"}, - {file = "idna-3.1.tar.gz", hash = "sha256:c5b02147e01ea9920e6b0a3f1f7bb833612d507592c837a6c49552768f4054e1"}, + {file = "idna-3.2-py3-none-any.whl", hash = "sha256:14475042e284991034cb48e06f6851428fb14c4dc953acd9be9a5e95c7b6dd7a"}, + {file = "idna-3.2.tar.gz", hash = "sha256:467fbad99067910785144ce333826c71fb0e63a425657295239737f7ecd125f3"}, ] imagesize = [ {file = "imagesize-1.2.0-py2.py3-none-any.whl", hash = "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1"}, {file = "imagesize-1.2.0.tar.gz", hash = "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1"}, ] importlib-metadata = [ - {file = "importlib_metadata-4.0.1-py3-none-any.whl", hash = "sha256:d7eb1dea6d6a6086f8be21784cc9e3bcfa55872b52309bc5fad53a8ea444465d"}, - {file = "importlib_metadata-4.0.1.tar.gz", hash = "sha256:8c501196e49fb9df5df43833bdb1e4328f64847763ec8a50703148b73784d581"}, + {file = "importlib_metadata-4.5.0-py3-none-any.whl", hash = "sha256:833b26fb89d5de469b24a390e9df088d4e52e4ba33b01dc5e0e4f41b81a16c00"}, + {file = "importlib_metadata-4.5.0.tar.gz", hash = "sha256:b142cc1dd1342f31ff04bb7d022492b09920cb64fed867cd3ea6f80fe3ebd139"}, ] iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, ] ipython = [ - {file = "ipython-7.23.1-py3-none-any.whl", hash = "sha256:f78c6a3972dde1cc9e4041cbf4de583546314ba52d3c97208e5b6b2221a9cb7d"}, - {file = "ipython-7.23.1.tar.gz", hash = "sha256:714810a5c74f512b69d5f3b944c86e592cee0a5fb9c728e582f074610f6cf038"}, + {file = "ipython-7.24.1-py3-none-any.whl", hash = "sha256:d513e93327cf8657d6467c81f1f894adc125334ffe0e4ddd1abbb1c78d828703"}, + {file = "ipython-7.24.1.tar.gz", hash = "sha256:9bc24a99f5d19721fb8a2d1408908e9c0520a17fff2233ffe82620847f17f1b6"}, ] ipython-genutils = [ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, @@ -1637,13 +1667,17 @@ pytest-asyncio = [ {file = "pytest_asyncio-0.15.1-py3-none-any.whl", hash = "sha256:3042bcdf1c5d978f6b74d96a151c4cfb9dcece65006198389ccd7e6c60eb1eea"}, ] pytest-cov = [ - {file = "pytest-cov-2.12.0.tar.gz", hash = "sha256:8535764137fecce504a49c2b742288e3d34bc09eed298ad65963616cc98fd45e"}, - {file = "pytest_cov-2.12.0-py2.py3-none-any.whl", hash = "sha256:95d4933dcbbacfa377bb60b29801daa30d90c33981ab2a79e9ab4452c165066e"}, + {file = "pytest-cov-2.12.1.tar.gz", hash = "sha256:261ceeb8c227b726249b376b8526b600f38667ee314f910353fa318caa01f4d7"}, + {file = "pytest_cov-2.12.1-py2.py3-none-any.whl", hash = "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a"}, ] pytest-html = [ {file = "pytest-html-3.1.1.tar.gz", hash = "sha256:3ee1cf319c913d19fe53aeb0bc400e7b0bc2dbeb477553733db1dad12eb75ee3"}, {file = "pytest_html-3.1.1-py3-none-any.whl", hash = "sha256:b7f82f123936a3f4d2950bc993c2c1ca09ce262c9ae12f9ac763a2401380b455"}, ] +pytest-lazy-fixture = [ + {file = "pytest-lazy-fixture-0.6.3.tar.gz", hash = "sha256:0e7d0c7f74ba33e6e80905e9bfd81f9d15ef9a790de97993e34213deb5ad10ac"}, + {file = "pytest_lazy_fixture-0.6.3-py3-none-any.whl", hash = "sha256:e0b379f38299ff27a653f03eaa69b08a6fd4484e46fd1c9907d984b9f9daeda6"}, +] pytest-metadata = [ {file = "pytest-metadata-1.11.0.tar.gz", hash = "sha256:71b506d49d34e539cc3cfdb7ce2c5f072bea5c953320002c95968e0238f8ecf1"}, {file = "pytest_metadata-1.11.0-py2.py3-none-any.whl", hash = "sha256:576055b8336dd4a9006dd2a47615f76f2f8c30ab12b1b1c039d99e834583523f"}, @@ -1763,8 +1797,8 @@ sphinx-autobuild = [ {file = "sphinx_autobuild-2020.9.1-py3-none-any.whl", hash = "sha256:df5c72cb8b8fc9b31279c4619780c4e95029be6de569ff60a8bb2e99d20f63dd"}, ] sphinx-copybutton = [ - {file = "sphinx-copybutton-0.3.1.tar.gz", hash = "sha256:0e0461df394515284e3907e3f418a0c60ef6ab6c9a27a800c8552772d0a402a2"}, - {file = "sphinx_copybutton-0.3.1-py3-none-any.whl", hash = "sha256:5125c718e763596e6e52d92e15ee0d6f4800ad3817939be6dee51218870b3e3d"}, + {file = "sphinx-copybutton-0.3.2.tar.gz", hash = "sha256:f901f17e7dadc063bcfca592c5160f9113ec17501a59e046af3edb82b7527656"}, + {file = "sphinx_copybutton-0.3.2-py3-none-any.whl", hash = "sha256:f16f8ed8dfc60f2b34a58cb69bfa04722e24be2f6d7e04db5554c32cde4df815"}, ] sphinx-intl = [ {file = "sphinx-intl-2.0.1.tar.gz", hash = "sha256:b25a6ec169347909e8d983eefe2d8adecb3edc2f27760db79b965c69950638b4"}, diff --git a/pyproject.toml b/pyproject.toml index 4dee13c6..4befda3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aiogram" -version = "3.0.0-alpha.8" +version = "3.0.0-alpha.9" description = "Modern and fully asynchronous framework for Telegram Bot API" authors = ["Alex Root Junior "] license = "MIT" @@ -38,8 +38,9 @@ Babel = "^2.9.1" aiofiles = "^0.6.0" async_lru = "^1.0.2" aiohttp-socks = { version = "^0.5.5", optional = true } +aioredis = { version = "^2.0.0a1", allow-prereleases = true, optional = true } typing-extensions = { version = "^3.7.4", python = "<3.8" } -magic-filter = {version = "1.0.0a1", allow-prereleases = true} +magic-filter = { version = "1.0.0a1", allow-prereleases = true } sphinx = { version = "^3.1.0", optional = true } sphinx-intl = { version = "^2.0.1", optional = true } sphinx-autobuild = { version = "^2020.9.1", optional = true } @@ -50,6 +51,7 @@ Sphinx-Substitution-Extensions = { version = "^2020.9.30", optional = true } [tool.poetry.dev-dependencies] aiohttp-socks = "^0.5" +aioredis = { version = "^2.0.0a1", allow-prereleases = true } ipython = "^7.22.0" uvloop = { version = "^0.15.2", markers = "sys_platform == 'darwin' or sys_platform == 'linux'" } black = "^21.4b2" @@ -79,9 +81,11 @@ sphinx-copybutton = "^0.3.1" furo = "^2020.11.15-beta.17" sphinx-prompt = "^1.3.0" Sphinx-Substitution-Extensions = "^2020.9.30" +pytest-lazy-fixture = "^0.6.3" [tool.poetry.extras] fast = ["uvloop"] +redis = ["aioredis"] proxy = ["aiohttp-socks"] docs = [ "sphinx", diff --git a/tests/conftest.py b/tests/conftest.py index 60d9d0fe..92dd97fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,64 @@ import pytest +from _pytest.config import UsageError +from aioredis.connection import parse_url as parse_redis_url from aiogram import Bot +from aiogram.dispatcher.fsm.storage.memory import MemoryStorage +from aiogram.dispatcher.fsm.storage.redis import RedisStorage from tests.mocked_bot import MockedBot +def pytest_addoption(parser): + parser.addoption("--redis", default=None, help="run tests which require redis connection") + + +def pytest_configure(config): + config.addinivalue_line("markers", "redis: marked tests require redis connection to run") + + +def pytest_collection_modifyitems(config, items): + redis_uri = config.getoption("--redis") + if redis_uri is None: + skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run") + for item in items: + if "redis" in item.keywords: + item.add_marker(skip_redis) + return + try: + parse_redis_url(redis_uri) + except ValueError as e: + raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}") + + +@pytest.fixture(scope="session") +def redis_server(request): + redis_uri = request.config.getoption("--redis") + return redis_uri + + +@pytest.fixture() +@pytest.mark.redis +async def redis_storage(redis_server): + if not redis_server: + pytest.skip("Redis is not available here") + storage = RedisStorage.from_url(redis_server) + try: + yield storage + finally: + conn = await storage.redis + await conn.flushdb() + await storage.close() + + +@pytest.fixture() +async def memory_storage(): + storage = MemoryStorage() + try: + yield storage + finally: + await storage.close() + + @pytest.fixture() def bot(): bot = MockedBot() diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml new file mode 100644 index 00000000..453f5e5a --- /dev/null +++ b/tests/docker-compose.yml @@ -0,0 +1,7 @@ +version: "3.9" + +services: + redis: + image: redis:6-alpine + ports: + - "${REDIS_PORT-6379}:6379" diff --git a/tests/test_api/test_client/test_session/test_base_session.py b/tests/test_api/test_client/test_session/test_base_session.py index 448f663e..ef82c1d3 100644 --- a/tests/test_api/test_client/test_session/test_base_session.py +++ b/tests/test_api/test_client/test_session/test_base_session.py @@ -6,7 +6,7 @@ import pytest from aiogram.client.session.base import BaseSession, TelegramType from aiogram.client.telegram import PRODUCTION, TelegramAPIServer -from aiogram.methods import DeleteMessage, GetMe, Response, TelegramMethod +from aiogram.methods import DeleteMessage, GetMe, TelegramMethod from aiogram.types import UNSET try: @@ -20,7 +20,9 @@ class CustomSession(BaseSession): async def close(self): pass - async def make_request(self, token: str, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET) -> None: # type: ignore + async def make_request( + self, token: str, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET + ) -> None: # type: ignore assert isinstance(token, str) assert isinstance(method, TelegramMethod) diff --git a/tests/test_api/test_methods/test_edit_message_media.py b/tests/test_api/test_methods/test_edit_message_media.py index c6715163..ee3003b9 100644 --- a/tests/test_api/test_methods/test_edit_message_media.py +++ b/tests/test_api/test_methods/test_edit_message_media.py @@ -3,7 +3,7 @@ from typing import Union import pytest from aiogram.methods import EditMessageMedia, Request -from aiogram.types import BufferedInputFile, InputMedia, InputMediaPhoto, Message +from aiogram.types import BufferedInputFile, InputMediaPhoto, Message from tests.mocked_bot import MockedBot diff --git a/tests/test_api/test_methods/test_get_url.py b/tests/test_api/test_methods/test_get_url.py index 3c769ca2..76b24200 100644 --- a/tests/test_api/test_methods/test_get_url.py +++ b/tests/test_api/test_methods/test_get_url.py @@ -2,8 +2,8 @@ import datetime from typing import Optional import pytest -from aiogram.types import Chat, Message +from aiogram.types import Chat, Message from tests.mocked_bot import MockedBot diff --git a/tests/test_api/test_methods/test_send_audio.py b/tests/test_api/test_methods/test_send_audio.py index 4a33bbdc..2a5e67fd 100644 --- a/tests/test_api/test_methods/test_send_audio.py +++ b/tests/test_api/test_methods/test_send_audio.py @@ -3,7 +3,7 @@ import datetime import pytest from aiogram.methods import Request, SendAudio -from aiogram.types import Audio, Chat, File, Message +from aiogram.types import Audio, Chat, Message from tests.mocked_bot import MockedBot diff --git a/tests/test_api/test_methods/test_set_chat_administrator_custom_title.py b/tests/test_api/test_methods/test_set_chat_administrator_custom_title.py index 2f4752c7..1395df0d 100644 --- a/tests/test_api/test_methods/test_set_chat_administrator_custom_title.py +++ b/tests/test_api/test_methods/test_set_chat_administrator_custom_title.py @@ -1,6 +1,6 @@ import pytest -from aiogram.methods import Request, SetChatAdministratorCustomTitle, SetChatTitle +from aiogram.methods import Request, SetChatAdministratorCustomTitle from tests.mocked_bot import MockedBot diff --git a/tests/test_api/test_methods/test_set_chat_photo.py b/tests/test_api/test_methods/test_set_chat_photo.py index 02e00670..f648ccdb 100644 --- a/tests/test_api/test_methods/test_set_chat_photo.py +++ b/tests/test_api/test_methods/test_set_chat_photo.py @@ -1,7 +1,7 @@ import pytest from aiogram.methods import Request, SetChatPhoto -from aiogram.types import BufferedInputFile, InputFile +from aiogram.types import BufferedInputFile from tests.mocked_bot import MockedBot diff --git a/tests/test_dispatcher/test_dispatcher.py b/tests/test_dispatcher/test_dispatcher.py index ecf44712..37bbf634 100644 --- a/tests/test_dispatcher/test_dispatcher.py +++ b/tests/test_dispatcher/test_dispatcher.py @@ -9,8 +9,6 @@ import pytest from aiogram import Bot from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler -from aiogram.dispatcher.fsm.strategy import FSMStrategy -from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware from aiogram.dispatcher.router import Router from aiogram.methods import GetMe, GetUpdates, SendMessage from aiogram.types import ( @@ -423,7 +421,7 @@ class TestDispatcher: assert User.get_current(False) return kwargs - result = await router.update.trigger(update, test="PASS") + result = await router.update.trigger(update, test="PASS", bot=None) assert isinstance(result, dict) assert result["event_update"] == update assert result["event_router"] == router @@ -526,8 +524,9 @@ class TestDispatcher: assert len(log_records) == 1 assert "Cause exception while process update" in log_records[0] + @pytest.mark.parametrize("as_task", [True, False]) @pytest.mark.asyncio - async def test_polling(self, bot: MockedBot): + async def test_polling(self, bot: MockedBot, as_task: bool): dispatcher = Dispatcher() async def _mock_updates(*_): @@ -539,8 +538,11 @@ class TestDispatcher: "aiogram.dispatcher.dispatcher.Dispatcher._listen_updates" ) as patched_listen_updates: patched_listen_updates.return_value = _mock_updates() - await dispatcher._polling(bot=bot) - mocked_process_update.assert_awaited() + await dispatcher._polling(bot=bot, handle_as_tasks=as_task) + if as_task: + pass + else: + mocked_process_update.assert_awaited() @pytest.mark.asyncio async def test_exception_handler_catch_exceptions(self): @@ -548,9 +550,12 @@ class TestDispatcher: router = Router() dp.include_router(router) + class CustomException(Exception): + pass + @router.message() async def message_handler(message: Message): - raise Exception("KABOOM") + raise CustomException("KABOOM") update = Update( update_id=42, @@ -562,23 +567,23 @@ class TestDispatcher: from_user=User(id=42, is_bot=False, first_name="Test"), ), ) - with pytest.raises(Exception, match="KABOOM"): - await dp.update.trigger(update) + with pytest.raises(CustomException, match="KABOOM"): + await dp.update.trigger(update, bot=None) @router.errors() async def error_handler(event: Update, exception: Exception): return "KABOOM" - response = await dp.update.trigger(update) + response = await dp.update.trigger(update, bot=None) assert response == "KABOOM" @dp.errors() async def root_error_handler(event: Update, exception: Exception): return exception - response = await dp.update.trigger(update) + response = await dp.update.trigger(update, bot=None) - assert isinstance(response, Exception) + assert isinstance(response, CustomException) assert str(response) == "KABOOM" @pytest.mark.asyncio @@ -654,20 +659,3 @@ class TestDispatcher: log_records = [rec.message for rec in caplog.records] assert "Cause exception while process update" in log_records[0] - - @pytest.mark.parametrize( - "strategy,case,expected", - [ - [FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)], - [FSMStrategy.CHAT, (-42, 42), (-42, -42)], - [FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)], - [FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)], - [FSMStrategy.CHAT, (42, 42), (42, 42)], - [FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)], - ], - ) - def test_get_current_state_context(self, strategy, case, expected): - dp = Dispatcher(fsm_strategy=strategy) - chat_id, user_id = case - state = dp.current_state(chat_id=chat_id, user_id=user_id) - assert (state.chat_id, state.user_id) == expected diff --git a/tests/test_dispatcher/test_event/test_handler.py b/tests/test_dispatcher/test_event/test_handler.py index d7e6a1da..168dac59 100644 --- a/tests/test_dispatcher/test_event/test_handler.py +++ b/tests/test_dispatcher/test_event/test_handler.py @@ -5,7 +5,6 @@ import pytest from aiogram import F from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject -from aiogram.dispatcher.filters import Text from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.handler.base import BaseHandler from aiogram.types import Update diff --git a/tests/test_dispatcher/test_filters/test_command.py b/tests/test_dispatcher/test_filters/test_command.py index 6eb24097..a3ea4756 100644 --- a/tests/test_dispatcher/test_filters/test_command.py +++ b/tests/test_dispatcher/test_filters/test_command.py @@ -1,10 +1,11 @@ import datetime import re -from typing import Match import pytest +from aiogram import F from aiogram.dispatcher.filters import Command, CommandObject +from aiogram.dispatcher.filters.command import CommandStart from aiogram.methods import GetMe from aiogram.types import Chat, Message, User from tests.mocked_bot import MockedBot @@ -18,45 +19,54 @@ class TestCommandFilter: assert cmd.commands[0] == "start" assert cmd == Command(commands=["start"]) + @pytest.mark.parametrize( + "text,command,result", + [ + ["/test@tbot", Command(commands=["test"], commands_prefix="/"), True], + ["!test", Command(commands=["test"], commands_prefix="/"), False], + ["/test@mention", Command(commands=["test"], commands_prefix="/"), False], + ["/tests", Command(commands=["test"], commands_prefix="/"), False], + ["/", Command(commands=["test"], commands_prefix="/"), False], + ["/ test", Command(commands=["test"], commands_prefix="/"), False], + ["", Command(commands=["test"], commands_prefix="/"), False], + [" ", Command(commands=["test"], commands_prefix="/"), False], + ["test", Command(commands=["test"], commands_prefix="/"), False], + [" test", Command(commands=["test"], commands_prefix="/"), False], + ["a", Command(commands=["test"], commands_prefix="/"), False], + ["/test@tbot some args", Command(commands=["test"]), True], + ["/test42@tbot some args", Command(commands=[re.compile(r"test(\d+)")]), True], + [ + "/test42@tbot some args", + Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "some args"), + True, + ], + [ + "/test42@tbot some args", + Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "test"), + False, + ], + ["/start test", CommandStart(), True], + ["/start", CommandStart(deep_link=True), False], + ["/start test", CommandStart(deep_link=True), True], + ["/start test", CommandStart(deep_link=True, deep_link_encoded=True), False], + ["/start dGVzdA", CommandStart(deep_link=True, deep_link_encoded=True), True], + ], + ) @pytest.mark.asyncio - async def test_parse_command(self, bot: MockedBot): - # TODO: parametrize + async def test_parse_command(self, bot: MockedBot, text: str, result: bool, command: Command): # TODO: test ignore case # TODO: test ignore mention bot.add_result_for( GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot") ) - command = Command(commands=["test", re.compile(r"test(\d+)")], commands_prefix="/") - assert await command.parse_command("/test@tbot", bot) - assert not await command.parse_command("!test", bot) - assert not await command.parse_command("/test@mention", bot) - assert not await command.parse_command("/tests", bot) - assert not await command.parse_command("/", bot) - assert not await command.parse_command("/ test", bot) - assert not await command.parse_command("", bot) - assert not await command.parse_command(" ", bot) - assert not await command.parse_command("test", bot) - assert not await command.parse_command(" test", bot) - assert not await command.parse_command("a", bot) + message = Message( + message_id=0, text=text, chat=Chat(id=42, type="private"), date=datetime.datetime.now() + ) - result = await command.parse_command("/test@tbot some args", bot) - assert isinstance(result, dict) - assert "command" in result - assert isinstance(result["command"], CommandObject) - assert result["command"].command == "test" - assert result["command"].mention == "tbot" - assert result["command"].args == "some args" - - result = await command.parse_command("/test42@tbot some args", bot) - assert isinstance(result, dict) - assert "command" in result - assert isinstance(result["command"], CommandObject) - assert result["command"].command == "test42" - assert result["command"].mention == "tbot" - assert result["command"].args == "some args" - assert isinstance(result["command"].match, Match) + response = await command(message, bot) + assert bool(response) is result @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/test_dispatcher/test_fsm/storage/test_memory.py b/tests/test_dispatcher/test_fsm/storage/test_memory.py deleted file mode 100644 index 2f587075..00000000 --- a/tests/test_dispatcher/test_fsm/storage/test_memory.py +++ /dev/null @@ -1,45 +0,0 @@ -import pytest - -from aiogram.dispatcher.fsm.storage.memory import MemoryStorage, MemoryStorageRecord - - -@pytest.fixture() -def storage(): - return MemoryStorage() - - -class TestMemoryStorage: - @pytest.mark.asyncio - async def test_set_state(self, storage: MemoryStorage): - assert await storage.get_state(chat_id=-42, user_id=42) is None - - await storage.set_state(chat_id=-42, user_id=42, state="state") - assert await storage.get_state(chat_id=-42, user_id=42) == "state" - - assert -42 in storage.storage - assert 42 in storage.storage[-42] - assert isinstance(storage.storage[-42][42], MemoryStorageRecord) - assert storage.storage[-42][42].state == "state" - - @pytest.mark.asyncio - async def test_set_data(self, storage: MemoryStorage): - assert await storage.get_data(chat_id=-42, user_id=42) == {} - - await storage.set_data(chat_id=-42, user_id=42, data={"foo": "bar"}) - assert await storage.get_data(chat_id=-42, user_id=42) == {"foo": "bar"} - - assert -42 in storage.storage - assert 42 in storage.storage[-42] - assert isinstance(storage.storage[-42][42], MemoryStorageRecord) - assert storage.storage[-42][42].data == {"foo": "bar"} - - @pytest.mark.asyncio - async def test_update_data(self, storage: MemoryStorage): - assert await storage.get_data(chat_id=-42, user_id=42) == {} - assert await storage.update_data(chat_id=-42, user_id=42, data={"foo": "bar"}) == { - "foo": "bar" - } - assert await storage.update_data(chat_id=-42, user_id=42, data={"baz": "spam"}) == { - "foo": "bar", - "baz": "spam", - } diff --git a/tests/test_dispatcher/test_fsm/storage/test_redis.py b/tests/test_dispatcher/test_fsm/storage/test_redis.py new file mode 100644 index 00000000..7b914a33 --- /dev/null +++ b/tests/test_dispatcher/test_fsm/storage/test_redis.py @@ -0,0 +1,21 @@ +import pytest + +from aiogram.dispatcher.fsm.storage.redis import RedisStorage +from tests.mocked_bot import MockedBot + + +@pytest.mark.redis +class TestRedisStorage: + @pytest.mark.parametrize( + "prefix_bot,result", + [ + [False, "fsm:-1:2"], + [True, "fsm:42:-1:2"], + [{42: "kaboom"}, "fsm:kaboom:-1:2"], + [lambda bot: "kaboom", "fsm:kaboom:-1:2"], + ], + ) + @pytest.mark.asyncio + async def test_generate_key(self, bot: MockedBot, redis_server, prefix_bot, result): + storage = RedisStorage.from_url(redis_server, prefix_bot=prefix_bot) + assert storage.generate_key(bot, -1, 2) == result diff --git a/tests/test_dispatcher/test_fsm/storage/test_storages.py b/tests/test_dispatcher/test_fsm/storage/test_storages.py new file mode 100644 index 00000000..fcb2deae --- /dev/null +++ b/tests/test_dispatcher/test_fsm/storage/test_storages.py @@ -0,0 +1,44 @@ +import pytest + +from aiogram.dispatcher.fsm.storage.base import BaseStorage +from tests.mocked_bot import MockedBot + + +@pytest.mark.parametrize( + "storage", + [pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")], +) +class TestStorages: + @pytest.mark.asyncio + async def test_lock(self, bot: MockedBot, storage: BaseStorage): + # TODO: ?!? + async with storage.lock(bot=bot, chat_id=-42, user_id=42): + assert True, "You are kidding me?" + + @pytest.mark.asyncio + async def test_set_state(self, bot: MockedBot, storage: BaseStorage): + assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None + + await storage.set_state(bot=bot, chat_id=-42, user_id=42, state="state") + assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) == "state" + await storage.set_state(bot=bot, chat_id=-42, user_id=42, state=None) + assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None + + @pytest.mark.asyncio + async def test_set_data(self, bot: MockedBot, storage: BaseStorage): + assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {} + + await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"}) + assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {"foo": "bar"} + await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={}) + assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {} + + @pytest.mark.asyncio + async def test_update_data(self, bot: MockedBot, storage: BaseStorage): + assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {} + assert await storage.update_data( + bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"} + ) == {"foo": "bar"} + assert await storage.update_data( + bot=bot, chat_id=-42, user_id=42, data={"baz": "spam"} + ) == {"foo": "bar", "baz": "spam"} diff --git a/tests/test_dispatcher/test_fsm/test_context.py b/tests/test_dispatcher/test_fsm/test_context.py index 6c444c44..fb98c423 100644 --- a/tests/test_dispatcher/test_fsm/test_context.py +++ b/tests/test_dispatcher/test_fsm/test_context.py @@ -2,27 +2,28 @@ import pytest from aiogram.dispatcher.fsm.context import FSMContext from aiogram.dispatcher.fsm.storage.memory import MemoryStorage +from tests.mocked_bot import MockedBot @pytest.fixture() -def state(): +def state(bot: MockedBot): storage = MemoryStorage() - ctx = storage.storage[-42][42] + ctx = storage.storage[bot][-42][42] ctx.state = "test" ctx.data = {"foo": "bar"} - return FSMContext(storage=storage, user_id=-42, chat_id=42) + return FSMContext(bot=bot, storage=storage, user_id=-42, chat_id=42) class TestFSMContext: @pytest.mark.asyncio - async def test_address_mapping(self): + async def test_address_mapping(self, bot: MockedBot): storage = MemoryStorage() - ctx = storage.storage[-42][42] + ctx = storage.storage[bot][-42][42] ctx.state = "test" ctx.data = {"foo": "bar"} - state = FSMContext(storage=storage, chat_id=-42, user_id=42) - state2 = FSMContext(storage=storage, chat_id=42, user_id=42) - state3 = FSMContext(storage=storage, chat_id=69, user_id=69) + state = FSMContext(bot=bot, storage=storage, chat_id=-42, user_id=42) + state2 = FSMContext(bot=bot, storage=storage, chat_id=42, user_id=42) + state3 = FSMContext(bot=bot, storage=storage, chat_id=69, user_id=69) assert await state.get_state() == "test" assert await state2.get_state() is None diff --git a/tests/test_dispatcher/test_handler/test_chosen_inline_result.py b/tests/test_dispatcher/test_handler/test_chosen_inline_result.py index 2e1f4045..ecbb363d 100644 --- a/tests/test_dispatcher/test_handler/test_chosen_inline_result.py +++ b/tests/test_dispatcher/test_handler/test_chosen_inline_result.py @@ -3,7 +3,7 @@ from typing import Any import pytest from aiogram.dispatcher.handler import ChosenInlineResultHandler -from aiogram.types import CallbackQuery, ChosenInlineResult, User +from aiogram.types import ChosenInlineResult, User class TestChosenInlineResultHandler: diff --git a/tests/test_dispatcher/test_handler/test_error.py b/tests/test_dispatcher/test_handler/test_error.py index f6e6b090..a83d96a4 100644 --- a/tests/test_dispatcher/test_handler/test_error.py +++ b/tests/test_dispatcher/test_handler/test_error.py @@ -2,16 +2,7 @@ from typing import Any import pytest -from aiogram.dispatcher.handler import ErrorHandler, PollHandler -from aiogram.types import ( - CallbackQuery, - InlineQuery, - Poll, - PollOption, - ShippingAddress, - ShippingQuery, - User, -) +from aiogram.dispatcher.handler import ErrorHandler class TestErrorHandler: diff --git a/tests/test_dispatcher/test_handler/test_inline_query.py b/tests/test_dispatcher/test_handler/test_inline_query.py index 100fccdd..99ed913f 100644 --- a/tests/test_dispatcher/test_handler/test_inline_query.py +++ b/tests/test_dispatcher/test_handler/test_inline_query.py @@ -3,7 +3,7 @@ from typing import Any import pytest from aiogram.dispatcher.handler import InlineQueryHandler -from aiogram.types import CallbackQuery, InlineQuery, User +from aiogram.types import InlineQuery, User class TestCallbackQueryHandler: diff --git a/tests/test_dispatcher/test_handler/test_poll.py b/tests/test_dispatcher/test_handler/test_poll.py index 172012d6..6fc23e9e 100644 --- a/tests/test_dispatcher/test_handler/test_poll.py +++ b/tests/test_dispatcher/test_handler/test_poll.py @@ -3,15 +3,7 @@ from typing import Any import pytest from aiogram.dispatcher.handler import PollHandler -from aiogram.types import ( - CallbackQuery, - InlineQuery, - Poll, - PollOption, - ShippingAddress, - ShippingQuery, - User, -) +from aiogram.types import Poll, PollOption class TestShippingQueryHandler: diff --git a/tests/test_dispatcher/test_handler/test_shipping_query.py b/tests/test_dispatcher/test_handler/test_shipping_query.py index 0d5aa578..0e938571 100644 --- a/tests/test_dispatcher/test_handler/test_shipping_query.py +++ b/tests/test_dispatcher/test_handler/test_shipping_query.py @@ -3,7 +3,7 @@ from typing import Any import pytest from aiogram.dispatcher.handler import ShippingQueryHandler -from aiogram.types import CallbackQuery, InlineQuery, ShippingAddress, ShippingQuery, User +from aiogram.types import ShippingAddress, ShippingQuery, User class TestShippingQueryHandler: diff --git a/tests/test_utils/test_auth_widget.py b/tests/test_utils/test_auth_widget.py new file mode 100644 index 00000000..a6071760 --- /dev/null +++ b/tests/test_utils/test_auth_widget.py @@ -0,0 +1,27 @@ +import pytest + +from aiogram.utils.auth_widget import check_integrity + +TOKEN = "123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11" + + +@pytest.fixture +def data(): + return { + "id": "42", + "first_name": "John", + "last_name": "Smith", + "username": "username", + "photo_url": "https://t.me/i/userpic/320/picname.jpg", + "auth_date": "1565810688", + "hash": "c303db2b5a06fe41d23a9b14f7c545cfc11dcc7473c07c9c5034ae60062461ce", + } + + +class TestCheckIntegrity: + def test_ok(self, data): + assert check_integrity(TOKEN, data) is True + + def test_fail(self, data): + data.pop("username") + assert check_integrity(TOKEN, data) is False diff --git a/tests/test_utils/test_deep_linking.py b/tests/test_utils/test_deep_linking.py new file mode 100644 index 00000000..de44725c --- /dev/null +++ b/tests/test_utils/test_deep_linking.py @@ -0,0 +1,94 @@ +import pytest +from async_lru import alru_cache + +from aiogram.utils.deep_linking import ( + create_start_link, + create_startgroup_link, + decode_payload, + encode_payload, +) +from tests.mocked_bot import MockedBot + +PAYLOADS = [ + "foo", + "AAbbCCddEEff1122334455", + "aaBBccDDeeFF5544332211", + -12345678901234567890, + 12345678901234567890, +] +WRONG_PAYLOADS = [ + "@BotFather", + "Some:special$characters#=", + "spaces spaces spaces", + 1234567890123456789.0, +] + + +@pytest.fixture(params=PAYLOADS, name="payload") +def payload_fixture(request): + return request.param + + +@pytest.fixture(params=WRONG_PAYLOADS, name="wrong_payload") +def wrong_payload_fixture(request): + return request.param + + +@pytest.fixture(autouse=True) +def get_bot_user_fixture(monkeypatch): + """Monkey patching of bot.me calling.""" + + @alru_cache() + async def get_bot_user_mock(self): + from aiogram.types import User + + return User( + id=12345678, + is_bot=True, + first_name="FirstName", + last_name="LastName", + username="username", + language_code="uk-UA", + ) + + monkeypatch.setattr(MockedBot, "me", get_bot_user_mock) + + +@pytest.mark.asyncio +class TestDeepLinking: + async def test_get_start_link(self, bot, payload): + link = await create_start_link(bot=bot, payload=payload) + assert link == f"https://t.me/username?start={payload}" + + async def test_wrong_symbols(self, bot, wrong_payload): + with pytest.raises(ValueError): + await create_start_link(bot, wrong_payload) + + async def test_get_startgroup_link(self, bot, payload): + link = await create_startgroup_link(bot, payload) + assert link == f"https://t.me/username?startgroup={payload}" + + async def test_filter_encode_and_decode(self, payload): + encoded = encode_payload(payload) + decoded = decode_payload(encoded) + assert decoded == str(payload) + + async def test_get_start_link_with_encoding(self, bot, wrong_payload): + # define link + link = await create_start_link(bot, wrong_payload, encode=True) + + # define reference link + encoded_payload = encode_payload(wrong_payload) + + assert link == f"https://t.me/username?start={encoded_payload}" + + async def test_64_len_payload(self, bot): + payload = "p" * 64 + link = await create_start_link(bot, payload) + assert link + + async def test_too_long_payload(self, bot): + payload = "p" * 65 + print(payload, len(payload)) + with pytest.raises(ValueError): + await create_start_link(bot, payload) diff --git a/tests/test_utils/test_link.py b/tests/test_utils/test_link.py new file mode 100644 index 00000000..4dbfe8a2 --- /dev/null +++ b/tests/test_utils/test_link.py @@ -0,0 +1,24 @@ +from typing import Any, Dict + +import pytest + +from aiogram.utils.link import create_telegram_link, create_tg_link + + +class TestLink: + @pytest.mark.parametrize( + "base,params,result", + [["user", dict(id=42), "tg://user?id=42"]], + ) + def test_create_tg_link(self, base: str, params: Dict[str, Any], result: str): + assert create_tg_link(base, **params) == result + + @pytest.mark.parametrize( + "base,params,result", + [ + ["username", dict(), "https://t.me/username"], + ["username", dict(start="test"), "https://t.me/username?start=test"], + ], + ) + def test_create_telegram_link(self, base: str, params: Dict[str, Any], result: str): + assert create_telegram_link(base, **params) == result diff --git a/tests/test_utils/test_markdown.py b/tests/test_utils/test_markdown.py index 12e44ccf..815b1c5d 100644 --- a/tests/test_utils/test_markdown.py +++ b/tests/test_utils/test_markdown.py @@ -35,7 +35,7 @@ class TestMarkdown: [hitalic, ("test", "test"), " ", "test test"], [code, ("test", "test"), " ", "`test test`"], [hcode, ("test", "test"), " ", "test test"], - [pre, ("test", "test"), " ", "```test test```"], + [pre, ("test", "test"), " ", "```\ntest test\n```"], [hpre, ("test", "test"), " ", "
test test
"], [underline, ("test", "test"), " ", "__\rtest test__\r"], [hunderline, ("test", "test"), " ", "test test"], diff --git a/tests/test_utils/test_text_decorations.py b/tests/test_utils/test_text_decorations.py index 6cb5105d..da171575 100644 --- a/tests/test_utils/test_text_decorations.py +++ b/tests/test_utils/test_text_decorations.py @@ -55,7 +55,7 @@ class TestTextDecoration: [markdown_decoration, MessageEntity(type="bold", offset=0, length=5), "*test*"], [markdown_decoration, MessageEntity(type="italic", offset=0, length=5), "_\rtest_\r"], [markdown_decoration, MessageEntity(type="code", offset=0, length=5), "`test`"], - [markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```test```"], + [markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```\ntest\n```"], [ markdown_decoration, MessageEntity(type="pre", offset=0, length=5, language="python"),