mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-12 02:03:04 +00:00
Implement new middlewares
This commit is contained in:
parent
c262cc0ce6
commit
7f26ec9935
29 changed files with 532 additions and 1252 deletions
|
|
@ -11,7 +11,13 @@ class TestInlineQuery:
|
|||
offset="",
|
||||
)
|
||||
|
||||
kwargs = dict(results=[], cache_time=123, next_offset="123", switch_pm_text="foo", switch_pm_parameter="foo")
|
||||
kwargs = dict(
|
||||
results=[],
|
||||
cache_time=123,
|
||||
next_offset="123",
|
||||
switch_pm_text="foo",
|
||||
switch_pm_parameter="foo",
|
||||
)
|
||||
|
||||
api_method = inline_query.answer(**kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from aiogram.api.methods import AnswerShippingQuery
|
||||
from aiogram.api.types import ShippingAddress, ShippingQuery, User, ShippingOption, LabeledPrice
|
||||
from aiogram.api.types import LabeledPrice, ShippingAddress, ShippingOption, ShippingQuery, User
|
||||
|
||||
|
||||
class TestInlineQuery:
|
||||
|
|
@ -19,7 +19,8 @@ class TestInlineQuery:
|
|||
)
|
||||
|
||||
shipping_options = [
|
||||
ShippingOption(id="id", title="foo", prices=[LabeledPrice(label="foo", amount=123)])]
|
||||
ShippingOption(id="id", title="foo", prices=[LabeledPrice(label="foo", amount=123)])
|
||||
]
|
||||
|
||||
kwargs = dict(ok=True, shipping_options=shipping_options, error_message="foo")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from aiogram.dispatcher.event.observer import TelegramEventObserver
|
||||
from aiogram.dispatcher.event.telegram import TelegramEventObserver
|
||||
from aiogram.dispatcher.router import Router
|
||||
from tests.deprecated import check_deprecated
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from aiogram import Bot
|
|||
from aiogram.api.methods import GetMe, GetUpdates, SendMessage
|
||||
from aiogram.api.types import Chat, Message, Update, User
|
||||
from aiogram.dispatcher.dispatcher import Dispatcher
|
||||
from aiogram.dispatcher.event.bases import NOT_HANDLED
|
||||
from aiogram.dispatcher.router import Router
|
||||
from tests.mocked_bot import MockedBot
|
||||
|
||||
|
|
@ -63,7 +64,7 @@ class TestDispatcher:
|
|||
return message.text
|
||||
|
||||
results_count = 0
|
||||
async for result in dp.feed_update(
|
||||
result = await dp.feed_update(
|
||||
bot=bot,
|
||||
update=Update(
|
||||
update_id=42,
|
||||
|
|
@ -75,11 +76,9 @@ class TestDispatcher:
|
|||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
),
|
||||
),
|
||||
):
|
||||
results_count += 1
|
||||
assert result == "test"
|
||||
|
||||
assert results_count == 1
|
||||
)
|
||||
results_count += 1
|
||||
assert result == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feed_raw_update(self):
|
||||
|
|
@ -91,8 +90,7 @@ class TestDispatcher:
|
|||
assert message.text == "test"
|
||||
return message.text
|
||||
|
||||
handled = False
|
||||
async for result in dp.feed_raw_update(
|
||||
result = await dp.feed_raw_update(
|
||||
bot=bot,
|
||||
update={
|
||||
"update_id": 42,
|
||||
|
|
@ -101,13 +99,11 @@ class TestDispatcher:
|
|||
"date": int(time.time()),
|
||||
"text": "test",
|
||||
"chat": {"id": 42, "type": "private"},
|
||||
"user": {"id": 42, "is_bot": False, "first_name": "Test"},
|
||||
"from": {"id": 42, "is_bot": False, "first_name": "Test"},
|
||||
},
|
||||
},
|
||||
):
|
||||
handled = True
|
||||
assert result == "test"
|
||||
assert handled
|
||||
)
|
||||
assert result == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_updates(self, bot: MockedBot):
|
||||
|
|
@ -136,7 +132,8 @@ class TestDispatcher:
|
|||
async def test_process_update_empty(self, bot: MockedBot):
|
||||
dispatcher = Dispatcher()
|
||||
|
||||
assert not await dispatcher.process_update(bot=bot, update=Update(update_id=42))
|
||||
result = await dispatcher._process_update(bot=bot, update=Update(update_id=42))
|
||||
assert result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_update_handled(self, bot: MockedBot):
|
||||
|
|
@ -146,22 +143,25 @@ class TestDispatcher:
|
|||
async def update_handler(update: Update):
|
||||
pass
|
||||
|
||||
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
|
||||
assert await dispatcher._process_update(bot=bot, update=Update(update_id=42))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_update_call_request(self, bot: MockedBot):
|
||||
dispatcher = Dispatcher()
|
||||
|
||||
@dispatcher.update()
|
||||
async def update_handler(update: Update):
|
||||
async def message_handler(update: Update):
|
||||
return GetMe()
|
||||
|
||||
dispatcher.update.handlers.reverse()
|
||||
|
||||
with patch(
|
||||
"aiogram.dispatcher.dispatcher.Dispatcher._silent_call_request",
|
||||
new_callable=CoroutineMock,
|
||||
) as mocked_silent_call_request:
|
||||
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
|
||||
mocked_silent_call_request.assert_awaited_once()
|
||||
result = await dispatcher._process_update(bot=bot, update=Update(update_id=42))
|
||||
print(result)
|
||||
mocked_silent_call_request.assert_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_update_exception(self, bot: MockedBot, caplog):
|
||||
|
|
@ -171,7 +171,7 @@ class TestDispatcher:
|
|||
async def update_handler(update: Update):
|
||||
raise Exception("Kaboom!")
|
||||
|
||||
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
|
||||
assert await dispatcher._process_update(bot=bot, update=Update(update_id=42))
|
||||
log_records = [rec.message for rec in caplog.records]
|
||||
assert len(log_records) == 1
|
||||
assert "Cause exception while process update" in log_records[0]
|
||||
|
|
@ -184,7 +184,7 @@ class TestDispatcher:
|
|||
yield Update(update_id=42)
|
||||
|
||||
with patch(
|
||||
"aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock
|
||||
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock
|
||||
) as mocked_process_update, patch(
|
||||
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates"
|
||||
) as patched_listen_updates:
|
||||
|
|
@ -203,7 +203,7 @@ class TestDispatcher:
|
|||
yield Update(update_id=42)
|
||||
|
||||
with patch(
|
||||
"aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock
|
||||
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock
|
||||
) as mocked_process_update, patch(
|
||||
"aiogram.dispatcher.router.Router.emit_startup", new_callable=CoroutineMock
|
||||
) as mocked_emit_startup, patch(
|
||||
|
|
|
|||
59
tests/test_dispatcher/test_event/test_event.py
Normal file
59
tests/test_dispatcher/test_event/test_event.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import functools
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from aiogram.dispatcher.event.event import EventObserver
|
||||
from aiogram.dispatcher.event.handler import HandlerObject
|
||||
|
||||
try:
|
||||
from asynctest import CoroutineMock, patch
|
||||
except ImportError:
|
||||
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
|
||||
|
||||
|
||||
async def my_handler(value: str, index: int = 0) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
class TestEventObserver:
|
||||
@pytest.mark.parametrize("via_decorator", [True, False])
|
||||
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||
def test_register_filters(self, via_decorator, count, handler):
|
||||
observer = EventObserver()
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
if via_decorator:
|
||||
register_result = observer()(wrapped_handler)
|
||||
assert register_result == wrapped_handler
|
||||
else:
|
||||
register_result = observer.register(wrapped_handler)
|
||||
assert register_result is None
|
||||
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert not registered_handler.filters
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger(self):
|
||||
observer = EventObserver()
|
||||
|
||||
observer.register(my_handler)
|
||||
observer.register(lambda e: True)
|
||||
observer.register(my_handler)
|
||||
|
||||
assert observer.handlers[0].awaitable
|
||||
assert not observer.handlers[1].awaitable
|
||||
assert observer.handlers[2].awaitable
|
||||
|
||||
with patch(
|
||||
"aiogram.dispatcher.event.handler.CallableMixin.call", new_callable=CoroutineMock,
|
||||
) as mocked_my_handler:
|
||||
results = await observer.trigger("test")
|
||||
assert results is None
|
||||
mocked_my_handler.assert_awaited_with("test")
|
||||
assert mocked_my_handler.call_count == 3
|
||||
|
|
@ -5,11 +5,14 @@ from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
|
|||
import pytest
|
||||
|
||||
from aiogram.api.types import Chat, Message, User
|
||||
from aiogram.dispatcher.event.bases import SkipHandler
|
||||
from aiogram.dispatcher.event.handler import HandlerObject
|
||||
from aiogram.dispatcher.event.observer import EventObserver, SkipHandler, TelegramEventObserver
|
||||
from aiogram.dispatcher.event.telegram import TelegramEventObserver
|
||||
from aiogram.dispatcher.filters.base import BaseFilter
|
||||
from aiogram.dispatcher.router import Router
|
||||
|
||||
# TODO: Test middlewares in routers tree
|
||||
|
||||
|
||||
async def my_handler(event: Any, index: int = 0) -> Any:
|
||||
return event
|
||||
|
|
@ -38,54 +41,6 @@ class MyFilter3(MyFilter1):
|
|||
pass
|
||||
|
||||
|
||||
class TestEventObserver:
|
||||
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||
def test_register_filters(self, count, handler):
|
||||
observer = EventObserver()
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
observer.register(wrapped_handler)
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert not registered_handler.filters
|
||||
|
||||
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
|
||||
def test_register_filters_via_decorator(self, count, handler):
|
||||
observer = EventObserver()
|
||||
|
||||
for index in range(count):
|
||||
wrapped_handler = functools.partial(handler, index=index)
|
||||
observer()(wrapped_handler)
|
||||
registered_handler = observer.handlers[index]
|
||||
|
||||
assert len(observer.handlers) == index + 1
|
||||
assert isinstance(registered_handler, HandlerObject)
|
||||
assert registered_handler.callback == wrapped_handler
|
||||
assert not registered_handler.filters
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_accepted_bool(self):
|
||||
observer = EventObserver()
|
||||
observer.register(my_handler)
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [42]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_with_skip(self):
|
||||
observer = EventObserver()
|
||||
observer.register(skip_my_handler)
|
||||
observer.register(my_handler)
|
||||
observer.register(my_handler)
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [42, 42]
|
||||
|
||||
|
||||
class TestTelegramEventObserver:
|
||||
def test_bind_filter(self):
|
||||
event_observer = TelegramEventObserver(Router(), "test")
|
||||
|
|
@ -198,8 +153,8 @@ class TestTelegramEventObserver:
|
|||
from_user=User(id=42, is_bot=False, first_name="Test"),
|
||||
)
|
||||
|
||||
results = [result async for result in observer.trigger(message)]
|
||||
assert results == [message]
|
||||
results = await observer.trigger(message)
|
||||
assert results is message
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"count,handler,filters",
|
||||
|
|
@ -223,15 +178,58 @@ class TestTelegramEventObserver:
|
|||
assert registered_handler.callback == wrapped_handler
|
||||
assert len(registered_handler.filters) == len(filters)
|
||||
|
||||
#
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_right_context_in_handlers(self):
|
||||
router = Router(use_builtin_filters=False)
|
||||
observer = router.message
|
||||
observer.register(
|
||||
pipe_handler, lambda event: {"a": 1}, lambda event: False
|
||||
) # {"a": 1} should not be in result
|
||||
observer.register(pipe_handler, lambda event: {"b": 2})
|
||||
|
||||
results = [result async for result in observer.trigger(42)]
|
||||
assert results == [((42,), {"b": 2})]
|
||||
async def mix_unnecessary_data(event):
|
||||
return {"a": 1}
|
||||
|
||||
async def mix_data(event):
|
||||
return {"b": 2}
|
||||
|
||||
async def handler(event, **kwargs):
|
||||
return False
|
||||
|
||||
observer.register(
|
||||
pipe_handler, mix_unnecessary_data, handler
|
||||
) # {"a": 1} should not be in result
|
||||
observer.register(pipe_handler, mix_data)
|
||||
|
||||
results = await observer.trigger(42)
|
||||
assert results == ((42,), {"b": 2})
|
||||
|
||||
@pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware"))
|
||||
def test_register_middleware(self, middleware_type):
|
||||
event_observer = TelegramEventObserver(Router(), "test")
|
||||
|
||||
middlewares = getattr(event_observer, f"{middleware_type}s")
|
||||
decorator = getattr(event_observer, middleware_type)
|
||||
|
||||
@decorator
|
||||
async def my_middleware1(handler, event, data):
|
||||
pass
|
||||
|
||||
assert my_middleware1 is not None
|
||||
assert my_middleware1.__name__ == "my_middleware1"
|
||||
assert my_middleware1 in middlewares
|
||||
|
||||
@decorator()
|
||||
async def my_middleware2(handler, event, data):
|
||||
pass
|
||||
|
||||
assert my_middleware2 is not None
|
||||
assert my_middleware2.__name__ == "my_middleware2"
|
||||
assert my_middleware2 in middlewares
|
||||
|
||||
async def my_middleware3(handler, event, data):
|
||||
pass
|
||||
|
||||
decorator(my_middleware3)
|
||||
|
||||
assert my_middleware3 is not None
|
||||
assert my_middleware3.__name__ == "my_middleware3"
|
||||
assert my_middleware3 in middlewares
|
||||
|
||||
assert middlewares == [my_middleware1, my_middleware2, my_middleware3]
|
||||
|
|
@ -1,257 +0,0 @@
|
|||
import datetime
|
||||
from typing import Any, Dict, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from aiogram.api.types import (
|
||||
CallbackQuery,
|
||||
Chat,
|
||||
ChosenInlineResult,
|
||||
InlineQuery,
|
||||
Message,
|
||||
Poll,
|
||||
PollAnswer,
|
||||
PreCheckoutQuery,
|
||||
ShippingQuery,
|
||||
Update,
|
||||
User,
|
||||
)
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
|
||||
|
||||
try:
|
||||
from asynctest import CoroutineMock, patch
|
||||
except ImportError:
|
||||
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
|
||||
|
||||
|
||||
class MyMiddleware(BaseMiddleware):
|
||||
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
|
||||
return "update"
|
||||
|
||||
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
|
||||
return "message"
|
||||
|
||||
async def on_pre_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "edited_message"
|
||||
|
||||
async def on_pre_process_channel_post(
|
||||
self, channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "channel_post"
|
||||
|
||||
async def on_pre_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "edited_channel_post"
|
||||
|
||||
async def on_pre_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "inline_query"
|
||||
|
||||
async def on_pre_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "chosen_inline_result"
|
||||
|
||||
async def on_pre_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "callback_query"
|
||||
|
||||
async def on_pre_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "shipping_query"
|
||||
|
||||
async def on_pre_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "pre_checkout_query"
|
||||
|
||||
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
|
||||
return "poll"
|
||||
|
||||
async def on_pre_process_poll_answer(
|
||||
self, poll_answer: PollAnswer, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "poll_answer"
|
||||
|
||||
async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
|
||||
return "error"
|
||||
|
||||
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
|
||||
return "update"
|
||||
|
||||
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
|
||||
return "message"
|
||||
|
||||
async def on_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "edited_message"
|
||||
|
||||
async def on_process_channel_post(self, channel_post: Message, data: Dict[str, Any]) -> Any:
|
||||
return "channel_post"
|
||||
|
||||
async def on_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "edited_channel_post"
|
||||
|
||||
async def on_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "inline_query"
|
||||
|
||||
async def on_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "chosen_inline_result"
|
||||
|
||||
async def on_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "callback_query"
|
||||
|
||||
async def on_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "shipping_query"
|
||||
|
||||
async def on_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
|
||||
) -> Any:
|
||||
return "pre_checkout_query"
|
||||
|
||||
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
|
||||
return "poll"
|
||||
|
||||
async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any:
|
||||
return "poll_answer"
|
||||
|
||||
async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
|
||||
return "error"
|
||||
|
||||
async def on_post_process_update(
|
||||
self, update: Update, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "update"
|
||||
|
||||
async def on_post_process_message(
|
||||
self, message: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "message"
|
||||
|
||||
async def on_post_process_edited_message(
|
||||
self, edited_message: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "edited_message"
|
||||
|
||||
async def on_post_process_channel_post(
|
||||
self, channel_post: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "channel_post"
|
||||
|
||||
async def on_post_process_edited_channel_post(
|
||||
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "edited_channel_post"
|
||||
|
||||
async def on_post_process_inline_query(
|
||||
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "inline_query"
|
||||
|
||||
async def on_post_process_chosen_inline_result(
|
||||
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "chosen_inline_result"
|
||||
|
||||
async def on_post_process_callback_query(
|
||||
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "callback_query"
|
||||
|
||||
async def on_post_process_shipping_query(
|
||||
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "shipping_query"
|
||||
|
||||
async def on_post_process_pre_checkout_query(
|
||||
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "pre_checkout_query"
|
||||
|
||||
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
|
||||
return "poll"
|
||||
|
||||
async def on_post_process_poll_answer(
|
||||
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "poll_answer"
|
||||
|
||||
async def on_post_process_error(
|
||||
self, exception: Exception, data: Dict[str, Any], result: Any
|
||||
) -> Any:
|
||||
return "error"
|
||||
|
||||
|
||||
UPDATE = Update(update_id=42)
|
||||
MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private"))
|
||||
POLL_ANSWER = PollAnswer(
|
||||
poll_id="poll", user=User(id=42, is_bot=False, first_name="Test"), option_ids=[0]
|
||||
)
|
||||
|
||||
|
||||
class TestBaseMiddleware:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"middleware_cls,should_be_awaited", [[MyMiddleware, True], [BaseMiddleware, False]]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"step", [MiddlewareStep.PRE_PROCESS, MiddlewareStep.PROCESS, MiddlewareStep.POST_PROCESS]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"event_name,event",
|
||||
[
|
||||
["update", UPDATE],
|
||||
["message", MESSAGE],
|
||||
["poll_answer", POLL_ANSWER],
|
||||
["error", Exception("KABOOM")],
|
||||
],
|
||||
)
|
||||
async def test_trigger(
|
||||
self,
|
||||
step: MiddlewareStep,
|
||||
event_name: str,
|
||||
event: UpdateType,
|
||||
middleware_cls: Type[BaseMiddleware],
|
||||
should_be_awaited: bool,
|
||||
):
|
||||
middleware = middleware_cls()
|
||||
|
||||
with patch(
|
||||
f"tests.test_dispatcher.test_middlewares.test_base."
|
||||
f"MyMiddleware.on_{step.value}_{event_name}",
|
||||
new_callable=CoroutineMock,
|
||||
) as mocked_call:
|
||||
response = await middleware.trigger(
|
||||
step=step, event_name=event_name, event=event, data={}
|
||||
)
|
||||
if should_be_awaited:
|
||||
mocked_call.assert_awaited()
|
||||
assert response is not None
|
||||
else:
|
||||
mocked_call.assert_not_awaited()
|
||||
assert response is None
|
||||
|
||||
def test_not_configured(self):
|
||||
middleware = BaseMiddleware()
|
||||
assert not middleware.configured
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
manager = middleware.manager
|
||||
|
|
@ -1,82 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from aiogram import Router
|
||||
from aiogram.api.types import Update
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
|
||||
from aiogram.dispatcher.middlewares.types import MiddlewareStep
|
||||
|
||||
try:
|
||||
from asynctest import CoroutineMock, patch
|
||||
except ImportError:
|
||||
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture("function")
|
||||
def router():
|
||||
return Router()
|
||||
|
||||
|
||||
@pytest.fixture("function")
|
||||
def manager(router: Router):
|
||||
return MiddlewareManager(router)
|
||||
|
||||
|
||||
class TestManager:
|
||||
def test_setup(self, manager: MiddlewareManager):
|
||||
middleware = BaseMiddleware()
|
||||
returned = manager.setup(middleware)
|
||||
assert returned is middleware
|
||||
assert middleware.configured
|
||||
assert middleware.manager is manager
|
||||
assert middleware in manager
|
||||
|
||||
@pytest.mark.parametrize("obj", [object, object(), None, BaseMiddleware])
|
||||
def test_setup_invalid_type(self, manager: MiddlewareManager, obj):
|
||||
with pytest.raises(TypeError):
|
||||
assert manager.setup(obj)
|
||||
|
||||
def test_configure_twice_different_managers(self, manager: MiddlewareManager, router: Router):
|
||||
middleware = BaseMiddleware()
|
||||
manager.setup(middleware)
|
||||
|
||||
assert middleware.configured
|
||||
|
||||
new_manager = MiddlewareManager(router)
|
||||
with pytest.raises(ValueError):
|
||||
new_manager.setup(middleware)
|
||||
with pytest.raises(ValueError):
|
||||
middleware.setup(new_manager)
|
||||
|
||||
def test_configure_twice(self, manager: MiddlewareManager):
|
||||
middleware = BaseMiddleware()
|
||||
manager.setup(middleware)
|
||||
|
||||
assert middleware.configured
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
|
||||
manager.setup(middleware)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
|
||||
middleware.setup(manager)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("count", range(5))
|
||||
async def test_trigger(self, manager: MiddlewareManager, count: int):
|
||||
for _ in range(count):
|
||||
manager.setup(BaseMiddleware())
|
||||
|
||||
with patch(
|
||||
"aiogram.dispatcher.middlewares.base.BaseMiddleware.trigger",
|
||||
new_callable=CoroutineMock,
|
||||
) as mocked_call:
|
||||
await manager.trigger(
|
||||
step=MiddlewareStep.PROCESS,
|
||||
event_name="update",
|
||||
event=Update(update_id=42),
|
||||
data={},
|
||||
result=None,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
assert mocked_call.await_count == count
|
||||
|
|
@ -10,6 +10,7 @@ from aiogram.api.types import (
|
|||
InlineQuery,
|
||||
Message,
|
||||
Poll,
|
||||
PollAnswer,
|
||||
PollOption,
|
||||
PreCheckoutQuery,
|
||||
ShippingAddress,
|
||||
|
|
@ -17,8 +18,8 @@ from aiogram.api.types import (
|
|||
Update,
|
||||
User,
|
||||
)
|
||||
from aiogram.dispatcher.event.observer import SkipHandler, skip
|
||||
from aiogram.dispatcher.middlewares.base import BaseMiddleware
|
||||
from aiogram.dispatcher.event.bases import NOT_HANDLED, SkipHandler, skip
|
||||
from aiogram.dispatcher.middlewares.update_processing_context import UserContextMiddleware
|
||||
from aiogram.dispatcher.router import Router
|
||||
from aiogram.utils.warnings import CodeHasNoEffect
|
||||
|
||||
|
|
@ -274,12 +275,26 @@ class TestRouter:
|
|||
False,
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
"poll_answer",
|
||||
Update(
|
||||
update_id=42,
|
||||
poll_answer=PollAnswer(
|
||||
poll_id="poll id",
|
||||
user=User(id=42, is_bot=False, first_name="Test"),
|
||||
option_ids=[42],
|
||||
),
|
||||
),
|
||||
False,
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_listen_update(
|
||||
self, event_type: str, update: Update, has_chat: bool, has_user: bool
|
||||
):
|
||||
router = Router()
|
||||
router.update.outer_middleware(UserContextMiddleware())
|
||||
observer = router.observers[event_type]
|
||||
|
||||
@observer()
|
||||
|
|
@ -291,7 +306,7 @@ class TestRouter:
|
|||
assert User.get_current(False)
|
||||
return kwargs
|
||||
|
||||
result = await router._listen_update(update, test="PASS")
|
||||
result = await router.update.trigger(update, test="PASS")
|
||||
assert isinstance(result, dict)
|
||||
assert result["event_update"] == update
|
||||
assert result["event_router"] == router
|
||||
|
|
@ -313,26 +328,26 @@ class TestRouter:
|
|||
async def handler(event: Any):
|
||||
pass
|
||||
|
||||
with pytest.raises(SkipHandler):
|
||||
await router._listen_update(
|
||||
Update(
|
||||
update_id=42,
|
||||
poll=Poll(
|
||||
id="poll id",
|
||||
question="Q?",
|
||||
options=[
|
||||
PollOption(text="A1", voter_count=2),
|
||||
PollOption(text="A2", voter_count=3),
|
||||
],
|
||||
is_closed=False,
|
||||
is_anonymous=False,
|
||||
type="quiz",
|
||||
allows_multiple_answers=False,
|
||||
total_voter_count=0,
|
||||
correct_option_id=0,
|
||||
),
|
||||
)
|
||||
response = await router._listen_update(
|
||||
Update(
|
||||
update_id=42,
|
||||
poll=Poll(
|
||||
id="poll id",
|
||||
question="Q?",
|
||||
options=[
|
||||
PollOption(text="A1", voter_count=2),
|
||||
PollOption(text="A2", voter_count=3),
|
||||
],
|
||||
is_closed=False,
|
||||
is_anonymous=False,
|
||||
type="quiz",
|
||||
allows_multiple_answers=False,
|
||||
total_voter_count=0,
|
||||
correct_option_id=0,
|
||||
),
|
||||
)
|
||||
)
|
||||
assert response is NOT_HANDLED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_router_listen_update(self):
|
||||
|
|
@ -345,8 +360,6 @@ class TestRouter:
|
|||
|
||||
@observer()
|
||||
async def my_handler(event: Message, **kwargs: Any):
|
||||
assert Chat.get_current(False)
|
||||
assert User.get_current(False)
|
||||
return kwargs
|
||||
|
||||
update = Update(
|
||||
|
|
@ -409,14 +422,6 @@ class TestRouter:
|
|||
await router1.emit_shutdown()
|
||||
assert results == [2, 1, 2]
|
||||
|
||||
def test_use(self):
|
||||
router = Router()
|
||||
|
||||
middleware = router.use(BaseMiddleware())
|
||||
assert isinstance(middleware, BaseMiddleware)
|
||||
assert middleware.configured
|
||||
assert middleware.manager == router.middleware
|
||||
|
||||
def test_skip(self):
|
||||
with pytest.raises(SkipHandler):
|
||||
skip()
|
||||
|
|
@ -444,37 +449,20 @@ class TestRouter:
|
|||
),
|
||||
)
|
||||
with pytest.raises(Exception, match="KABOOM"):
|
||||
await root_router.listen_update(
|
||||
update_type="message",
|
||||
update=update,
|
||||
event=update.message,
|
||||
from_user=update.message.from_user,
|
||||
chat=update.message.chat,
|
||||
)
|
||||
await root_router.update.trigger(update)
|
||||
|
||||
@root_router.errors()
|
||||
async def root_error_handler(exception: Exception):
|
||||
async def root_error_handler(event: Update, exception: Exception):
|
||||
return exception
|
||||
|
||||
response = await root_router.listen_update(
|
||||
update_type="message",
|
||||
update=update,
|
||||
event=update.message,
|
||||
from_user=update.message.from_user,
|
||||
chat=update.message.chat,
|
||||
)
|
||||
response = await root_router.update.trigger(update)
|
||||
|
||||
assert isinstance(response, Exception)
|
||||
assert str(response) == "KABOOM"
|
||||
|
||||
@router.errors()
|
||||
async def error_handler(exception: Exception):
|
||||
async def error_handler(event: Update, exception: Exception):
|
||||
return "KABOOM"
|
||||
|
||||
response = await root_router.listen_update(
|
||||
update_type="message",
|
||||
update=update,
|
||||
event=update.message,
|
||||
from_user=update.message.from_user,
|
||||
chat=update.message.chat,
|
||||
)
|
||||
response = await root_router.update.trigger(update)
|
||||
assert response == "KABOOM"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue