Merge branch 'webhook' into dev

commit 882d60f477
Author: Alex Root Junior
Date:   2017-08-09 04:47:02 +03:00

13 changed files with 1859 additions and 38 deletions


@@ -151,7 +151,9 @@ class BaseBot:
         :param payload: request payload
         :return: resonse
         """
-        if isinstance(file, str):
+        if file is None:
+            files = {}
+        elif isinstance(file, str):
             # You can use file ID or URL in the most of requests
             payload[file_type] = file
             files = None
@@ -1076,8 +1078,7 @@ class BaseBot:
             If you have created a Game and accepted the conditions via @Botfather,
             specify the URL that opens your game note that this will only work
             if the query comes from a callback_game button.
-            Otherwise, you may use links like t.me/your_bot?start=XXXX that open your bot with a parameter.
+            Otherwise, you may use links like t.me/your_bot?start=XXXX that open your bot with a parameter.
         :param cache_time: Integer (Optional) - The maximum amount of time in seconds that the result
             of the callback query may be cached client-side. Telegram apps will support
             caching starting in version 3.14. Defaults to 0.
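The first hunk above lets the file-upload helper accept no file at all, alongside a file_id/URL string or a raw file object. Below is a minimal sketch of the branching it introduces; the helper name, its signature, and the final else branch are assumptions, while the payload/file_type/files names come from the hunk itself.

    def split_file_payload(file, file_type, payload):
        # Hypothetical helper mirroring the hunk's logic, not the real method name.
        if file is None:
            # Nothing to upload: send the payload as a plain API request.
            files = {}
        elif isinstance(file, str):
            # A file_id or URL travels inside the payload, not as multipart data.
            payload[file_type] = file
            files = None
        else:
            # Assumed branch: raw bytes / file-like objects become multipart form data.
            files = {file_type: file}
        return payload, files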


@@ -10,6 +10,12 @@ class MemoryStorage(BaseStorage):
     This type of storage is not recommended for usage in bots, because you will lost all states after restarting.
     """

+    async def wait_closed(self):
+        pass
+
+    def close(self):
+        self.data.clear()
+
     def __init__(self):
         self.data = {}


@@ -2,6 +2,7 @@
 This module has redis storage for finite-state machine based on `aioredis <https://github.com/aio-libs/aioredis>`_ driver
 """
+import asyncio
 import typing

 import aioredis
@@ -21,6 +22,13 @@ class RedisStorage(BaseStorage):
         storage = RedisStorage('localhost', 6379, db=5)
         dp = Dispatcher(bot, storage=storage)

+    And need to close Redis connection when shutdown
+
+    .. code-block:: python3
+
+        dp.storage.close()
+        await dp.storage.wait_closed()
+
     """

     def __init__(self, host, port, db=None, password=None, ssl=None, loop=None, **kwargs):
@@ -29,10 +37,22 @@ class RedisStorage(BaseStorage):
         self._db = db
         self._password = password
         self._ssl = ssl
-        self._loop = loop
+        self._loop = loop or asyncio.get_event_loop()
         self._kwargs = kwargs

         self._redis: aioredis.RedisConnection = None
+        self._connection_lock = asyncio.Lock(loop=self._loop)
+
+    def close(self):
+        if self._redis and not self._redis.closed:
+            self._redis.close()
+            del self._redis
+            self._redis = None
+
+    async def wait_closed(self):
+        if self._redis:
+            return await self._redis.wait_closed()
+        return True

     @property
     async def redis(self) -> aioredis.RedisConnection:
@@ -41,11 +61,13 @@ class RedisStorage(BaseStorage):
         This property is awaitable.
         """
-        if self._redis is None:
-            self._redis = await aioredis.create_connection((self._host, self._port),
-                                                            db=self._db, password=self._password, ssl=self._ssl,
-                                                            loop=self._loop,
-                                                            **self._kwargs)
+        # Use thread-safe asyncio Lock because this method without that is not safe
+        async with self._connection_lock:
+            if self._redis is None:
+                self._redis = await aioredis.create_connection((self._host, self._port),
+                                                                db=self._db, password=self._password, ssl=self._ssl,
+                                                                loop=self._loop,
+                                                                **self._kwargs)
         return self._redis

     async def get_record(self, *,
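The redis property now guards lazy connection creation with an asyncio.Lock, so two coroutines awaiting it concurrently cannot both open a connection. A generic sketch of that pattern under those assumptions (names here are illustrative, not aiogram's API):

    import asyncio

    class LazyConnection:
        """Open a shared resource at most once, even under concurrent access."""

        def __init__(self):
            self._conn = None
            self._lock = asyncio.Lock()

        async def get(self):
            # Without the lock, two tasks could both observe None here and
            # each open their own connection.
            async with self._lock:
                if self._conn is None:
                    self._conn = await self._open()
            return self._conn

        async def _open(self):
            # Stand-in for aioredis.create_connection(...)
            await asyncio.sleep(0)
            return object()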


@@ -5,6 +5,7 @@ import typing
 from .filters import CommandsFilter, RegexpFilter, ContentTypeFilter, generate_default_filters
 from .handler import Handler
 from .storage import DisabledStorage, BaseStorage, FSMContext
+from .webhook import BaseResponse
 from ..bot import Bot
 from ..types.message import ContentType
@@ -74,8 +75,10 @@ class Dispatcher:
         :param updates:
         :return:
         """
+        tasks = []
         for update in updates:
-            self.loop.create_task(self.updates_handler.notify(update))
+            tasks.append(self.loop.create_task(self.updates_handler.notify(update)))
+        return await asyncio.gather(*tasks)

     async def process_update(self, update):
         """
@@ -86,30 +89,31 @@ class Dispatcher:
         """
         self.last_update_id = update.update_id
         if update.message:
-            await self.message_handlers.notify(update.message)
+            return await self.message_handlers.notify(update.message)
         if update.edited_message:
-            await self.edited_message_handlers.notify(update.edited_message)
+            return await self.edited_message_handlers.notify(update.edited_message)
         if update.channel_post:
-            await self.channel_post_handlers.notify(update.channel_post)
+            return await self.channel_post_handlers.notify(update.channel_post)
         if update.edited_channel_post:
-            await self.edited_channel_post_handlers.notify(update.edited_channel_post)
+            return await self.edited_channel_post_handlers.notify(update.edited_channel_post)
         if update.inline_query:
-            await self.inline_query_handlers.notify(update.inline_query)
+            return await self.inline_query_handlers.notify(update.inline_query)
         if update.chosen_inline_result:
-            await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
+            return await self.chosen_inline_result_handlers.notify(update.chosen_inline_result)
         if update.callback_query:
-            await self.callback_query_handlers.notify(update.callback_query)
+            return await self.callback_query_handlers.notify(update.callback_query)
         if update.shipping_query:
-            await self.shipping_query_handlers.notify(update.shipping_query)
+            return await self.shipping_query_handlers.notify(update.shipping_query)
         if update.pre_checkout_query:
-            await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)
+            return await self.pre_checkout_query_handlers.notify(update.pre_checkout_query)

-    async def start_pooling(self, timeout=20, relax=0.1):
+    async def start_pooling(self, timeout=20, relax=0.1, limit=None):
         """
         Start long-pooling

         :param timeout:
         :param relax:
+        :param limit:
         :return:
         """
         if self._pooling:
@@ -120,21 +124,42 @@ class Dispatcher:
         offset = None
         while self._pooling:
             try:
-                updates = await self.bot.get_updates(offset=offset, timeout=timeout)
+                updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout)
             except Exception as e:
                 log.exception('Cause exception while getting updates')
-                await asyncio.sleep(relax)
+                if relax:
+                    await asyncio.sleep(relax)
                 continue

             if updates:
                 log.info("Received {0} updates.".format(len(updates)))
                 offset = updates[-1].update_id + 1
-                await self.process_updates(updates)
+                self.loop.create_task(self._process_pooling_updates(updates))

             await asyncio.sleep(relax)

         log.warning('Pooling is stopped.')

+    async def _process_pooling_updates(self, updates):
+        """
+        Process updates received from long-pooling.
+
+        :param updates: list of updates.
+        """
+        need_to_call = []
+        for update in await self.process_updates(updates):
+            for responses in update:
+                for response in responses:
+                    if not isinstance(response, BaseResponse):
+                        continue
+                    need_to_call.append(response.execute_response(self.bot))
+        if need_to_call:
+            try:
+                asyncio.gather(*need_to_call)
+            except Exception as e:
+                log.exception('Cause exception while processing updates.')
+
     def stop_pooling(self):
         """
         Break long-pooling process.
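Because process_update() now returns the handlers' results and _process_pooling_updates() executes any BaseResponse it finds among them, a handler can answer by returning a response object instead of calling the bot directly. A hedged example, assuming the suppressed webhook module exports a SendMessage response class with an execute_response(bot) coroutine:

    from aiogram.dispatcher.webhook import SendMessage  # assumed export of the new module

    async def echo(message):
        # The return value is collected by Handler.notify(); during long polling
        # the dispatcher later awaits response.execute_response(bot) for it.
        return SendMessage(chat_id=message.chat.id, text=message.text)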


@@ -34,13 +34,19 @@ class Handler:
         raise ValueError('This handler is not registered!')

     async def notify(self, *args, **kwargs):
+        results = []
+
         for filters, handler in self.handlers:
             if await check_filters(filters, args, kwargs):
                 try:
-                    await handler(*args, **kwargs)
+                    response = await handler(*args, **kwargs)
+                    if results is not None:
+                        results.append(response)
+
                     if self.once:
                         break
                 except SkipHandler:
                     continue
                 except CancelHandler:
                     break
+
+        return results
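Handler.notify() now returns the list of values produced by every matching handler, which is what lets the dispatcher inspect them above. A self-contained toy illustration of that contract (not aiogram code):

    import asyncio

    async def handler_a(value):
        return value * 2

    async def handler_b(value):
        return value + 1

    async def notify(handlers, value):
        # Mirrors the new behaviour: run each handler and collect its result.
        results = []
        for handler in handlers:
            results.append(await handler(value))
        return results

    loop = asyncio.get_event_loop()
    print(loop.run_until_complete(notify([handler_a, handler_b], 20)))  # [40, 21]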


@@ -6,6 +6,23 @@ class BaseStorage:
     In states-storage you can save current user state and data for all steps
     """

+    def close(self):
+        """
+        Need override this method and use when application is shutdowns.
+        You can save data or etc.
+
+        :return:
+        """
+        raise NotImplementedError
+
+    async def wait_closed(self):
+        """
+        You need override this method for all asynchronously storage's like Redis.
+
+        :return:
+        """
+        raise NotImplementedError
+
     @classmethod
     def check_address(cls, *,
                       chat: typing.Union[str, int, None] = None,
@@ -209,6 +226,12 @@ class DisabledStorage(BaseStorage):
     Empty storage. Use it if you don't want to use Finite-State Machine
     """

+    def close(self):
+        pass
+
+    async def wait_closed(self):
+        pass
+
     async def get_state(self, *,
                         chat: typing.Union[str, int, None] = None,
                         user: typing.Union[str, int, None] = None,
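Every storage backend is now expected to provide close() and wait_closed() for graceful shutdown. A sketch of what a custom backend might add, assuming BaseStorage is importable from aiogram.dispatcher.storage as the Dispatcher diff above implies:

    from aiogram.dispatcher.storage import BaseStorage  # assumed import path

    class MyStorage(BaseStorage):
        # The state/data methods required elsewhere in BaseStorage are omitted here.

        def __init__(self):
            self.data = {}

        def close(self):
            # Persist or drop in-memory state when the application shuts down.
            self.data.clear()

        async def wait_closed(self):
            # Nothing asynchronous to wait for in this sketch.
            pass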

File diff suppressed because it is too large


@@ -59,6 +59,11 @@ class Deserializable:
             result[name] = attr
         return result

+    @classmethod
+    def _parse_date(cls, unix_time):
+        if unix_time is not None:
+            return datetime.datetime.fromtimestamp(unix_time)
+
     @property
     def bot(self) -> 'Bot':
         """


@@ -35,11 +35,6 @@ class ChatMember(Deserializable):
         self.can_send_other_messages: bool = can_send_other_messages
         self.can_add_web_page_previews: bool = can_add_web_page_previews

-    @classmethod
-    def _parse_date(cls, unix_time):
-        if unix_time is not None:
-            return datetime.datetime.fromtimestamp(unix_time)
-
     @classmethod
     def de_json(cls, raw_data):
         user = User.deserialize(raw_data.get('user'))


@@ -74,10 +74,6 @@ class Message(Deserializable):
         self.content_type = content_type

-    @classmethod
-    def _parse_date(cls, unix_time):
-        return datetime.datetime.fromtimestamp(unix_time)
-
     @classmethod
     def de_json(cls, raw_data):
         message_id = raw_data.get('message_id')


@@ -19,10 +19,6 @@ class WebhookInfo(Deserializable):
         self.max_connections: int = max_connections
         self.allowed_updates: [str] = allowed_updates

-    @classmethod
-    def _parse_date(cls, unix_time):
-        return datetime.datetime.fromtimestamp(unix_time)
-
     @classmethod
     def de_json(cls, raw_data):
         url = raw_data.get('url')


@@ -15,7 +15,9 @@ def generate_payload(exclude=None, **kwargs):
 def prepare_arg(value):
-    if isinstance(value, (list, dict)):
+    if value is None:
+        return None
+    elif isinstance(value, (list, dict)):
         return json.dumps(value)
     elif hasattr(value, 'to_json'):
         return json.dumps(value.to_json())
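With prepare_arg() returning None for unset values, optional Bot API parameters can be passed straight through and dropped from the request payload. A small illustration; build_payload() below is a hypothetical stand-in for the module's generate_payload():

    import json

    def prepare_arg(value):
        if value is None:
            return None
        elif isinstance(value, (list, dict)):
            return json.dumps(value)
        return value

    def build_payload(**kwargs):
        # Hypothetical: keep only parameters the caller actually set.
        return {key: prepare_arg(value) for key, value in kwargs.items() if value is not None}

    print(build_payload(chat_id=42, text='hi', reply_markup=None))
    # {'chat_id': 42, 'text': 'hi'}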


@@ -32,3 +32,10 @@ Middlewares
 .. automodule:: aiogram.dispatcher.middlewares
     :members:
     :show-inheritance:
+
+Webhook
+-------
+
+.. automodule:: aiogram.dispatcher.webhook
+    :members:
+    :show-inheritance: