mirror of
https://github.com/aiogram/aiogram.git
synced 2025-12-11 01:54:53 +00:00
Optimize state filter.
This commit is contained in:
parent
0fcb75e997
commit
a57c91067e
3 changed files with 66 additions and 9 deletions
|
|
@ -3,16 +3,20 @@ import functools
|
|||
import logging
|
||||
import typing
|
||||
|
||||
from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter, generate_default_filters
|
||||
from .filters import CommandsFilter, ContentTypeFilter, RegexpFilter, USER_STATE, generate_default_filters
|
||||
from .handler import Handler
|
||||
from .storage import DisabledStorage, BaseStorage, FSMContext
|
||||
from .storage import BaseStorage, DisabledStorage, FSMContext
|
||||
from .webhook import BaseResponse
|
||||
from ..bot import Bot
|
||||
from ..types.message import ContentType
|
||||
from ..utils.exceptions import TelegramAPIError, NetworkError
|
||||
from ..utils import context
|
||||
from ..utils.exceptions import NetworkError, TelegramAPIError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
MODE = 'MODE'
|
||||
LONG_POOLING = 'long-pooling'
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
"""
|
||||
|
|
@ -79,7 +83,7 @@ class Dispatcher:
|
|||
"""
|
||||
tasks = []
|
||||
for update in updates:
|
||||
tasks.append(self.updates_handler.notify(update))
|
||||
tasks.append(self.process_update(update))
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def process_update(self, update):
|
||||
|
|
@ -90,23 +94,56 @@ class Dispatcher:
|
|||
:return:
|
||||
"""
|
||||
self.last_update_id = update.update_id
|
||||
has_context = context.check_configured()
|
||||
if update.message:
|
||||
if has_context:
|
||||
state = self.storage.get_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.message_handlers.notify(update.message)
|
||||
if update.edited_message:
|
||||
if has_context:
|
||||
state = self.storage.get_state(chat=update.edited_message.chat.id,
|
||||
user=update.edited_message.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.edited_message_handlers.notify(update.edited_message)
|
||||
if update.channel_post:
|
||||
if has_context:
|
||||
state = self.storage.get_state(chat=update.message.chat.id,
|
||||
user=update.message.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.channel_post_handlers.notify(update.channel_post)
|
||||
if update.edited_channel_post:
|
||||
if has_context:
|
||||
state = self.storage.get_state(chat=update.edited_channel_post.chat.id,
|
||||
user=update.edited_channel_post.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
|
||||
if update.inline_query:
|
||||
if has_context:
|
||||
state = self.storage.get_state(user=update.inline_query.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.inline_query_handlers.notify(update.inline_query)
|
||||
if update.chosen_inline_result:
|
||||
if has_context:
|
||||
state = self.storage.get_state(user=update.chosen_inline_result.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
|
||||
if update.callback_query:
|
||||
if has_context:
|
||||
state = self.storage.get_state(chat=update.callback_query.message.chat.id,
|
||||
user=update.callback_query.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.callback_query_handlers.notify(update.callback_query)
|
||||
if update.shipping_query:
|
||||
if has_context:
|
||||
state = self.storage.get_state(user=update.shipping_query.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.shipping_query_handlers.notify(update.shipping_query)
|
||||
if update.pre_checkout_query:
|
||||
if has_context:
|
||||
state = self.storage.get_state(user=update.pre_checkout_query.from_user.id)
|
||||
context.set_value(USER_STATE, await state)
|
||||
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
|
||||
|
||||
async def start_pooling(self, timeout=20, relax=0.1, limit=None):
|
||||
|
|
@ -121,6 +158,7 @@ class Dispatcher:
|
|||
if self._pooling:
|
||||
raise RuntimeError('Pooling already started')
|
||||
log.info('Start pooling.')
|
||||
context.set_value(MODE, LONG_POOLING)
|
||||
|
||||
self._pooling = True
|
||||
offset = None
|
||||
|
|
@ -730,6 +768,7 @@ class Dispatcher:
|
|||
:param func:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def process_response(task):
|
||||
response = task.result()
|
||||
self.loop.create_task(response.execute_response(self.bot))
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import inspect
|
||||
import re
|
||||
|
||||
from aiogram.utils import context
|
||||
from ..utils.helper import Helper, HelperMode, Item
|
||||
|
||||
USER_STATE = 'USER_STATE'
|
||||
|
||||
|
||||
async def check_filter(filter_, args, kwargs):
|
||||
if not callable(filter_):
|
||||
|
|
@ -102,10 +105,14 @@ class StateFilter(AsyncFilter):
|
|||
if self.state == '*':
|
||||
return True
|
||||
|
||||
chat, user = self.get_target(obj)
|
||||
if context.check_value(USER_STATE):
|
||||
context_state = context.get_value(USER_STATE)
|
||||
return self.state == context_state
|
||||
else:
|
||||
chat, user = self.get_target(obj)
|
||||
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
|
||||
if chat or user:
|
||||
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@ import asyncio.tasks
|
|||
import datetime
|
||||
import functools
|
||||
import typing
|
||||
from typing import Union, Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .. import types
|
||||
from ..bot import api
|
||||
from ..bot.base import Integer, String, Boolean, Float
|
||||
from ..bot.base import Boolean, Float, Integer, String
|
||||
from ..utils import context
|
||||
from ..utils import json
|
||||
from ..utils.deprecated import warn_deprecated as warn
|
||||
from ..utils.exceptions import TimeoutWarning
|
||||
|
|
@ -20,6 +21,10 @@ BOT_DISPATCHER_KEY = 'BOT_DISPATCHER'
|
|||
|
||||
RESPONSE_TIMEOUT = 55
|
||||
|
||||
WEBHOOK = 'webhook'
|
||||
WEBHOOK_CONNECTION = 'WEBHOOK_CONNECTION'
|
||||
WEBHOOK_REQUEST = 'WEBHOOK_REQUEST'
|
||||
|
||||
|
||||
class WebhookRequestHandler(web.View):
|
||||
"""
|
||||
|
|
@ -71,6 +76,11 @@ class WebhookRequestHandler(web.View):
|
|||
|
||||
:return: :class:`aiohttp.web.Response`
|
||||
"""
|
||||
|
||||
context.update_state({'CALLER': WEBHOOK,
|
||||
WEBHOOK_CONNECTION: True,
|
||||
WEBHOOK_REQUEST: self.request})
|
||||
|
||||
dispatcher = self.get_dispatcher()
|
||||
update = await self.parse_update(dispatcher.bot)
|
||||
|
||||
|
|
@ -113,6 +123,7 @@ class WebhookRequestHandler(web.View):
|
|||
if fut.done():
|
||||
return fut.result()
|
||||
else:
|
||||
context.set_value(WEBHOOK_CONNECTION, False)
|
||||
fut.remove_done_callback(cb)
|
||||
fut.add_done_callback(self.respond_via_request)
|
||||
finally:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue