Implement new middlewares

This commit is contained in:
Alex Root Junior 2020-05-26 00:23:35 +03:00
parent c262cc0ce6
commit 7f26ec9935
29 changed files with 532 additions and 1252 deletions

View file

@ -1,6 +1,6 @@
import pytest
from aiogram.dispatcher.event.observer import TelegramEventObserver
from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.router import Router
from tests.deprecated import check_deprecated

View file

@ -9,6 +9,7 @@ from aiogram import Bot
from aiogram.api.methods import GetMe, GetUpdates, SendMessage
from aiogram.api.types import Chat, Message, Update, User
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import NOT_HANDLED
from aiogram.dispatcher.router import Router
from tests.mocked_bot import MockedBot
@ -63,7 +64,7 @@ class TestDispatcher:
return message.text
results_count = 0
async for result in dp.feed_update(
result = await dp.feed_update(
bot=bot,
update=Update(
update_id=42,
@ -75,11 +76,9 @@ class TestDispatcher:
from_user=User(id=42, is_bot=False, first_name="Test"),
),
),
):
results_count += 1
assert result == "test"
assert results_count == 1
)
results_count += 1
assert result == "test"
@pytest.mark.asyncio
async def test_feed_raw_update(self):
@ -91,8 +90,7 @@ class TestDispatcher:
assert message.text == "test"
return message.text
handled = False
async for result in dp.feed_raw_update(
result = await dp.feed_raw_update(
bot=bot,
update={
"update_id": 42,
@ -101,13 +99,11 @@ class TestDispatcher:
"date": int(time.time()),
"text": "test",
"chat": {"id": 42, "type": "private"},
"user": {"id": 42, "is_bot": False, "first_name": "Test"},
"from": {"id": 42, "is_bot": False, "first_name": "Test"},
},
},
):
handled = True
assert result == "test"
assert handled
)
assert result == "test"
@pytest.mark.asyncio
async def test_listen_updates(self, bot: MockedBot):
@ -136,7 +132,8 @@ class TestDispatcher:
async def test_process_update_empty(self, bot: MockedBot):
dispatcher = Dispatcher()
assert not await dispatcher.process_update(bot=bot, update=Update(update_id=42))
result = await dispatcher._process_update(bot=bot, update=Update(update_id=42))
assert result
@pytest.mark.asyncio
async def test_process_update_handled(self, bot: MockedBot):
@ -146,22 +143,25 @@ class TestDispatcher:
async def update_handler(update: Update):
pass
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
assert await dispatcher._process_update(bot=bot, update=Update(update_id=42))
@pytest.mark.asyncio
async def test_process_update_call_request(self, bot: MockedBot):
dispatcher = Dispatcher()
@dispatcher.update()
async def update_handler(update: Update):
async def message_handler(update: Update):
return GetMe()
dispatcher.update.handlers.reverse()
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher._silent_call_request",
new_callable=CoroutineMock,
) as mocked_silent_call_request:
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
mocked_silent_call_request.assert_awaited_once()
result = await dispatcher._process_update(bot=bot, update=Update(update_id=42))
print(result)
mocked_silent_call_request.assert_awaited()
@pytest.mark.asyncio
async def test_process_update_exception(self, bot: MockedBot, caplog):
@ -171,7 +171,7 @@ class TestDispatcher:
async def update_handler(update: Update):
raise Exception("Kaboom!")
assert await dispatcher.process_update(bot=bot, update=Update(update_id=42))
assert await dispatcher._process_update(bot=bot, update=Update(update_id=42))
log_records = [rec.message for rec in caplog.records]
assert len(log_records) == 1
assert "Cause exception while process update" in log_records[0]
@ -184,7 +184,7 @@ class TestDispatcher:
yield Update(update_id=42)
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock
) as mocked_process_update, patch(
"aiogram.dispatcher.dispatcher.Dispatcher._listen_updates"
) as patched_listen_updates:
@ -203,7 +203,7 @@ class TestDispatcher:
yield Update(update_id=42)
with patch(
"aiogram.dispatcher.dispatcher.Dispatcher.process_update", new_callable=CoroutineMock
"aiogram.dispatcher.dispatcher.Dispatcher._process_update", new_callable=CoroutineMock
) as mocked_process_update, patch(
"aiogram.dispatcher.router.Router.emit_startup", new_callable=CoroutineMock
) as mocked_emit_startup, patch(

View file

@ -0,0 +1,59 @@
import functools
from typing import Any
import pytest
from aiogram.dispatcher.event.event import EventObserver
from aiogram.dispatcher.event.handler import HandlerObject
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
async def my_handler(value: str, index: int = 0) -> Any:
return value
class TestEventObserver:
@pytest.mark.parametrize("via_decorator", [True, False])
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters(self, via_decorator, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
if via_decorator:
register_result = observer()(wrapped_handler)
assert register_result == wrapped_handler
else:
register_result = observer.register(wrapped_handler)
assert register_result is None
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.asyncio
async def test_trigger(self):
observer = EventObserver()
observer.register(my_handler)
observer.register(lambda e: True)
observer.register(my_handler)
assert observer.handlers[0].awaitable
assert not observer.handlers[1].awaitable
assert observer.handlers[2].awaitable
with patch(
"aiogram.dispatcher.event.handler.CallableMixin.call", new_callable=CoroutineMock,
) as mocked_my_handler:
results = await observer.trigger("test")
assert results is None
mocked_my_handler.assert_awaited_with("test")
assert mocked_my_handler.call_count == 3

View file

@ -5,11 +5,14 @@ from typing import Any, Awaitable, Callable, Dict, NoReturn, Union
import pytest
from aiogram.api.types import Chat, Message, User
from aiogram.dispatcher.event.bases import SkipHandler
from aiogram.dispatcher.event.handler import HandlerObject
from aiogram.dispatcher.event.observer import EventObserver, SkipHandler, TelegramEventObserver
from aiogram.dispatcher.event.telegram import TelegramEventObserver
from aiogram.dispatcher.filters.base import BaseFilter
from aiogram.dispatcher.router import Router
# TODO: Test middlewares in routers tree
async def my_handler(event: Any, index: int = 0) -> Any:
return event
@ -38,54 +41,6 @@ class MyFilter3(MyFilter1):
pass
class TestEventObserver:
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer.register(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.parametrize("count,handler", ([5, my_handler], [3, my_handler], [2, my_handler]))
def test_register_filters_via_decorator(self, count, handler):
observer = EventObserver()
for index in range(count):
wrapped_handler = functools.partial(handler, index=index)
observer()(wrapped_handler)
registered_handler = observer.handlers[index]
assert len(observer.handlers) == index + 1
assert isinstance(registered_handler, HandlerObject)
assert registered_handler.callback == wrapped_handler
assert not registered_handler.filters
@pytest.mark.asyncio
async def test_trigger_accepted_bool(self):
observer = EventObserver()
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42]
@pytest.mark.asyncio
async def test_trigger_with_skip(self):
observer = EventObserver()
observer.register(skip_my_handler)
observer.register(my_handler)
observer.register(my_handler)
results = [result async for result in observer.trigger(42)]
assert results == [42, 42]
class TestTelegramEventObserver:
def test_bind_filter(self):
event_observer = TelegramEventObserver(Router(), "test")
@ -198,8 +153,8 @@ class TestTelegramEventObserver:
from_user=User(id=42, is_bot=False, first_name="Test"),
)
results = [result async for result in observer.trigger(message)]
assert results == [message]
results = await observer.trigger(message)
assert results is message
@pytest.mark.parametrize(
"count,handler,filters",
@ -223,15 +178,58 @@ class TestTelegramEventObserver:
assert registered_handler.callback == wrapped_handler
assert len(registered_handler.filters) == len(filters)
#
@pytest.mark.asyncio
async def test_trigger_right_context_in_handlers(self):
router = Router(use_builtin_filters=False)
observer = router.message
observer.register(
pipe_handler, lambda event: {"a": 1}, lambda event: False
) # {"a": 1} should not be in result
observer.register(pipe_handler, lambda event: {"b": 2})
results = [result async for result in observer.trigger(42)]
assert results == [((42,), {"b": 2})]
async def mix_unnecessary_data(event):
return {"a": 1}
async def mix_data(event):
return {"b": 2}
async def handler(event, **kwargs):
return False
observer.register(
pipe_handler, mix_unnecessary_data, handler
) # {"a": 1} should not be in result
observer.register(pipe_handler, mix_data)
results = await observer.trigger(42)
assert results == ((42,), {"b": 2})
@pytest.mark.parametrize("middleware_type", ("middleware", "outer_middleware"))
def test_register_middleware(self, middleware_type):
event_observer = TelegramEventObserver(Router(), "test")
middlewares = getattr(event_observer, f"{middleware_type}s")
decorator = getattr(event_observer, middleware_type)
@decorator
async def my_middleware1(handler, event, data):
pass
assert my_middleware1 is not None
assert my_middleware1.__name__ == "my_middleware1"
assert my_middleware1 in middlewares
@decorator()
async def my_middleware2(handler, event, data):
pass
assert my_middleware2 is not None
assert my_middleware2.__name__ == "my_middleware2"
assert my_middleware2 in middlewares
async def my_middleware3(handler, event, data):
pass
decorator(my_middleware3)
assert my_middleware3 is not None
assert my_middleware3.__name__ == "my_middleware3"
assert my_middleware3 in middlewares
assert middlewares == [my_middleware1, my_middleware2, my_middleware3]

View file

@ -1,257 +0,0 @@
import datetime
from typing import Any, Dict, Type
import pytest
from aiogram.api.types import (
CallbackQuery,
Chat,
ChosenInlineResult,
InlineQuery,
Message,
Poll,
PollAnswer,
PreCheckoutQuery,
ShippingQuery,
Update,
User,
)
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.types import MiddlewareStep, UpdateType
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
class MyMiddleware(BaseMiddleware):
async def on_pre_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
return "update"
async def on_pre_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
return "message"
async def on_pre_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
return "edited_message"
async def on_pre_process_channel_post(
self, channel_post: Message, data: Dict[str, Any]
) -> Any:
return "channel_post"
async def on_pre_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
return "edited_channel_post"
async def on_pre_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
return "inline_query"
async def on_pre_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
return "chosen_inline_result"
async def on_pre_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
return "callback_query"
async def on_pre_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
return "shipping_query"
async def on_pre_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
return "pre_checkout_query"
async def on_pre_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
return "poll"
async def on_pre_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any]
) -> Any:
return "poll_answer"
async def on_pre_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
return "error"
async def on_process_update(self, update: Update, data: Dict[str, Any]) -> Any:
return "update"
async def on_process_message(self, message: Message, data: Dict[str, Any]) -> Any:
return "message"
async def on_process_edited_message(
self, edited_message: Message, data: Dict[str, Any]
) -> Any:
return "edited_message"
async def on_process_channel_post(self, channel_post: Message, data: Dict[str, Any]) -> Any:
return "channel_post"
async def on_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any]
) -> Any:
return "edited_channel_post"
async def on_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any]
) -> Any:
return "inline_query"
async def on_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any]
) -> Any:
return "chosen_inline_result"
async def on_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any]
) -> Any:
return "callback_query"
async def on_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any]
) -> Any:
return "shipping_query"
async def on_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any]
) -> Any:
return "pre_checkout_query"
async def on_process_poll(self, poll: Poll, data: Dict[str, Any]) -> Any:
return "poll"
async def on_process_poll_answer(self, poll_answer: PollAnswer, data: Dict[str, Any]) -> Any:
return "poll_answer"
async def on_process_error(self, exception: Exception, data: Dict[str, Any]) -> Any:
return "error"
async def on_post_process_update(
self, update: Update, data: Dict[str, Any], result: Any
) -> Any:
return "update"
async def on_post_process_message(
self, message: Message, data: Dict[str, Any], result: Any
) -> Any:
return "message"
async def on_post_process_edited_message(
self, edited_message: Message, data: Dict[str, Any], result: Any
) -> Any:
return "edited_message"
async def on_post_process_channel_post(
self, channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
return "channel_post"
async def on_post_process_edited_channel_post(
self, edited_channel_post: Message, data: Dict[str, Any], result: Any
) -> Any:
return "edited_channel_post"
async def on_post_process_inline_query(
self, inline_query: InlineQuery, data: Dict[str, Any], result: Any
) -> Any:
return "inline_query"
async def on_post_process_chosen_inline_result(
self, chosen_inline_result: ChosenInlineResult, data: Dict[str, Any], result: Any
) -> Any:
return "chosen_inline_result"
async def on_post_process_callback_query(
self, callback_query: CallbackQuery, data: Dict[str, Any], result: Any
) -> Any:
return "callback_query"
async def on_post_process_shipping_query(
self, shipping_query: ShippingQuery, data: Dict[str, Any], result: Any
) -> Any:
return "shipping_query"
async def on_post_process_pre_checkout_query(
self, pre_checkout_query: PreCheckoutQuery, data: Dict[str, Any], result: Any
) -> Any:
return "pre_checkout_query"
async def on_post_process_poll(self, poll: Poll, data: Dict[str, Any], result: Any) -> Any:
return "poll"
async def on_post_process_poll_answer(
self, poll_answer: PollAnswer, data: Dict[str, Any], result: Any
) -> Any:
return "poll_answer"
async def on_post_process_error(
self, exception: Exception, data: Dict[str, Any], result: Any
) -> Any:
return "error"
UPDATE = Update(update_id=42)
MESSAGE = Message(message_id=42, date=datetime.datetime.now(), chat=Chat(id=42, type="private"))
POLL_ANSWER = PollAnswer(
poll_id="poll", user=User(id=42, is_bot=False, first_name="Test"), option_ids=[0]
)
class TestBaseMiddleware:
@pytest.mark.asyncio
@pytest.mark.parametrize(
"middleware_cls,should_be_awaited", [[MyMiddleware, True], [BaseMiddleware, False]]
)
@pytest.mark.parametrize(
"step", [MiddlewareStep.PRE_PROCESS, MiddlewareStep.PROCESS, MiddlewareStep.POST_PROCESS]
)
@pytest.mark.parametrize(
"event_name,event",
[
["update", UPDATE],
["message", MESSAGE],
["poll_answer", POLL_ANSWER],
["error", Exception("KABOOM")],
],
)
async def test_trigger(
self,
step: MiddlewareStep,
event_name: str,
event: UpdateType,
middleware_cls: Type[BaseMiddleware],
should_be_awaited: bool,
):
middleware = middleware_cls()
with patch(
f"tests.test_dispatcher.test_middlewares.test_base."
f"MyMiddleware.on_{step.value}_{event_name}",
new_callable=CoroutineMock,
) as mocked_call:
response = await middleware.trigger(
step=step, event_name=event_name, event=event, data={}
)
if should_be_awaited:
mocked_call.assert_awaited()
assert response is not None
else:
mocked_call.assert_not_awaited()
assert response is None
def test_not_configured(self):
middleware = BaseMiddleware()
assert not middleware.configured
with pytest.raises(RuntimeError):
manager = middleware.manager

View file

@ -1,82 +0,0 @@
import pytest
from aiogram import Router
from aiogram.api.types import Update
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.manager import MiddlewareManager
from aiogram.dispatcher.middlewares.types import MiddlewareStep
try:
from asynctest import CoroutineMock, patch
except ImportError:
from unittest.mock import AsyncMock as CoroutineMock, patch # type: ignore
@pytest.fixture("function")
def router():
return Router()
@pytest.fixture("function")
def manager(router: Router):
return MiddlewareManager(router)
class TestManager:
def test_setup(self, manager: MiddlewareManager):
middleware = BaseMiddleware()
returned = manager.setup(middleware)
assert returned is middleware
assert middleware.configured
assert middleware.manager is manager
assert middleware in manager
@pytest.mark.parametrize("obj", [object, object(), None, BaseMiddleware])
def test_setup_invalid_type(self, manager: MiddlewareManager, obj):
with pytest.raises(TypeError):
assert manager.setup(obj)
def test_configure_twice_different_managers(self, manager: MiddlewareManager, router: Router):
middleware = BaseMiddleware()
manager.setup(middleware)
assert middleware.configured
new_manager = MiddlewareManager(router)
with pytest.raises(ValueError):
new_manager.setup(middleware)
with pytest.raises(ValueError):
middleware.setup(new_manager)
def test_configure_twice(self, manager: MiddlewareManager):
middleware = BaseMiddleware()
manager.setup(middleware)
assert middleware.configured
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
manager.setup(middleware)
with pytest.warns(RuntimeWarning, match="is already configured for this Router"):
middleware.setup(manager)
@pytest.mark.asyncio
@pytest.mark.parametrize("count", range(5))
async def test_trigger(self, manager: MiddlewareManager, count: int):
for _ in range(count):
manager.setup(BaseMiddleware())
with patch(
"aiogram.dispatcher.middlewares.base.BaseMiddleware.trigger",
new_callable=CoroutineMock,
) as mocked_call:
await manager.trigger(
step=MiddlewareStep.PROCESS,
event_name="update",
event=Update(update_id=42),
data={},
result=None,
reverse=True,
)
assert mocked_call.await_count == count

View file

@ -10,6 +10,7 @@ from aiogram.api.types import (
InlineQuery,
Message,
Poll,
PollAnswer,
PollOption,
PreCheckoutQuery,
ShippingAddress,
@ -17,8 +18,8 @@ from aiogram.api.types import (
Update,
User,
)
from aiogram.dispatcher.event.observer import SkipHandler, skip
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.event.bases import NOT_HANDLED, SkipHandler, skip
from aiogram.dispatcher.middlewares.update_processing_context import UserContextMiddleware
from aiogram.dispatcher.router import Router
from aiogram.utils.warnings import CodeHasNoEffect
@ -274,12 +275,26 @@ class TestRouter:
False,
False,
),
pytest.param(
"poll_answer",
Update(
update_id=42,
poll_answer=PollAnswer(
poll_id="poll id",
user=User(id=42, is_bot=False, first_name="Test"),
option_ids=[42],
),
),
False,
True,
),
],
)
async def test_listen_update(
self, event_type: str, update: Update, has_chat: bool, has_user: bool
):
router = Router()
router.update.outer_middleware(UserContextMiddleware())
observer = router.observers[event_type]
@observer()
@ -291,7 +306,7 @@ class TestRouter:
assert User.get_current(False)
return kwargs
result = await router._listen_update(update, test="PASS")
result = await router.update.trigger(update, test="PASS")
assert isinstance(result, dict)
assert result["event_update"] == update
assert result["event_router"] == router
@ -313,26 +328,26 @@ class TestRouter:
async def handler(event: Any):
pass
with pytest.raises(SkipHandler):
await router._listen_update(
Update(
update_id=42,
poll=Poll(
id="poll id",
question="Q?",
options=[
PollOption(text="A1", voter_count=2),
PollOption(text="A2", voter_count=3),
],
is_closed=False,
is_anonymous=False,
type="quiz",
allows_multiple_answers=False,
total_voter_count=0,
correct_option_id=0,
),
)
response = await router._listen_update(
Update(
update_id=42,
poll=Poll(
id="poll id",
question="Q?",
options=[
PollOption(text="A1", voter_count=2),
PollOption(text="A2", voter_count=3),
],
is_closed=False,
is_anonymous=False,
type="quiz",
allows_multiple_answers=False,
total_voter_count=0,
correct_option_id=0,
),
)
)
assert response is NOT_HANDLED
@pytest.mark.asyncio
async def test_nested_router_listen_update(self):
@ -345,8 +360,6 @@ class TestRouter:
@observer()
async def my_handler(event: Message, **kwargs: Any):
assert Chat.get_current(False)
assert User.get_current(False)
return kwargs
update = Update(
@ -409,14 +422,6 @@ class TestRouter:
await router1.emit_shutdown()
assert results == [2, 1, 2]
def test_use(self):
router = Router()
middleware = router.use(BaseMiddleware())
assert isinstance(middleware, BaseMiddleware)
assert middleware.configured
assert middleware.manager == router.middleware
def test_skip(self):
with pytest.raises(SkipHandler):
skip()
@ -444,37 +449,20 @@ class TestRouter:
),
)
with pytest.raises(Exception, match="KABOOM"):
await root_router.listen_update(
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
await root_router.update.trigger(update)
@root_router.errors()
async def root_error_handler(exception: Exception):
async def root_error_handler(event: Update, exception: Exception):
return exception
response = await root_router.listen_update(
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
response = await root_router.update.trigger(update)
assert isinstance(response, Exception)
assert str(response) == "KABOOM"
@router.errors()
async def error_handler(exception: Exception):
async def error_handler(event: Update, exception: Exception):
return "KABOOM"
response = await root_router.listen_update(
update_type="message",
update=update,
event=update.message,
from_user=update.message.from_user,
chat=update.message.chat,
)
response = await root_router.update.trigger(update)
assert response == "KABOOM"