diff --git a/CHANGES/1280.feature.rst b/CHANGES/1280.feature.rst new file mode 100644 index 00000000..965ff1fc --- /dev/null +++ b/CHANGES/1280.feature.rst @@ -0,0 +1,2 @@ +Introduced Scenes feature that helps you to simplify user interactions using Finite State Machines. +Read more about 👉 :ref:`Scenes `. diff --git a/aiogram/__init__.py b/aiogram/__init__.py index 9aedf85b..1fbee0ec 100644 --- a/aiogram/__init__.py +++ b/aiogram/__init__.py @@ -1,3 +1,4 @@ +import asyncio as _asyncio from contextlib import suppress from aiogram.dispatcher.flags import FlagGenerator @@ -14,11 +15,9 @@ from .utils.text_decorations import html_decoration as html from .utils.text_decorations import markdown_decoration as md with suppress(ImportError): - import asyncio - import uvloop as _uvloop - asyncio.set_event_loop_policy(_uvloop.EventLoopPolicy()) + _asyncio.set_event_loop_policy(_uvloop.EventLoopPolicy()) F = MagicFilter() diff --git a/aiogram/dispatcher/event/handler.py b/aiogram/dispatcher/event/handler.py index 1a353e1e..8c283dd2 100644 --- a/aiogram/dispatcher/event/handler.py +++ b/aiogram/dispatcher/event/handler.py @@ -18,7 +18,7 @@ CallbackType = Callable[..., Any] @dataclass -class CallableMixin: +class CallableObject: callback: CallbackType awaitable: bool = field(init=False) params: Set[str] = field(init=False) @@ -49,7 +49,7 @@ class CallableMixin: @dataclass -class FilterObject(CallableMixin): +class FilterObject(CallableObject): magic: Optional[MagicFilter] = None def __post_init__(self) -> None: @@ -76,7 +76,7 @@ class FilterObject(CallableMixin): @dataclass -class HandlerObject(CallableMixin): +class HandlerObject(CallableObject): filters: Optional[List[FilterObject]] = None flags: Dict[str, Any] = field(default_factory=dict) diff --git a/aiogram/exceptions.py b/aiogram/exceptions.py index 2632fcdc..d195aa7b 100644 --- a/aiogram/exceptions.py +++ b/aiogram/exceptions.py @@ -37,6 +37,12 @@ class CallbackAnswerException(AiogramError): """ +class SceneException(AiogramError): + """ + Exception for scenes. + """ + + class UnsupportedKeywordArgument(DetailedAiogramError): """ Exception raised when a keyword argument is passed as filter. diff --git a/aiogram/fsm/scene.py b/aiogram/fsm/scene.py new file mode 100644 index 00000000..fe4de5c4 --- /dev/null +++ b/aiogram/fsm/scene.py @@ -0,0 +1,912 @@ +from __future__ import annotations + +import inspect +from collections import defaultdict +from dataclasses import dataclass, replace +from enum import Enum, auto +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union + +from typing_extensions import Self + +from aiogram import loggers +from aiogram.dispatcher.dispatcher import Dispatcher +from aiogram.dispatcher.event.bases import NextMiddlewareType +from aiogram.dispatcher.event.handler import CallableObject, CallbackType +from aiogram.dispatcher.flags import extract_flags_from_object +from aiogram.dispatcher.router import Router +from aiogram.exceptions import SceneException +from aiogram.filters import StateFilter +from aiogram.fsm.context import FSMContext +from aiogram.fsm.state import State +from aiogram.fsm.storage.memory import MemoryStorageRecord +from aiogram.types import TelegramObject, Update + + +class HistoryManager: + def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10): + self._size = size + self._state = state + self._history_state = FSMContext( + storage=state.storage, key=replace(state.key, destiny=destiny) + ) + + async def push(self, state: Optional[str], data: Dict[str, Any]) -> None: + history_data = await self._history_state.get_data() + history = history_data.setdefault("history", []) + history.append({"state": state, "data": data}) + if len(history) > self._size: + history = history[-self._size :] + loggers.scene.debug("Push state=%s data=%s to history", state, data) + + await self._history_state.update_data(history=history) + + async def pop(self) -> Optional[MemoryStorageRecord]: + history_data = await self._history_state.get_data() + history = history_data.setdefault("history", []) + if not history: + return None + record = history.pop() + state = record["state"] + data = record["data"] + if not history: + await self._history_state.set_data({}) + else: + await self._history_state.update_data(history=history) + loggers.scene.debug("Pop state=%s data=%s from history", state, data) + return MemoryStorageRecord(state=state, data=data) + + async def get(self) -> Optional[MemoryStorageRecord]: + history_data = await self._history_state.get_data() + history = history_data.setdefault("history", []) + if not history: + return None + return MemoryStorageRecord(**history[-1]) + + async def all(self) -> List[MemoryStorageRecord]: + history_data = await self._history_state.get_data() + history = history_data.setdefault("history", []) + return [MemoryStorageRecord(**item) for item in history] + + async def clear(self) -> None: + loggers.scene.debug("Clear history") + await self._history_state.set_data({}) + + async def snapshot(self) -> None: + state = await self._state.get_state() + data = await self._state.get_data() + await self.push(state, data) + + async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None: + await self._state.set_state(state) + await self._state.set_data(data) + + async def rollback(self) -> Optional[str]: + previous_state = await self.pop() + if not previous_state: + await self._set_state(None, {}) + return None + + loggers.scene.debug( + "Rollback to state=%s data=%s", + previous_state.state, + previous_state.data, + ) + await self._set_state(previous_state.state, previous_state.data) + return previous_state.state + + +class ObserverDecorator: + def __init__( + self, + name: str, + filters: tuple[CallbackType, ...], + action: SceneAction | None = None, + after: Optional[After] = None, + ) -> None: + self.name = name + self.filters = filters + self.action = action + self.after = after + + def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None: + handlers = getattr(target, "__aiogram_handler__", None) + if not handlers: + handlers = [] + setattr(target, "__aiogram_handler__", handlers) + + handlers.append( + HandlerContainer( + name=self.name, + handler=target, + filters=self.filters, + after=self.after, + ) + ) + + def _wrap_action(self, target: CallbackType) -> None: + assert self.action is not None, "Scene action is not specified" + + action = getattr(target, "__aiogram_action__", None) + if action is None: + action = defaultdict(dict) + setattr(target, "__aiogram_action__", action) + action[self.action][self.name] = CallableObject(target) + + def __call__(self, target: CallbackType) -> CallbackType: + if inspect.isfunction(target): + if self.action is None: + self._wrap_filter(target) + else: + self._wrap_action(target) + else: + raise TypeError("Only function or method is allowed") + return target + + def leave(self) -> ActionContainer: + return ActionContainer(self.name, self.filters, SceneAction.leave) + + def enter(self, target: Type[Scene]) -> ActionContainer: + return ActionContainer(self.name, self.filters, SceneAction.enter, target) + + def exit(self) -> ActionContainer: + return ActionContainer(self.name, self.filters, SceneAction.exit) + + def back(self) -> ActionContainer: + return ActionContainer(self.name, self.filters, SceneAction.back) + + +class SceneAction(Enum): + enter = auto() + leave = auto() + exit = auto() + back = auto() + + +class ActionContainer: + def __init__( + self, + name: str, + filters: Tuple[CallbackType, ...], + action: SceneAction, + target: Optional[Union[Type[Scene], str]] = None, + ) -> None: + self.name = name + self.filters = filters + self.action = action + self.target = target + + async def execute(self, wizard: SceneWizard) -> None: + if self.action == SceneAction.enter and self.target is not None: + await wizard.goto(self.target) + elif self.action == SceneAction.leave: + await wizard.leave() + elif self.action == SceneAction.exit: + await wizard.exit() + elif self.action == SceneAction.back: + await wizard.back() + + +class HandlerContainer: + def __init__( + self, + name: str, + handler: CallbackType, + filters: Tuple[CallbackType, ...], + after: Optional[After] = None, + ) -> None: + self.name = name + self.handler = handler + self.filters = filters + self.after = after + + +@dataclass() +class SceneConfig: + state: Optional[str] + """Scene state""" + handlers: List[HandlerContainer] + """Scene handlers""" + actions: Dict[SceneAction, Dict[str, CallableObject]] + """Scene actions""" + reset_data_on_enter: Optional[bool] = None + """Reset scene data on enter""" + reset_history_on_enter: Optional[bool] = None + """Reset scene history on enter""" + callback_query_without_state: Optional[bool] = None + """Allow callback query without state""" + + +async def _empty_handler(*args: Any, **kwargs: Any) -> None: + pass + + +class SceneHandlerWrapper: + def __init__( + self, + scene: Type[Scene], + handler: CallbackType, + after: Optional[After] = None, + ) -> None: + self.scene = scene + self.handler = CallableObject(handler) + self.after = after + + async def __call__( + self, + event: TelegramObject, + **kwargs: Any, + ) -> Any: + state: FSMContext = kwargs["state"] + scenes: ScenesManager = kwargs["scenes"] + event_update: Update = kwargs["event_update"] + scene = self.scene( + wizard=SceneWizard( + scene_config=self.scene.__scene_config__, + manager=scenes, + state=state, + update_type=event_update.event_type, + event=event, + data=kwargs, + ) + ) + + result = await self.handler.call(scene, event, **kwargs) + + if self.after: + action_container = ActionContainer( + "after", + (), + self.after.action, + self.after.scene, + ) + await action_container.execute(scene.wizard) + return result + + def __await__(self) -> Self: + return self + + def __str__(self) -> str: + result = f"SceneHandlerWrapper({self.scene}, {self.handler.callback}" + if self.after: + result += f", after={self.after}" + result += ")" + return result + + +class Scene: + """ + Represents a scene in a conversation flow. + + A scene is a specific state in a conversation where certain actions can take place. + + Each scene has a set of filters that determine when it should be triggered, + and a set of handlers that define the actions to be executed when the scene is active. + + .. note:: + This class is not meant to be used directly. Instead, it should be subclassed + to define custom scenes. + """ + + __scene_config__: ClassVar[SceneConfig] + """Scene configuration.""" + + def __init__( + self, + wizard: SceneWizard, + ) -> None: + self.wizard = wizard + self.wizard.scene = self + + def __init_subclass__(cls, **kwargs: Any) -> None: + state_name = kwargs.pop("state", None) + reset_data_on_enter = kwargs.pop("reset_data_on_enter", None) + reset_history_on_enter = kwargs.pop("reset_history_on_enter", None) + callback_query_without_state = kwargs.pop("callback_query_without_state", None) + + super().__init_subclass__(**kwargs) + + handlers: list[HandlerContainer] = [] + actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict) + + for base in cls.__bases__: + if not issubclass(base, Scene): + continue + + parent_scene_config = getattr(base, "__scene_config__", None) + if not parent_scene_config: + continue + + handlers.extend(parent_scene_config.handlers) + for action, action_handlers in parent_scene_config.actions.items(): + actions[action].update(action_handlers) + + if reset_data_on_enter is None: + reset_data_on_enter = parent_scene_config.reset_data_on_enter + if reset_history_on_enter is None: + reset_history_on_enter = parent_scene_config.reset_history_on_enter + if callback_query_without_state is None: + callback_query_without_state = parent_scene_config.callback_query_without_state + + for name in vars(cls): + value = getattr(cls, name) + + if scene_handlers := getattr(value, "__aiogram_handler__", None): + handlers.extend(scene_handlers) + if isinstance(value, ObserverDecorator): + handlers.append( + HandlerContainer( + value.name, + _empty_handler, + value.filters, + after=value.after, + ) + ) + if hasattr(value, "__aiogram_action__"): + for action, action_handlers in value.__aiogram_action__.items(): + actions[action].update(action_handlers) + + cls.__scene_config__ = SceneConfig( + state=state_name, + handlers=handlers, + actions=dict(actions), + reset_data_on_enter=reset_data_on_enter, + reset_history_on_enter=reset_history_on_enter, + callback_query_without_state=callback_query_without_state, + ) + + @classmethod + def add_to_router(cls, router: Router) -> None: + """ + Adds the scene to the given router. + + :param router: + :return: + """ + scene_config = cls.__scene_config__ + used_observers = set() + + for handler in scene_config.handlers: + router.observers[handler.name].register( + SceneHandlerWrapper( + cls, + handler.handler, + after=handler.after, + ), + *handler.filters, + flags=extract_flags_from_object(handler.handler), + ) + used_observers.add(handler.name) + + for observer_name in used_observers: + if scene_config.callback_query_without_state and observer_name == "callback_query": + continue + router.observers[observer_name].filter(StateFilter(scene_config.state)) + + @classmethod + def as_router(cls, name: Optional[str] = None) -> Router: + """ + Returns the scene as a router. + + :return: new router + """ + if name is None: + name = ( + f"Scene '{cls.__module__}.{cls.__qualname__}' " + f"for state {cls.__scene_config__.state!r}" + ) + router = Router(name=name) + cls.add_to_router(router) + return router + + @classmethod + def as_handler(cls, **kwargs: Any) -> CallbackType: + """ + Create an entry point handler for the scene, can be used to simplify the handler + that starts the scene. + + >>> router.message.register(MyScene.as_handler(), Command("start")) + """ + + async def enter_to_scene_handler(event: TelegramObject, scenes: ScenesManager) -> None: + await scenes.enter(cls, **kwargs) + + return enter_to_scene_handler + + +class SceneWizard: + """ + A class that represents a wizard for managing scenes in a Telegram bot. + + Instance of this class is passed to each scene as a parameter. + So, you can use it to transition between scenes, get and set data, etc. + + .. note:: + + This class is not meant to be used directly. Instead, it should be used + as a parameter in the scene constructor. + + """ + + def __init__( + self, + scene_config: SceneConfig, + manager: ScenesManager, + state: FSMContext, + update_type: str, + event: TelegramObject, + data: Dict[str, Any], + ): + """ + A class that represents a wizard for managing scenes in a Telegram bot. + + :param scene_config: The configuration of the scene. + :param manager: The scene manager. + :param state: The FSMContext object for storing the state of the scene. + :param update_type: The type of the update event. + :param event: The TelegramObject represents the event. + :param data: Additional data for the scene. + """ + self.scene_config = scene_config + self.manager = manager + self.state = state + self.update_type = update_type + self.event = event + self.data = data + + self.scene: Optional[Scene] = None + + async def enter(self, **kwargs: Any) -> None: + """ + Enter method is used to transition into a scene in the SceneWizard class. + It sets the state, clears data and history if specified, + and triggers entering event of the scene. + + :param kwargs: Additional keyword arguments. + :return: None + """ + loggers.scene.debug("Entering scene %r", self.scene_config.state) + if self.scene_config.reset_data_on_enter: + await self.state.set_data({}) + if self.scene_config.reset_history_on_enter: + await self.manager.history.clear() + await self.state.set_state(self.scene_config.state) + await self._on_action(SceneAction.enter, **kwargs) + + async def leave(self, _with_history: bool = True, **kwargs: Any) -> None: + """ + Leaves the current scene. + This method is used to exit a scene and transition to the next scene. + + :param _with_history: Whether to include history in the snapshot. Defaults to True. + :param kwargs: Additional keyword arguments. + :return: None + + """ + loggers.scene.debug("Leaving scene %r", self.scene_config.state) + if _with_history: + await self.manager.history.snapshot() + await self._on_action(SceneAction.leave, **kwargs) + + async def exit(self, **kwargs: Any) -> None: + """ + Exit the current scene and enter the default scene/state. + + :param kwargs: Additional keyword arguments. + :return: None + """ + loggers.scene.debug("Exiting scene %r", self.scene_config.state) + await self.manager.history.clear() + await self._on_action(SceneAction.exit, **kwargs) + await self.manager.enter(None, _check_active=False, **kwargs) + + async def back(self, **kwargs: Any) -> None: + """ + This method is used to go back to the previous scene. + + :param kwargs: Keyword arguments that can be passed to the method. + :return: None + """ + loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state) + await self.leave(_with_history=False, **kwargs) + new_scene = await self.manager.history.rollback() + await self.manager.enter(new_scene, _check_active=False, **kwargs) + + async def retake(self, **kwargs: Any) -> None: + """ + This method allows to re-enter the current scene. + + :param kwargs: Additional keyword arguments to pass to the scene. + :return: None + """ + assert self.scene_config.state is not None, "Scene state is not specified" + await self.goto(self.scene_config.state, **kwargs) + + async def goto(self, scene: Union[Type[Scene], str], **kwargs: Any) -> None: + """ + The `goto` method transitions to a new scene. + It first calls the `leave` method to perform any necessary cleanup + in the current scene, then calls the `enter` event to enter the specified scene. + + :param scene: The scene to transition to. Can be either a `Scene` instance + or a string representing the scene. + :param kwargs: Additional keyword arguments to pass to the `enter` + method of the scene manager. + :return: None + """ + await self.leave(**kwargs) + await self.manager.enter(scene, _check_active=False, **kwargs) + + async def _on_action(self, action: SceneAction, **kwargs: Any) -> bool: + if not self.scene: + raise SceneException("Scene is not initialized") + + loggers.scene.debug("Call action %r in scene %r", action.name, self.scene_config.state) + action_config = self.scene_config.actions.get(action, {}) + if not action_config: + loggers.scene.debug( + "Action %r not found in scene %r", action.name, self.scene_config.state + ) + return False + + event_type = self.update_type + if event_type not in action_config: + loggers.scene.debug( + "Action %r for event %r not found in scene %r", + action.name, + event_type, + self.scene_config.state, + ) + return False + + await action_config[event_type].call(self.scene, self.event, **{**self.data, **kwargs}) + return True + + async def set_data(self, data: Dict[str, Any]) -> None: + """ + Sets custom data in the current state. + + :param data: A dictionary containing the custom data to be set in the current state. + :return: None + """ + await self.state.set_data(data=data) + + async def get_data(self) -> Dict[str, Any]: + """ + This method returns the data stored in the current state. + + :return: A dictionary containing the data stored in the scene state. + """ + return await self.state.get_data() + + async def update_data( + self, data: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> Dict[str, Any]: + """ + This method updates the data stored in the current state + + :param data: Optional dictionary of data to update. + :param kwargs: Additional key-value pairs of data to update. + :return: Dictionary of updated data + """ + if data: + kwargs.update(data) + return await self.state.update_data(data=kwargs) + + async def clear_data(self) -> None: + """ + Clears the data. + + :return: None + """ + await self.set_data({}) + + +class ScenesManager: + """ + The ScenesManager class is responsible for managing scenes in an application. + It provides methods for entering and exiting scenes, as well as retrieving the active scene. + """ + + def __init__( + self, + registry: SceneRegistry, + update_type: str, + event: TelegramObject, + state: FSMContext, + data: Dict[str, Any], + ) -> None: + self.registry = registry + self.update_type = update_type + self.event = event + self.state = state + self.data = data + + self.history = HistoryManager(self.state) + + async def _get_scene(self, scene_type: Optional[Union[Type[Scene], str]]) -> Scene: + scene_type = self.registry.get(scene_type) + return scene_type( + wizard=SceneWizard( + scene_config=scene_type.__scene_config__, + manager=self, + state=self.state, + update_type=self.update_type, + event=self.event, + data=self.data, + ), + ) + + async def _get_active_scene(self) -> Optional[Scene]: + state = await self.state.get_state() + try: + return await self._get_scene(state) + except SceneException: + return None + + async def enter( + self, + scene_type: Optional[Union[Type[Scene], str]], + _check_active: bool = True, + **kwargs: Any, + ) -> None: + """ + Enters the specified scene. + + :param scene_type: Optional Type[Scene] or str representing the scene type to enter. + :param _check_active: Optional bool indicating whether to check if + there is an active scene to exit before entering the new scene. Defaults to True. + :param kwargs: Additional keyword arguments to pass to the scene's wizard.enter() method. + :return: None + """ + if _check_active: + active_scene = await self._get_active_scene() + if active_scene is not None: + await active_scene.wizard.exit(**kwargs) + + try: + scene = await self._get_scene(scene_type) + except SceneException: + if scene_type is not None: + raise + await self.state.set_state(None) + else: + await scene.wizard.enter(**kwargs) + + async def close(self, **kwargs: Any) -> None: + """ + Close method is used to exit the currently active scene in the ScenesManager. + + :param kwargs: Additional keyword arguments passed to the scene's exit method. + :return: None + """ + scene = await self._get_active_scene() + if not scene: + return + await scene.wizard.exit(**kwargs) + + +class SceneRegistry: + """ + A class that represents a registry for scenes in a Telegram bot. + """ + + def __init__(self, router: Router, register_on_add: bool = True) -> None: + """ + Initialize a new instance of the SceneRegistry class. + + :param router: The router instance used for scene registration. + :param register_on_add: Whether to register the scenes to the router when they are added. + """ + self.router = router + self.register_on_add = register_on_add + + self._scenes: Dict[Optional[str], Type[Scene]] = {} + self._setup_middleware(router) + + def _setup_middleware(self, router: Router) -> None: + if isinstance(router, Dispatcher): + # Small optimization for Dispatcher + # - we don't need to set up middleware for all observers + router.update.outer_middleware(self._update_middleware) + return + + for observer in router.observers.values(): + if observer.event_name in {"update", "error"}: + continue + observer.outer_middleware(self._middleware) + + async def _update_middleware( + self, + handler: NextMiddlewareType[TelegramObject], + event: TelegramObject, + data: Dict[str, Any], + ) -> Any: + assert isinstance(event, Update), "Event must be an Update instance" + + data["scenes"] = ScenesManager( + registry=self, + update_type=event.event_type, + event=event.event, + state=data["state"], + data=data, + ) + return await handler(event, data) + + async def _middleware( + self, + handler: NextMiddlewareType[TelegramObject], + event: TelegramObject, + data: Dict[str, Any], + ) -> Any: + update: Update = data["event_update"] + data["scenes"] = ScenesManager( + registry=self, + update_type=update.event_type, + event=event, + state=data["state"], + data=data, + ) + return await handler(event, data) + + def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None: + """ + This method adds the specified scenes to the registry + and optionally registers it to the router. + + If a scene with the same state already exists in the registry, a SceneException is raised. + + .. warning:: + + If the router is not specified, the scenes will not be registered to the router. + You will need to include the scenes manually to the router or use the register method. + + :param scenes: A variable length parameter that accepts one or more types of scenes. + These scenes are instances of the Scene class. + :param router: An optional parameter that specifies the router + to which the scenes should be added. + :return: None + """ + if not scenes: + raise ValueError("At least one scene must be specified") + + for scene in scenes: + if scene.__scene_config__.state in self._scenes: + raise SceneException( + f"Scene with state {scene.__scene_config__.state!r} already exists" + ) + + self._scenes[scene.__scene_config__.state] = scene + + if router: + router.include_router(scene.as_router()) + elif self.register_on_add: + self.router.include_router(scene.as_router()) + + def register(self, *scenes: Type[Scene]) -> None: + """ + Registers one or more scenes to the SceneRegistry. + + :param scenes: One or more scene classes to register. + :return: None + """ + self.add(*scenes, router=self.router) + + def get(self, scene: Optional[Union[Type[Scene], str]]) -> Type[Scene]: + """ + This method returns the registered Scene object for the specified scene. + The scene parameter can be either a Scene object or a string representing + the name of the scene. If a Scene object is provided, the state attribute + of the SceneConfig object associated with the Scene object will be used as the scene name. + If None or an invalid type is provided, a SceneException will be raised. + + If the specified scene is not registered in the SceneRegistry object, + a SceneException will be raised. + + :param scene: A Scene object or a string representing the name of the scene. + :return: The registered Scene object corresponding to the given scene parameter. + + """ + if inspect.isclass(scene) and issubclass(scene, Scene): + scene = scene.__scene_config__.state + if isinstance(scene, State): + scene = scene.state + if scene is not None and not isinstance(scene, str): + raise SceneException("Scene must be a subclass of Scene or a string") + + try: + return self._scenes[scene] + except KeyError: + raise SceneException(f"Scene {scene!r} is not registered") + + +@dataclass +class After: + action: SceneAction + scene: Optional[Union[Type[Scene], str]] = None + + @classmethod + def exit(cls) -> After: + return cls(action=SceneAction.exit) + + @classmethod + def back(cls) -> After: + return cls(action=SceneAction.back) + + @classmethod + def goto(cls, scene: Optional[Union[Type[Scene], str]]) -> After: + return cls(action=SceneAction.enter, scene=scene) + + +class ObserverMarker: + def __init__(self, name: str) -> None: + self.name = name + + def __call__( + self, + *filters: CallbackType, + after: Optional[After] = None, + ) -> ObserverDecorator: + return ObserverDecorator( + self.name, + filters, + after=after, + ) + + def enter(self, *filters: CallbackType) -> ObserverDecorator: + return ObserverDecorator(self.name, filters, action=SceneAction.enter) + + def leave(self) -> ObserverDecorator: + return ObserverDecorator(self.name, (), action=SceneAction.leave) + + def exit(self) -> ObserverDecorator: + return ObserverDecorator(self.name, (), action=SceneAction.exit) + + def back(self) -> ObserverDecorator: + return ObserverDecorator(self.name, (), action=SceneAction.back) + + +class OnMarker: + """ + The `OnMarker` class is used as a marker class to define different + types of events in the Scenes. + + Attributes: + + - :code:`message`: Event marker for handling `Message` events. + - :code:`edited_message`: Event marker for handling edited `Message` events. + - :code:`channel_post`: Event marker for handling channel `Post` events. + - :code:`edited_channel_post`: Event marker for handling edited channel `Post` events. + - :code:`inline_query`: Event marker for handling `InlineQuery` events. + - :code:`chosen_inline_result`: Event marker for handling chosen `InlineResult` events. + - :code:`callback_query`: Event marker for handling `CallbackQuery` events. + - :code:`shipping_query`: Event marker for handling `ShippingQuery` events. + - :code:`pre_checkout_query`: Event marker for handling `PreCheckoutQuery` events. + - :code:`poll`: Event marker for handling `Poll` events. + - :code:`poll_answer`: Event marker for handling `PollAnswer` events. + - :code:`my_chat_member`: Event marker for handling my chat `Member` events. + - :code:`chat_member`: Event marker for handling chat `Member` events. + - :code:`chat_join_request`: Event marker for handling chat `JoinRequest` events. + - :code:`error`: Event marker for handling `Error` events. + + .. note:: + + This is a marker class and does not contain any methods or implementation logic. + """ + + message = ObserverMarker("message") + edited_message = ObserverMarker("edited_message") + channel_post = ObserverMarker("channel_post") + edited_channel_post = ObserverMarker("edited_channel_post") + inline_query = ObserverMarker("inline_query") + chosen_inline_result = ObserverMarker("chosen_inline_result") + callback_query = ObserverMarker("callback_query") + shipping_query = ObserverMarker("shipping_query") + pre_checkout_query = ObserverMarker("pre_checkout_query") + poll = ObserverMarker("poll") + poll_answer = ObserverMarker("poll_answer") + my_chat_member = ObserverMarker("my_chat_member") + chat_member = ObserverMarker("chat_member") + chat_join_request = ObserverMarker("chat_join_request") + + +on = OnMarker() diff --git a/aiogram/loggers.py b/aiogram/loggers.py index ae871eaf..942c124d 100644 --- a/aiogram/loggers.py +++ b/aiogram/loggers.py @@ -4,3 +4,4 @@ dispatcher = logging.getLogger("aiogram.dispatcher") event = logging.getLogger("aiogram.event") middlewares = logging.getLogger("aiogram.middlewares") webhook = logging.getLogger("aiogram.webhook") +scene = logging.getLogger("aiogram.scene") diff --git a/docs/api/download_file.rst b/docs/api/download_file.rst index d60e8051..06450ba3 100644 --- a/docs/api/download_file.rst +++ b/docs/api/download_file.rst @@ -31,9 +31,7 @@ Download file by `file_path` to destination. If you want to automatically create destination (:obj:`io.BytesIO`) use default value of destination and handle result of this method. -.. autoclass:: aiogram.client.bot.Bot - :members: download_file - :exclude-members: __init__ +.. automethod:: aiogram.client.bot.Bot.download_file There are two options where you can download the file: to **disk** or to **binary I/O object**. @@ -81,9 +79,7 @@ Download file by `file_id` or `Downloadable` object to destination. If you want to automatically create destination (:obj:`io.BytesIO`) use default value of destination and handle result of this method. -.. autoclass:: aiogram.client.bot.Bot - :members: download - :exclude-members: __init__ +.. automethod:: aiogram.client.bot.Bot.download It differs from `download_file <#download-file>`__ **only** in that it accepts `file_id` or an `Downloadable` object (object that contains the `file_id` attribute) instead of `file_path`. diff --git a/docs/dispatcher/finite_state_machine/index.rst b/docs/dispatcher/finite_state_machine/index.rst index afa62bff..d14f282f 100644 --- a/docs/dispatcher/finite_state_machine/index.rst +++ b/docs/dispatcher/finite_state_machine/index.rst @@ -95,6 +95,7 @@ Read more .. toctree:: storages + scene .. _wiki: https://en.wikipedia.org/wiki/Finite-state_machine diff --git a/docs/dispatcher/finite_state_machine/scene.rst b/docs/dispatcher/finite_state_machine/scene.rst new file mode 100644 index 00000000..aeb21bb4 --- /dev/null +++ b/docs/dispatcher/finite_state_machine/scene.rst @@ -0,0 +1,243 @@ +.. _Scenes: + +============= +Scenes Wizard +============= + +.. versionadded:: 3.2 + +.. warning:: + + This feature is experimental and may be changed in future versions. + +**aiogram's** basics API is easy to use and powerful, +allowing the implementation of simple interactions such as triggering a command or message +for a response. +However, certain tasks require a dialogue between the user and the bot. +This is where Scenes come into play. + +Understanding Scenes +==================== + +A Scene in **aiogram** is like an abstract, isolated namespace or room that a user can be +ushered into via the code. When a user is inside a Scene, all other global commands or +message handlers are isolated, and they stop responding to user actions. +Scenes provide a structure for more complex interactions, +effectively isolating and managing contexts for different stages of the conversation. +They allow you to control and manage the flow of the conversation in a more organized manner. + +Scene Lifecycle +--------------- + +Each Scene can be "entered", "left" of "exited", allowing for clear transitions between different +stages of the conversation. +For instance, in a multi-step form filling interaction, each step could be a Scene - +the bot guides the user from one Scene to the next as they provide the required information. + +Scene Listeners +--------------- + +Scenes have their own hooks which are command or message listeners that only act while +the user is within the Scene. +These hooks react to user actions while the user is 'inside' the Scene, +providing the responses or actions appropriate for that context. +When the user is ushered from one Scene to another, the actions and responses change +accordingly as the user is now interacting with the set of listeners inside the new Scene. +These 'Scene-specific' hooks or listeners, detached from the global listening context, +allow for more streamlined and organized bot-user interactions. + + +Scene Interactions +------------------ + +Each Scene is like a self-contained world, with interactions defined within the scope of that Scene. +As such, only the handlers defined within the specific Scene will react to user's input during +the lifecycle of that Scene. + + +Scene Benefits +-------------- + +Scenes can help manage more complex interaction workflows and enable more interactive and dynamic +dialogs between the user and the bot. +This offers great flexibility in handling multi-step interactions or conversations with the users. + +How to use Scenes +================= + +For example we have a quiz bot, which asks the user a series of questions and then displays the results. + +Lets start with the data models, in this example simple data models are used to represent +the questions and answers, in real life you would probably use a database to store the data. + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :lines: 18-94 + :caption: Questions list + +Then, we need to create a Scene class that will represent the quiz game scene: + +.. note:: + + Keyword argument passed into class definition describes the scene name - is the same as state of the scene. + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :pyobject: QuizScene + :emphasize-lines: 1 + :lines: -7 + :caption: Quiz Scene + + +Also we need to define a handler that helps to start the quiz game: + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Start command handler + :lines: 260-262 + +Once the scene is defined, we need to register it in the SceneRegistry: + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :pyobject: create_dispatcher + :caption: Registering the scene + +So, now we can implement the quiz game logic, each question is sent to the user one by one, +and the user's answer is checked at the end of all questions. + +Now we need to write an entry point for the question handler: + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Question handler entry point + :pyobject: QuizScene.on_enter + + +Once scene is entered, we should expect the user's answer, so we need to write a handler for it, +this handler should expect the text message, save the answer and retake +the question handler for the next question: + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Answer handler + :pyobject: QuizScene.answer + +When user answer with unknown message, we should expect the text message again: + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Unknown message handler + :pyobject: QuizScene.unknown_message + +When all questions are answered, we should show the results to the user, as you can see in the code below, +we use `await self.wizard.exit()` to exit from the scene when questions list is over in the `QuizScene.on_enter` handler. + +Thats means that we need to write an exit handler to show the results to the user: + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Show results handler + :pyobject: QuizScene.on_exit + +Also we can implement a actions to exit from the quiz game or go back to the previous question: + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Exit handler + :pyobject: QuizScene.exit + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Back handler + :pyobject: QuizScene.back + +Now we can run the bot and test the quiz game: + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Run the bot + :lines: 291- + +Complete them all + +.. literalinclude:: ../../../examples/quiz_scene.py + :language: python + :caption: Quiz Example + + +Components +========== + +- :class:`aiogram.fsm.scene.Scene` - represents a scene, contains handlers +- :class:`aiogram.fsm.scene.SceneRegistry` - container for all scenes in the bot, used to register scenes and resolve them by name +- :class:`aiogram.fsm.scene.ScenesManager` - manages scenes for each user, used to enter, leave and resolve current scene for user +- :class:`aiogram.fsm.scene.SceneConfig` - scene configuration, used to configure scene +- :class:`aiogram.fsm.scene.SceneWizard` - scene wizard, used to interact with user in scene from active scene handler +- Markers - marker for scene handlers, used to mark scene handlers + + +.. autoclass:: aiogram.fsm.scene.Scene + :members: + +.. autoclass:: aiogram.fsm.scene.SceneRegistry + :members: + +.. autoclass:: aiogram.fsm.scene.ScenesManager + :members: + +.. autoclass:: aiogram.fsm.scene.SceneConfig + :members: + +.. autoclass:: aiogram.fsm.scene.SceneWizard + :members: + +Markers +------- + +Markers are similar to the Router event registering mechanism, +but they are used to mark scene handlers in the Scene class. + +It can be imported from :code:`from aiogram.fsm.scene import on` and should be used as decorator. + +Allowed event types: + +- message +- edited_message +- channel_post +- edited_channel_post +- inline_query +- chosen_inline_result +- callback_query +- shipping_query +- pre_checkout_query +- poll +- poll_answer +- my_chat_member +- chat_member +- chat_join_request + +Each event type can be filtered in the same way as in the Router. + +Also each event type can be marked as scene entry point, exit point or leave point. + +If you want to mark the scene can be entered from message or inline query, +you should use :code:`on.message` or :code:`on.inline_query` marker: + +.. code-block:: python + + class MyScene(Scene, name="my_scene"): + @on.message.enter() + async def on_enter(self, message: types.Message): + pass + + @on.callback_query.enter() + async def on_enter(self, callback_query: types.CallbackQuery): + pass + + +Scene has only tree points for transitions: + +- enter point - when user enters to the scene +- leave point - when user leaves the scene and the enter another scene +- exit point - when user exits from the scene diff --git a/examples/quiz_scene.py b/examples/quiz_scene.py new file mode 100644 index 00000000..343fef75 --- /dev/null +++ b/examples/quiz_scene.py @@ -0,0 +1,301 @@ +import asyncio +import logging +from dataclasses import dataclass, field +from os import getenv +from typing import Any + +from aiogram import Bot, Dispatcher, F, Router, html +from aiogram.filters import Command +from aiogram.fsm.context import FSMContext +from aiogram.fsm.scene import Scene, SceneRegistry, ScenesManager, on +from aiogram.fsm.storage.memory import SimpleEventIsolation +from aiogram.types import KeyboardButton, Message, ReplyKeyboardRemove +from aiogram.utils.formatting import ( + Bold, + as_key_value, + as_list, + as_numbered_list, + as_section, +) +from aiogram.utils.keyboard import ReplyKeyboardBuilder + +TOKEN = getenv("BOT_TOKEN") + + +@dataclass +class Answer: + """ + Represents an answer to a question. + """ + + text: str + """The answer text""" + is_correct: bool = False + """Indicates if the answer is correct""" + + +@dataclass +class Question: + """ + Class representing a quiz with a question and a list of answers. + """ + + text: str + """The question text""" + answers: list[Answer] + """List of answers""" + + correct_answer: str = field(init=False) + + def __post_init__(self): + self.correct_answer = next(answer.text for answer in self.answers if answer.is_correct) + + +# Fake data, in real application you should use a database or something else +QUESTIONS = [ + Question( + text="What is the capital of France?", + answers=[ + Answer("Paris", is_correct=True), + Answer("London"), + Answer("Berlin"), + Answer("Madrid"), + ], + ), + Question( + text="What is the capital of Spain?", + answers=[ + Answer("Paris"), + Answer("London"), + Answer("Berlin"), + Answer("Madrid", is_correct=True), + ], + ), + Question( + text="What is the capital of Germany?", + answers=[ + Answer("Paris"), + Answer("London"), + Answer("Berlin", is_correct=True), + Answer("Madrid"), + ], + ), + Question( + text="What is the capital of England?", + answers=[ + Answer("Paris"), + Answer("London", is_correct=True), + Answer("Berlin"), + Answer("Madrid"), + ], + ), + Question( + text="What is the capital of Italy?", + answers=[ + Answer("Paris"), + Answer("London"), + Answer("Berlin"), + Answer("Rome", is_correct=True), + ], + ), +] + + +class QuizScene(Scene, state="quiz"): + """ + This class represents a scene for a quiz game. + + It inherits from Scene class and is associated with the state "quiz". + It handles the logic and flow of the quiz game. + """ + + @on.message.enter() + async def on_enter(self, message: Message, state: FSMContext, step: int | None = 0) -> Any: + """ + Method triggered when the user enters the quiz scene. + + It displays the current question and answer options to the user. + + :param message: + :param state: + :param step: Scene argument, can be passed to the scene using the wizard + :return: + """ + if not step: + # This is the first step, so we should greet the user + await message.answer("Welcome to the quiz!") + + try: + quiz = QUESTIONS[step] + except IndexError: + # This error means that the question's list is over + return await self.wizard.exit() + + markup = ReplyKeyboardBuilder() + markup.add(*[KeyboardButton(text=answer.text) for answer in quiz.answers]) + + if step > 0: + markup.button(text="🔙 Back") + markup.button(text="🚫 Exit") + + await state.update_data(step=step) + return await message.answer( + text=QUESTIONS[step].text, + reply_markup=markup.adjust(2).as_markup(resize_keyboard=True), + ) + + @on.message.exit() + async def on_exit(self, message: Message, state: FSMContext) -> None: + """ + Method triggered when the user exits the quiz scene. + + It calculates the user's answers, displays the summary, and clears the stored answers. + + :param message: + :param state: + :return: + """ + data = await state.get_data() + answers = data.get("answers", {}) + + correct = 0 + incorrect = 0 + user_answers = [] + for step, quiz in enumerate(QUESTIONS): + answer = answers.get(step) + is_correct = answer == quiz.correct_answer + if is_correct: + correct += 1 + icon = "✅" + else: + incorrect += 1 + icon = "❌" + if answer is None: + answer = "no answer" + user_answers.append(f"{quiz.text} ({icon} {html.quote(answer)})") + + content = as_list( + as_section( + Bold("Your answers:"), + as_numbered_list(*user_answers), + ), + "", + as_section( + Bold("Summary:"), + as_list( + as_key_value("Correct", correct), + as_key_value("Incorrect", incorrect), + ), + ), + ) + + await message.answer(**content.as_kwargs(), reply_markup=ReplyKeyboardRemove()) + await state.set_data({}) + + @on.message(F.text == "🔙 Back") + async def back(self, message: Message, state: FSMContext) -> None: + """ + Method triggered when the user selects the "Back" button. + + It allows the user to go back to the previous question. + + :param message: + :param state: + :return: + """ + data = await state.get_data() + step = data["step"] + + previous_step = step - 1 + if previous_step < 0: + # In case when the user tries to go back from the first question, + # we just exit the quiz + return await self.wizard.exit() + return await self.wizard.back(step=previous_step) + + @on.message(F.text == "🚫 Exit") + async def exit(self, message: Message) -> None: + """ + Method triggered when the user selects the "Exit" button. + + It exits the quiz. + + :param message: + :return: + """ + await self.wizard.exit() + + @on.message(F.text) + async def answer(self, message: Message, state: FSMContext) -> None: + """ + Method triggered when the user selects an answer. + + It stores the answer and proceeds to the next question. + + :param message: + :param state: + :return: + """ + data = await state.get_data() + step = data["step"] + answers = data.get("answers", {}) + answers[step] = message.text + await state.update_data(answers=answers) + + await self.wizard.retake(step=step + 1) + + @on.message() + async def unknown_message(self, message: Message) -> None: + """ + Method triggered when the user sends a message that is not a command or an answer. + + It asks the user to select an answer. + + :param message: The message received from the user. + :return: None + """ + await message.answer("Please select an answer.") + + +quiz_router = Router(name=__name__) +# Add handler that initializes the scene +quiz_router.message.register(QuizScene.as_handler(), Command("quiz")) + + +@quiz_router.message(Command("start")) +async def command_start(message: Message, scenes: ScenesManager): + await scenes.close() + await message.answer( + "Hi! This is a quiz bot. To start the quiz, use the /quiz command.", + reply_markup=ReplyKeyboardRemove(), + ) + + +def create_dispatcher(): + # Event isolation is needed to correctly handle fast user responses + dispatcher = Dispatcher( + events_isolation=SimpleEventIsolation(), + ) + dispatcher.include_router(quiz_router) + + # To use scenes, you should create a SceneRegistry and register your scenes there + scene_registry = SceneRegistry(dispatcher) + # ... and then register a scene in the registry + # by default, Scene will be mounted to the router that passed to the SceneRegistry, + # but you can specify the router explicitly using the `router` argument + scene_registry.add(QuizScene) + + return dispatcher + + +async def main(): + dispatcher = create_dispatcher() + bot = Bot(TOKEN) + await dispatcher.start_polling(bot) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main()) + # Alternatively, you can use aiogram-cli: + # `aiogram run polling quiz_scene:create_dispatcher --log-level info --token BOT_TOKEN` diff --git a/examples/scene.py b/examples/scene.py new file mode 100644 index 00000000..090c64c2 --- /dev/null +++ b/examples/scene.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +from os import getenv +from typing import TypedDict + +from aiogram import Bot, Dispatcher, F, html +from aiogram.filters import Command +from aiogram.fsm.scene import After, Scene, SceneRegistry, on +from aiogram.types import ( + CallbackQuery, + InlineKeyboardButton, + InlineKeyboardMarkup, + KeyboardButton, + Message, + ReplyKeyboardMarkup, + ReplyKeyboardRemove, +) + +BUTTON_CANCEL = KeyboardButton(text="❌ Cancel") +BUTTON_BACK = KeyboardButton(text="🔙 Back") + + +class FSMData(TypedDict, total=False): + name: str + language: str + + +class CancellableScene(Scene): + """ + This scene is used to handle cancel and back buttons, + can be used as a base class for other scenes that needs to support cancel and back buttons. + """ + + @on.message(F.text.casefold() == BUTTON_CANCEL.text.casefold(), after=After.exit()) + async def handle_cancel(self, message: Message): + await message.answer("Cancelled.", reply_markup=ReplyKeyboardRemove()) + + @on.message(F.text.casefold() == BUTTON_BACK.text.casefold(), after=After.back()) + async def handle_back(self, message: Message): + await message.answer("Back.") + + +class LanguageScene(CancellableScene, state="language"): + """ + This scene is used to ask user what language he prefers. + """ + + @on.message.enter() + async def on_enter(self, message: Message): + await message.answer( + "What language do you prefer?", + reply_markup=ReplyKeyboardMarkup( + keyboard=[[BUTTON_BACK, BUTTON_CANCEL]], + resize_keyboard=True, + ), + ) + + @on.message(F.text.casefold() == "python", after=After.exit()) + async def process_python(self, message: Message): + await message.answer( + "Python, you say? That's the language that makes my circuits light up! 😉" + ) + await self.input_language(message) + + @on.message(after=After.exit()) + async def input_language(self, message: Message): + data: FSMData = await self.wizard.get_data() + await self.show_results(message, language=message.text, **data) + + async def show_results(self, message: Message, name: str, language: str) -> None: + await message.answer( + text=f"I'll keep in mind that, {html.quote(name)}, " + f"you like to write bots with {html.quote(language)}.", + reply_markup=ReplyKeyboardRemove(), + ) + + +class LikeBotsScene(CancellableScene, state="like_bots"): + """ + This scene is used to ask user if he likes to write bots. + """ + + @on.message.enter() + async def on_enter(self, message: Message): + await message.answer( + "Did you like to write bots?", + reply_markup=ReplyKeyboardMarkup( + keyboard=[ + [KeyboardButton(text="Yes"), KeyboardButton(text="No")], + [BUTTON_BACK, BUTTON_CANCEL], + ], + resize_keyboard=True, + ), + ) + + @on.message(F.text.casefold() == "yes", after=After.goto(LanguageScene)) + async def process_like_write_bots(self, message: Message): + await message.reply("Cool! I'm too!") + + @on.message(F.text.casefold() == "no", after=After.exit()) + async def process_dont_like_write_bots(self, message: Message): + await message.answer( + "Not bad not terrible.\nSee you soon.", + reply_markup=ReplyKeyboardRemove(), + ) + + @on.message() + async def input_like_bots(self, message: Message): + await message.answer("I don't understand you :(") + + +class NameScene(CancellableScene, state="name"): + """ + This scene is used to ask user's name. + """ + + @on.message.enter() # Marker for handler that should be called when a user enters the scene. + async def on_enter(self, message: Message): + await message.answer( + "Hi there! What's your name?", + reply_markup=ReplyKeyboardMarkup(keyboard=[[BUTTON_CANCEL]], resize_keyboard=True), + ) + + @on.callback_query.enter() # different types of updates that start the scene also supported. + async def on_enter_callback(self, callback_query: CallbackQuery): + await callback_query.answer() + await self.on_enter(callback_query.message) + + @on.message.leave() # Marker for handler that should be called when a user leaves the scene. + async def on_leave(self, message: Message): + data: FSMData = await self.wizard.get_data() + name = data.get("name", "Anonymous") + await message.answer(f"Nice to meet you, {html.quote(name)}!") + + @on.message(after=After.goto(LikeBotsScene)) + async def input_name(self, message: Message): + await self.wizard.update_data(name=message.text) + + +class DefaultScene( + Scene, + reset_data_on_enter=True, # Reset state data + reset_history_on_enter=True, # Reset history + callback_query_without_state=True, # Handle callback queries even if user in any scene +): + """ + Default scene for the bot. + + This scene is used to handle all messages that are not handled by other scenes. + """ + + start_demo = on.message(F.text.casefold() == "demo", after=After.goto(NameScene)) + + @on.message(Command("demo")) + async def demo(self, message: Message): + await message.answer( + "Demo started", + reply_markup=InlineKeyboardMarkup( + inline_keyboard=[[InlineKeyboardButton(text="Go to form", callback_data="start")]] + ), + ) + + @on.callback_query(F.data == "start", after=After.goto(NameScene)) + async def demo_callback(self, callback_query: CallbackQuery): + await callback_query.answer(cache_time=0) + await callback_query.message.delete_reply_markup() + + @on.message.enter() # Mark that this handler should be called when a user enters the scene. + @on.message() + async def default_handler(self, message: Message): + await message.answer( + "Start demo?\nYou can also start demo via command /demo", + reply_markup=ReplyKeyboardMarkup( + keyboard=[[KeyboardButton(text="Demo")]], + resize_keyboard=True, + ), + ) + + +def create_dispatcher() -> Dispatcher: + dispatcher = Dispatcher() + + # Scene registry should be the only one instance in your application for proper work. + # It stores all available scenes. + # You can use any router for scenes, not only `Dispatcher`. + registry = SceneRegistry(dispatcher) + # All scenes at register time converts to Routers and includes into specified router. + registry.add( + DefaultScene, + NameScene, + LikeBotsScene, + LanguageScene, + ) + + return dispatcher + + +if __name__ == "__main__": + # Recommended to use CLI instead of this snippet. + # `aiogram run polling scene_example:create_dispatcher --token BOT_TOKEN --log-level info` + dp = create_dispatcher() + bot = Bot(token=getenv("TELEGRAM_TOKEN")) + dp.run_polling() diff --git a/tests/test_dispatcher/test_event/test_event.py b/tests/test_dispatcher/test_event/test_event.py index 3b35579c..91c5fbec 100644 --- a/tests/test_dispatcher/test_event/test_event.py +++ b/tests/test_dispatcher/test_event/test_event.py @@ -46,7 +46,7 @@ class TestEventObserver: assert observer.handlers[2].awaitable with patch( - "aiogram.dispatcher.event.handler.CallableMixin.call", + "aiogram.dispatcher.event.handler.CallableObject.call", new_callable=AsyncMock, ) as mocked_my_handler: results = await observer.trigger("test") diff --git a/tests/test_dispatcher/test_event/test_handler.py b/tests/test_dispatcher/test_event/test_handler.py index b5492dce..1f8be4af 100644 --- a/tests/test_dispatcher/test_event/test_handler.py +++ b/tests/test_dispatcher/test_event/test_handler.py @@ -5,7 +5,7 @@ import pytest from magic_filter import F as A from aiogram import F -from aiogram.dispatcher.event.handler import CallableMixin, FilterObject, HandlerObject +from aiogram.dispatcher.event.handler import CallableObject, FilterObject, HandlerObject from aiogram.filters import Filter from aiogram.handlers import BaseHandler from aiogram.types import Update @@ -38,16 +38,16 @@ class SyncCallable: return locals() -class TestCallableMixin: +class TestCallableObject: @pytest.mark.parametrize("callback", [callback2, TestFilter()]) def test_init_awaitable(self, callback): - obj = CallableMixin(callback) + obj = CallableObject(callback) assert obj.awaitable assert obj.callback == callback @pytest.mark.parametrize("callback", [callback1, SyncCallable()]) def test_init_not_awaitable(self, callback): - obj = CallableMixin(callback) + obj = CallableObject(callback) assert not obj.awaitable assert obj.callback == callback @@ -62,7 +62,7 @@ class TestCallableMixin: ], ) def test_init_args_spec(self, callback: Callable, args: Set[str]): - obj = CallableMixin(callback) + obj = CallableObject(callback) assert set(obj.params) == args def test_init_decorated(self): @@ -82,8 +82,8 @@ class TestCallableMixin: def callback2(foo, bar, baz): pass - obj1 = CallableMixin(callback1) - obj2 = CallableMixin(callback2) + obj1 = CallableObject(callback1) + obj2 = CallableObject(callback2) assert set(obj1.params) == {"foo", "bar", "baz"} assert obj1.callback == callback1 @@ -127,17 +127,17 @@ class TestCallableMixin: def test_prepare_kwargs( self, callback: Callable, kwargs: Dict[str, Any], result: Dict[str, Any] ): - obj = CallableMixin(callback) + obj = CallableObject(callback) assert obj._prepare_kwargs(kwargs) == result async def test_sync_call(self): - obj = CallableMixin(callback1) + obj = CallableObject(callback1) result = await obj.call(foo=42, bar="test", baz="fuz", spam=True) assert result == {"foo": 42, "bar": "test", "baz": "fuz"} async def test_async_call(self): - obj = CallableMixin(callback2) + obj = CallableObject(callback2) result = await obj.call(foo=42, bar="test", baz="fuz", spam=True) assert result == {"foo": 42, "bar": "test", "baz": "fuz"} diff --git a/tests/test_fsm/test_scene.py b/tests/test_fsm/test_scene.py new file mode 100644 index 00000000..2bb2cf0a --- /dev/null +++ b/tests/test_fsm/test_scene.py @@ -0,0 +1,1547 @@ +import inspect +import platform +from datetime import datetime +from unittest.mock import ANY, AsyncMock, patch + +import pytest + +from aiogram import Dispatcher, F, Router +from aiogram.dispatcher.event.bases import NextMiddlewareType +from aiogram.exceptions import SceneException +from aiogram.filters import StateFilter +from aiogram.fsm.context import FSMContext +from aiogram.fsm.scene import ( + ActionContainer, + After, + HandlerContainer, + HistoryManager, + ObserverDecorator, + ObserverMarker, + Scene, + SceneAction, + SceneConfig, + SceneHandlerWrapper, + SceneRegistry, + ScenesManager, + SceneWizard, + _empty_handler, + on, +) +from aiogram.fsm.state import State, StatesGroup +from aiogram.fsm.storage.base import StorageKey +from aiogram.fsm.storage.memory import MemoryStorage, MemoryStorageRecord +from aiogram.types import Chat, Message, Update +from tests.mocked_bot import MockedBot + + +class TestOnMarker: + @pytest.mark.parametrize( + "marker_name", + [ + "message", + "edited_message", + "channel_post", + "edited_channel_post", + "inline_query", + "chosen_inline_result", + "callback_query", + "shipping_query", + "pre_checkout_query", + "poll", + "poll_answer", + "my_chat_member", + "chat_member", + "chat_join_request", + ], + ) + def test_marker_name(self, marker_name: str): + attr = getattr(on, marker_name) + assert isinstance(attr, ObserverMarker) + assert attr.name == marker_name + + +async def test_empty_handler(): + result = await _empty_handler() + assert result is None + + +class TestAfter: + def test_exit(self): + after = After.exit() + assert after is not None + assert after.action == SceneAction.exit + assert after.scene is None + + def test_back(self): + after = After.back() + assert after is not None + assert after.action == SceneAction.back + assert after.scene is None + + def test_goto(self): + after = After.goto("test") + assert after is not None + assert after.action == SceneAction.enter + assert after.scene == "test" + + +class TestObserverMarker: + def test_decorator(self): + marker = ObserverMarker("test") + decorator = marker(F.test, after=After.back()) + assert isinstance(decorator, ObserverDecorator) + assert decorator.name == "test" + assert len(decorator.filters) == 1 + assert decorator.action is None + assert decorator.after is not None + + def test_enter(self): + marker = ObserverMarker("test") + decorator = marker.enter(F.test) + assert isinstance(decorator, ObserverDecorator) + assert decorator.name == "test" + assert len(decorator.filters) == 1 + assert decorator.action == SceneAction.enter + assert decorator.after is None + + def test_leave(self): + marker = ObserverMarker("test") + decorator = marker.leave() + assert isinstance(decorator, ObserverDecorator) + assert decorator.name == "test" + assert len(decorator.filters) == 0 + assert decorator.action == SceneAction.leave + assert decorator.after is None + + def test_exit(self): + marker = ObserverMarker("test") + decorator = marker.exit() + assert isinstance(decorator, ObserverDecorator) + assert decorator.name == "test" + assert len(decorator.filters) == 0 + assert decorator.action == SceneAction.exit + assert decorator.after is None + + def test_back(self): + marker = ObserverMarker("test") + decorator = marker.back() + assert isinstance(decorator, ObserverDecorator) + assert decorator.name == "test" + assert len(decorator.filters) == 0 + assert decorator.action == SceneAction.back + assert decorator.after is None + + +class TestObserverDecorator: + def test_wrap_something(self): + decorator = ObserverDecorator("test", F.test) + + with pytest.raises(TypeError): + decorator("test") + + def test_wrap_handler(self): + decorator = ObserverDecorator("test", F.test) + + def handler(): + pass + + wrapped = decorator(handler) + + assert wrapped is not None + assert hasattr(wrapped, "__aiogram_handler__") + assert isinstance(wrapped.__aiogram_handler__, list) + assert len(wrapped.__aiogram_handler__) == 1 + + wrapped2 = decorator(handler) + + assert len(wrapped2.__aiogram_handler__) == 2 + + def test_wrap_action(self): + decorator = ObserverDecorator("test", F.test, action=SceneAction.enter) + + def handler(): + pass + + wrapped = decorator(handler) + assert wrapped is not None + assert not hasattr(wrapped, "__aiogram_handler__") + assert hasattr(wrapped, "__aiogram_action__") + + assert isinstance(wrapped.__aiogram_action__, dict) + assert len(wrapped.__aiogram_action__) == 1 + assert SceneAction.enter in wrapped.__aiogram_action__ + assert "test" in wrapped.__aiogram_action__[SceneAction.enter] + + def test_observer_decorator_leave(self): + observer_decorator = ObserverDecorator("Test Name", (F.text,)) + action_container = observer_decorator.leave() + assert isinstance(action_container, ActionContainer) + assert action_container.name == "Test Name" + assert action_container.filters == (F.text,) + assert action_container.action == SceneAction.leave + + def test_observer_decorator_enter(self): + observer_decorator = ObserverDecorator("test", (F.text,)) + target = "mock_target" + action_container = observer_decorator.enter(target) + assert isinstance(action_container, ActionContainer) + assert action_container.name == "test" + assert action_container.filters == (F.text,) + assert action_container.action == SceneAction.enter + assert action_container.target == target + + def test_observer_decorator_exit(self): + observer_decorator = ObserverDecorator("test", (F.text,)) + action_container = observer_decorator.exit() + assert isinstance(action_container, ActionContainer) + assert action_container.name == "test" + assert action_container.filters == (F.text,) + assert action_container.action == SceneAction.exit + + def test_observer_decorator_back(self): + observer_decorator = ObserverDecorator("test", (F.text,)) + action_container = observer_decorator.back() + assert isinstance(action_container, ActionContainer) + assert action_container.name == "test" + assert action_container.filters == (F.text,) + assert action_container.action == SceneAction.back + + +class TestActionContainer: + async def test_action_container_execute_enter(self): + wizard_mock = AsyncMock(spec=SceneWizard) + + action_container = ActionContainer( + "Test Name", (F.text,), SceneAction.enter, "Test Target" + ) + await action_container.execute(wizard_mock) + + wizard_mock.goto.assert_called_once_with("Test Target") + + async def test_action_container_execute_leave(self): + wizard_mock = AsyncMock(spec=SceneWizard) + + action_container = ActionContainer("Test Name", (F.text,), SceneAction.leave) + await action_container.execute(wizard_mock) + + wizard_mock.leave.assert_called_once() + + async def test_action_container_execute_exit(self): + wizard_mock = AsyncMock(spec=SceneWizard) + + action_container = ActionContainer("Test Name", (F.text,), SceneAction.exit) + await action_container.execute(wizard_mock) + + wizard_mock.exit.assert_called_once() + + async def test_action_container_execute_back(self): + wizard_mock = AsyncMock(spec=SceneWizard) + + action_container = ActionContainer("Test Name", (F.text,), SceneAction.back) + await action_container.execute(wizard_mock) + + wizard_mock.back.assert_called_once() + + +class TestSceneHandlerWrapper: + async def test_scene_handler_wrapper_call(self): + class MyScene(Scene): + pass + + async def handler_mock(*args, **kwargs): + return 42 + + state_mock = AsyncMock(spec=FSMContext) + scenes_mock = AsyncMock(spec=ScenesManager) + event_update_mock = Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ) + kwargs = {"state": state_mock, "scenes": scenes_mock, "event_update": event_update_mock} + + scene_handler_wrapper = SceneHandlerWrapper(MyScene, handler_mock) + result = await scene_handler_wrapper(event_update_mock, **kwargs) + + # Check whether result is correct + assert result == 42 + + async def test_scene_handler_wrapper_call_with_after(self): + class MyScene(Scene): + pass + + async def handler_mock(*args, **kwargs): + return 42 + + state_mock = AsyncMock(spec=FSMContext) + scenes_mock = AsyncMock(spec=ScenesManager) + event_update_mock = Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ) + kwargs = {"state": state_mock, "scenes": scenes_mock, "event_update": event_update_mock} + + scene_handler_wrapper = SceneHandlerWrapper(MyScene, handler_mock, after=After.exit()) + + with patch( + "aiogram.fsm.scene.ActionContainer.execute", new_callable=AsyncMock + ) as after_mock: + result = await scene_handler_wrapper(event_update_mock, **kwargs) + + # Check whether after_mock is called + after_mock.assert_called_once_with(ANY) + + # Check whether result is correct + assert result == 42 + + def test_scene_handler_wrapper_str(self): + class MyScene(Scene): + pass + + async def handler_mock(*args, **kwargs): + pass + + after = After.back() + + scene_handler_wrapper = SceneHandlerWrapper(MyScene, handler_mock, after=after) + result = str(scene_handler_wrapper) + + assert result == f"SceneHandlerWrapper({MyScene}, {handler_mock}, after={after})" + + def test_await(self): + class MyScene(Scene): + pass + + async def handler_mock(*args, **kwargs): + pass + + scene_handler_wrapper = SceneHandlerWrapper(MyScene, handler_mock) + + assert inspect.isawaitable(scene_handler_wrapper) + + assert hasattr(scene_handler_wrapper, "__await__") + assert scene_handler_wrapper.__await__() is scene_handler_wrapper + + +class TestHistoryManager: + async def test_history_manager_push(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + history_data = await history_manager._history_state.get_data() + assert history_data.get("history") == [{"state": "test_state", "data": data}] + + async def test_history_manager_push_if_history_overflow(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state, size=2) + + states = ["test_state", "test_state2", "test_state3", "test_state4"] + data = {"test_data": "test_data"} + for state in states: + await history_manager.push(state, data) + + history_data = await history_manager._history_state.get_data() + assert history_data.get("history") == [ + {"state": "test_state3", "data": data}, + {"state": "test_state4", "data": data}, + ] + + async def test_history_manager_pop(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + await history_manager.push("test_state2", data) + + record = await history_manager.pop() + history_data = await history_manager._history_state.get_data() + + assert isinstance(record, MemoryStorageRecord) + assert record == MemoryStorageRecord(state="test_state2", data=data) + assert history_data.get("history") == [{"state": "test_state", "data": data}] + + async def test_history_manager_pop_if_history_empty(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + record = await history_manager.pop() + assert record is None + + async def test_history_manager_pop_if_history_become_empty_after_pop(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + await history_manager.pop() + + assert await history_manager._history_state.get_data() == {} + + async def test_history_manager_get(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + record = await history_manager.get() + + assert isinstance(record, MemoryStorageRecord) + assert record == MemoryStorageRecord(state="test_state", data=data) + + async def test_history_manager_get_if_history_empty(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + record = await history_manager.get() + assert record is None + + async def test_history_manager_all(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager_size = 10 + history_manager = HistoryManager(state=state, size=history_manager_size) + + data = {"test_data": "test_data"} + for i in range(history_manager_size): + await history_manager.push(f"test_state{i}", data) + + records = await history_manager.all() + + assert isinstance(records, list) + assert len(records) == history_manager_size + assert all(isinstance(record, MemoryStorageRecord) for record in records) + + async def test_history_manager_all_if_history_empty(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + history_manager = HistoryManager(state=state) + + records = await history_manager.all() + assert records == [] + + async def test_history_manager_clear(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + + history_manager = HistoryManager(state=state) + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + await history_manager.clear() + + assert await history_manager._history_state.get_data() == {} + + async def test_history_manager_snapshot(self): + state = FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + + history_manager = HistoryManager(state=state) + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + await history_manager.snapshot() + + assert await history_manager._history_state.get_data() == { + "history": [ + {"state": "test_state", "data": data}, + { + "state": await history_manager._state.get_state(), + "data": await history_manager._state.get_data(), + }, + ] + } + + async def test_history_manager_set_state(self): + state_mock = AsyncMock(spec=FSMContext) + state_mock.storage = MemoryStorage() + state_mock.key = StorageKey(bot_id=42, chat_id=-42, user_id=42) + state_mock.set_state = AsyncMock() + state_mock.set_data = AsyncMock() + + history_manager = HistoryManager(state=state_mock) + history_manager._state = state_mock + + state = "test_state" + data = {"test_data": "test_data"} + await history_manager._set_state(state, data) + + state_mock.set_state.assert_called_once_with(state) + state_mock.set_data.assert_called_once_with(data) + + async def test_history_manager_rollback(self): + history_manager = HistoryManager( + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + ) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + await history_manager.push("test_state2", data) + + record = await history_manager.get() + assert record == MemoryStorageRecord(state="test_state2", data=data) + + await history_manager.rollback() + + record = await history_manager.get() + assert record == MemoryStorageRecord(state="test_state", data=data) + + async def test_history_manager_rollback_if_not_previous_state(self): + history_manager = HistoryManager( + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ) + ) + + data = {"test_data": "test_data"} + await history_manager.push("test_state", data) + + state = await history_manager.rollback() + assert state == "test_state" + + state = await history_manager.rollback() + assert state is None + + +class TestScene: + def test_scene_subclass_initialisation(self): + class ParentScene(Scene): + @on.message(F.text) + def parent_handler(self, *args, **kwargs): + pass + + @on.message.enter(F.text) + def parent_action(self, *args, **kwargs): + pass + + class SubScene( + ParentScene, + state="sub_state", + reset_data_on_enter=True, + reset_history_on_enter=True, + callback_query_without_state=True, + ): + general_handler = on.message( + F.text.casefold() == "test", after=After.goto("sub_state") + ) + + @on.message(F.text) + def sub_handler(self, *args, **kwargs): + pass + + @on.message.exit() + def sub_action(self, *args, **kwargs): + pass + + # Assert __scene_config__ attributes are correctly set for SubScene + assert isinstance(SubScene.__scene_config__, SceneConfig) + assert SubScene.__scene_config__.state == "sub_state" + assert SubScene.__scene_config__.reset_data_on_enter is True + assert SubScene.__scene_config__.reset_history_on_enter is True + assert SubScene.__scene_config__.callback_query_without_state is True + + # Assert handlers are correctly set + assert len(SubScene.__scene_config__.handlers) == 3 + + for handler in SubScene.__scene_config__.handlers: + assert isinstance(handler, HandlerContainer) + assert handler.name == "message" + assert handler.handler in ( + ParentScene.parent_handler, + SubScene.sub_handler, + _empty_handler, + ) + assert handler.filters == (F.text,) + + # Assert actions are correctly set + assert len(SubScene.__scene_config__.actions) == 2 + + enter_action = SubScene.__scene_config__.actions[SceneAction.enter] + assert isinstance(enter_action, dict) + assert "message" in enter_action + assert enter_action["message"].callback == ParentScene.parent_action + + exit_action = SubScene.__scene_config__.actions[SceneAction.exit] + assert isinstance(exit_action, dict) + assert "message" in exit_action + assert exit_action["message"].callback == SubScene.sub_action + + def test_scene_subclass_initialisation_bases_is_scene_subclass(self): + class NotAScene: + pass + + class MyScene(Scene, NotAScene): + pass + + class TestClass(MyScene, NotAScene): + pass + + assert MyScene in TestClass.__bases__ + assert NotAScene in TestClass.__bases__ + bases = [base for base in TestClass.__bases__ if not issubclass(base, Scene)] + assert Scene not in bases + assert NotAScene in bases + + def test_scene_add_to_router(self): + class MyScene(Scene): + @on.message(F.text) + def test_handler(self, *args, **kwargs): + pass + + router = Router() + MyScene.add_to_router(router) + + assert len(router.observers["message"].handlers) == 1 + + def test_scene_add_to_router_scene_with_callback_query_without_state(self): + class MyScene(Scene, callback_query_without_state=True): + @on.callback_query(F.data) + def test_handler(self, *args, **kwargs): + pass + + router = Router() + MyScene.add_to_router(router) + + assert len(router.observers["callback_query"].handlers) == 1 + assert ( + StateFilter(MyScene.__scene_config__.state) + not in router.observers["callback_query"].handlers[0].filters + ) + + def test_scene_as_handler(self): + class MyScene(Scene): + @on.message(F.text) + def test_handler(self, *args, **kwargs): + pass + + handler = MyScene.as_handler() + + router = Router() + router.message.register(handler) + assert router.observers["message"].handlers[0].callback == handler + + async def test_scene_as_handler_enter(self): + class MyScene(Scene): + @on.message.enter(F.text) + def test_handler(self, *args, **kwargs): + pass + + event = AsyncMock() + + scenes = ScenesManager( + registry=SceneRegistry(Router()), + update_type="message", + event=event, + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(bot_id=42, chat_id=-42, user_id=42) + ), + data={}, + ) + scenes.enter = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + handler = MyScene.as_handler(**kwargs) + await handler(event, scenes) + + scenes.enter.assert_called_once_with(MyScene, **kwargs) + + +class TestSceneWizard: + async def test_scene_wizard_enter_with_reset_data_on_enter(self): + class MyScene(Scene, reset_data_on_enter=True): + pass + + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + state_mock = AsyncMock(spec=FSMContext) + state_mock.set_state = AsyncMock() + + wizard = SceneWizard( + scene_config=MyScene.__scene_config__, + manager=AsyncMock(spec=ScenesManager), + state=state_mock, + update_type="message", + event=AsyncMock(), + data={}, + ) + kwargs = {"test_kwargs": "test_kwargs"} + wizard._on_action = AsyncMock() + + await wizard.enter(**kwargs) + + state_mock.set_data.assert_called_once_with({}) + state_mock.set_state.assert_called_once_with(MyScene.__scene_config__.state) + wizard._on_action.assert_called_once_with(SceneAction.enter, **kwargs) + + async def test_scene_wizard_enter_with_reset_history_on_enter(self): + class MyScene(Scene, reset_history_on_enter=True): + pass + + state_mock = AsyncMock(spec=FSMContext) + state_mock.set_state = AsyncMock() + + manager = AsyncMock(spec=ScenesManager) + manager.history = AsyncMock(spec=HistoryManager) + manager.history.clear = AsyncMock() + + wizard = SceneWizard( + scene_config=MyScene.__scene_config__, + manager=manager, + state=state_mock, + update_type="message", + event=AsyncMock(), + data={}, + ) + kwargs = {"test_kwargs": "test_kwargs"} + wizard._on_action = AsyncMock() + + await wizard.enter(**kwargs) + + manager.history.clear.assert_called_once() + state_mock.set_state.assert_called_once_with(MyScene.__scene_config__.state) + wizard._on_action.assert_called_once_with(SceneAction.enter, **kwargs) + + async def test_scene_wizard_leave_with_history(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + manager = AsyncMock(spec=ScenesManager) + manager.history = AsyncMock(spec=HistoryManager) + manager.history.snapshot = AsyncMock() + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=manager, + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard._on_action = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.leave(_with_history=False, **kwargs) + + manager.history.snapshot.assert_not_called() + wizard._on_action.assert_called_once_with(SceneAction.leave, **kwargs) + + async def test_scene_wizard_leave_without_history(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + manager = AsyncMock(spec=ScenesManager) + manager.history = AsyncMock(spec=HistoryManager) + manager.history.snapshot = AsyncMock() + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=manager, + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard._on_action = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.leave(**kwargs) + + manager.history.snapshot.assert_called_once() + wizard._on_action.assert_called_once_with(SceneAction.leave, **kwargs) + + async def test_scene_wizard_back(self): + current_scene_config_mock = AsyncMock() + current_scene_config_mock.state = "test_state" + + previous_scene_config_mock = AsyncMock() + previous_scene_config_mock.state = "previous_test_state" + + previous_scene_mock = AsyncMock() + previous_scene_mock.__scene_config__ = previous_scene_config_mock + + manager = AsyncMock(spec=ScenesManager) + manager.history = AsyncMock(spec=HistoryManager) + manager.history.rollback = AsyncMock(return_value=previous_scene_mock) + manager.enter = AsyncMock() + + wizard = SceneWizard( + scene_config=current_scene_config_mock, + manager=manager, + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.leave = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.back(**kwargs) + + wizard.leave.assert_called_once_with(_with_history=False, **kwargs) + manager.history.rollback.assert_called_once() + manager.enter.assert_called_once_with(previous_scene_mock, _check_active=False, **kwargs) + + async def test_scene_wizard_retake(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(spec=ScenesManager), + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.goto = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.retake(**kwargs) + + wizard.goto.assert_called_once_with(scene_config_mock.state, **kwargs) + + async def test_scene_wizard_retake_exception(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = None + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(spec=ScenesManager), + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + + kwargs = {"test_kwargs": "test_kwargs"} + + with pytest.raises(AssertionError, match="Scene state is not specified"): + await wizard.retake(**kwargs) + + async def test_scene_wizard_goto(self): + scene_config_mock = AsyncMock() + scene_config_mock.state = "test_state" + + scene_mock = AsyncMock() + scene_mock.__scene_config__ = scene_config_mock + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(spec=ScenesManager), + state=AsyncMock(spec=FSMContext), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.leave = AsyncMock() + wizard.manager.enter = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + await wizard.goto(scene_mock, **kwargs) + + wizard.leave.assert_called_once_with(**kwargs) + wizard.manager.enter.assert_called_once_with(scene_mock, _check_active=False, **kwargs) + + async def test_scene_wizard_on_action(self): + scene_config_mock = AsyncMock() + scene_config_mock.actions = {SceneAction.enter: {"message": AsyncMock()}} + scene_config_mock.state = "test_state" + + scene_mock = AsyncMock() + + event_mock = AsyncMock() + event_mock.type = "message" + + data = {"test_data": "test_data"} + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=event_mock, + data=data, + ) + wizard.scene = scene_mock + + kwargs = {"test_kwargs": "test_kwargs"} + result = await wizard._on_action(SceneAction.enter, **kwargs) + + scene_config_mock.actions[SceneAction.enter]["message"].call.assert_called_once_with( + scene_mock, event_mock, **{**data, **kwargs} + ) + assert result is True + + async def test_scene_wizard_on_action_no_scene(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + + with pytest.raises(SceneException, match="Scene is not initialized"): + await wizard._on_action(SceneAction.enter) + + async def test_scene_wizard_on_action_no_action_config(self): + scene_config_mock = AsyncMock() + scene_config_mock.actions = {} + scene_config_mock.state = "test_state" + + scene_mock = AsyncMock() + + event_mock = AsyncMock() + event_mock.type = "message" + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=event_mock, + data={}, + ) + wizard.scene = scene_mock + + kwargs = {"test_kwargs": "test_kwargs"} + result = await wizard._on_action(SceneAction.enter, **kwargs) + + assert result is False + + async def test_scene_wizard_on_action_event_type_not_in_action_config(self): + scene_config_mock = AsyncMock() + scene_config_mock.actions = {SceneAction.enter: {"test_update_type": AsyncMock()}} + scene_config_mock.state = "test_state" + + event_mock = AsyncMock() + event_mock.type = "message" + + wizard = SceneWizard( + scene_config=scene_config_mock, + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=event_mock, + data={}, + ) + wizard.scene = AsyncMock() + + kwargs = {"test_kwargs": "test_kwargs"} + result = await wizard._on_action(SceneAction.enter, **kwargs) + + assert result is False + + async def test_scene_wizard_set_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.state.set_data = AsyncMock() + + data = {"test_key": "test_value"} + await wizard.set_data(data) + + wizard.state.set_data.assert_called_once_with(data=data) + + async def test_scene_wizard_get_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.state.get_data = AsyncMock() + + await wizard.get_data() + + wizard.state.get_data.assert_called_once_with() + + async def test_scene_wizard_update_data_if_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + data = {"test_key": "test_value"} + kwargs = {"test_kwargs": "test_kwargs"} + + wizard.state.update_data = AsyncMock(return_value={**data, **kwargs}) + result = await wizard.update_data(data=data, **kwargs) + + wizard.state.update_data.assert_called_once_with(data={**data, **kwargs}) + assert result == {**data, **kwargs} + + async def test_scene_wizard_update_data_if_no_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + data = None + kwargs = {"test_kwargs": "test_kwargs"} + + wizard.state.update_data = AsyncMock(return_value={**kwargs}) + result = await wizard.update_data(data=data, **kwargs) + + wizard.state.update_data.assert_called_once_with(data=kwargs) + assert result == {**kwargs} + + async def test_scene_wizard_clear_data(self): + wizard = SceneWizard( + scene_config=AsyncMock(), + manager=AsyncMock(), + state=AsyncMock(), + update_type="message", + event=AsyncMock(), + data={}, + ) + wizard.set_data = AsyncMock() + + await wizard.clear_data() + + wizard.set_data.assert_called_once_with({}) + + +class TestScenesManager: + async def test_scenes_manager_get_scene(self, bot: MockedBot): + class MyScene(Scene): + pass + + router = Router() + + registry = SceneRegistry(router) + registry.add(MyScene) + + scenes_manager = ScenesManager( + registry=registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + scene = await scenes_manager._get_scene(MyScene) + assert isinstance(scene, MyScene) + assert isinstance(scene.wizard, SceneWizard) + assert scene.wizard.scene_config == MyScene.__scene_config__ + assert scene.wizard.manager == scenes_manager + assert scene.wizard.update_type == "message" + assert scene.wizard.data == {} + + async def test_scenes_manager_get_active_scene(self, bot: MockedBot): + class TestScene(Scene): + pass + + class TestScene2(Scene, state="test_state2"): + pass + + registry = SceneRegistry(Router()) + registry.add(TestScene, TestScene2) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + scene = await manager._get_active_scene() + assert isinstance(scene, TestScene) + + await manager.enter(TestScene2) + scene = await manager._get_active_scene() + assert isinstance(scene, TestScene2) + + async def test_scenes_manager_get_active_scene_with_scene_exception(self, bot: MockedBot): + registry = SceneRegistry(Router()) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + scene = await manager._get_active_scene() + + assert scene is None + + async def test_scenes_manager_enter_with_scene_type_none(self, bot: MockedBot): + registry = SceneRegistry(Router()) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + assert await manager.enter(None) is None + + async def test_scenes_manager_enter_with_scene_exception(self, bot: MockedBot): + registry = SceneRegistry(Router()) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + scene = "invalid_scene" + with pytest.raises(SceneException, match=f"Scene {scene!r} is not registered"): + await manager.enter(scene) + + async def test_scenes_manager_close_if_active_scene(self, bot: MockedBot): + class TestScene(Scene): + pass + + registry = SceneRegistry(Router()) + registry.add(TestScene) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + manager._get_active_scene = AsyncMock( + return_value=TestScene( + SceneWizard( + scene_config=TestScene.__scene_config__, + manager=manager, + state=manager.state, + update_type="message", + event=manager.event, + data={}, + ) + ) + ) + manager._get_active_scene.return_value.wizard.exit = AsyncMock() + + await manager.close() + + manager._get_active_scene.assert_called_once() + manager._get_active_scene.return_value.wizard.exit.assert_called_once() + + async def test_scenes_manager_close_if_no_active_scene(self, bot: MockedBot): + registry = SceneRegistry(Router()) + + manager = ScenesManager( + registry, + update_type="message", + event=Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat( + type="private", + id=42, + ), + ), + ), + state=FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + data={}, + ) + + manager._get_active_scene = AsyncMock(return_value=None) + + result = await manager.close() + + manager._get_active_scene.assert_called_once() + + assert result is None + + +class TestSceneRegistry: + def test_scene_registry_initialization(self): + router = Router() + register_on_add = True + + registry = SceneRegistry(router, register_on_add) + + assert registry.router == router + assert registry.register_on_add == register_on_add + assert registry._scenes == {} + + def test_scene_registry_add_scene(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_add_scene_pass_router(self): + class MyScene(Scene): + pass + + router = Router() + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router) + registry.add(MyScene, router=router) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_add_scene_already_exists(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + with pytest.raises(SceneException): + registry.add(MyScene) + + def test_scene_registry_add_scene_no_scenes(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry._scenes = {} + + with pytest.raises(ValueError, match="At least one scene must be specified"): + registry.add() + + def test_scene_registry_register(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.register(MyScene) + + assert len(registry._scenes) == 1 + assert registry._scenes["test_scene"] == MyScene + + def test_scene_registry_get_scene_if_scene_type_is_str(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get("test_scene") + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_type_is_scene(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = "test_scene" + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get(MyScene) + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_state_is_state(self): + class MyStates(StatesGroup): + test_state = State() + + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = MyStates.test_state + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + retrieved_scene = registry.get(MyScene) + + assert retrieved_scene == MyScene + + def test_scene_registry_get_scene_if_scene_state_is_not_str(self): + class MyScene(Scene): + pass + + router = Router() + register_on_add = True + MyScene.__scene_config__.state = 42 + + registry = SceneRegistry(router, register_on_add) + registry.add(MyScene) + + with pytest.raises(SceneException, match="Scene must be a subclass of Scene or a string"): + registry.get(MyScene) + + def test_scene_registry_get_scene_not_registered(self): + router = Router() + register_on_add = True + + registry = SceneRegistry(router, register_on_add) + + with pytest.raises(SceneException): + registry.get("test_scene") + + def test_scene_registry_setup_middleware_with_dispatcher(self): + router = Router() + + registry = SceneRegistry(router) + + dispatcher = Dispatcher() + registry._setup_middleware(dispatcher) + + assert registry._update_middleware in dispatcher.update.outer_middleware + + for name, observer in dispatcher.observers.items(): + if name == "update": + continue + assert registry._update_middleware not in observer.outer_middleware + + def test_scene_registry_setup_middleware_with_router(self): + inner_router = Router() + + registry = SceneRegistry(inner_router) + + outer_router = Router() + registry._setup_middleware(outer_router) + + for name, observer in outer_router.observers.items(): + if name in ("update", "error"): + continue + assert registry._middleware in observer.outer_middleware + + async def test_scene_registry_update_middleware(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + ), + ) + data = { + "state": FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ) + } + + result = await registry._update_middleware(handler, event, data) + + assert "scenes" in data + assert isinstance(data["scenes"], ScenesManager) + handler.assert_called_once_with(event, data) + assert result == handler.return_value + + async def test_scene_registry_update_middleware_not_update(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + ) + data = { + "state": FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ) + } + + with pytest.raises(AssertionError, match="Event must be an Update instance"): + await registry._update_middleware(handler, event, data) + + async def test_scene_registry_middleware(self, bot: MockedBot): + router = Router() + registry = SceneRegistry(router) + handler = AsyncMock(spec=NextMiddlewareType) + event = Update( + update_id=42, + message=Message( + message_id=42, + text="test", + date=datetime.now(), + chat=Chat(id=42, type="private"), + ), + ) + data = { + "state": FSMContext( + storage=MemoryStorage(), key=StorageKey(chat_id=-42, user_id=42, bot_id=bot.id) + ), + "event_update": event, + } + + result = await registry._middleware(handler, event, data) + + assert "scenes" in data + assert isinstance(data["scenes"], ScenesManager) + handler.assert_called_once_with(event, data) + assert result == handler.return_value