Backport and improvements (#601)

* Backport RedisStorage, deep-linking
* Allow prereleases for aioredis
* Bump dependencies
* Correctly skip Redis tests on Windows
* Reformat tests code and bump Makefile
This commit is contained in:
Alex Root Junior 2021-06-15 01:45:31 +03:00 committed by GitHub
parent 32bc05130f
commit 83d6ab48c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
43 changed files with 1004 additions and 327 deletions

View file

@ -43,6 +43,12 @@ jobs:
virtualenvs-create: true virtualenvs-create: true
virtualenvs-in-project: true virtualenvs-in-project: true
- name: Setup redis
if: ${{ matrix.os != 'windows-latest' }}
uses: shogo82148/actions-setup-redis@v1
with:
redis-version: 6
- name: Load cached venv - name: Load cached venv
id: cached-poetry-dependencies id: cached-poetry-dependencies
uses: actions/cache@v2 uses: actions/cache@v2
@ -64,7 +70,14 @@ jobs:
run: | run: |
poetry run black --check --diff aiogram tests poetry run black --check --diff aiogram tests
- name: Run tests - name: Run tests (with Redis)
if: ${{ matrix.os != 'windows-latest' }}
run: |
poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml --redis redis://localhost:6379/0
- name: Run tests (without Redis)
# Redis can't be used on GitHub Windows Runners
if: ${{ matrix.os == 'windows-latest' }}
run: | run: |
poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml

View file

@ -4,8 +4,14 @@ base_python := python3
py := poetry run py := poetry run
python := $(py) python python := $(py) python
package_dir := aiogram
tests_dir := tests
scripts_dir := scripts
code_dir := $(package_dir) $(tests_dir) $(scripts_dir)
reports_dir := reports reports_dir := reports
redis_connection := redis://localhost:6379
.PHONY: help .PHONY: help
help: help:
@echo "=======================================================================================" @echo "======================================================================================="
@ -17,13 +23,8 @@ help:
@echo " clean: Delete temporary files" @echo " clean: Delete temporary files"
@echo "" @echo ""
@echo "Code quality:" @echo "Code quality:"
@echo " isort: Run isort tool" @echo " lint: Lint code by isort, black, flake8 and mypy tools"
@echo " black: Run black tool" @echo " reformat: Reformat code by isort and black tools"
@echo " flake8: Run flake8 tool"
@echo " flake8-report: Run flake8 with HTML reporting"
@echo " mypy: Run mypy tool"
@echo " mypy-report: Run mypy tool with HTML reporting"
@echo " lint: Run isort, black, flake8 and mypy tools"
@echo "" @echo ""
@echo "Tests:" @echo "Tests:"
@echo " test: Run tests" @echo " test: Run tests"
@ -65,33 +66,17 @@ clean:
# Code quality # Code quality
# ================================================================================================= # =================================================================================================
.PHONY: isort
isort:
$(py) isort aiogram tests scripts
.PHONY: black
black:
$(py) black aiogram tests scripts
.PHONY: flake8
flake8:
$(py) flake8 aiogram
.PHONY: flake8-report
flake8-report:
mkdir -p $(reports_dir)/flake8
$(py) flake8 --format=html --htmldir=$(reports_dir)/flake8 aiogram
.PHONY: mypy
mypy:
$(py) mypy aiogram
.PHONY: mypy-report
mypy-report:
$(py) mypy aiogram --html-report $(reports_dir)/typechecking
.PHONY: lint .PHONY: lint
lint: isort black flake8 mypy lint:
$(py) isort --check-only $(code_dir)
$(py) black --check --diff $(code_dir)
$(py) flake8 $(code_dir)
$(py) mypy $(package_dir)
.PHONY: reformat
reformat:
$(py) black $(code_dir)
$(py) isort $(code_dir)
# ================================================================================================= # =================================================================================================
# Tests # Tests
@ -99,12 +84,12 @@ lint: isort black flake8 mypy
.PHONY: test .PHONY: test
test: test:
$(py) pytest --cov=aiogram --cov-config .coveragerc tests/ $(py) pytest --cov=aiogram --cov-config .coveragerc tests/ --redis $(redis_connection)
.PHONY: test-coverage .PHONY: test-coverage
test-coverage: test-coverage:
mkdir -p $(reports_dir)/tests/ mkdir -p $(reports_dir)/tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ $(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ --redis $(redis_connection)
.PHONY: test-coverage-report .PHONY: test-coverage-report
test-coverage-report: test-coverage-report:

View file

@ -6,6 +6,8 @@ from .dispatcher import filters, handler
from .dispatcher.dispatcher import Dispatcher from .dispatcher.dispatcher import Dispatcher
from .dispatcher.middlewares.base import BaseMiddleware from .dispatcher.middlewares.base import BaseMiddleware
from .dispatcher.router import Router from .dispatcher.router import Router
from .utils.text_decorations import html_decoration as _html_decoration
from .utils.text_decorations import markdown_decoration as _markdown_decoration
try: try:
import uvloop as _uvloop import uvloop as _uvloop
@ -15,6 +17,8 @@ except ImportError: # pragma: no cover
pass pass
F = MagicFilter() F = MagicFilter()
html = _html_decoration
md = _markdown_decoration
__all__ = ( __all__ = (
"__api_version__", "__api_version__",
@ -29,6 +33,8 @@ __all__ = (
"filters", "filters",
"handler", "handler",
"F", "F",
"html",
"md",
) )
__version__ = "3.0.0-alpha.8" __version__ = "3.0.0-alpha.8"

View file

@ -4,7 +4,7 @@ import asyncio
import contextvars import contextvars
import warnings import warnings
from asyncio import CancelledError, Future, Lock from asyncio import CancelledError, Future, Lock
from typing import Any, AsyncGenerator, Dict, Optional, Union, cast from typing import Any, AsyncGenerator, Dict, Optional, Union
from .. import loggers from .. import loggers
from ..client.bot import Bot from ..client.bot import Bot
@ -13,7 +13,6 @@ from ..types import TelegramObject, Update, User
from ..utils.exceptions.base import TelegramAPIError from ..utils.exceptions.base import TelegramAPIError
from .event.bases import UNHANDLED, SkipHandler from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver from .event.telegram import TelegramEventObserver
from .fsm.context import FSMContext
from .fsm.middleware import FSMContextMiddleware from .fsm.middleware import FSMContextMiddleware
from .fsm.storage.base import BaseStorage from .fsm.storage.base import BaseStorage
from .fsm.storage.memory import MemoryStorage from .fsm.storage.memory import MemoryStorage
@ -32,7 +31,7 @@ class Dispatcher(Router):
self, self,
storage: Optional[BaseStorage] = None, storage: Optional[BaseStorage] = None,
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT, fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
isolate_events: bool = True, isolate_events: bool = False,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super(Dispatcher, self).__init__(**kwargs) super(Dispatcher, self).__init__(**kwargs)
@ -255,7 +254,9 @@ class Dispatcher(Router):
) )
return True # because update was processed but unsuccessful return True # because update was processed but unsuccessful
async def _polling(self, bot: Bot, polling_timeout: int = 30, **kwargs: Any) -> None: async def _polling(
self, bot: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
""" """
Internal polling process Internal polling process
@ -264,7 +265,11 @@ class Dispatcher(Router):
:return: :return:
""" """
async for update in self._listen_updates(bot, polling_timeout=polling_timeout): async for update in self._listen_updates(bot, polling_timeout=polling_timeout):
await self._process_update(bot=bot, update=update, **kwargs) handle_update = self._process_update(bot=bot, update=update, **kwargs)
if handle_as_tasks:
asyncio.create_task(handle_update)
else:
await handle_update
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any: async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
""" """
@ -342,11 +347,15 @@ class Dispatcher(Router):
return None return None
async def start_polling(self, *bots: Bot, polling_timeout: int = 10, **kwargs: Any) -> None: async def start_polling(
self, *bots: Bot, polling_timeout: int = 10, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
""" """
Polling runner Polling runner
:param bots: :param bots:
:param polling_timeout:
:param handle_as_tasks:
:param kwargs: :param kwargs:
:return: :return:
""" """
@ -363,7 +372,12 @@ class Dispatcher(Router):
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name "Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
) )
coro_list.append( coro_list.append(
self._polling(bot=bot, polling_timeout=polling_timeout, **kwargs) self._polling(
bot=bot,
handle_as_tasks=handle_as_tasks,
polling_timeout=polling_timeout,
**kwargs,
)
) )
await asyncio.gather(*coro_list) await asyncio.gather(*coro_list)
finally: finally:
@ -372,22 +386,27 @@ class Dispatcher(Router):
loggers.dispatcher.info("Polling stopped") loggers.dispatcher.info("Polling stopped")
await self.emit_shutdown(**workflow_data) await self.emit_shutdown(**workflow_data)
def run_polling(self, *bots: Bot, polling_timeout: int = 30, **kwargs: Any) -> None: def run_polling(
self, *bots: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
""" """
Run many bots with polling Run many bots with polling
:param bots: Bot instances :param bots: Bot instances
:param polling_timeout: Poling timeout :param polling_timeout: Poling timeout
:param handle_as_tasks: Run task for each event and no wait result
:param kwargs: contextual data :param kwargs: contextual data
:return: :return:
""" """
try: try:
return asyncio.run( return asyncio.run(
self.start_polling(*bots, **kwargs, polling_timeout=polling_timeout) self.start_polling(
*bots,
**kwargs,
polling_timeout=polling_timeout,
handle_as_tasks=handle_as_tasks,
)
) )
except (KeyboardInterrupt, SystemExit): # pragma: no cover except (KeyboardInterrupt, SystemExit): # pragma: no cover
# Allow to graceful shutdown # Allow to graceful shutdown
pass pass
def current_state(self, chat_id: int, user_id: int) -> FSMContext:
return cast(FSMContext, self.fsm.resolve_context(chat_id=chat_id, user_id=user_id))

View file

@ -1,18 +1,24 @@
from __future__ import annotations from __future__ import annotations
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field, replace
from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast from typing import Any, Dict, Match, Optional, Pattern, Sequence, Tuple, Union, cast
from pydantic import validator from magic_filter import MagicFilter
from pydantic import Field, validator
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.filters import BaseFilter from aiogram.dispatcher.filters import BaseFilter
from aiogram.types import Message from aiogram.types import Message
from aiogram.utils.deep_linking import decode_payload
CommandPatterType = Union[str, re.Pattern] CommandPatterType = Union[str, re.Pattern]
class CommandException(Exception):
pass
class Command(BaseFilter): class Command(BaseFilter):
""" """
This filter can be helpful for handling commands from the text messages. This filter can be helpful for handling commands from the text messages.
@ -29,6 +35,8 @@ class Command(BaseFilter):
"""Ignore case (Does not work with regexp, use flags instead)""" """Ignore case (Does not work with regexp, use flags instead)"""
commands_ignore_mention: bool = False commands_ignore_mention: bool = False
"""Ignore bot mention. By default bot can not handle commands intended for other bots""" """Ignore bot mention. By default bot can not handle commands intended for other bots"""
command_magic: Optional[MagicFilter] = None
"""Validate command object via Magic filter after all checks done"""
@validator("commands", always=True) @validator("commands", always=True)
def _validate_commands( def _validate_commands(
@ -39,12 +47,54 @@ class Command(BaseFilter):
return value return value
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]: async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
if not message.text: text = message.text or message.caption
if not text:
return False return False
return await self.parse_command(text=message.text, bot=bot) try:
command = await self.parse_command(text=cast(str, message.text), bot=bot)
except CommandException:
return False
return {"command": command}
async def parse_command(self, text: str, bot: Bot) -> Union[bool, Dict[str, CommandObject]]: def extract_command(self, text: str) -> CommandObject:
# First step: separate command with arguments
# "/command@mention arg1 arg2" -> "/command@mention", ["arg1 arg2"]
try:
full_command, *args = text.split(maxsplit=1)
except ValueError:
raise CommandException("not enough values to unpack")
# Separate command into valuable parts
# "/command@mention" -> "/", ("command", "@", "mention")
prefix, (command, _, mention) = full_command[0], full_command[1:].partition("@")
return CommandObject(
prefix=prefix, command=command, mention=mention, args=args[0] if args else None
)
def validate_prefix(self, command: CommandObject) -> None:
if command.prefix not in self.commands_prefix:
raise CommandException("Invalid command prefix")
async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
if command.mention and not self.commands_ignore_mention:
me = await bot.me()
if me.username and command.mention.lower() != me.username.lower():
raise CommandException("Mention did not match")
def validate_command(self, command: CommandObject) -> CommandObject:
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
# Command can be presented as regexp pattern or raw string
# then need to validate that in different ways
if isinstance(allowed_command, Pattern): # Regexp
result = allowed_command.match(command.command)
if result:
return replace(command, match=result)
elif command.command == allowed_command: # String
return command
raise CommandException("Command did not match pattern")
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
""" """
Extract command from the text and validate Extract command from the text and validate
@ -52,56 +102,18 @@ class Command(BaseFilter):
:param bot: :param bot:
:return: :return:
""" """
if not text.strip(): command = self.extract_command(text)
return False self.validate_prefix(command=command)
await self.validate_mention(bot=bot, command=command)
command = self.validate_command(command)
self.do_magic(command=command)
return command
# First step: separate command with arguments def do_magic(self, command: CommandObject) -> None:
# "/command@mention arg1 arg2" -> "/command@mention", ["arg1 arg2"] if not self.command_magic:
full_command, *args = text.split(maxsplit=1) return
if not self.command_magic.resolve(command):
# Separate command into valuable parts raise CommandException("Rejected via magic filter")
# "/command@mention" -> "/", ("command", "@", "mention")
prefix, (command, _, mention) = full_command[0], full_command[1:].partition("@")
# Validate prefixes
if prefix not in self.commands_prefix:
return False
# Validate mention
if mention and not self.commands_ignore_mention:
me = await bot.me()
if me.username and mention.lower() != me.username.lower():
return False
# Validate command
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
# Command can be presented as regexp pattern or raw string
# then need to validate that in different ways
if isinstance(allowed_command, Pattern): # Regexp
result = allowed_command.match(command)
if result:
return {
"command": CommandObject(
prefix=prefix,
command=command,
mention=mention,
args=args[0] if args else None,
match=result,
)
}
elif command == allowed_command: # String
return {
"command": CommandObject(
prefix=prefix,
command=command,
mention=mention,
args=args[0] if args else None,
match=None,
)
}
return False
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -143,3 +155,40 @@ class CommandObject:
if self.args: if self.args:
line += " " + self.args line += " " + self.args
return line return line
class CommandStart(Command):
commands: Tuple[str] = Field(("start",), const=True)
commands_prefix: str = Field("/", const=True)
deep_link: bool = False
deep_link_encoded: bool = False
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
"""
Extract command from the text and validate
:param text:
:param bot:
:return:
"""
command = self.extract_command(text)
self.validate_prefix(command=command)
await self.validate_mention(bot=bot, command=command)
command = self.validate_command(command)
command = self.validate_deeplink(command=command)
self.do_magic(command=command)
return command
def validate_deeplink(self, command: CommandObject) -> CommandObject:
if not self.deep_link:
return command
if not command.args:
raise CommandException("Deep-link was missing")
args = command.args
if self.deep_link_encoded:
try:
args = decode_payload(args)
except UnicodeDecodeError as e:
raise CommandException(f"Failed to decode Base64: {e}")
return replace(command, args=args)
return command

View file

@ -1,25 +1,35 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
class FSMContext: class FSMContext:
def __init__(self, storage: BaseStorage, chat_id: int, user_id: int) -> None: def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None:
self.bot = bot
self.storage = storage self.storage = storage
self.chat_id = chat_id self.chat_id = chat_id
self.user_id = user_id self.user_id = user_id
async def set_state(self, state: StateType = None) -> None: async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(chat_id=self.chat_id, user_id=self.user_id, state=state) await self.storage.set_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
)
async def get_state(self) -> Optional[str]: async def get_state(self) -> Optional[str]:
return await self.storage.get_state(chat_id=self.chat_id, user_id=self.user_id) return await self.storage.get_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
async def set_data(self, data: Dict[str, Any]) -> None: async def set_data(self, data: Dict[str, Any]) -> None:
await self.storage.set_data(chat_id=self.chat_id, user_id=self.user_id, data=data) await self.storage.set_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
)
async def get_data(self) -> Dict[str, Any]: async def get_data(self) -> Dict[str, Any]:
return await self.storage.get_data(chat_id=self.chat_id, user_id=self.user_id) return await self.storage.get_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
async def update_data( async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
@ -27,7 +37,7 @@ class FSMContext:
if data: if data:
kwargs.update(data) kwargs.update(data)
return await self.storage.update_data( return await self.storage.update_data(
chat_id=self.chat_id, user_id=self.user_id, data=kwargs bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
) )
async def clear(self) -> None: async def clear(self) -> None:

View file

@ -1,5 +1,6 @@
from typing import Any, Awaitable, Callable, Dict, Optional from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot
from aiogram.dispatcher.fsm.context import FSMContext from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import BaseStorage from aiogram.dispatcher.fsm.storage.base import BaseStorage
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
@ -24,24 +25,27 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
event: Update, event: Update,
data: Dict[str, Any], data: Dict[str, Any],
) -> Any: ) -> Any:
context = self.resolve_event_context(data) bot: Bot = cast(Bot, data["bot"])
context = self.resolve_event_context(bot, data)
data["fsm_storage"] = self.storage data["fsm_storage"] = self.storage
if context: if context:
data.update({"state": context, "raw_state": await context.get_state()}) data.update({"state": context, "raw_state": await context.get_state()})
if self.isolate_events: if self.isolate_events:
async with self.storage.lock(chat_id=context.chat_id, user_id=context.user_id): async with self.storage.lock(
bot=bot, chat_id=context.chat_id, user_id=context.user_id
):
return await handler(event, data) return await handler(event, data)
return await handler(event, data) return await handler(event, data)
def resolve_event_context(self, data: Dict[str, Any]) -> Optional[FSMContext]: def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]:
user = data.get("event_from_user") user = data.get("event_from_user")
chat = data.get("event_chat") chat = data.get("event_chat")
chat_id = chat.id if chat else None chat_id = chat.id if chat else None
user_id = user.id if user else None user_id = user.id if user else None
return self.resolve_context(chat_id=chat_id, user_id=user_id) return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id)
def resolve_context( def resolve_context(
self, chat_id: Optional[int], user_id: Optional[int] self, bot: Bot, chat_id: Optional[int], user_id: Optional[int]
) -> Optional[FSMContext]: ) -> Optional[FSMContext]:
if chat_id is None: if chat_id is None:
chat_id = user_id chat_id = user_id
@ -50,8 +54,8 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
chat_id, user_id = apply_strategy( chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy chat_id=chat_id, user_id=user_id, strategy=self.strategy
) )
return self.get_context(chat_id=chat_id, user_id=user_id) return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id)
return None return None
def get_context(self, chat_id: int, user_id: int) -> FSMContext: def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext:
return FSMContext(storage=self.storage, chat_id=chat_id, user_id=user_id) return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id)

View file

@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Union from typing import Any, AsyncGenerator, Dict, Optional, Union
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State from aiogram.dispatcher.fsm.state import State
StateType = Optional[Union[str, State]] StateType = Optional[Union[str, State]]
@ -11,34 +12,42 @@ class BaseStorage(ABC):
@abstractmethod @abstractmethod
@asynccontextmanager @asynccontextmanager
async def lock( async def lock(
self, chat_id: int, user_id: int self, bot: Bot, chat_id: int, user_id: int
) -> AsyncGenerator[None, None]: # pragma: no cover ) -> AsyncGenerator[None, None]: # pragma: no cover
yield None yield None
@abstractmethod @abstractmethod
async def set_state( async def set_state(
self, chat_id: int, user_id: int, state: StateType = None self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None: # pragma: no cover ) -> None: # pragma: no cover
pass pass
@abstractmethod @abstractmethod
async def get_state(self, chat_id: int, user_id: int) -> Optional[str]: # pragma: no cover async def get_state(
self, bot: Bot, chat_id: int, user_id: int
) -> Optional[str]: # pragma: no cover
pass pass
@abstractmethod @abstractmethod
async def set_data( async def set_data(
self, chat_id: int, user_id: int, data: Dict[str, Any] self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> None: # pragma: no cover ) -> None: # pragma: no cover
pass pass
@abstractmethod @abstractmethod
async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]: # pragma: no cover async def get_data(
self, bot: Bot, chat_id: int, user_id: int
) -> Dict[str, Any]: # pragma: no cover
pass pass
async def update_data( async def update_data(
self, chat_id: int, user_id: int, data: Dict[str, Any] self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
current_data = await self.get_data(chat_id=chat_id, user_id=user_id) current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
current_data.update(data) current_data.update(data)
await self.set_data(chat_id=chat_id, user_id=user_id, data=current_data) await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data)
return current_data.copy() return current_data.copy()
@abstractmethod
async def close(self) -> None: # pragma: no cover
pass

View file

@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
@ -17,23 +18,30 @@ class MemoryStorageRecord:
class MemoryStorage(BaseStorage): class MemoryStorage(BaseStorage):
def __init__(self) -> None: def __init__(self) -> None:
self.storage: DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] = defaultdict( self.storage: DefaultDict[
lambda: defaultdict(MemoryStorageRecord) Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]]
) ] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord)))
async def close(self) -> None:
pass
@asynccontextmanager @asynccontextmanager
async def lock(self, chat_id: int, user_id: int) -> AsyncGenerator[None, None]: async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
async with self.storage[chat_id][user_id].lock: async with self.storage[bot][chat_id][user_id].lock:
yield None yield None
async def set_state(self, chat_id: int, user_id: int, state: StateType = None) -> None: async def set_state(
self.storage[chat_id][user_id].state = state.state if isinstance(state, State) else state self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None:
self.storage[bot][chat_id][user_id].state = (
state.state if isinstance(state, State) else state
)
async def get_state(self, chat_id: int, user_id: int) -> Optional[str]: async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
return self.storage[chat_id][user_id].state return self.storage[bot][chat_id][user_id].state
async def set_data(self, chat_id: int, user_id: int, data: Dict[str, Any]) -> None: async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
self.storage[chat_id][user_id].data = data.copy() self.storage[bot][chat_id][user_id].data = data.copy()
async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]: async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
return self.storage[chat_id][user_id].data.copy() return self.storage[bot][chat_id][user_id].data.copy()

View file

@ -0,0 +1,101 @@
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast
from aioredis import ConnectionPool, Redis
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
PrefixFactoryType = Callable[[Bot], str]
STATE_KEY = "state"
STATE_DATA_KEY = "data"
STATE_LOCK_KEY = "lock"
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
class RedisStorage(BaseStorage):
def __init__(
self,
redis: Redis,
prefix: str = "fsm",
prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
state_ttl: Optional[int] = None,
data_ttl: Optional[int] = None,
lock_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if lock_kwargs is None:
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
self.redis = redis
self.prefix = prefix
self.prefix_bot = prefix_bot
self.state_ttl = state_ttl
self.data_ttl = data_ttl
self.lock_kwargs = lock_kwargs
@classmethod
def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> "RedisStorage":
if connection_kwargs is None:
connection_kwargs = {}
pool = ConnectionPool.from_url(url, **connection_kwargs)
redis = Redis(connection_pool=pool)
return cls(redis=redis, **kwargs)
async def close(self) -> None:
await self.redis.close()
def generate_key(self, bot: Bot, *parts: Any) -> str:
prefix_parts = [self.prefix]
if self.prefix_bot:
if isinstance(self.prefix_bot, dict):
prefix_parts.append(self.prefix_bot[bot.id])
elif callable(self.prefix_bot):
prefix_parts.append(self.prefix_bot(bot))
else:
prefix_parts.append(str(bot.id))
prefix_parts.extend(parts)
return ":".join(map(str, prefix_parts))
@asynccontextmanager
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
key = self.generate_key(bot, chat_id, user_id, STATE_LOCK_KEY)
async with self.redis.lock(name=key, **self.lock_kwargs):
yield None
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None:
key = self.generate_key(bot, chat_id, user_id, STATE_KEY)
if state is None:
await self.redis.delete(key)
else:
await self.redis.set(
key, state.state if isinstance(state, State) else state, ex=self.state_ttl
)
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
key = self.generate_key(bot, chat_id, user_id, STATE_KEY)
value = await self.redis.get(key)
if isinstance(value, bytes):
return value.decode("utf-8")
return cast(Optional[str], value)
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
if not data:
await self.redis.delete(key)
return
json_data = bot.session.json_dumps(data)
await self.redis.set(key, json_data, ex=self.data_ttl)
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
value = await self.redis.get(key)
if value is None:
return {}
if isinstance(value, bytes):
value = value.decode("utf-8")
return cast(Dict[str, Any], bot.session.json_loads(value))

View file

@ -0,0 +1,34 @@
import hashlib
import hmac
from typing import Any, Dict
def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
"""
Generate hexadecimal representation
of the HMAC-SHA-256 signature of the data-check-string
with the SHA256 hash of the bot's token used as a secret key
:param token:
:param hash:
:param kwargs: all params received on auth
:return:
"""
secret = hashlib.sha256(token.encode("utf-8"))
check_string = "\n".join(map(lambda k: f"{k}={kwargs[k]}", sorted(kwargs)))
hmac_string = hmac.new(
secret.digest(), check_string.encode("utf-8"), digestmod=hashlib.sha256
).hexdigest()
return hmac_string == hash
def check_integrity(token: str, data: Dict[str, Any]) -> bool:
"""
Verify the authentication and the integrity
of the data received on user's auth
:param token: Bot's token
:param data: all data that came on auth
:return:
"""
return check_signature(token, **data)

View file

@ -0,0 +1,131 @@
"""
Deep linking
Telegram bots have a deep linking mechanism, that allows for passing
additional parameters to the bot on startup. It could be a command that
launches the bot or an auth token to connect the user's Telegram
account to their account on some external service.
You can read detailed description in the source:
https://core.telegram.org/bots#deep-linking
We have add some utils to get deep links more handy.
Basic link example:
.. code-block:: python
from aiogram.utils.deep_linking import get_start_link
link = await get_start_link('foo')
# result: 'https://t.me/MyBot?start=foo'
Encoded link example:
.. code-block:: python
from aiogram.utils.deep_linking import get_start_link
link = await get_start_link('foo', encode=True)
# result: 'https://t.me/MyBot?start=Zm9v'
Decode it back example:
.. code-block:: python
from aiogram.utils.deep_linking import decode_payload
from aiogram.types import Message
@dp.message_handler(commands=["start"])
async def handler(message: Message):
args = message.get_args()
payload = decode_payload(args)
await message.answer(f"Your payload: {payload}")
"""
import re
from base64 import urlsafe_b64decode, urlsafe_b64encode
from typing import Literal, cast
from aiogram import Bot
from aiogram.utils.link import create_telegram_link
BAD_PATTERN = re.compile(r"[^_A-z0-9-]")
async def create_start_link(bot: Bot, payload: str, encode: bool = False) -> str:
"""
Create 'start' deep link with your payload.
If you need to encode payload or pass special characters -
set encode as True
:param bot: bot instance
:param payload: args passed with /start
:param encode: encode payload with base64url
:return: link
"""
username = (await bot.me()).username
return create_deep_link(username=username, link_type="start", payload=payload, encode=encode)
async def create_startgroup_link(bot: Bot, payload: str, encode: bool = False) -> str:
"""
Create 'startgroup' deep link with your payload.
If you need to encode payload or pass special characters -
set encode as True
:param bot: bot instance
:param payload: args passed with /start
:param encode: encode payload with base64url
:return: link
"""
username = (await bot.me()).username
return create_deep_link(
username=username, link_type="startgroup", payload=payload, encode=encode
)
def create_deep_link(
username: str, link_type: Literal["start", "startgroup"], payload: str, encode: bool = False
) -> str:
"""
Create deep link.
:param username:
:param link_type: `start` or `startgroup`
:param payload: any string-convertible data
:param encode: pass True to encode the payload
:return: deeplink
"""
if not isinstance(payload, str):
payload = str(payload)
if encode:
payload = encode_payload(payload)
if re.search(BAD_PATTERN, payload):
raise ValueError(
"Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
"Pass `encode=True` or encode payload manually."
)
if len(payload) > 64:
raise ValueError("Payload must be up to 64 characters long.")
return create_telegram_link(username, **{cast(str, link_type): payload})
def encode_payload(payload: str) -> str:
"""Encode payload with URL-safe base64url."""
payload = str(payload)
bytes_payload: bytes = urlsafe_b64encode(payload.encode())
str_payload = bytes_payload.decode()
return str_payload.replace("=", "")
def decode_payload(payload: str) -> str:
"""Decode payload with URL-safe base64url."""
payload += "=" * (4 - len(payload) % 4)
result: bytes = urlsafe_b64decode(payload)
return result.decode()

View file

@ -2,13 +2,27 @@ from __future__ import annotations
from itertools import chain from itertools import chain
from itertools import cycle as repeat_all from itertools import cycle as repeat_all
from typing import Any, Generator, Generic, Iterable, List, Optional, Type, TypeVar, Union from typing import (
TYPE_CHECKING,
Any,
Generator,
Generic,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
no_type_check,
)
from aiogram.dispatcher.filters.callback_data import CallbackData from aiogram.dispatcher.filters.callback_data import CallbackData
from aiogram.types import ( from aiogram.types import (
CallbackGame,
InlineKeyboardButton, InlineKeyboardButton,
InlineKeyboardMarkup, InlineKeyboardMarkup,
KeyboardButton, KeyboardButton,
LoginUrl,
ReplyKeyboardMarkup, ReplyKeyboardMarkup,
) )
@ -239,3 +253,28 @@ def repeat_last(items: Iterable[T]) -> Generator[T, None, None]:
except StopIteration: except StopIteration:
finished = True finished = True
yield value yield value
class InlineKeyboardConstructor(KeyboardConstructor[InlineKeyboardButton]):
if TYPE_CHECKING: # pragma: no cover
@no_type_check
def button(
self,
text: str,
url: Optional[str] = None,
login_url: Optional[LoginUrl] = None,
callback_data: Optional[Union[str, CallbackData]] = None,
switch_inline_query: Optional[str] = None,
switch_inline_query_current_chat: Optional[str] = None,
callback_game: Optional[CallbackGame] = None,
pay: Optional[bool] = None,
**kwargs: Any,
) -> "KeyboardConstructor[InlineKeyboardButton]":
...
def as_markup(self, **kwargs: Any) -> InlineKeyboardMarkup:
...
def __init__(self) -> None:
super().__init__(InlineKeyboardButton)

18
aiogram/utils/link.py Normal file
View file

@ -0,0 +1,18 @@
from typing import Any
from urllib.parse import urlencode, urljoin
def create_tg_link(link: str, **kwargs: Any) -> str:
url = f"tg://{link}"
if kwargs:
query = urlencode(kwargs)
url += f"?{query}"
return url
def create_telegram_link(uri: str, **kwargs: Any) -> str:
url = urljoin("https://t.me", uri)
if kwargs:
query = urlencode(query=kwargs)
url += f"?{query}"
return url

View file

@ -183,7 +183,7 @@ class MarkdownDecoration(TextDecoration):
return f"`{value}`" return f"`{value}`"
def pre(self, value: str) -> str: def pre(self, value: str) -> str:
return f"```{value}```" return f"```\n{value}\n```"
def pre_language(self, value: str, language: str) -> str: def pre_language(self, value: str, language: str) -> str:
return f"```{language}\n{value}\n```" return f"```{language}\n{value}\n```"

View file

@ -1,6 +1,6 @@
[mypy] [mypy]
;plugins = pydantic.mypy ;plugins = pydantic.mypy
python_version = 3.7 python_version = 3.8
show_error_codes = True show_error_codes = True
show_error_context = True show_error_context = True
pretty = True pretty = True
@ -29,3 +29,6 @@ ignore_missing_imports = True
[mypy-uvloop] [mypy-uvloop]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-aioredis]
ignore_missing_imports = True

106
poetry.lock generated
View file

@ -38,6 +38,21 @@ aiohttp = ">=2.3.2"
attrs = ">=19.2.0" attrs = ">=19.2.0"
python-socks = {version = ">=1.0.1", extras = ["asyncio"]} python-socks = {version = ">=1.0.1", extras = ["asyncio"]}
[[package]]
name = "aioredis"
version = "2.0.0a1"
description = "asyncio (PEP 3156) Redis support"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
async-timeout = "*"
typing-extensions = "*"
[package.extras]
hiredis = ["hiredis (>=1.0)"]
[[package]] [[package]]
name = "alabaster" name = "alabaster"
version = "0.7.12" version = "0.7.12"
@ -156,7 +171,7 @@ lxml = ["lxml"]
[[package]] [[package]]
name = "black" name = "black"
version = "21.5b1" version = "21.6b0"
description = "The uncompromising code formatter." description = "The uncompromising code formatter."
category = "dev" category = "dev"
optional = false optional = false
@ -172,8 +187,9 @@ toml = ">=0.10.1"
[package.extras] [package.extras]
colorama = ["colorama (>=0.4.3)"] colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.6.0)", "aiohttp-cors"] d = ["aiohttp (>=3.6.0)", "aiohttp-cors (>=0.4.0)"]
python2 = ["typed-ast (>=1.4.2)"] python2 = ["typed-ast (>=1.4.2)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]] [[package]]
name = "cfgv" name = "cfgv"
@ -218,9 +234,6 @@ category = "dev"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
[package.dependencies]
toml = {version = "*", optional = true, markers = "extra == \"toml\""}
[package.extras] [package.extras]
toml = ["toml"] toml = ["toml"]
@ -234,7 +247,7 @@ python-versions = ">=3.5"
[[package]] [[package]]
name = "distlib" name = "distlib"
version = "0.3.1" version = "0.3.2"
description = "Distribution utilities" description = "Distribution utilities"
category = "dev" category = "dev"
optional = false optional = false
@ -301,7 +314,7 @@ test = ["pytest", "pytest-cov", "pytest-xdist"]
[[package]] [[package]]
name = "identify" name = "identify"
version = "2.2.5" version = "2.2.10"
description = "File identification library for Python" description = "File identification library for Python"
category = "dev" category = "dev"
optional = false optional = false
@ -312,11 +325,11 @@ license = ["editdistance-s"]
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.1" version = "3.2"
description = "Internationalized Domain Names in Applications (IDNA)" description = "Internationalized Domain Names in Applications (IDNA)"
category = "main" category = "main"
optional = false optional = false
python-versions = ">=3.4" python-versions = ">=3.5"
[[package]] [[package]]
name = "imagesize" name = "imagesize"
@ -328,7 +341,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]] [[package]]
name = "importlib-metadata" name = "importlib-metadata"
version = "4.0.1" version = "4.5.0"
description = "Read metadata from Python packages" description = "Read metadata from Python packages"
category = "dev" category = "dev"
optional = false optional = false
@ -351,7 +364,7 @@ python-versions = "*"
[[package]] [[package]]
name = "ipython" name = "ipython"
version = "7.23.1" version = "7.24.1"
description = "IPython: Productive Interactive Computing" description = "IPython: Productive Interactive Computing"
category = "dev" category = "dev"
optional = false optional = false
@ -371,7 +384,7 @@ pygments = "*"
traitlets = ">=4.2" traitlets = ">=4.2"
[package.extras] [package.extras]
all = ["Sphinx (>=1.3)", "ipykernel", "ipyparallel", "ipywidgets", "nbconvert", "nbformat", "nose (>=0.10.1)", "notebook", "numpy (>=1.16)", "pygments", "qtconsole", "requests", "testpath"] all = ["Sphinx (>=1.3)", "ipykernel", "ipyparallel", "ipywidgets", "nbconvert", "nbformat", "nose (>=0.10.1)", "notebook", "numpy (>=1.17)", "pygments", "qtconsole", "requests", "testpath"]
doc = ["Sphinx (>=1.3)"] doc = ["Sphinx (>=1.3)"]
kernel = ["ipykernel"] kernel = ["ipykernel"]
nbconvert = ["nbconvert"] nbconvert = ["nbconvert"]
@ -379,7 +392,7 @@ nbformat = ["nbformat"]
notebook = ["notebook", "ipywidgets"] notebook = ["notebook", "ipywidgets"]
parallel = ["ipyparallel"] parallel = ["ipyparallel"]
qtconsole = ["qtconsole"] qtconsole = ["qtconsole"]
test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.16)"] test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.17)"]
[[package]] [[package]]
name = "ipython-genutils" name = "ipython-genutils"
@ -739,18 +752,19 @@ testing = ["coverage", "hypothesis (>=5.7.1)"]
[[package]] [[package]]
name = "pytest-cov" name = "pytest-cov"
version = "2.12.0" version = "2.12.1"
description = "Pytest plugin for measuring coverage." description = "Pytest plugin for measuring coverage."
category = "dev" category = "dev"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[package.dependencies] [package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]} coverage = ">=5.2.1"
pytest = ">=4.6" pytest = ">=4.6"
toml = "*"
[package.extras] [package.extras]
testing = ["fields", "hunter", "process-tests (==2.0.2)", "six", "pytest-xdist", "virtualenv"] testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"]
[[package]] [[package]]
name = "pytest-html" name = "pytest-html"
@ -764,6 +778,17 @@ python-versions = ">=3.6"
pytest = ">=5.0,<6.0.0 || >6.0.0" pytest = ">=5.0,<6.0.0 || >6.0.0"
pytest-metadata = "*" pytest-metadata = "*"
[[package]]
name = "pytest-lazy-fixture"
version = "0.6.3"
description = "It helps to use fixtures in pytest.mark.parametrize"
category = "dev"
optional = false
python-versions = "*"
[package.dependencies]
pytest = ">=3.2.5"
[[package]] [[package]]
name = "pytest-metadata" name = "pytest-metadata"
version = "1.11.0" version = "1.11.0"
@ -930,17 +955,17 @@ test = ["pytest", "pytest-cov"]
[[package]] [[package]]
name = "sphinx-copybutton" name = "sphinx-copybutton"
version = "0.3.1" version = "0.3.2"
description = "Add a copy button to each of your code cells." description = "Add a copy button to each of your code cells."
category = "main" category = "main"
optional = false optional = false
python-versions = "*" python-versions = ">=3.6"
[package.dependencies] [package.dependencies]
sphinx = ">=1.8" sphinx = ">=1.8"
[package.extras] [package.extras]
code_style = ["flake8 (>=3.7.0,<3.8.0)", "black", "pre-commit (==1.17.0)"] code_style = ["pre-commit (==2.12.1)"]
[[package]] [[package]]
name = "sphinx-intl" name = "sphinx-intl"
@ -1171,11 +1196,12 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pyt
docs = ["sphinx", "sphinx-intl", "sphinx-autobuild", "sphinx-copybutton", "furo", "sphinx-prompt", "Sphinx-Substitution-Extensions"] docs = ["sphinx", "sphinx-intl", "sphinx-autobuild", "sphinx-copybutton", "furo", "sphinx-prompt", "Sphinx-Substitution-Extensions"]
fast = [] fast = []
proxy = ["aiohttp-socks"] proxy = ["aiohttp-socks"]
redis = ["aioredis"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "2fcd44a8937b3ea48196c8eba8ceb0533281af34c884103bcc5b4f5f16b817d5" content-hash = "362a6caf937b1c457599cbf2cd5d000eab4cac529bd7fe8c257ae713ebc63331"
[metadata.files] [metadata.files]
aiofiles = [ aiofiles = [
@ -1225,6 +1251,10 @@ aiohttp-socks = [
{file = "aiohttp_socks-0.5.5-py3-none-any.whl", hash = "sha256:faaa25ed4dc34440ca888d23e089420f3b1918dc4ecf062c3fd9474827ad6a39"}, {file = "aiohttp_socks-0.5.5-py3-none-any.whl", hash = "sha256:faaa25ed4dc34440ca888d23e089420f3b1918dc4ecf062c3fd9474827ad6a39"},
{file = "aiohttp_socks-0.5.5.tar.gz", hash = "sha256:2eb2059756bde34c55bb429541cbf2eba3fd53e36ac80875b461221e2858b04a"}, {file = "aiohttp_socks-0.5.5.tar.gz", hash = "sha256:2eb2059756bde34c55bb429541cbf2eba3fd53e36ac80875b461221e2858b04a"},
] ]
aioredis = [
{file = "aioredis-2.0.0a1-py3-none-any.whl", hash = "sha256:32d7910724282a475c91b8b34403867069a4f07bf0c5ad5fe66cd797322f9a0d"},
{file = "aioredis-2.0.0a1.tar.gz", hash = "sha256:5884f384b8ecb143bb73320a96e7c464fd38e117950a7d48340a35db8e35e7d2"},
]
alabaster = [ alabaster = [
{file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"}, {file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"},
{file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"}, {file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"},
@ -1274,8 +1304,8 @@ beautifulsoup4 = [
{file = "beautifulsoup4-4.9.3.tar.gz", hash = "sha256:84729e322ad1d5b4d25f805bfa05b902dd96450f43842c4e99067d5e1369eb25"}, {file = "beautifulsoup4-4.9.3.tar.gz", hash = "sha256:84729e322ad1d5b4d25f805bfa05b902dd96450f43842c4e99067d5e1369eb25"},
] ]
black = [ black = [
{file = "black-21.5b1-py3-none-any.whl", hash = "sha256:8a60071a0043876a4ae96e6c69bd3a127dad2c1ca7c8083573eb82f92705d008"}, {file = "black-21.6b0-py3-none-any.whl", hash = "sha256:dfb8c5a069012b2ab1e972e7b908f5fb42b6bbabcba0a788b86dc05067c7d9c7"},
{file = "black-21.5b1.tar.gz", hash = "sha256:23695358dbcb3deafe7f0a3ad89feee5999a46be5fec21f4f1d108be0bcdb3b1"}, {file = "black-21.6b0.tar.gz", hash = "sha256:dc132348a88d103016726fe360cb9ede02cecf99b76e3660ce6c596be132ce04"},
] ]
cfgv = [ cfgv = [
{file = "cfgv-3.3.0-py2.py3-none-any.whl", hash = "sha256:b449c9c6118fe8cca7fa5e00b9ec60ba08145d281d52164230a69211c5d597a1"}, {file = "cfgv-3.3.0-py2.py3-none-any.whl", hash = "sha256:b449c9c6118fe8cca7fa5e00b9ec60ba08145d281d52164230a69211c5d597a1"},
@ -1352,8 +1382,8 @@ decorator = [
{file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"}, {file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"},
] ]
distlib = [ distlib = [
{file = "distlib-0.3.1-py2.py3-none-any.whl", hash = "sha256:8c09de2c67b3e7deef7184574fc060ab8a793e7adbb183d942c389c8b13c52fb"}, {file = "distlib-0.3.2-py2.py3-none-any.whl", hash = "sha256:23e223426b28491b1ced97dc3bbe183027419dfc7982b4fa2f05d5f3ff10711c"},
{file = "distlib-0.3.1.zip", hash = "sha256:edf6116872c863e1aa9d5bb7cb5e05a022c519a4594dc703843343a9ddd9bff1"}, {file = "distlib-0.3.2.zip", hash = "sha256:106fef6dc37dd8c0e2c0a60d3fca3e77460a48907f335fa28420463a6f799736"},
] ]
docutils = [ docutils = [
{file = "docutils-0.17.1-py2.py3-none-any.whl", hash = "sha256:cf316c8370a737a022b72b56874f6602acf974a37a9fba42ec2876387549fc61"}, {file = "docutils-0.17.1-py2.py3-none-any.whl", hash = "sha256:cf316c8370a737a022b72b56874f6602acf974a37a9fba42ec2876387549fc61"},
@ -1376,28 +1406,28 @@ furo = [
{file = "furo-2020.12.30b24.tar.gz", hash = "sha256:30171899c9c06d692a778e6daf6cb2e5cbb05efc6006e1692e5e776007dc8a8c"}, {file = "furo-2020.12.30b24.tar.gz", hash = "sha256:30171899c9c06d692a778e6daf6cb2e5cbb05efc6006e1692e5e776007dc8a8c"},
] ]
identify = [ identify = [
{file = "identify-2.2.5-py2.py3-none-any.whl", hash = "sha256:9c3ab58543c03bd794a1735e4552ef6dec49ec32053278130d525f0982447d47"}, {file = "identify-2.2.10-py2.py3-none-any.whl", hash = "sha256:18d0c531ee3dbc112fa6181f34faa179de3f57ea57ae2899754f16a7e0ff6421"},
{file = "identify-2.2.5.tar.gz", hash = "sha256:bc1705694253763a3160b943316867792ec00ba7a0ee40b46e20aebaf4e0c46a"}, {file = "identify-2.2.10.tar.gz", hash = "sha256:5b41f71471bc738e7b586308c3fca172f78940195cb3bf6734c1e66fdac49306"},
] ]
idna = [ idna = [
{file = "idna-3.1-py3-none-any.whl", hash = "sha256:5205d03e7bcbb919cc9c19885f9920d622ca52448306f2377daede5cf3faac16"}, {file = "idna-3.2-py3-none-any.whl", hash = "sha256:14475042e284991034cb48e06f6851428fb14c4dc953acd9be9a5e95c7b6dd7a"},
{file = "idna-3.1.tar.gz", hash = "sha256:c5b02147e01ea9920e6b0a3f1f7bb833612d507592c837a6c49552768f4054e1"}, {file = "idna-3.2.tar.gz", hash = "sha256:467fbad99067910785144ce333826c71fb0e63a425657295239737f7ecd125f3"},
] ]
imagesize = [ imagesize = [
{file = "imagesize-1.2.0-py2.py3-none-any.whl", hash = "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1"}, {file = "imagesize-1.2.0-py2.py3-none-any.whl", hash = "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1"},
{file = "imagesize-1.2.0.tar.gz", hash = "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1"}, {file = "imagesize-1.2.0.tar.gz", hash = "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1"},
] ]
importlib-metadata = [ importlib-metadata = [
{file = "importlib_metadata-4.0.1-py3-none-any.whl", hash = "sha256:d7eb1dea6d6a6086f8be21784cc9e3bcfa55872b52309bc5fad53a8ea444465d"}, {file = "importlib_metadata-4.5.0-py3-none-any.whl", hash = "sha256:833b26fb89d5de469b24a390e9df088d4e52e4ba33b01dc5e0e4f41b81a16c00"},
{file = "importlib_metadata-4.0.1.tar.gz", hash = "sha256:8c501196e49fb9df5df43833bdb1e4328f64847763ec8a50703148b73784d581"}, {file = "importlib_metadata-4.5.0.tar.gz", hash = "sha256:b142cc1dd1342f31ff04bb7d022492b09920cb64fed867cd3ea6f80fe3ebd139"},
] ]
iniconfig = [ iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
] ]
ipython = [ ipython = [
{file = "ipython-7.23.1-py3-none-any.whl", hash = "sha256:f78c6a3972dde1cc9e4041cbf4de583546314ba52d3c97208e5b6b2221a9cb7d"}, {file = "ipython-7.24.1-py3-none-any.whl", hash = "sha256:d513e93327cf8657d6467c81f1f894adc125334ffe0e4ddd1abbb1c78d828703"},
{file = "ipython-7.23.1.tar.gz", hash = "sha256:714810a5c74f512b69d5f3b944c86e592cee0a5fb9c728e582f074610f6cf038"}, {file = "ipython-7.24.1.tar.gz", hash = "sha256:9bc24a99f5d19721fb8a2d1408908e9c0520a17fff2233ffe82620847f17f1b6"},
] ]
ipython-genutils = [ ipython-genutils = [
{file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"},
@ -1637,13 +1667,17 @@ pytest-asyncio = [
{file = "pytest_asyncio-0.15.1-py3-none-any.whl", hash = "sha256:3042bcdf1c5d978f6b74d96a151c4cfb9dcece65006198389ccd7e6c60eb1eea"}, {file = "pytest_asyncio-0.15.1-py3-none-any.whl", hash = "sha256:3042bcdf1c5d978f6b74d96a151c4cfb9dcece65006198389ccd7e6c60eb1eea"},
] ]
pytest-cov = [ pytest-cov = [
{file = "pytest-cov-2.12.0.tar.gz", hash = "sha256:8535764137fecce504a49c2b742288e3d34bc09eed298ad65963616cc98fd45e"}, {file = "pytest-cov-2.12.1.tar.gz", hash = "sha256:261ceeb8c227b726249b376b8526b600f38667ee314f910353fa318caa01f4d7"},
{file = "pytest_cov-2.12.0-py2.py3-none-any.whl", hash = "sha256:95d4933dcbbacfa377bb60b29801daa30d90c33981ab2a79e9ab4452c165066e"}, {file = "pytest_cov-2.12.1-py2.py3-none-any.whl", hash = "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a"},
] ]
pytest-html = [ pytest-html = [
{file = "pytest-html-3.1.1.tar.gz", hash = "sha256:3ee1cf319c913d19fe53aeb0bc400e7b0bc2dbeb477553733db1dad12eb75ee3"}, {file = "pytest-html-3.1.1.tar.gz", hash = "sha256:3ee1cf319c913d19fe53aeb0bc400e7b0bc2dbeb477553733db1dad12eb75ee3"},
{file = "pytest_html-3.1.1-py3-none-any.whl", hash = "sha256:b7f82f123936a3f4d2950bc993c2c1ca09ce262c9ae12f9ac763a2401380b455"}, {file = "pytest_html-3.1.1-py3-none-any.whl", hash = "sha256:b7f82f123936a3f4d2950bc993c2c1ca09ce262c9ae12f9ac763a2401380b455"},
] ]
pytest-lazy-fixture = [
{file = "pytest-lazy-fixture-0.6.3.tar.gz", hash = "sha256:0e7d0c7f74ba33e6e80905e9bfd81f9d15ef9a790de97993e34213deb5ad10ac"},
{file = "pytest_lazy_fixture-0.6.3-py3-none-any.whl", hash = "sha256:e0b379f38299ff27a653f03eaa69b08a6fd4484e46fd1c9907d984b9f9daeda6"},
]
pytest-metadata = [ pytest-metadata = [
{file = "pytest-metadata-1.11.0.tar.gz", hash = "sha256:71b506d49d34e539cc3cfdb7ce2c5f072bea5c953320002c95968e0238f8ecf1"}, {file = "pytest-metadata-1.11.0.tar.gz", hash = "sha256:71b506d49d34e539cc3cfdb7ce2c5f072bea5c953320002c95968e0238f8ecf1"},
{file = "pytest_metadata-1.11.0-py2.py3-none-any.whl", hash = "sha256:576055b8336dd4a9006dd2a47615f76f2f8c30ab12b1b1c039d99e834583523f"}, {file = "pytest_metadata-1.11.0-py2.py3-none-any.whl", hash = "sha256:576055b8336dd4a9006dd2a47615f76f2f8c30ab12b1b1c039d99e834583523f"},
@ -1763,8 +1797,8 @@ sphinx-autobuild = [
{file = "sphinx_autobuild-2020.9.1-py3-none-any.whl", hash = "sha256:df5c72cb8b8fc9b31279c4619780c4e95029be6de569ff60a8bb2e99d20f63dd"}, {file = "sphinx_autobuild-2020.9.1-py3-none-any.whl", hash = "sha256:df5c72cb8b8fc9b31279c4619780c4e95029be6de569ff60a8bb2e99d20f63dd"},
] ]
sphinx-copybutton = [ sphinx-copybutton = [
{file = "sphinx-copybutton-0.3.1.tar.gz", hash = "sha256:0e0461df394515284e3907e3f418a0c60ef6ab6c9a27a800c8552772d0a402a2"}, {file = "sphinx-copybutton-0.3.2.tar.gz", hash = "sha256:f901f17e7dadc063bcfca592c5160f9113ec17501a59e046af3edb82b7527656"},
{file = "sphinx_copybutton-0.3.1-py3-none-any.whl", hash = "sha256:5125c718e763596e6e52d92e15ee0d6f4800ad3817939be6dee51218870b3e3d"}, {file = "sphinx_copybutton-0.3.2-py3-none-any.whl", hash = "sha256:f16f8ed8dfc60f2b34a58cb69bfa04722e24be2f6d7e04db5554c32cde4df815"},
] ]
sphinx-intl = [ sphinx-intl = [
{file = "sphinx-intl-2.0.1.tar.gz", hash = "sha256:b25a6ec169347909e8d983eefe2d8adecb3edc2f27760db79b965c69950638b4"}, {file = "sphinx-intl-2.0.1.tar.gz", hash = "sha256:b25a6ec169347909e8d983eefe2d8adecb3edc2f27760db79b965c69950638b4"},

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "aiogram" name = "aiogram"
version = "3.0.0-alpha.8" version = "3.0.0-alpha.9"
description = "Modern and fully asynchronous framework for Telegram Bot API" description = "Modern and fully asynchronous framework for Telegram Bot API"
authors = ["Alex Root Junior <jroot.junior@gmail.com>"] authors = ["Alex Root Junior <jroot.junior@gmail.com>"]
license = "MIT" license = "MIT"
@ -38,6 +38,7 @@ Babel = "^2.9.1"
aiofiles = "^0.6.0" aiofiles = "^0.6.0"
async_lru = "^1.0.2" async_lru = "^1.0.2"
aiohttp-socks = { version = "^0.5.5", optional = true } aiohttp-socks = { version = "^0.5.5", optional = true }
aioredis = { version = "^2.0.0a1", allow-prereleases = true, optional = true }
typing-extensions = { version = "^3.7.4", python = "<3.8" } typing-extensions = { version = "^3.7.4", python = "<3.8" }
magic-filter = { version = "1.0.0a1", allow-prereleases = true } magic-filter = { version = "1.0.0a1", allow-prereleases = true }
sphinx = { version = "^3.1.0", optional = true } sphinx = { version = "^3.1.0", optional = true }
@ -50,6 +51,7 @@ Sphinx-Substitution-Extensions = { version = "^2020.9.30", optional = true }
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
aiohttp-socks = "^0.5" aiohttp-socks = "^0.5"
aioredis = { version = "^2.0.0a1", allow-prereleases = true }
ipython = "^7.22.0" ipython = "^7.22.0"
uvloop = { version = "^0.15.2", markers = "sys_platform == 'darwin' or sys_platform == 'linux'" } uvloop = { version = "^0.15.2", markers = "sys_platform == 'darwin' or sys_platform == 'linux'" }
black = "^21.4b2" black = "^21.4b2"
@ -79,9 +81,11 @@ sphinx-copybutton = "^0.3.1"
furo = "^2020.11.15-beta.17" furo = "^2020.11.15-beta.17"
sphinx-prompt = "^1.3.0" sphinx-prompt = "^1.3.0"
Sphinx-Substitution-Extensions = "^2020.9.30" Sphinx-Substitution-Extensions = "^2020.9.30"
pytest-lazy-fixture = "^0.6.3"
[tool.poetry.extras] [tool.poetry.extras]
fast = ["uvloop"] fast = ["uvloop"]
redis = ["aioredis"]
proxy = ["aiohttp-socks"] proxy = ["aiohttp-socks"]
docs = [ docs = [
"sphinx", "sphinx",

View file

@ -1,9 +1,64 @@
import pytest import pytest
from _pytest.config import UsageError
from aioredis.connection import parse_url as parse_redis_url
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from aiogram.dispatcher.fsm.storage.redis import RedisStorage
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot
def pytest_addoption(parser):
parser.addoption("--redis", default=None, help="run tests which require redis connection")
def pytest_configure(config):
config.addinivalue_line("markers", "redis: marked tests require redis connection to run")
def pytest_collection_modifyitems(config, items):
redis_uri = config.getoption("--redis")
if redis_uri is None:
skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run")
for item in items:
if "redis" in item.keywords:
item.add_marker(skip_redis)
return
try:
parse_redis_url(redis_uri)
except ValueError as e:
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
@pytest.fixture(scope="session")
def redis_server(request):
redis_uri = request.config.getoption("--redis")
return redis_uri
@pytest.fixture()
@pytest.mark.redis
async def redis_storage(redis_server):
if not redis_server:
pytest.skip("Redis is not available here")
storage = RedisStorage.from_url(redis_server)
try:
yield storage
finally:
conn = await storage.redis
await conn.flushdb()
await storage.close()
@pytest.fixture()
async def memory_storage():
storage = MemoryStorage()
try:
yield storage
finally:
await storage.close()
@pytest.fixture() @pytest.fixture()
def bot(): def bot():
bot = MockedBot() bot = MockedBot()

7
tests/docker-compose.yml Normal file
View file

@ -0,0 +1,7 @@
version: "3.9"
services:
redis:
image: redis:6-alpine
ports:
- "${REDIS_PORT-6379}:6379"

View file

@ -6,7 +6,7 @@ import pytest
from aiogram.client.session.base import BaseSession, TelegramType from aiogram.client.session.base import BaseSession, TelegramType
from aiogram.client.telegram import PRODUCTION, TelegramAPIServer from aiogram.client.telegram import PRODUCTION, TelegramAPIServer
from aiogram.methods import DeleteMessage, GetMe, Response, TelegramMethod from aiogram.methods import DeleteMessage, GetMe, TelegramMethod
from aiogram.types import UNSET from aiogram.types import UNSET
try: try:
@ -20,7 +20,9 @@ class CustomSession(BaseSession):
async def close(self): async def close(self):
pass pass
async def make_request(self, token: str, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET) -> None: # type: ignore async def make_request(
self, token: str, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET
) -> None: # type: ignore
assert isinstance(token, str) assert isinstance(token, str)
assert isinstance(method, TelegramMethod) assert isinstance(method, TelegramMethod)

View file

@ -3,7 +3,7 @@ from typing import Union
import pytest import pytest
from aiogram.methods import EditMessageMedia, Request from aiogram.methods import EditMessageMedia, Request
from aiogram.types import BufferedInputFile, InputMedia, InputMediaPhoto, Message from aiogram.types import BufferedInputFile, InputMediaPhoto, Message
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot

View file

@ -2,8 +2,8 @@ import datetime
from typing import Optional from typing import Optional
import pytest import pytest
from aiogram.types import Chat, Message
from aiogram.types import Chat, Message
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot

View file

@ -3,7 +3,7 @@ import datetime
import pytest import pytest
from aiogram.methods import Request, SendAudio from aiogram.methods import Request, SendAudio
from aiogram.types import Audio, Chat, File, Message from aiogram.types import Audio, Chat, Message
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aiogram.methods import Request, SetChatAdministratorCustomTitle, SetChatTitle from aiogram.methods import Request, SetChatAdministratorCustomTitle
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot

View file

@ -1,7 +1,7 @@
import pytest import pytest
from aiogram.methods import Request, SetChatPhoto from aiogram.methods import Request, SetChatPhoto
from aiogram.types import BufferedInputFile, InputFile from aiogram.types import BufferedInputFile
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot

View file

@ -9,8 +9,6 @@ import pytest
from aiogram import Bot from aiogram import Bot
from aiogram.dispatcher.dispatcher import Dispatcher from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler
from aiogram.dispatcher.fsm.strategy import FSMStrategy
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.dispatcher.router import Router from aiogram.dispatcher.router import Router
from aiogram.methods import GetMe, GetUpdates, SendMessage from aiogram.methods import GetMe, GetUpdates, SendMessage
from aiogram.types import ( from aiogram.types import (
@ -423,7 +421,7 @@ class TestDispatcher:
assert User.get_current(False) assert User.get_current(False)
return kwargs return kwargs
result = await router.update.trigger(update, test="PASS") result = await router.update.trigger(update, test="PASS", bot=None)
assert isinstance(result, dict) assert isinstance(result, dict)
assert result["event_update"] == update assert result["event_update"] == update
assert result["event_router"] == router assert result["event_router"] == router
@ -526,8 +524,9 @@ class TestDispatcher:
assert len(log_records) == 1 assert len(log_records) == 1
assert "Cause exception while process update" in log_records[0] assert "Cause exception while process update" in log_records[0]
@pytest.mark.parametrize("as_task", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_polling(self, bot: MockedBot): async def test_polling(self, bot: MockedBot, as_task: bool):
dispatcher = Dispatcher() dispatcher = Dispatcher()
async def _mock_updates(*_): async def _mock_updates(*_):
@ -539,7 +538,10 @@ class TestDispatcher:
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates" "aiogram.dispatcher.dispatcher.Dispatcher._listen_updates"
) as patched_listen_updates: ) as patched_listen_updates:
patched_listen_updates.return_value = _mock_updates() patched_listen_updates.return_value = _mock_updates()
await dispatcher._polling(bot=bot) await dispatcher._polling(bot=bot, handle_as_tasks=as_task)
if as_task:
pass
else:
mocked_process_update.assert_awaited() mocked_process_update.assert_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -548,9 +550,12 @@ class TestDispatcher:
router = Router() router = Router()
dp.include_router(router) dp.include_router(router)
class CustomException(Exception):
pass
@router.message() @router.message()
async def message_handler(message: Message): async def message_handler(message: Message):
raise Exception("KABOOM") raise CustomException("KABOOM")
update = Update( update = Update(
update_id=42, update_id=42,
@ -562,23 +567,23 @@ class TestDispatcher:
from_user=User(id=42, is_bot=False, first_name="Test"), from_user=User(id=42, is_bot=False, first_name="Test"),
), ),
) )
with pytest.raises(Exception, match="KABOOM"): with pytest.raises(CustomException, match="KABOOM"):
await dp.update.trigger(update) await dp.update.trigger(update, bot=None)
@router.errors() @router.errors()
async def error_handler(event: Update, exception: Exception): async def error_handler(event: Update, exception: Exception):
return "KABOOM" return "KABOOM"
response = await dp.update.trigger(update) response = await dp.update.trigger(update, bot=None)
assert response == "KABOOM" assert response == "KABOOM"
@dp.errors() @dp.errors()
async def root_error_handler(event: Update, exception: Exception): async def root_error_handler(event: Update, exception: Exception):
return exception return exception
response = await dp.update.trigger(update) response = await dp.update.trigger(update, bot=None)
assert isinstance(response, Exception) assert isinstance(response, CustomException)
assert str(response) == "KABOOM" assert str(response) == "KABOOM"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -654,20 +659,3 @@ class TestDispatcher:
log_records = [rec.message for rec in caplog.records] log_records = [rec.message for rec in caplog.records]
assert "Cause exception while process update" in log_records[0] assert "Cause exception while process update" in log_records[0]
@pytest.mark.parametrize(
"strategy,case,expected",
[
[FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)],
[FSMStrategy.CHAT, (-42, 42), (-42, -42)],
[FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)],
[FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)],
[FSMStrategy.CHAT, (42, 42), (42, 42)],
[FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)],
],
)
def test_get_current_state_context(self, strategy, case, expected):
dp = Dispatcher(fsm_strategy=strategy)
chat_id, user_id = case
state = dp.current_state(chat_id=chat_id, user_id=user_id)
assert (state.chat_id, state.user_id) == expected

View file

@ -5,7 +5,6 @@ import pytest
from aiogram import F from aiogram import F
from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject
from aiogram.dispatcher.filters import Text
from aiogram.dispatcher.filters.base import BaseFilter from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler from aiogram.dispatcher.handler.base import BaseHandler
from aiogram.types import Update from aiogram.types import Update

View file

@ -1,10 +1,11 @@
import datetime import datetime
import re import re
from typing import Match
import pytest import pytest
from aiogram import F
from aiogram.dispatcher.filters import Command, CommandObject from aiogram.dispatcher.filters import Command, CommandObject
from aiogram.dispatcher.filters.command import CommandStart
from aiogram.methods import GetMe from aiogram.methods import GetMe
from aiogram.types import Chat, Message, User from aiogram.types import Chat, Message, User
from tests.mocked_bot import MockedBot from tests.mocked_bot import MockedBot
@ -18,45 +19,54 @@ class TestCommandFilter:
assert cmd.commands[0] == "start" assert cmd.commands[0] == "start"
assert cmd == Command(commands=["start"]) assert cmd == Command(commands=["start"])
@pytest.mark.parametrize(
"text,command,result",
[
["/test@tbot", Command(commands=["test"], commands_prefix="/"), True],
["!test", Command(commands=["test"], commands_prefix="/"), False],
["/test@mention", Command(commands=["test"], commands_prefix="/"), False],
["/tests", Command(commands=["test"], commands_prefix="/"), False],
["/", Command(commands=["test"], commands_prefix="/"), False],
["/ test", Command(commands=["test"], commands_prefix="/"), False],
["", Command(commands=["test"], commands_prefix="/"), False],
[" ", Command(commands=["test"], commands_prefix="/"), False],
["test", Command(commands=["test"], commands_prefix="/"), False],
[" test", Command(commands=["test"], commands_prefix="/"), False],
["a", Command(commands=["test"], commands_prefix="/"), False],
["/test@tbot some args", Command(commands=["test"]), True],
["/test42@tbot some args", Command(commands=[re.compile(r"test(\d+)")]), True],
[
"/test42@tbot some args",
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "some args"),
True,
],
[
"/test42@tbot some args",
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "test"),
False,
],
["/start test", CommandStart(), True],
["/start", CommandStart(deep_link=True), False],
["/start test", CommandStart(deep_link=True), True],
["/start test", CommandStart(deep_link=True, deep_link_encoded=True), False],
["/start dGVzdA", CommandStart(deep_link=True, deep_link_encoded=True), True],
],
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_command(self, bot: MockedBot): async def test_parse_command(self, bot: MockedBot, text: str, result: bool, command: Command):
# TODO: parametrize
# TODO: test ignore case # TODO: test ignore case
# TODO: test ignore mention # TODO: test ignore mention
bot.add_result_for( bot.add_result_for(
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot") GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
) )
command = Command(commands=["test", re.compile(r"test(\d+)")], commands_prefix="/")
assert await command.parse_command("/test@tbot", bot) message = Message(
assert not await command.parse_command("!test", bot) message_id=0, text=text, chat=Chat(id=42, type="private"), date=datetime.datetime.now()
assert not await command.parse_command("/test@mention", bot) )
assert not await command.parse_command("/tests", bot)
assert not await command.parse_command("/", bot)
assert not await command.parse_command("/ test", bot)
assert not await command.parse_command("", bot)
assert not await command.parse_command(" ", bot)
assert not await command.parse_command("test", bot)
assert not await command.parse_command(" test", bot)
assert not await command.parse_command("a", bot)
result = await command.parse_command("/test@tbot some args", bot) response = await command(message, bot)
assert isinstance(result, dict) assert bool(response) is result
assert "command" in result
assert isinstance(result["command"], CommandObject)
assert result["command"].command == "test"
assert result["command"].mention == "tbot"
assert result["command"].args == "some args"
result = await command.parse_command("/test42@tbot some args", bot)
assert isinstance(result, dict)
assert "command" in result
assert isinstance(result["command"], CommandObject)
assert result["command"].command == "test42"
assert result["command"].mention == "tbot"
assert result["command"].args == "some args"
assert isinstance(result["command"].match, Match)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -1,45 +0,0 @@
import pytest
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage, MemoryStorageRecord
@pytest.fixture()
def storage():
return MemoryStorage()
class TestMemoryStorage:
@pytest.mark.asyncio
async def test_set_state(self, storage: MemoryStorage):
assert await storage.get_state(chat_id=-42, user_id=42) is None
await storage.set_state(chat_id=-42, user_id=42, state="state")
assert await storage.get_state(chat_id=-42, user_id=42) == "state"
assert -42 in storage.storage
assert 42 in storage.storage[-42]
assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
assert storage.storage[-42][42].state == "state"
@pytest.mark.asyncio
async def test_set_data(self, storage: MemoryStorage):
assert await storage.get_data(chat_id=-42, user_id=42) == {}
await storage.set_data(chat_id=-42, user_id=42, data={"foo": "bar"})
assert await storage.get_data(chat_id=-42, user_id=42) == {"foo": "bar"}
assert -42 in storage.storage
assert 42 in storage.storage[-42]
assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
assert storage.storage[-42][42].data == {"foo": "bar"}
@pytest.mark.asyncio
async def test_update_data(self, storage: MemoryStorage):
assert await storage.get_data(chat_id=-42, user_id=42) == {}
assert await storage.update_data(chat_id=-42, user_id=42, data={"foo": "bar"}) == {
"foo": "bar"
}
assert await storage.update_data(chat_id=-42, user_id=42, data={"baz": "spam"}) == {
"foo": "bar",
"baz": "spam",
}

View file

@ -0,0 +1,21 @@
import pytest
from aiogram.dispatcher.fsm.storage.redis import RedisStorage
from tests.mocked_bot import MockedBot
@pytest.mark.redis
class TestRedisStorage:
@pytest.mark.parametrize(
"prefix_bot,result",
[
[False, "fsm:-1:2"],
[True, "fsm:42:-1:2"],
[{42: "kaboom"}, "fsm:kaboom:-1:2"],
[lambda bot: "kaboom", "fsm:kaboom:-1:2"],
],
)
@pytest.mark.asyncio
async def test_generate_key(self, bot: MockedBot, redis_server, prefix_bot, result):
storage = RedisStorage.from_url(redis_server, prefix_bot=prefix_bot)
assert storage.generate_key(bot, -1, 2) == result

View file

@ -0,0 +1,44 @@
import pytest
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from tests.mocked_bot import MockedBot
@pytest.mark.parametrize(
"storage",
[pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")],
)
class TestStorages:
@pytest.mark.asyncio
async def test_lock(self, bot: MockedBot, storage: BaseStorage):
# TODO: ?!?
async with storage.lock(bot=bot, chat_id=-42, user_id=42):
assert True, "You are kidding me?"
@pytest.mark.asyncio
async def test_set_state(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state="state")
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) == "state"
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state=None)
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
@pytest.mark.asyncio
async def test_set_data(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {"foo": "bar"}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
@pytest.mark.asyncio
async def test_update_data(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
assert await storage.update_data(
bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"}
) == {"foo": "bar"}
assert await storage.update_data(
bot=bot, chat_id=-42, user_id=42, data={"baz": "spam"}
) == {"foo": "bar", "baz": "spam"}

View file

@ -2,27 +2,28 @@ import pytest
from aiogram.dispatcher.fsm.context import FSMContext from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from tests.mocked_bot import MockedBot
@pytest.fixture() @pytest.fixture()
def state(): def state(bot: MockedBot):
storage = MemoryStorage() storage = MemoryStorage()
ctx = storage.storage[-42][42] ctx = storage.storage[bot][-42][42]
ctx.state = "test" ctx.state = "test"
ctx.data = {"foo": "bar"} ctx.data = {"foo": "bar"}
return FSMContext(storage=storage, user_id=-42, chat_id=42) return FSMContext(bot=bot, storage=storage, user_id=-42, chat_id=42)
class TestFSMContext: class TestFSMContext:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_address_mapping(self): async def test_address_mapping(self, bot: MockedBot):
storage = MemoryStorage() storage = MemoryStorage()
ctx = storage.storage[-42][42] ctx = storage.storage[bot][-42][42]
ctx.state = "test" ctx.state = "test"
ctx.data = {"foo": "bar"} ctx.data = {"foo": "bar"}
state = FSMContext(storage=storage, chat_id=-42, user_id=42) state = FSMContext(bot=bot, storage=storage, chat_id=-42, user_id=42)
state2 = FSMContext(storage=storage, chat_id=42, user_id=42) state2 = FSMContext(bot=bot, storage=storage, chat_id=42, user_id=42)
state3 = FSMContext(storage=storage, chat_id=69, user_id=69) state3 = FSMContext(bot=bot, storage=storage, chat_id=69, user_id=69)
assert await state.get_state() == "test" assert await state.get_state() == "test"
assert await state2.get_state() is None assert await state2.get_state() is None

View file

@ -3,7 +3,7 @@ from typing import Any
import pytest import pytest
from aiogram.dispatcher.handler import ChosenInlineResultHandler from aiogram.dispatcher.handler import ChosenInlineResultHandler
from aiogram.types import CallbackQuery, ChosenInlineResult, User from aiogram.types import ChosenInlineResult, User
class TestChosenInlineResultHandler: class TestChosenInlineResultHandler:

View file

@ -2,16 +2,7 @@ from typing import Any
import pytest import pytest
from aiogram.dispatcher.handler import ErrorHandler, PollHandler from aiogram.dispatcher.handler import ErrorHandler
from aiogram.types import (
CallbackQuery,
InlineQuery,
Poll,
PollOption,
ShippingAddress,
ShippingQuery,
User,
)
class TestErrorHandler: class TestErrorHandler:

View file

@ -3,7 +3,7 @@ from typing import Any
import pytest import pytest
from aiogram.dispatcher.handler import InlineQueryHandler from aiogram.dispatcher.handler import InlineQueryHandler
from aiogram.types import CallbackQuery, InlineQuery, User from aiogram.types import InlineQuery, User
class TestCallbackQueryHandler: class TestCallbackQueryHandler:

View file

@ -3,15 +3,7 @@ from typing import Any
import pytest import pytest
from aiogram.dispatcher.handler import PollHandler from aiogram.dispatcher.handler import PollHandler
from aiogram.types import ( from aiogram.types import Poll, PollOption
CallbackQuery,
InlineQuery,
Poll,
PollOption,
ShippingAddress,
ShippingQuery,
User,
)
class TestShippingQueryHandler: class TestShippingQueryHandler:

View file

@ -3,7 +3,7 @@ from typing import Any
import pytest import pytest
from aiogram.dispatcher.handler import ShippingQueryHandler from aiogram.dispatcher.handler import ShippingQueryHandler
from aiogram.types import CallbackQuery, InlineQuery, ShippingAddress, ShippingQuery, User from aiogram.types import ShippingAddress, ShippingQuery, User
class TestShippingQueryHandler: class TestShippingQueryHandler:

View file

@ -0,0 +1,27 @@
import pytest
from aiogram.utils.auth_widget import check_integrity
TOKEN = "123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11"
@pytest.fixture
def data():
return {
"id": "42",
"first_name": "John",
"last_name": "Smith",
"username": "username",
"photo_url": "https://t.me/i/userpic/320/picname.jpg",
"auth_date": "1565810688",
"hash": "c303db2b5a06fe41d23a9b14f7c545cfc11dcc7473c07c9c5034ae60062461ce",
}
class TestCheckIntegrity:
def test_ok(self, data):
assert check_integrity(TOKEN, data) is True
def test_fail(self, data):
data.pop("username")
assert check_integrity(TOKEN, data) is False

View file

@ -0,0 +1,94 @@
import pytest
from async_lru import alru_cache
from aiogram.utils.deep_linking import (
create_start_link,
create_startgroup_link,
decode_payload,
encode_payload,
)
from tests.mocked_bot import MockedBot
PAYLOADS = [
"foo",
"AAbbCCddEEff1122334455",
"aaBBccDDeeFF5544332211",
-12345678901234567890,
12345678901234567890,
]
WRONG_PAYLOADS = [
"@BotFather",
"Some:special$characters#=",
"spaces spaces spaces",
1234567890123456789.0,
]
@pytest.fixture(params=PAYLOADS, name="payload")
def payload_fixture(request):
return request.param
@pytest.fixture(params=WRONG_PAYLOADS, name="wrong_payload")
def wrong_payload_fixture(request):
return request.param
@pytest.fixture(autouse=True)
def get_bot_user_fixture(monkeypatch):
"""Monkey patching of bot.me calling."""
@alru_cache()
async def get_bot_user_mock(self):
from aiogram.types import User
return User(
id=12345678,
is_bot=True,
first_name="FirstName",
last_name="LastName",
username="username",
language_code="uk-UA",
)
monkeypatch.setattr(MockedBot, "me", get_bot_user_mock)
@pytest.mark.asyncio
class TestDeepLinking:
async def test_get_start_link(self, bot, payload):
link = await create_start_link(bot=bot, payload=payload)
assert link == f"https://t.me/username?start={payload}"
async def test_wrong_symbols(self, bot, wrong_payload):
with pytest.raises(ValueError):
await create_start_link(bot, wrong_payload)
async def test_get_startgroup_link(self, bot, payload):
link = await create_startgroup_link(bot, payload)
assert link == f"https://t.me/username?startgroup={payload}"
async def test_filter_encode_and_decode(self, payload):
encoded = encode_payload(payload)
decoded = decode_payload(encoded)
assert decoded == str(payload)
async def test_get_start_link_with_encoding(self, bot, wrong_payload):
# define link
link = await create_start_link(bot, wrong_payload, encode=True)
# define reference link
encoded_payload = encode_payload(wrong_payload)
assert link == f"https://t.me/username?start={encoded_payload}"
async def test_64_len_payload(self, bot):
payload = "p" * 64
link = await create_start_link(bot, payload)
assert link
async def test_too_long_payload(self, bot):
payload = "p" * 65
print(payload, len(payload))
with pytest.raises(ValueError):
await create_start_link(bot, payload)

View file

@ -0,0 +1,24 @@
from typing import Any, Dict
import pytest
from aiogram.utils.link import create_telegram_link, create_tg_link
class TestLink:
@pytest.mark.parametrize(
"base,params,result",
[["user", dict(id=42), "tg://user?id=42"]],
)
def test_create_tg_link(self, base: str, params: Dict[str, Any], result: str):
assert create_tg_link(base, **params) == result
@pytest.mark.parametrize(
"base,params,result",
[
["username", dict(), "https://t.me/username"],
["username", dict(start="test"), "https://t.me/username?start=test"],
],
)
def test_create_telegram_link(self, base: str, params: Dict[str, Any], result: str):
assert create_telegram_link(base, **params) == result

View file

@ -35,7 +35,7 @@ class TestMarkdown:
[hitalic, ("test", "test"), " ", "<i>test test</i>"], [hitalic, ("test", "test"), " ", "<i>test test</i>"],
[code, ("test", "test"), " ", "`test test`"], [code, ("test", "test"), " ", "`test test`"],
[hcode, ("test", "test"), " ", "<code>test test</code>"], [hcode, ("test", "test"), " ", "<code>test test</code>"],
[pre, ("test", "test"), " ", "```test test```"], [pre, ("test", "test"), " ", "```\ntest test\n```"],
[hpre, ("test", "test"), " ", "<pre>test test</pre>"], [hpre, ("test", "test"), " ", "<pre>test test</pre>"],
[underline, ("test", "test"), " ", "__\rtest test__\r"], [underline, ("test", "test"), " ", "__\rtest test__\r"],
[hunderline, ("test", "test"), " ", "<u>test test</u>"], [hunderline, ("test", "test"), " ", "<u>test test</u>"],

View file

@ -55,7 +55,7 @@ class TestTextDecoration:
[markdown_decoration, MessageEntity(type="bold", offset=0, length=5), "*test*"], [markdown_decoration, MessageEntity(type="bold", offset=0, length=5), "*test*"],
[markdown_decoration, MessageEntity(type="italic", offset=0, length=5), "_\rtest_\r"], [markdown_decoration, MessageEntity(type="italic", offset=0, length=5), "_\rtest_\r"],
[markdown_decoration, MessageEntity(type="code", offset=0, length=5), "`test`"], [markdown_decoration, MessageEntity(type="code", offset=0, length=5), "`test`"],
[markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```test```"], [markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```\ntest\n```"],
[ [
markdown_decoration, markdown_decoration,
MessageEntity(type="pre", offset=0, length=5, language="python"), MessageEntity(type="pre", offset=0, length=5, language="python"),