Backport and improvements (#601)

* Backport RedisStorage, deep-linking
* Allow prereleases for aioredis
* Bump dependencies
* Correctly skip Redis tests on Windows
* Reformat tests code and bump Makefile
This commit is contained in:
Alex Root Junior 2021-06-15 01:45:31 +03:00 committed by GitHub
parent 32bc05130f
commit 83d6ab48c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
43 changed files with 1004 additions and 327 deletions

View file

@ -43,6 +43,12 @@ jobs:
virtualenvs-create: true
virtualenvs-in-project: true
- name: Setup redis
if: ${{ matrix.os != 'windows-latest' }}
uses: shogo82148/actions-setup-redis@v1
with:
redis-version: 6
- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v2
@ -64,7 +70,14 @@ jobs:
run: |
poetry run black --check --diff aiogram tests
- name: Run tests
- name: Run tests (with Redis)
if: ${{ matrix.os != 'windows-latest' }}
run: |
poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml --redis redis://localhost:6379/0
- name: Run tests (without Redis)
# Redis can't be used on GitHub Windows Runners
if: ${{ matrix.os == 'windows-latest' }}
run: |
poetry run pytest --cov=aiogram --cov-config .coveragerc --cov-report=xml

View file

@ -4,8 +4,14 @@ base_python := python3
py := poetry run
python := $(py) python
package_dir := aiogram
tests_dir := tests
scripts_dir := scripts
code_dir := $(package_dir) $(tests_dir) $(scripts_dir)
reports_dir := reports
redis_connection := redis://localhost:6379
.PHONY: help
help:
@echo "======================================================================================="
@ -17,13 +23,8 @@ help:
@echo " clean: Delete temporary files"
@echo ""
@echo "Code quality:"
@echo " isort: Run isort tool"
@echo " black: Run black tool"
@echo " flake8: Run flake8 tool"
@echo " flake8-report: Run flake8 with HTML reporting"
@echo " mypy: Run mypy tool"
@echo " mypy-report: Run mypy tool with HTML reporting"
@echo " lint: Run isort, black, flake8 and mypy tools"
@echo " lint: Lint code by isort, black, flake8 and mypy tools"
@echo " reformat: Reformat code by isort and black tools"
@echo ""
@echo "Tests:"
@echo " test: Run tests"
@ -65,33 +66,17 @@ clean:
# Code quality
# =================================================================================================
.PHONY: isort
isort:
$(py) isort aiogram tests scripts
.PHONY: black
black:
$(py) black aiogram tests scripts
.PHONY: flake8
flake8:
$(py) flake8 aiogram
.PHONY: flake8-report
flake8-report:
mkdir -p $(reports_dir)/flake8
$(py) flake8 --format=html --htmldir=$(reports_dir)/flake8 aiogram
.PHONY: mypy
mypy:
$(py) mypy aiogram
.PHONY: mypy-report
mypy-report:
$(py) mypy aiogram --html-report $(reports_dir)/typechecking
.PHONY: lint
lint: isort black flake8 mypy
lint:
$(py) isort --check-only $(code_dir)
$(py) black --check --diff $(code_dir)
$(py) flake8 $(code_dir)
$(py) mypy $(package_dir)
.PHONY: reformat
reformat:
$(py) black $(code_dir)
$(py) isort $(code_dir)
# =================================================================================================
# Tests
@ -99,12 +84,12 @@ lint: isort black flake8 mypy
.PHONY: test
test:
$(py) pytest --cov=aiogram --cov-config .coveragerc tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc tests/ --redis $(redis_connection)
.PHONY: test-coverage
test-coverage:
mkdir -p $(reports_dir)/tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/
$(py) pytest --cov=aiogram --cov-config .coveragerc --html=$(reports_dir)/tests/index.html tests/ --redis $(redis_connection)
.PHONY: test-coverage-report
test-coverage-report:

View file

@ -6,6 +6,8 @@ from .dispatcher import filters, handler
from .dispatcher.dispatcher import Dispatcher
from .dispatcher.middlewares.base import BaseMiddleware
from .dispatcher.router import Router
from .utils.text_decorations import html_decoration as _html_decoration
from .utils.text_decorations import markdown_decoration as _markdown_decoration
try:
import uvloop as _uvloop
@ -15,6 +17,8 @@ except ImportError: # pragma: no cover
pass
F = MagicFilter()
html = _html_decoration
md = _markdown_decoration
__all__ = (
"__api_version__",
@ -29,6 +33,8 @@ __all__ = (
"filters",
"handler",
"F",
"html",
"md",
)
__version__ = "3.0.0-alpha.8"

View file

@ -4,7 +4,7 @@ import asyncio
import contextvars
import warnings
from asyncio import CancelledError, Future, Lock
from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
from typing import Any, AsyncGenerator, Dict, Optional, Union
from .. import loggers
from ..client.bot import Bot
@ -13,7 +13,6 @@ from ..types import TelegramObject, Update, User
from ..utils.exceptions.base import TelegramAPIError
from .event.bases import UNHANDLED, SkipHandler
from .event.telegram import TelegramEventObserver
from .fsm.context import FSMContext
from .fsm.middleware import FSMContextMiddleware
from .fsm.storage.base import BaseStorage
from .fsm.storage.memory import MemoryStorage
@ -32,7 +31,7 @@ class Dispatcher(Router):
self,
storage: Optional[BaseStorage] = None,
fsm_strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
isolate_events: bool = True,
isolate_events: bool = False,
**kwargs: Any,
) -> None:
super(Dispatcher, self).__init__(**kwargs)
@ -255,7 +254,9 @@ class Dispatcher(Router):
)
return True # because update was processed but unsuccessful
async def _polling(self, bot: Bot, polling_timeout: int = 30, **kwargs: Any) -> None:
async def _polling(
self, bot: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
"""
Internal polling process
@ -264,7 +265,11 @@ class Dispatcher(Router):
:return:
"""
async for update in self._listen_updates(bot, polling_timeout=polling_timeout):
await self._process_update(bot=bot, update=update, **kwargs)
handle_update = self._process_update(bot=bot, update=update, **kwargs)
if handle_as_tasks:
asyncio.create_task(handle_update)
else:
await handle_update
async def _feed_webhook_update(self, bot: Bot, update: Update, **kwargs: Any) -> Any:
"""
@ -342,11 +347,15 @@ class Dispatcher(Router):
return None
async def start_polling(self, *bots: Bot, polling_timeout: int = 10, **kwargs: Any) -> None:
async def start_polling(
self, *bots: Bot, polling_timeout: int = 10, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
"""
Polling runner
:param bots:
:param polling_timeout:
:param handle_as_tasks:
:param kwargs:
:return:
"""
@ -363,7 +372,12 @@ class Dispatcher(Router):
"Run polling for bot @%s id=%d - %r", user.username, bot.id, user.full_name
)
coro_list.append(
self._polling(bot=bot, polling_timeout=polling_timeout, **kwargs)
self._polling(
bot=bot,
handle_as_tasks=handle_as_tasks,
polling_timeout=polling_timeout,
**kwargs,
)
)
await asyncio.gather(*coro_list)
finally:
@ -372,22 +386,27 @@ class Dispatcher(Router):
loggers.dispatcher.info("Polling stopped")
await self.emit_shutdown(**workflow_data)
def run_polling(self, *bots: Bot, polling_timeout: int = 30, **kwargs: Any) -> None:
def run_polling(
self, *bots: Bot, polling_timeout: int = 30, handle_as_tasks: bool = True, **kwargs: Any
) -> None:
"""
Run many bots with polling
:param bots: Bot instances
:param polling_timeout: Poling timeout
:param handle_as_tasks: Run task for each event and no wait result
:param kwargs: contextual data
:return:
"""
try:
return asyncio.run(
self.start_polling(*bots, **kwargs, polling_timeout=polling_timeout)
self.start_polling(
*bots,
**kwargs,
polling_timeout=polling_timeout,
handle_as_tasks=handle_as_tasks,
)
)
except (KeyboardInterrupt, SystemExit): # pragma: no cover
# Allow to graceful shutdown
pass
def current_state(self, chat_id: int, user_id: int) -> FSMContext:
return cast(FSMContext, self.fsm.resolve_context(chat_id=chat_id, user_id=user_id))

View file

@ -1,18 +1,24 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Match, Optional, Pattern, Sequence, Union, cast
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Match, Optional, Pattern, Sequence, Tuple, Union, cast
from pydantic import validator
from magic_filter import MagicFilter
from pydantic import Field, validator
from aiogram import Bot
from aiogram.dispatcher.filters import BaseFilter
from aiogram.types import Message
from aiogram.utils.deep_linking import decode_payload
CommandPatterType = Union[str, re.Pattern]
class CommandException(Exception):
pass
class Command(BaseFilter):
"""
This filter can be helpful for handling commands from the text messages.
@ -29,6 +35,8 @@ class Command(BaseFilter):
"""Ignore case (Does not work with regexp, use flags instead)"""
commands_ignore_mention: bool = False
"""Ignore bot mention. By default bot can not handle commands intended for other bots"""
command_magic: Optional[MagicFilter] = None
"""Validate command object via Magic filter after all checks done"""
@validator("commands", always=True)
def _validate_commands(
@ -39,12 +47,54 @@ class Command(BaseFilter):
return value
async def __call__(self, message: Message, bot: Bot) -> Union[bool, Dict[str, Any]]:
if not message.text:
text = message.text or message.caption
if not text:
return False
return await self.parse_command(text=message.text, bot=bot)
try:
command = await self.parse_command(text=cast(str, message.text), bot=bot)
except CommandException:
return False
return {"command": command}
async def parse_command(self, text: str, bot: Bot) -> Union[bool, Dict[str, CommandObject]]:
def extract_command(self, text: str) -> CommandObject:
# First step: separate command with arguments
# "/command@mention arg1 arg2" -> "/command@mention", ["arg1 arg2"]
try:
full_command, *args = text.split(maxsplit=1)
except ValueError:
raise CommandException("not enough values to unpack")
# Separate command into valuable parts
# "/command@mention" -> "/", ("command", "@", "mention")
prefix, (command, _, mention) = full_command[0], full_command[1:].partition("@")
return CommandObject(
prefix=prefix, command=command, mention=mention, args=args[0] if args else None
)
def validate_prefix(self, command: CommandObject) -> None:
if command.prefix not in self.commands_prefix:
raise CommandException("Invalid command prefix")
async def validate_mention(self, bot: Bot, command: CommandObject) -> None:
if command.mention and not self.commands_ignore_mention:
me = await bot.me()
if me.username and command.mention.lower() != me.username.lower():
raise CommandException("Mention did not match")
def validate_command(self, command: CommandObject) -> CommandObject:
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
# Command can be presented as regexp pattern or raw string
# then need to validate that in different ways
if isinstance(allowed_command, Pattern): # Regexp
result = allowed_command.match(command.command)
if result:
return replace(command, match=result)
elif command.command == allowed_command: # String
return command
raise CommandException("Command did not match pattern")
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
"""
Extract command from the text and validate
@ -52,56 +102,18 @@ class Command(BaseFilter):
:param bot:
:return:
"""
if not text.strip():
return False
command = self.extract_command(text)
self.validate_prefix(command=command)
await self.validate_mention(bot=bot, command=command)
command = self.validate_command(command)
self.do_magic(command=command)
return command
# First step: separate command with arguments
# "/command@mention arg1 arg2" -> "/command@mention", ["arg1 arg2"]
full_command, *args = text.split(maxsplit=1)
# Separate command into valuable parts
# "/command@mention" -> "/", ("command", "@", "mention")
prefix, (command, _, mention) = full_command[0], full_command[1:].partition("@")
# Validate prefixes
if prefix not in self.commands_prefix:
return False
# Validate mention
if mention and not self.commands_ignore_mention:
me = await bot.me()
if me.username and mention.lower() != me.username.lower():
return False
# Validate command
for allowed_command in cast(Sequence[CommandPatterType], self.commands):
# Command can be presented as regexp pattern or raw string
# then need to validate that in different ways
if isinstance(allowed_command, Pattern): # Regexp
result = allowed_command.match(command)
if result:
return {
"command": CommandObject(
prefix=prefix,
command=command,
mention=mention,
args=args[0] if args else None,
match=result,
)
}
elif command == allowed_command: # String
return {
"command": CommandObject(
prefix=prefix,
command=command,
mention=mention,
args=args[0] if args else None,
match=None,
)
}
return False
def do_magic(self, command: CommandObject) -> None:
if not self.command_magic:
return
if not self.command_magic.resolve(command):
raise CommandException("Rejected via magic filter")
class Config:
arbitrary_types_allowed = True
@ -143,3 +155,40 @@ class CommandObject:
if self.args:
line += " " + self.args
return line
class CommandStart(Command):
commands: Tuple[str] = Field(("start",), const=True)
commands_prefix: str = Field("/", const=True)
deep_link: bool = False
deep_link_encoded: bool = False
async def parse_command(self, text: str, bot: Bot) -> CommandObject:
"""
Extract command from the text and validate
:param text:
:param bot:
:return:
"""
command = self.extract_command(text)
self.validate_prefix(command=command)
await self.validate_mention(bot=bot, command=command)
command = self.validate_command(command)
command = self.validate_deeplink(command=command)
self.do_magic(command=command)
return command
def validate_deeplink(self, command: CommandObject) -> CommandObject:
if not self.deep_link:
return command
if not command.args:
raise CommandException("Deep-link was missing")
args = command.args
if self.deep_link_encoded:
try:
args = decode_payload(args)
except UnicodeDecodeError as e:
raise CommandException(f"Failed to decode Base64: {e}")
return replace(command, args=args)
return command

View file

@ -1,25 +1,35 @@
from typing import Any, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
class FSMContext:
def __init__(self, storage: BaseStorage, chat_id: int, user_id: int) -> None:
def __init__(self, bot: Bot, storage: BaseStorage, chat_id: int, user_id: int) -> None:
self.bot = bot
self.storage = storage
self.chat_id = chat_id
self.user_id = user_id
async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(chat_id=self.chat_id, user_id=self.user_id, state=state)
await self.storage.set_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, state=state
)
async def get_state(self) -> Optional[str]:
return await self.storage.get_state(chat_id=self.chat_id, user_id=self.user_id)
return await self.storage.get_state(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
async def set_data(self, data: Dict[str, Any]) -> None:
await self.storage.set_data(chat_id=self.chat_id, user_id=self.user_id, data=data)
await self.storage.set_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=data
)
async def get_data(self) -> Dict[str, Any]:
return await self.storage.get_data(chat_id=self.chat_id, user_id=self.user_id)
return await self.storage.get_data(
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id
)
async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
@ -27,7 +37,7 @@ class FSMContext:
if data:
kwargs.update(data)
return await self.storage.update_data(
chat_id=self.chat_id, user_id=self.user_id, data=kwargs
bot=self.bot, chat_id=self.chat_id, user_id=self.user_id, data=kwargs
)
async def clear(self) -> None:

View file

@ -1,5 +1,6 @@
from typing import Any, Awaitable, Callable, Dict, Optional
from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from aiogram.dispatcher.fsm.strategy import FSMStrategy, apply_strategy
@ -24,24 +25,27 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
event: Update,
data: Dict[str, Any],
) -> Any:
context = self.resolve_event_context(data)
bot: Bot = cast(Bot, data["bot"])
context = self.resolve_event_context(bot, data)
data["fsm_storage"] = self.storage
if context:
data.update({"state": context, "raw_state": await context.get_state()})
if self.isolate_events:
async with self.storage.lock(chat_id=context.chat_id, user_id=context.user_id):
async with self.storage.lock(
bot=bot, chat_id=context.chat_id, user_id=context.user_id
):
return await handler(event, data)
return await handler(event, data)
def resolve_event_context(self, data: Dict[str, Any]) -> Optional[FSMContext]:
def resolve_event_context(self, bot: Bot, data: Dict[str, Any]) -> Optional[FSMContext]:
user = data.get("event_from_user")
chat = data.get("event_chat")
chat_id = chat.id if chat else None
user_id = user.id if user else None
return self.resolve_context(chat_id=chat_id, user_id=user_id)
return self.resolve_context(bot=bot, chat_id=chat_id, user_id=user_id)
def resolve_context(
self, chat_id: Optional[int], user_id: Optional[int]
self, bot: Bot, chat_id: Optional[int], user_id: Optional[int]
) -> Optional[FSMContext]:
if chat_id is None:
chat_id = user_id
@ -50,8 +54,8 @@ class FSMContextMiddleware(BaseMiddleware[Update]):
chat_id, user_id = apply_strategy(
chat_id=chat_id, user_id=user_id, strategy=self.strategy
)
return self.get_context(chat_id=chat_id, user_id=user_id)
return self.get_context(bot=bot, chat_id=chat_id, user_id=user_id)
return None
def get_context(self, chat_id: int, user_id: int) -> FSMContext:
return FSMContext(storage=self.storage, chat_id=chat_id, user_id=user_id)
def get_context(self, bot: Bot, chat_id: int, user_id: int) -> FSMContext:
return FSMContext(bot=bot, storage=self.storage, chat_id=chat_id, user_id=user_id)

View file

@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Union
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
StateType = Optional[Union[str, State]]
@ -11,34 +12,42 @@ class BaseStorage(ABC):
@abstractmethod
@asynccontextmanager
async def lock(
self, chat_id: int, user_id: int
self, bot: Bot, chat_id: int, user_id: int
) -> AsyncGenerator[None, None]: # pragma: no cover
yield None
@abstractmethod
async def set_state(
self, chat_id: int, user_id: int, state: StateType = None
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None: # pragma: no cover
pass
@abstractmethod
async def get_state(self, chat_id: int, user_id: int) -> Optional[str]: # pragma: no cover
async def get_state(
self, bot: Bot, chat_id: int, user_id: int
) -> Optional[str]: # pragma: no cover
pass
@abstractmethod
async def set_data(
self, chat_id: int, user_id: int, data: Dict[str, Any]
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> None: # pragma: no cover
pass
@abstractmethod
async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]: # pragma: no cover
async def get_data(
self, bot: Bot, chat_id: int, user_id: int
) -> Dict[str, Any]: # pragma: no cover
pass
async def update_data(
self, chat_id: int, user_id: int, data: Dict[str, Any]
self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]
) -> Dict[str, Any]:
current_data = await self.get_data(chat_id=chat_id, user_id=user_id)
current_data = await self.get_data(bot=bot, chat_id=chat_id, user_id=user_id)
current_data.update(data)
await self.set_data(chat_id=chat_id, user_id=user_id, data=current_data)
await self.set_data(bot=bot, chat_id=chat_id, user_id=user_id, data=current_data)
return current_data.copy()
@abstractmethod
async def close(self) -> None: # pragma: no cover
pass

View file

@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, DefaultDict, Dict, Optional
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
@ -17,23 +18,30 @@ class MemoryStorageRecord:
class MemoryStorage(BaseStorage):
def __init__(self) -> None:
self.storage: DefaultDict[int, DefaultDict[int, MemoryStorageRecord]] = defaultdict(
lambda: defaultdict(MemoryStorageRecord)
)
self.storage: DefaultDict[
Bot, DefaultDict[int, DefaultDict[int, MemoryStorageRecord]]
] = defaultdict(lambda: defaultdict(lambda: defaultdict(MemoryStorageRecord)))
async def close(self) -> None:
pass
@asynccontextmanager
async def lock(self, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
async with self.storage[chat_id][user_id].lock:
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
async with self.storage[bot][chat_id][user_id].lock:
yield None
async def set_state(self, chat_id: int, user_id: int, state: StateType = None) -> None:
self.storage[chat_id][user_id].state = state.state if isinstance(state, State) else state
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None:
self.storage[bot][chat_id][user_id].state = (
state.state if isinstance(state, State) else state
)
async def get_state(self, chat_id: int, user_id: int) -> Optional[str]:
return self.storage[chat_id][user_id].state
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
return self.storage[bot][chat_id][user_id].state
async def set_data(self, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
self.storage[chat_id][user_id].data = data.copy()
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
self.storage[bot][chat_id][user_id].data = data.copy()
async def get_data(self, chat_id: int, user_id: int) -> Dict[str, Any]:
return self.storage[chat_id][user_id].data.copy()
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
return self.storage[bot][chat_id][user_id].data.copy()

View file

@ -0,0 +1,101 @@
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, cast
from aioredis import ConnectionPool, Redis
from aiogram import Bot
from aiogram.dispatcher.fsm.state import State
from aiogram.dispatcher.fsm.storage.base import BaseStorage, StateType
PrefixFactoryType = Callable[[Bot], str]
STATE_KEY = "state"
STATE_DATA_KEY = "data"
STATE_LOCK_KEY = "lock"
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
class RedisStorage(BaseStorage):
def __init__(
self,
redis: Redis,
prefix: str = "fsm",
prefix_bot: Union[bool, PrefixFactoryType, Dict[int, str]] = False,
state_ttl: Optional[int] = None,
data_ttl: Optional[int] = None,
lock_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if lock_kwargs is None:
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
self.redis = redis
self.prefix = prefix
self.prefix_bot = prefix_bot
self.state_ttl = state_ttl
self.data_ttl = data_ttl
self.lock_kwargs = lock_kwargs
@classmethod
def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> "RedisStorage":
if connection_kwargs is None:
connection_kwargs = {}
pool = ConnectionPool.from_url(url, **connection_kwargs)
redis = Redis(connection_pool=pool)
return cls(redis=redis, **kwargs)
async def close(self) -> None:
await self.redis.close()
def generate_key(self, bot: Bot, *parts: Any) -> str:
prefix_parts = [self.prefix]
if self.prefix_bot:
if isinstance(self.prefix_bot, dict):
prefix_parts.append(self.prefix_bot[bot.id])
elif callable(self.prefix_bot):
prefix_parts.append(self.prefix_bot(bot))
else:
prefix_parts.append(str(bot.id))
prefix_parts.extend(parts)
return ":".join(map(str, prefix_parts))
@asynccontextmanager
async def lock(self, bot: Bot, chat_id: int, user_id: int) -> AsyncGenerator[None, None]:
key = self.generate_key(bot, chat_id, user_id, STATE_LOCK_KEY)
async with self.redis.lock(name=key, **self.lock_kwargs):
yield None
async def set_state(
self, bot: Bot, chat_id: int, user_id: int, state: StateType = None
) -> None:
key = self.generate_key(bot, chat_id, user_id, STATE_KEY)
if state is None:
await self.redis.delete(key)
else:
await self.redis.set(
key, state.state if isinstance(state, State) else state, ex=self.state_ttl
)
async def get_state(self, bot: Bot, chat_id: int, user_id: int) -> Optional[str]:
key = self.generate_key(bot, chat_id, user_id, STATE_KEY)
value = await self.redis.get(key)
if isinstance(value, bytes):
return value.decode("utf-8")
return cast(Optional[str], value)
async def set_data(self, bot: Bot, chat_id: int, user_id: int, data: Dict[str, Any]) -> None:
key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
if not data:
await self.redis.delete(key)
return
json_data = bot.session.json_dumps(data)
await self.redis.set(key, json_data, ex=self.data_ttl)
async def get_data(self, bot: Bot, chat_id: int, user_id: int) -> Dict[str, Any]:
key = self.generate_key(bot, chat_id, user_id, STATE_DATA_KEY)
value = await self.redis.get(key)
if value is None:
return {}
if isinstance(value, bytes):
value = value.decode("utf-8")
return cast(Dict[str, Any], bot.session.json_loads(value))

View file

@ -0,0 +1,34 @@
import hashlib
import hmac
from typing import Any, Dict
def check_signature(token: str, hash: str, **kwargs: Any) -> bool:
"""
Generate hexadecimal representation
of the HMAC-SHA-256 signature of the data-check-string
with the SHA256 hash of the bot's token used as a secret key
:param token:
:param hash:
:param kwargs: all params received on auth
:return:
"""
secret = hashlib.sha256(token.encode("utf-8"))
check_string = "\n".join(map(lambda k: f"{k}={kwargs[k]}", sorted(kwargs)))
hmac_string = hmac.new(
secret.digest(), check_string.encode("utf-8"), digestmod=hashlib.sha256
).hexdigest()
return hmac_string == hash
def check_integrity(token: str, data: Dict[str, Any]) -> bool:
"""
Verify the authentication and the integrity
of the data received on user's auth
:param token: Bot's token
:param data: all data that came on auth
:return:
"""
return check_signature(token, **data)

View file

@ -0,0 +1,131 @@
"""
Deep linking
Telegram bots have a deep linking mechanism, that allows for passing
additional parameters to the bot on startup. It could be a command that
launches the bot or an auth token to connect the user's Telegram
account to their account on some external service.
You can read detailed description in the source:
https://core.telegram.org/bots#deep-linking
We have add some utils to get deep links more handy.
Basic link example:
.. code-block:: python
from aiogram.utils.deep_linking import get_start_link
link = await get_start_link('foo')
# result: 'https://t.me/MyBot?start=foo'
Encoded link example:
.. code-block:: python
from aiogram.utils.deep_linking import get_start_link
link = await get_start_link('foo', encode=True)
# result: 'https://t.me/MyBot?start=Zm9v'
Decode it back example:
.. code-block:: python
from aiogram.utils.deep_linking import decode_payload
from aiogram.types import Message
@dp.message_handler(commands=["start"])
async def handler(message: Message):
args = message.get_args()
payload = decode_payload(args)
await message.answer(f"Your payload: {payload}")
"""
import re
from base64 import urlsafe_b64decode, urlsafe_b64encode
from typing import Literal, cast
from aiogram import Bot
from aiogram.utils.link import create_telegram_link
BAD_PATTERN = re.compile(r"[^_A-z0-9-]")
async def create_start_link(bot: Bot, payload: str, encode: bool = False) -> str:
"""
Create 'start' deep link with your payload.
If you need to encode payload or pass special characters -
set encode as True
:param bot: bot instance
:param payload: args passed with /start
:param encode: encode payload with base64url
:return: link
"""
username = (await bot.me()).username
return create_deep_link(username=username, link_type="start", payload=payload, encode=encode)
async def create_startgroup_link(bot: Bot, payload: str, encode: bool = False) -> str:
"""
Create 'startgroup' deep link with your payload.
If you need to encode payload or pass special characters -
set encode as True
:param bot: bot instance
:param payload: args passed with /start
:param encode: encode payload with base64url
:return: link
"""
username = (await bot.me()).username
return create_deep_link(
username=username, link_type="startgroup", payload=payload, encode=encode
)
def create_deep_link(
username: str, link_type: Literal["start", "startgroup"], payload: str, encode: bool = False
) -> str:
"""
Create deep link.
:param username:
:param link_type: `start` or `startgroup`
:param payload: any string-convertible data
:param encode: pass True to encode the payload
:return: deeplink
"""
if not isinstance(payload, str):
payload = str(payload)
if encode:
payload = encode_payload(payload)
if re.search(BAD_PATTERN, payload):
raise ValueError(
"Wrong payload! Only A-Z, a-z, 0-9, _ and - are allowed. "
"Pass `encode=True` or encode payload manually."
)
if len(payload) > 64:
raise ValueError("Payload must be up to 64 characters long.")
return create_telegram_link(username, **{cast(str, link_type): payload})
def encode_payload(payload: str) -> str:
"""Encode payload with URL-safe base64url."""
payload = str(payload)
bytes_payload: bytes = urlsafe_b64encode(payload.encode())
str_payload = bytes_payload.decode()
return str_payload.replace("=", "")
def decode_payload(payload: str) -> str:
"""Decode payload with URL-safe base64url."""
payload += "=" * (4 - len(payload) % 4)
result: bytes = urlsafe_b64decode(payload)
return result.decode()

View file

@ -2,13 +2,27 @@ from __future__ import annotations
from itertools import chain
from itertools import cycle as repeat_all
from typing import Any, Generator, Generic, Iterable, List, Optional, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Generator,
Generic,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
no_type_check,
)
from aiogram.dispatcher.filters.callback_data import CallbackData
from aiogram.types import (
CallbackGame,
InlineKeyboardButton,
InlineKeyboardMarkup,
KeyboardButton,
LoginUrl,
ReplyKeyboardMarkup,
)
@ -239,3 +253,28 @@ def repeat_last(items: Iterable[T]) -> Generator[T, None, None]:
except StopIteration:
finished = True
yield value
class InlineKeyboardConstructor(KeyboardConstructor[InlineKeyboardButton]):
if TYPE_CHECKING: # pragma: no cover
@no_type_check
def button(
self,
text: str,
url: Optional[str] = None,
login_url: Optional[LoginUrl] = None,
callback_data: Optional[Union[str, CallbackData]] = None,
switch_inline_query: Optional[str] = None,
switch_inline_query_current_chat: Optional[str] = None,
callback_game: Optional[CallbackGame] = None,
pay: Optional[bool] = None,
**kwargs: Any,
) -> "KeyboardConstructor[InlineKeyboardButton]":
...
def as_markup(self, **kwargs: Any) -> InlineKeyboardMarkup:
...
def __init__(self) -> None:
super().__init__(InlineKeyboardButton)

18
aiogram/utils/link.py Normal file
View file

@ -0,0 +1,18 @@
from typing import Any
from urllib.parse import urlencode, urljoin
def create_tg_link(link: str, **kwargs: Any) -> str:
url = f"tg://{link}"
if kwargs:
query = urlencode(kwargs)
url += f"?{query}"
return url
def create_telegram_link(uri: str, **kwargs: Any) -> str:
url = urljoin("https://t.me", uri)
if kwargs:
query = urlencode(query=kwargs)
url += f"?{query}"
return url

View file

@ -183,7 +183,7 @@ class MarkdownDecoration(TextDecoration):
return f"`{value}`"
def pre(self, value: str) -> str:
return f"```{value}```"
return f"```\n{value}\n```"
def pre_language(self, value: str, language: str) -> str:
return f"```{language}\n{value}\n```"

View file

@ -1,6 +1,6 @@
[mypy]
;plugins = pydantic.mypy
python_version = 3.7
python_version = 3.8
show_error_codes = True
show_error_context = True
pretty = True
@ -29,3 +29,6 @@ ignore_missing_imports = True
[mypy-uvloop]
ignore_missing_imports = True
[mypy-aioredis]
ignore_missing_imports = True

106
poetry.lock generated
View file

@ -38,6 +38,21 @@ aiohttp = ">=2.3.2"
attrs = ">=19.2.0"
python-socks = {version = ">=1.0.1", extras = ["asyncio"]}
[[package]]
name = "aioredis"
version = "2.0.0a1"
description = "asyncio (PEP 3156) Redis support"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
async-timeout = "*"
typing-extensions = "*"
[package.extras]
hiredis = ["hiredis (>=1.0)"]
[[package]]
name = "alabaster"
version = "0.7.12"
@ -156,7 +171,7 @@ lxml = ["lxml"]
[[package]]
name = "black"
version = "21.5b1"
version = "21.6b0"
description = "The uncompromising code formatter."
category = "dev"
optional = false
@ -172,8 +187,9 @@ toml = ">=0.10.1"
[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.6.0)", "aiohttp-cors"]
d = ["aiohttp (>=3.6.0)", "aiohttp-cors (>=0.4.0)"]
python2 = ["typed-ast (>=1.4.2)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "cfgv"
@ -218,9 +234,6 @@ category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
[package.dependencies]
toml = {version = "*", optional = true, markers = "extra == \"toml\""}
[package.extras]
toml = ["toml"]
@ -234,7 +247,7 @@ python-versions = ">=3.5"
[[package]]
name = "distlib"
version = "0.3.1"
version = "0.3.2"
description = "Distribution utilities"
category = "dev"
optional = false
@ -301,7 +314,7 @@ test = ["pytest", "pytest-cov", "pytest-xdist"]
[[package]]
name = "identify"
version = "2.2.5"
version = "2.2.10"
description = "File identification library for Python"
category = "dev"
optional = false
@ -312,11 +325,11 @@ license = ["editdistance-s"]
[[package]]
name = "idna"
version = "3.1"
version = "3.2"
description = "Internationalized Domain Names in Applications (IDNA)"
category = "main"
optional = false
python-versions = ">=3.4"
python-versions = ">=3.5"
[[package]]
name = "imagesize"
@ -328,7 +341,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
name = "importlib-metadata"
version = "4.0.1"
version = "4.5.0"
description = "Read metadata from Python packages"
category = "dev"
optional = false
@ -351,7 +364,7 @@ python-versions = "*"
[[package]]
name = "ipython"
version = "7.23.1"
version = "7.24.1"
description = "IPython: Productive Interactive Computing"
category = "dev"
optional = false
@ -371,7 +384,7 @@ pygments = "*"
traitlets = ">=4.2"
[package.extras]
all = ["Sphinx (>=1.3)", "ipykernel", "ipyparallel", "ipywidgets", "nbconvert", "nbformat", "nose (>=0.10.1)", "notebook", "numpy (>=1.16)", "pygments", "qtconsole", "requests", "testpath"]
all = ["Sphinx (>=1.3)", "ipykernel", "ipyparallel", "ipywidgets", "nbconvert", "nbformat", "nose (>=0.10.1)", "notebook", "numpy (>=1.17)", "pygments", "qtconsole", "requests", "testpath"]
doc = ["Sphinx (>=1.3)"]
kernel = ["ipykernel"]
nbconvert = ["nbconvert"]
@ -379,7 +392,7 @@ nbformat = ["nbformat"]
notebook = ["notebook", "ipywidgets"]
parallel = ["ipyparallel"]
qtconsole = ["qtconsole"]
test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.16)"]
test = ["nose (>=0.10.1)", "requests", "testpath", "pygments", "nbformat", "ipykernel", "numpy (>=1.17)"]
[[package]]
name = "ipython-genutils"
@ -739,18 +752,19 @@ testing = ["coverage", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
version = "2.12.0"
version = "2.12.1"
description = "Pytest plugin for measuring coverage."
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
coverage = ">=5.2.1"
pytest = ">=4.6"
toml = "*"
[package.extras]
testing = ["fields", "hunter", "process-tests (==2.0.2)", "six", "pytest-xdist", "virtualenv"]
testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"]
[[package]]
name = "pytest-html"
@ -764,6 +778,17 @@ python-versions = ">=3.6"
pytest = ">=5.0,<6.0.0 || >6.0.0"
pytest-metadata = "*"
[[package]]
name = "pytest-lazy-fixture"
version = "0.6.3"
description = "It helps to use fixtures in pytest.mark.parametrize"
category = "dev"
optional = false
python-versions = "*"
[package.dependencies]
pytest = ">=3.2.5"
[[package]]
name = "pytest-metadata"
version = "1.11.0"
@ -930,17 +955,17 @@ test = ["pytest", "pytest-cov"]
[[package]]
name = "sphinx-copybutton"
version = "0.3.1"
version = "0.3.2"
description = "Add a copy button to each of your code cells."
category = "main"
optional = false
python-versions = "*"
python-versions = ">=3.6"
[package.dependencies]
sphinx = ">=1.8"
[package.extras]
code_style = ["flake8 (>=3.7.0,<3.8.0)", "black", "pre-commit (==1.17.0)"]
code_style = ["pre-commit (==2.12.1)"]
[[package]]
name = "sphinx-intl"
@ -1171,11 +1196,12 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pyt
docs = ["sphinx", "sphinx-intl", "sphinx-autobuild", "sphinx-copybutton", "furo", "sphinx-prompt", "Sphinx-Substitution-Extensions"]
fast = []
proxy = ["aiohttp-socks"]
redis = ["aioredis"]
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "2fcd44a8937b3ea48196c8eba8ceb0533281af34c884103bcc5b4f5f16b817d5"
content-hash = "362a6caf937b1c457599cbf2cd5d000eab4cac529bd7fe8c257ae713ebc63331"
[metadata.files]
aiofiles = [
@ -1225,6 +1251,10 @@ aiohttp-socks = [
{file = "aiohttp_socks-0.5.5-py3-none-any.whl", hash = "sha256:faaa25ed4dc34440ca888d23e089420f3b1918dc4ecf062c3fd9474827ad6a39"},
{file = "aiohttp_socks-0.5.5.tar.gz", hash = "sha256:2eb2059756bde34c55bb429541cbf2eba3fd53e36ac80875b461221e2858b04a"},
]
aioredis = [
{file = "aioredis-2.0.0a1-py3-none-any.whl", hash = "sha256:32d7910724282a475c91b8b34403867069a4f07bf0c5ad5fe66cd797322f9a0d"},
{file = "aioredis-2.0.0a1.tar.gz", hash = "sha256:5884f384b8ecb143bb73320a96e7c464fd38e117950a7d48340a35db8e35e7d2"},
]
alabaster = [
{file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"},
{file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"},
@ -1274,8 +1304,8 @@ beautifulsoup4 = [
{file = "beautifulsoup4-4.9.3.tar.gz", hash = "sha256:84729e322ad1d5b4d25f805bfa05b902dd96450f43842c4e99067d5e1369eb25"},
]
black = [
{file = "black-21.5b1-py3-none-any.whl", hash = "sha256:8a60071a0043876a4ae96e6c69bd3a127dad2c1ca7c8083573eb82f92705d008"},
{file = "black-21.5b1.tar.gz", hash = "sha256:23695358dbcb3deafe7f0a3ad89feee5999a46be5fec21f4f1d108be0bcdb3b1"},
{file = "black-21.6b0-py3-none-any.whl", hash = "sha256:dfb8c5a069012b2ab1e972e7b908f5fb42b6bbabcba0a788b86dc05067c7d9c7"},
{file = "black-21.6b0.tar.gz", hash = "sha256:dc132348a88d103016726fe360cb9ede02cecf99b76e3660ce6c596be132ce04"},
]
cfgv = [
{file = "cfgv-3.3.0-py2.py3-none-any.whl", hash = "sha256:b449c9c6118fe8cca7fa5e00b9ec60ba08145d281d52164230a69211c5d597a1"},
@ -1352,8 +1382,8 @@ decorator = [
{file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"},
]
distlib = [
{file = "distlib-0.3.1-py2.py3-none-any.whl", hash = "sha256:8c09de2c67b3e7deef7184574fc060ab8a793e7adbb183d942c389c8b13c52fb"},
{file = "distlib-0.3.1.zip", hash = "sha256:edf6116872c863e1aa9d5bb7cb5e05a022c519a4594dc703843343a9ddd9bff1"},
{file = "distlib-0.3.2-py2.py3-none-any.whl", hash = "sha256:23e223426b28491b1ced97dc3bbe183027419dfc7982b4fa2f05d5f3ff10711c"},
{file = "distlib-0.3.2.zip", hash = "sha256:106fef6dc37dd8c0e2c0a60d3fca3e77460a48907f335fa28420463a6f799736"},
]
docutils = [
{file = "docutils-0.17.1-py2.py3-none-any.whl", hash = "sha256:cf316c8370a737a022b72b56874f6602acf974a37a9fba42ec2876387549fc61"},
@ -1376,28 +1406,28 @@ furo = [
{file = "furo-2020.12.30b24.tar.gz", hash = "sha256:30171899c9c06d692a778e6daf6cb2e5cbb05efc6006e1692e5e776007dc8a8c"},
]
identify = [
{file = "identify-2.2.5-py2.py3-none-any.whl", hash = "sha256:9c3ab58543c03bd794a1735e4552ef6dec49ec32053278130d525f0982447d47"},
{file = "identify-2.2.5.tar.gz", hash = "sha256:bc1705694253763a3160b943316867792ec00ba7a0ee40b46e20aebaf4e0c46a"},
{file = "identify-2.2.10-py2.py3-none-any.whl", hash = "sha256:18d0c531ee3dbc112fa6181f34faa179de3f57ea57ae2899754f16a7e0ff6421"},
{file = "identify-2.2.10.tar.gz", hash = "sha256:5b41f71471bc738e7b586308c3fca172f78940195cb3bf6734c1e66fdac49306"},
]
idna = [
{file = "idna-3.1-py3-none-any.whl", hash = "sha256:5205d03e7bcbb919cc9c19885f9920d622ca52448306f2377daede5cf3faac16"},
{file = "idna-3.1.tar.gz", hash = "sha256:c5b02147e01ea9920e6b0a3f1f7bb833612d507592c837a6c49552768f4054e1"},
{file = "idna-3.2-py3-none-any.whl", hash = "sha256:14475042e284991034cb48e06f6851428fb14c4dc953acd9be9a5e95c7b6dd7a"},
{file = "idna-3.2.tar.gz", hash = "sha256:467fbad99067910785144ce333826c71fb0e63a425657295239737f7ecd125f3"},
]
imagesize = [
{file = "imagesize-1.2.0-py2.py3-none-any.whl", hash = "sha256:6965f19a6a2039c7d48bca7dba2473069ff854c36ae6f19d2cde309d998228a1"},
{file = "imagesize-1.2.0.tar.gz", hash = "sha256:b1f6b5a4eab1f73479a50fb79fcf729514a900c341d8503d62a62dbc4127a2b1"},
]
importlib-metadata = [
{file = "importlib_metadata-4.0.1-py3-none-any.whl", hash = "sha256:d7eb1dea6d6a6086f8be21784cc9e3bcfa55872b52309bc5fad53a8ea444465d"},
{file = "importlib_metadata-4.0.1.tar.gz", hash = "sha256:8c501196e49fb9df5df43833bdb1e4328f64847763ec8a50703148b73784d581"},
{file = "importlib_metadata-4.5.0-py3-none-any.whl", hash = "sha256:833b26fb89d5de469b24a390e9df088d4e52e4ba33b01dc5e0e4f41b81a16c00"},
{file = "importlib_metadata-4.5.0.tar.gz", hash = "sha256:b142cc1dd1342f31ff04bb7d022492b09920cb64fed867cd3ea6f80fe3ebd139"},
]
iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
]
ipython = [
{file = "ipython-7.23.1-py3-none-any.whl", hash = "sha256:f78c6a3972dde1cc9e4041cbf4de583546314ba52d3c97208e5b6b2221a9cb7d"},
{file = "ipython-7.23.1.tar.gz", hash = "sha256:714810a5c74f512b69d5f3b944c86e592cee0a5fb9c728e582f074610f6cf038"},
{file = "ipython-7.24.1-py3-none-any.whl", hash = "sha256:d513e93327cf8657d6467c81f1f894adc125334ffe0e4ddd1abbb1c78d828703"},
{file = "ipython-7.24.1.tar.gz", hash = "sha256:9bc24a99f5d19721fb8a2d1408908e9c0520a17fff2233ffe82620847f17f1b6"},
]
ipython-genutils = [
{file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"},
@ -1637,13 +1667,17 @@ pytest-asyncio = [
{file = "pytest_asyncio-0.15.1-py3-none-any.whl", hash = "sha256:3042bcdf1c5d978f6b74d96a151c4cfb9dcece65006198389ccd7e6c60eb1eea"},
]
pytest-cov = [
{file = "pytest-cov-2.12.0.tar.gz", hash = "sha256:8535764137fecce504a49c2b742288e3d34bc09eed298ad65963616cc98fd45e"},
{file = "pytest_cov-2.12.0-py2.py3-none-any.whl", hash = "sha256:95d4933dcbbacfa377bb60b29801daa30d90c33981ab2a79e9ab4452c165066e"},
{file = "pytest-cov-2.12.1.tar.gz", hash = "sha256:261ceeb8c227b726249b376b8526b600f38667ee314f910353fa318caa01f4d7"},
{file = "pytest_cov-2.12.1-py2.py3-none-any.whl", hash = "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a"},
]
pytest-html = [
{file = "pytest-html-3.1.1.tar.gz", hash = "sha256:3ee1cf319c913d19fe53aeb0bc400e7b0bc2dbeb477553733db1dad12eb75ee3"},
{file = "pytest_html-3.1.1-py3-none-any.whl", hash = "sha256:b7f82f123936a3f4d2950bc993c2c1ca09ce262c9ae12f9ac763a2401380b455"},
]
pytest-lazy-fixture = [
{file = "pytest-lazy-fixture-0.6.3.tar.gz", hash = "sha256:0e7d0c7f74ba33e6e80905e9bfd81f9d15ef9a790de97993e34213deb5ad10ac"},
{file = "pytest_lazy_fixture-0.6.3-py3-none-any.whl", hash = "sha256:e0b379f38299ff27a653f03eaa69b08a6fd4484e46fd1c9907d984b9f9daeda6"},
]
pytest-metadata = [
{file = "pytest-metadata-1.11.0.tar.gz", hash = "sha256:71b506d49d34e539cc3cfdb7ce2c5f072bea5c953320002c95968e0238f8ecf1"},
{file = "pytest_metadata-1.11.0-py2.py3-none-any.whl", hash = "sha256:576055b8336dd4a9006dd2a47615f76f2f8c30ab12b1b1c039d99e834583523f"},
@ -1763,8 +1797,8 @@ sphinx-autobuild = [
{file = "sphinx_autobuild-2020.9.1-py3-none-any.whl", hash = "sha256:df5c72cb8b8fc9b31279c4619780c4e95029be6de569ff60a8bb2e99d20f63dd"},
]
sphinx-copybutton = [
{file = "sphinx-copybutton-0.3.1.tar.gz", hash = "sha256:0e0461df394515284e3907e3f418a0c60ef6ab6c9a27a800c8552772d0a402a2"},
{file = "sphinx_copybutton-0.3.1-py3-none-any.whl", hash = "sha256:5125c718e763596e6e52d92e15ee0d6f4800ad3817939be6dee51218870b3e3d"},
{file = "sphinx-copybutton-0.3.2.tar.gz", hash = "sha256:f901f17e7dadc063bcfca592c5160f9113ec17501a59e046af3edb82b7527656"},
{file = "sphinx_copybutton-0.3.2-py3-none-any.whl", hash = "sha256:f16f8ed8dfc60f2b34a58cb69bfa04722e24be2f6d7e04db5554c32cde4df815"},
]
sphinx-intl = [
{file = "sphinx-intl-2.0.1.tar.gz", hash = "sha256:b25a6ec169347909e8d983eefe2d8adecb3edc2f27760db79b965c69950638b4"},

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "aiogram"
version = "3.0.0-alpha.8"
version = "3.0.0-alpha.9"
description = "Modern and fully asynchronous framework for Telegram Bot API"
authors = ["Alex Root Junior <jroot.junior@gmail.com>"]
license = "MIT"
@ -38,8 +38,9 @@ Babel = "^2.9.1"
aiofiles = "^0.6.0"
async_lru = "^1.0.2"
aiohttp-socks = { version = "^0.5.5", optional = true }
aioredis = { version = "^2.0.0a1", allow-prereleases = true, optional = true }
typing-extensions = { version = "^3.7.4", python = "<3.8" }
magic-filter = {version = "1.0.0a1", allow-prereleases = true}
magic-filter = { version = "1.0.0a1", allow-prereleases = true }
sphinx = { version = "^3.1.0", optional = true }
sphinx-intl = { version = "^2.0.1", optional = true }
sphinx-autobuild = { version = "^2020.9.1", optional = true }
@ -50,6 +51,7 @@ Sphinx-Substitution-Extensions = { version = "^2020.9.30", optional = true }
[tool.poetry.dev-dependencies]
aiohttp-socks = "^0.5"
aioredis = { version = "^2.0.0a1", allow-prereleases = true }
ipython = "^7.22.0"
uvloop = { version = "^0.15.2", markers = "sys_platform == 'darwin' or sys_platform == 'linux'" }
black = "^21.4b2"
@ -79,9 +81,11 @@ sphinx-copybutton = "^0.3.1"
furo = "^2020.11.15-beta.17"
sphinx-prompt = "^1.3.0"
Sphinx-Substitution-Extensions = "^2020.9.30"
pytest-lazy-fixture = "^0.6.3"
[tool.poetry.extras]
fast = ["uvloop"]
redis = ["aioredis"]
proxy = ["aiohttp-socks"]
docs = [
"sphinx",

View file

@ -1,9 +1,64 @@
import pytest
from _pytest.config import UsageError
from aioredis.connection import parse_url as parse_redis_url
from aiogram import Bot
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from aiogram.dispatcher.fsm.storage.redis import RedisStorage
from tests.mocked_bot import MockedBot
def pytest_addoption(parser):
parser.addoption("--redis", default=None, help="run tests which require redis connection")
def pytest_configure(config):
config.addinivalue_line("markers", "redis: marked tests require redis connection to run")
def pytest_collection_modifyitems(config, items):
redis_uri = config.getoption("--redis")
if redis_uri is None:
skip_redis = pytest.mark.skip(reason="need --redis option with redis URI to run")
for item in items:
if "redis" in item.keywords:
item.add_marker(skip_redis)
return
try:
parse_redis_url(redis_uri)
except ValueError as e:
raise UsageError(f"Invalid redis URI {redis_uri!r}: {e}")
@pytest.fixture(scope="session")
def redis_server(request):
redis_uri = request.config.getoption("--redis")
return redis_uri
@pytest.fixture()
@pytest.mark.redis
async def redis_storage(redis_server):
if not redis_server:
pytest.skip("Redis is not available here")
storage = RedisStorage.from_url(redis_server)
try:
yield storage
finally:
conn = await storage.redis
await conn.flushdb()
await storage.close()
@pytest.fixture()
async def memory_storage():
storage = MemoryStorage()
try:
yield storage
finally:
await storage.close()
@pytest.fixture()
def bot():
bot = MockedBot()

7
tests/docker-compose.yml Normal file
View file

@ -0,0 +1,7 @@
version: "3.9"
services:
redis:
image: redis:6-alpine
ports:
- "${REDIS_PORT-6379}:6379"

View file

@ -6,7 +6,7 @@ import pytest
from aiogram.client.session.base import BaseSession, TelegramType
from aiogram.client.telegram import PRODUCTION, TelegramAPIServer
from aiogram.methods import DeleteMessage, GetMe, Response, TelegramMethod
from aiogram.methods import DeleteMessage, GetMe, TelegramMethod
from aiogram.types import UNSET
try:
@ -20,7 +20,9 @@ class CustomSession(BaseSession):
async def close(self):
pass
async def make_request(self, token: str, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET) -> None: # type: ignore
async def make_request(
self, token: str, method: TelegramMethod[TelegramType], timeout: Optional[int] = UNSET
) -> None: # type: ignore
assert isinstance(token, str)
assert isinstance(method, TelegramMethod)

View file

@ -3,7 +3,7 @@ from typing import Union
import pytest
from aiogram.methods import EditMessageMedia, Request
from aiogram.types import BufferedInputFile, InputMedia, InputMediaPhoto, Message
from aiogram.types import BufferedInputFile, InputMediaPhoto, Message
from tests.mocked_bot import MockedBot

View file

@ -2,8 +2,8 @@ import datetime
from typing import Optional
import pytest
from aiogram.types import Chat, Message
from aiogram.types import Chat, Message
from tests.mocked_bot import MockedBot

View file

@ -3,7 +3,7 @@ import datetime
import pytest
from aiogram.methods import Request, SendAudio
from aiogram.types import Audio, Chat, File, Message
from aiogram.types import Audio, Chat, Message
from tests.mocked_bot import MockedBot

View file

@ -1,6 +1,6 @@
import pytest
from aiogram.methods import Request, SetChatAdministratorCustomTitle, SetChatTitle
from aiogram.methods import Request, SetChatAdministratorCustomTitle
from tests.mocked_bot import MockedBot

View file

@ -1,7 +1,7 @@
import pytest
from aiogram.methods import Request, SetChatPhoto
from aiogram.types import BufferedInputFile, InputFile
from aiogram.types import BufferedInputFile
from tests.mocked_bot import MockedBot

View file

@ -9,8 +9,6 @@ import pytest
from aiogram import Bot
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import UNHANDLED, SkipHandler
from aiogram.dispatcher.fsm.strategy import FSMStrategy
from aiogram.dispatcher.middlewares.user_context import UserContextMiddleware
from aiogram.dispatcher.router import Router
from aiogram.methods import GetMe, GetUpdates, SendMessage
from aiogram.types import (
@ -423,7 +421,7 @@ class TestDispatcher:
assert User.get_current(False)
return kwargs
result = await router.update.trigger(update, test="PASS")
result = await router.update.trigger(update, test="PASS", bot=None)
assert isinstance(result, dict)
assert result["event_update"] == update
assert result["event_router"] == router
@ -526,8 +524,9 @@ class TestDispatcher:
assert len(log_records) == 1
assert "Cause exception while process update" in log_records[0]
@pytest.mark.parametrize("as_task", [True, False])
@pytest.mark.asyncio
async def test_polling(self, bot: MockedBot):
async def test_polling(self, bot: MockedBot, as_task: bool):
dispatcher = Dispatcher()
async def _mock_updates(*_):
@ -539,7 +538,10 @@ class TestDispatcher:
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates"
) as patched_listen_updates:
patched_listen_updates.return_value = _mock_updates()
await dispatcher._polling(bot=bot)
await dispatcher._polling(bot=bot, handle_as_tasks=as_task)
if as_task:
pass
else:
mocked_process_update.assert_awaited()
@pytest.mark.asyncio
@ -548,9 +550,12 @@ class TestDispatcher:
router = Router()
dp.include_router(router)
class CustomException(Exception):
pass
@router.message()
async def message_handler(message: Message):
raise Exception("KABOOM")
raise CustomException("KABOOM")
update = Update(
update_id=42,
@ -562,23 +567,23 @@ class TestDispatcher:
from_user=User(id=42, is_bot=False, first_name="Test"),
),
)
with pytest.raises(Exception, match="KABOOM"):
await dp.update.trigger(update)
with pytest.raises(CustomException, match="KABOOM"):
await dp.update.trigger(update, bot=None)
@router.errors()
async def error_handler(event: Update, exception: Exception):
return "KABOOM"
response = await dp.update.trigger(update)
response = await dp.update.trigger(update, bot=None)
assert response == "KABOOM"
@dp.errors()
async def root_error_handler(event: Update, exception: Exception):
return exception
response = await dp.update.trigger(update)
response = await dp.update.trigger(update, bot=None)
assert isinstance(response, Exception)
assert isinstance(response, CustomException)
assert str(response) == "KABOOM"
@pytest.mark.asyncio
@ -654,20 +659,3 @@ class TestDispatcher:
log_records = [rec.message for rec in caplog.records]
assert "Cause exception while process update" in log_records[0]
@pytest.mark.parametrize(
"strategy,case,expected",
[
[FSMStrategy.USER_IN_CHAT, (-42, 42), (-42, 42)],
[FSMStrategy.CHAT, (-42, 42), (-42, -42)],
[FSMStrategy.GLOBAL_USER, (-42, 42), (42, 42)],
[FSMStrategy.USER_IN_CHAT, (42, 42), (42, 42)],
[FSMStrategy.CHAT, (42, 42), (42, 42)],
[FSMStrategy.GLOBAL_USER, (42, 42), (42, 42)],
],
)
def test_get_current_state_context(self, strategy, case, expected):
dp = Dispatcher(fsm_strategy=strategy)
chat_id, user_id = case
state = dp.current_state(chat_id=chat_id, user_id=user_id)
assert (state.chat_id, state.user_id) == expected

View file

@ -5,7 +5,6 @@ import pytest
from aiogram import F
from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject
from aiogram.dispatcher.filters import Text
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.handler.base import BaseHandler
from aiogram.types import Update

View file

@ -1,10 +1,11 @@
import datetime
import re
from typing import Match
import pytest
from aiogram import F
from aiogram.dispatcher.filters import Command, CommandObject
from aiogram.dispatcher.filters.command import CommandStart
from aiogram.methods import GetMe
from aiogram.types import Chat, Message, User
from tests.mocked_bot import MockedBot
@ -18,45 +19,54 @@ class TestCommandFilter:
assert cmd.commands[0] == "start"
assert cmd == Command(commands=["start"])
@pytest.mark.parametrize(
"text,command,result",
[
["/test@tbot", Command(commands=["test"], commands_prefix="/"), True],
["!test", Command(commands=["test"], commands_prefix="/"), False],
["/test@mention", Command(commands=["test"], commands_prefix="/"), False],
["/tests", Command(commands=["test"], commands_prefix="/"), False],
["/", Command(commands=["test"], commands_prefix="/"), False],
["/ test", Command(commands=["test"], commands_prefix="/"), False],
["", Command(commands=["test"], commands_prefix="/"), False],
[" ", Command(commands=["test"], commands_prefix="/"), False],
["test", Command(commands=["test"], commands_prefix="/"), False],
[" test", Command(commands=["test"], commands_prefix="/"), False],
["a", Command(commands=["test"], commands_prefix="/"), False],
["/test@tbot some args", Command(commands=["test"]), True],
["/test42@tbot some args", Command(commands=[re.compile(r"test(\d+)")]), True],
[
"/test42@tbot some args",
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "some args"),
True,
],
[
"/test42@tbot some args",
Command(commands=[re.compile(r"test(\d+)")], command_magic=F.args == "test"),
False,
],
["/start test", CommandStart(), True],
["/start", CommandStart(deep_link=True), False],
["/start test", CommandStart(deep_link=True), True],
["/start test", CommandStart(deep_link=True, deep_link_encoded=True), False],
["/start dGVzdA", CommandStart(deep_link=True, deep_link_encoded=True), True],
],
)
@pytest.mark.asyncio
async def test_parse_command(self, bot: MockedBot):
# TODO: parametrize
async def test_parse_command(self, bot: MockedBot, text: str, result: bool, command: Command):
# TODO: test ignore case
# TODO: test ignore mention
bot.add_result_for(
GetMe, ok=True, result=User(id=42, is_bot=True, first_name="The bot", username="tbot")
)
command = Command(commands=["test", re.compile(r"test(\d+)")], commands_prefix="/")
assert await command.parse_command("/test@tbot", bot)
assert not await command.parse_command("!test", bot)
assert not await command.parse_command("/test@mention", bot)
assert not await command.parse_command("/tests", bot)
assert not await command.parse_command("/", bot)
assert not await command.parse_command("/ test", bot)
assert not await command.parse_command("", bot)
assert not await command.parse_command(" ", bot)
assert not await command.parse_command("test", bot)
assert not await command.parse_command(" test", bot)
assert not await command.parse_command("a", bot)
message = Message(
message_id=0, text=text, chat=Chat(id=42, type="private"), date=datetime.datetime.now()
)
result = await command.parse_command("/test@tbot some args", bot)
assert isinstance(result, dict)
assert "command" in result
assert isinstance(result["command"], CommandObject)
assert result["command"].command == "test"
assert result["command"].mention == "tbot"
assert result["command"].args == "some args"
result = await command.parse_command("/test42@tbot some args", bot)
assert isinstance(result, dict)
assert "command" in result
assert isinstance(result["command"], CommandObject)
assert result["command"].command == "test42"
assert result["command"].mention == "tbot"
assert result["command"].args == "some args"
assert isinstance(result["command"].match, Match)
response = await command(message, bot)
assert bool(response) is result
@pytest.mark.asyncio
@pytest.mark.parametrize(

View file

@ -1,45 +0,0 @@
import pytest
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage, MemoryStorageRecord
@pytest.fixture()
def storage():
return MemoryStorage()
class TestMemoryStorage:
@pytest.mark.asyncio
async def test_set_state(self, storage: MemoryStorage):
assert await storage.get_state(chat_id=-42, user_id=42) is None
await storage.set_state(chat_id=-42, user_id=42, state="state")
assert await storage.get_state(chat_id=-42, user_id=42) == "state"
assert -42 in storage.storage
assert 42 in storage.storage[-42]
assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
assert storage.storage[-42][42].state == "state"
@pytest.mark.asyncio
async def test_set_data(self, storage: MemoryStorage):
assert await storage.get_data(chat_id=-42, user_id=42) == {}
await storage.set_data(chat_id=-42, user_id=42, data={"foo": "bar"})
assert await storage.get_data(chat_id=-42, user_id=42) == {"foo": "bar"}
assert -42 in storage.storage
assert 42 in storage.storage[-42]
assert isinstance(storage.storage[-42][42], MemoryStorageRecord)
assert storage.storage[-42][42].data == {"foo": "bar"}
@pytest.mark.asyncio
async def test_update_data(self, storage: MemoryStorage):
assert await storage.get_data(chat_id=-42, user_id=42) == {}
assert await storage.update_data(chat_id=-42, user_id=42, data={"foo": "bar"}) == {
"foo": "bar"
}
assert await storage.update_data(chat_id=-42, user_id=42, data={"baz": "spam"}) == {
"foo": "bar",
"baz": "spam",
}

View file

@ -0,0 +1,21 @@
import pytest
from aiogram.dispatcher.fsm.storage.redis import RedisStorage
from tests.mocked_bot import MockedBot
@pytest.mark.redis
class TestRedisStorage:
@pytest.mark.parametrize(
"prefix_bot,result",
[
[False, "fsm:-1:2"],
[True, "fsm:42:-1:2"],
[{42: "kaboom"}, "fsm:kaboom:-1:2"],
[lambda bot: "kaboom", "fsm:kaboom:-1:2"],
],
)
@pytest.mark.asyncio
async def test_generate_key(self, bot: MockedBot, redis_server, prefix_bot, result):
storage = RedisStorage.from_url(redis_server, prefix_bot=prefix_bot)
assert storage.generate_key(bot, -1, 2) == result

View file

@ -0,0 +1,44 @@
import pytest
from aiogram.dispatcher.fsm.storage.base import BaseStorage
from tests.mocked_bot import MockedBot
@pytest.mark.parametrize(
"storage",
[pytest.lazy_fixture("redis_storage"), pytest.lazy_fixture("memory_storage")],
)
class TestStorages:
@pytest.mark.asyncio
async def test_lock(self, bot: MockedBot, storage: BaseStorage):
# TODO: ?!?
async with storage.lock(bot=bot, chat_id=-42, user_id=42):
assert True, "You are kidding me?"
@pytest.mark.asyncio
async def test_set_state(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state="state")
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) == "state"
await storage.set_state(bot=bot, chat_id=-42, user_id=42, state=None)
assert await storage.get_state(bot=bot, chat_id=-42, user_id=42) is None
@pytest.mark.asyncio
async def test_set_data(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {"foo": "bar"}
await storage.set_data(bot=bot, chat_id=-42, user_id=42, data={})
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
@pytest.mark.asyncio
async def test_update_data(self, bot: MockedBot, storage: BaseStorage):
assert await storage.get_data(bot=bot, chat_id=-42, user_id=42) == {}
assert await storage.update_data(
bot=bot, chat_id=-42, user_id=42, data={"foo": "bar"}
) == {"foo": "bar"}
assert await storage.update_data(
bot=bot, chat_id=-42, user_id=42, data={"baz": "spam"}
) == {"foo": "bar", "baz": "spam"}

View file

@ -2,27 +2,28 @@ import pytest
from aiogram.dispatcher.fsm.context import FSMContext
from aiogram.dispatcher.fsm.storage.memory import MemoryStorage
from tests.mocked_bot import MockedBot
@pytest.fixture()
def state():
def state(bot: MockedBot):
storage = MemoryStorage()
ctx = storage.storage[-42][42]
ctx = storage.storage[bot][-42][42]
ctx.state = "test"
ctx.data = {"foo": "bar"}
return FSMContext(storage=storage, user_id=-42, chat_id=42)
return FSMContext(bot=bot, storage=storage, user_id=-42, chat_id=42)
class TestFSMContext:
@pytest.mark.asyncio
async def test_address_mapping(self):
async def test_address_mapping(self, bot: MockedBot):
storage = MemoryStorage()
ctx = storage.storage[-42][42]
ctx = storage.storage[bot][-42][42]
ctx.state = "test"
ctx.data = {"foo": "bar"}
state = FSMContext(storage=storage, chat_id=-42, user_id=42)
state2 = FSMContext(storage=storage, chat_id=42, user_id=42)
state3 = FSMContext(storage=storage, chat_id=69, user_id=69)
state = FSMContext(bot=bot, storage=storage, chat_id=-42, user_id=42)
state2 = FSMContext(bot=bot, storage=storage, chat_id=42, user_id=42)
state3 = FSMContext(bot=bot, storage=storage, chat_id=69, user_id=69)
assert await state.get_state() == "test"
assert await state2.get_state() is None

View file

@ -3,7 +3,7 @@ from typing import Any
import pytest
from aiogram.dispatcher.handler import ChosenInlineResultHandler
from aiogram.types import CallbackQuery, ChosenInlineResult, User
from aiogram.types import ChosenInlineResult, User
class TestChosenInlineResultHandler:

View file

@ -2,16 +2,7 @@ from typing import Any
import pytest
from aiogram.dispatcher.handler import ErrorHandler, PollHandler
from aiogram.types import (
CallbackQuery,
InlineQuery,
Poll,
PollOption,
ShippingAddress,
ShippingQuery,
User,
)
from aiogram.dispatcher.handler import ErrorHandler
class TestErrorHandler:

View file

@ -3,7 +3,7 @@ from typing import Any
import pytest
from aiogram.dispatcher.handler import InlineQueryHandler
from aiogram.types import CallbackQuery, InlineQuery, User
from aiogram.types import InlineQuery, User
class TestCallbackQueryHandler:

View file

@ -3,15 +3,7 @@ from typing import Any
import pytest
from aiogram.dispatcher.handler import PollHandler
from aiogram.types import (
CallbackQuery,
InlineQuery,
Poll,
PollOption,
ShippingAddress,
ShippingQuery,
User,
)
from aiogram.types import Poll, PollOption
class TestShippingQueryHandler:

View file

@ -3,7 +3,7 @@ from typing import Any
import pytest
from aiogram.dispatcher.handler import ShippingQueryHandler
from aiogram.types import CallbackQuery, InlineQuery, ShippingAddress, ShippingQuery, User
from aiogram.types import ShippingAddress, ShippingQuery, User
class TestShippingQueryHandler:

View file

@ -0,0 +1,27 @@
import pytest
from aiogram.utils.auth_widget import check_integrity
TOKEN = "123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11"
@pytest.fixture
def data():
return {
"id": "42",
"first_name": "John",
"last_name": "Smith",
"username": "username",
"photo_url": "https://t.me/i/userpic/320/picname.jpg",
"auth_date": "1565810688",
"hash": "c303db2b5a06fe41d23a9b14f7c545cfc11dcc7473c07c9c5034ae60062461ce",
}
class TestCheckIntegrity:
def test_ok(self, data):
assert check_integrity(TOKEN, data) is True
def test_fail(self, data):
data.pop("username")
assert check_integrity(TOKEN, data) is False

View file

@ -0,0 +1,94 @@
import pytest
from async_lru import alru_cache
from aiogram.utils.deep_linking import (
create_start_link,
create_startgroup_link,
decode_payload,
encode_payload,
)
from tests.mocked_bot import MockedBot
PAYLOADS = [
"foo",
"AAbbCCddEEff1122334455",
"aaBBccDDeeFF5544332211",
-12345678901234567890,
12345678901234567890,
]
WRONG_PAYLOADS = [
"@BotFather",
"Some:special$characters#=",
"spaces spaces spaces",
1234567890123456789.0,
]
@pytest.fixture(params=PAYLOADS, name="payload")
def payload_fixture(request):
return request.param
@pytest.fixture(params=WRONG_PAYLOADS, name="wrong_payload")
def wrong_payload_fixture(request):
return request.param
@pytest.fixture(autouse=True)
def get_bot_user_fixture(monkeypatch):
"""Monkey patching of bot.me calling."""
@alru_cache()
async def get_bot_user_mock(self):
from aiogram.types import User
return User(
id=12345678,
is_bot=True,
first_name="FirstName",
last_name="LastName",
username="username",
language_code="uk-UA",
)
monkeypatch.setattr(MockedBot, "me", get_bot_user_mock)
@pytest.mark.asyncio
class TestDeepLinking:
async def test_get_start_link(self, bot, payload):
link = await create_start_link(bot=bot, payload=payload)
assert link == f"https://t.me/username?start={payload}"
async def test_wrong_symbols(self, bot, wrong_payload):
with pytest.raises(ValueError):
await create_start_link(bot, wrong_payload)
async def test_get_startgroup_link(self, bot, payload):
link = await create_startgroup_link(bot, payload)
assert link == f"https://t.me/username?startgroup={payload}"
async def test_filter_encode_and_decode(self, payload):
encoded = encode_payload(payload)
decoded = decode_payload(encoded)
assert decoded == str(payload)
async def test_get_start_link_with_encoding(self, bot, wrong_payload):
# define link
link = await create_start_link(bot, wrong_payload, encode=True)
# define reference link
encoded_payload = encode_payload(wrong_payload)
assert link == f"https://t.me/username?start={encoded_payload}"
async def test_64_len_payload(self, bot):
payload = "p" * 64
link = await create_start_link(bot, payload)
assert link
async def test_too_long_payload(self, bot):
payload = "p" * 65
print(payload, len(payload))
with pytest.raises(ValueError):
await create_start_link(bot, payload)

View file

@ -0,0 +1,24 @@
from typing import Any, Dict
import pytest
from aiogram.utils.link import create_telegram_link, create_tg_link
class TestLink:
@pytest.mark.parametrize(
"base,params,result",
[["user", dict(id=42), "tg://user?id=42"]],
)
def test_create_tg_link(self, base: str, params: Dict[str, Any], result: str):
assert create_tg_link(base, **params) == result
@pytest.mark.parametrize(
"base,params,result",
[
["username", dict(), "https://t.me/username"],
["username", dict(start="test"), "https://t.me/username?start=test"],
],
)
def test_create_telegram_link(self, base: str, params: Dict[str, Any], result: str):
assert create_telegram_link(base, **params) == result

View file

@ -35,7 +35,7 @@ class TestMarkdown:
[hitalic, ("test", "test"), " ", "<i>test test</i>"],
[code, ("test", "test"), " ", "`test test`"],
[hcode, ("test", "test"), " ", "<code>test test</code>"],
[pre, ("test", "test"), " ", "```test test```"],
[pre, ("test", "test"), " ", "```\ntest test\n```"],
[hpre, ("test", "test"), " ", "<pre>test test</pre>"],
[underline, ("test", "test"), " ", "__\rtest test__\r"],
[hunderline, ("test", "test"), " ", "<u>test test</u>"],

View file

@ -55,7 +55,7 @@ class TestTextDecoration:
[markdown_decoration, MessageEntity(type="bold", offset=0, length=5), "*test*"],
[markdown_decoration, MessageEntity(type="italic", offset=0, length=5), "_\rtest_\r"],
[markdown_decoration, MessageEntity(type="code", offset=0, length=5), "`test`"],
[markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```test```"],
[markdown_decoration, MessageEntity(type="pre", offset=0, length=5), "```\ntest\n```"],
[
markdown_decoration,
MessageEntity(type="pre", offset=0, length=5, language="python"),