Optimize state filter.

This commit is contained in:
Alex Root Junior 2017-08-26 18:02:01 +03:00
parent 0fcb75e997
commit a57c91067e
3 changed files with 66 additions and 9 deletions

View file

@ -3,16 +3,20 @@ import functools
import logging
import typing
from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter, generate_default_filters
from .filters import CommandsFilter, ContentTypeFilter, RegexpFilter, USER_STATE, generate_default_filters
from .handler import Handler
from .storage import DisabledStorage, BaseStorage, FSMContext
from .storage import BaseStorage, DisabledStorage, FSMContext
from .webhook import BaseResponse
from ..bot import Bot
from ..types.message import ContentType
from ..utils.exceptions import TelegramAPIError, NetworkError
from ..utils import context
from ..utils.exceptions import NetworkError, TelegramAPIError
log = logging.getLogger(__name__)
MODE = 'MODE'
LONG_POOLING = 'long-pooling'
class Dispatcher:
"""
@ -79,7 +83,7 @@ class Dispatcher:
"""
tasks = []
for update in updates:
tasks.append(self.updates_handler.notify(update))
tasks.append(self.process_update(update))
return await asyncio.gather(*tasks)
async def process_update(self, update):
@ -90,23 +94,56 @@ class Dispatcher:
:return:
"""
self.last_update_id = update.update_id
has_context = context.check_configured()
if update.message:
if has_context:
state = self.storage.get_state(chat=update.message.chat.id,
user=update.message.from_user.id)
context.set_value(USER_STATE, await state)
return await self.message_handlers.notify(update.message)
if update.edited_message:
if has_context:
state = self.storage.get_state(chat=update.edited_message.chat.id,
user=update.edited_message.from_user.id)
context.set_value(USER_STATE, await state)
return await self.edited_message_handlers.notify(update.edited_message)
if update.channel_post:
if has_context:
state = self.storage.get_state(chat=update.message.chat.id,
user=update.message.from_user.id)
context.set_value(USER_STATE, await state)
return await self.channel_post_handlers.notify(update.channel_post)
if update.edited_channel_post:
if has_context:
state = self.storage.get_state(chat=update.edited_channel_post.chat.id,
user=update.edited_channel_post.from_user.id)
context.set_value(USER_STATE, await state)
return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
if update.inline_query:
if has_context:
state = self.storage.get_state(user=update.inline_query.from_user.id)
context.set_value(USER_STATE, await state)
return await self.inline_query_handlers.notify(update.inline_query)
if update.chosen_inline_result:
if has_context:
state = self.storage.get_state(user=update.chosen_inline_result.from_user.id)
context.set_value(USER_STATE, await state)
return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
if update.callback_query:
if has_context:
state = self.storage.get_state(chat=update.callback_query.message.chat.id,
user=update.callback_query.from_user.id)
context.set_value(USER_STATE, await state)
return await self.callback_query_handlers.notify(update.callback_query)
if update.shipping_query:
if has_context:
state = self.storage.get_state(user=update.shipping_query.from_user.id)
context.set_value(USER_STATE, await state)
return await self.shipping_query_handlers.notify(update.shipping_query)
if update.pre_checkout_query:
if has_context:
state = self.storage.get_state(user=update.pre_checkout_query.from_user.id)
context.set_value(USER_STATE, await state)
return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
async def start_pooling(self, timeout=20, relax=0.1, limit=None):
@ -121,6 +158,7 @@ class Dispatcher:
if self._pooling:
raise RuntimeError('Pooling already started')
log.info('Start pooling.')
context.set_value(MODE, LONG_POOLING)
self._pooling = True
offset = None
@ -730,6 +768,7 @@ class Dispatcher:
:param func:
:return:
"""
def process_response(task):
response = task.result()
self.loop.create_task(response.execute_response(self.bot))

View file

@ -1,8 +1,11 @@
import inspect
import re
from aiogram.utils import context
from ..utils.helper import Helper, HelperMode, Item
USER_STATE = 'USER_STATE'
async def check_filter(filter_, args, kwargs):
if not callable(filter_):
@ -102,10 +105,14 @@ class StateFilter(AsyncFilter):
if self.state == '*':
return True
chat, user = self.get_target(obj)
if context.check_value(USER_STATE):
context_state = context.get_value(USER_STATE)
return self.state == context_state
else:
chat, user = self.get_target(obj)
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
if chat or user:
return await self.dispatcher.storage.get_state(chat=chat, user=user) == self.state
return False

View file

@ -3,13 +3,14 @@ import asyncio.tasks
import datetime
import functools
import typing
from typing import Union, Dict, Optional
from typing import Dict, Optional, Union
from aiohttp import web
from .. import types
from ..bot import api
from ..bot.base import Integer, String, Boolean, Float
from ..bot.base import Boolean, Float, Integer, String
from ..utils import context
from ..utils import json
from ..utils.deprecated import warn_deprecated as warn
from ..utils.exceptions import TimeoutWarning
@ -20,6 +21,10 @@ BOT_DISPATCHER_KEY = 'BOT_DISPATCHER'
RESPONSE_TIMEOUT = 55
WEBHOOK = 'webhook'
WEBHOOK_CONNECTION = 'WEBHOOK_CONNECTION'
WEBHOOK_REQUEST = 'WEBHOOK_REQUEST'
class WebhookRequestHandler(web.View):
"""
@ -71,6 +76,11 @@ class WebhookRequestHandler(web.View):
:return: :class:`aiohttp.web.Response`
"""
context.update_state({'CALLER': WEBHOOK,
WEBHOOK_CONNECTION: True,
WEBHOOK_REQUEST: self.request})
dispatcher = self.get_dispatcher()
update = await self.parse_update(dispatcher.bot)
@ -113,6 +123,7 @@ class WebhookRequestHandler(web.View):
if fut.done():
return fut.result()
else:
context.set_value(WEBHOOK_CONNECTION, False)
fut.remove_done_callback(cb)
fut.add_done_callback(self.respond_via_request)
finally: