From 96042d13ecc6afe62cb3db65be326fbdb1ef36fc Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Sun, 15 Apr 2018 06:18:49 +0300 Subject: [PATCH] Refactoring of executor util --- aiogram/utils/executor.py | 276 ++++++++++++++++++++++++++++---------- 1 file changed, 204 insertions(+), 72 deletions(-) diff --git a/aiogram/utils/executor.py b/aiogram/utils/executor.py index 51ecd924..b375f135 100644 --- a/aiogram/utils/executor.py +++ b/aiogram/utils/executor.py @@ -1,87 +1,219 @@ +import asyncio +import datetime +import functools +import secrets +from warnings import warn + from aiohttp import web from . import context from ..bot.api import log -from ..dispatcher import Dispatcher -from ..dispatcher.webhook import BOT_DISPATCHER_KEY, get_new_configured_app +from ..dispatcher.webhook import BOT_DISPATCHER_KEY, WebhookRequestHandler + +APP_EXECUTOR_KEY = 'APP_EXECUTOR' -async def _startup(dispatcher: Dispatcher, skip_updates=False, callback=None): - user = await dispatcher.bot.me - log.info(f"Bot: {user.full_name} [@{user.username}]") - - if skip_updates: - await dispatcher.reset_webhook(True) - count = await dispatcher.skip_updates() - if count: - log.warning(f"Skipped {count} updates.") - - if callable(callback): - await callback(dispatcher) - - -async def _wh_startup(app): - callback = app.get('_startup_callback', None) - dispatcher = app.get(BOT_DISPATCHER_KEY, None) - skip_updates = app.get('_skip_updates', False) - await _startup(dispatcher, skip_updates=skip_updates, callback=callback) - - -async def _shutdown(dispatcher: Dispatcher, callback=None): - if callable(callback): - await callback(dispatcher) - - if dispatcher.is_polling(): - dispatcher.stop_polling() - # await dispatcher.wait_closed() - - await dispatcher.storage.close() - await dispatcher.storage.wait_closed() - - await dispatcher.bot.close() - - -async def _wh_shutdown(app): - callback = app.get('_shutdown_callback', None) - dispatcher = app.get(BOT_DISPATCHER_KEY, None) - await _shutdown(dispatcher, callback=callback) - - -def start_polling(dispatcher, *, loop=None, skip_updates=False, - on_startup=None, on_shutdown=None): - log.warning('Start bot with long-polling.') - if loop is None: - loop = dispatcher.loop - - loop.set_task_factory(context.task_factory) - - try: - loop.run_until_complete(_startup(dispatcher, skip_updates, on_startup)) - loop.create_task(dispatcher.start_polling(reset_webhook=True)) - loop.run_forever() - except (KeyboardInterrupt, SystemExit): +def _setup_callbacks(executor, on_startup, on_shutdown): + if on_startup is None: pass - finally: - loop.run_until_complete(_shutdown(dispatcher, callback=on_shutdown)) - log.warning("Goodbye!") + elif callable(on_startup): + executor.on_startup(on_startup) + else: + for callback in on_startup: + executor.on_startup(callback) + + if on_shutdown is None: + pass + elif callable(on_shutdown): + executor.on_shutdown(on_shutdown) + else: + for callback in on_shutdown: + executor.on_shutdown(callback) + + +def start_polling(dispatcher, *, loop=None, skip_updates=False, reset_webhook=True, + on_startup=None, on_shutdown=None): + executor = Executor(dispatcher, skip_updates=skip_updates, loop=loop) + _setup_callbacks(executor, on_startup, on_shutdown) + + executor.start_polling(reset_webhook=reset_webhook) def start_webhook(dispatcher, webhook_path, *, loop=None, skip_updates=None, on_startup=None, on_shutdown=None, check_ip=False, **kwargs): - log.warning('Start bot with webhook.') - if loop is None: - loop = dispatcher.loop + executor = Executor(dispatcher, skip_updates=skip_updates, check_ip=check_ip, loop=loop) + _setup_callbacks(executor, on_startup, on_shutdown) - loop.set_task_factory(context.task_factory) + executor.start_webhook(webhook_path, **kwargs) - app = get_new_configured_app(dispatcher, webhook_path) - app['_startup_callback'] = on_startup - app['_shutdown_callback'] = on_shutdown - app['_skip_updates'] = skip_updates - app['_check_ip'] = check_ip - app.on_startup.append(_wh_startup) - app.on_shutdown.append(_wh_shutdown) +def start(dispatcher, func, *, loop=None, skip_updates=None, + on_startup=None, on_shutdown=None): + executor = Executor(dispatcher, skip_updates=skip_updates, loop=loop) + _setup_callbacks(executor, on_startup, on_shutdown) - web.run_app(app, **kwargs) - return app + executor.start(func) + + +class Executor: + def __init__(self, dispatcher, skip_updates=None, check_ip=False, loop=None): + if loop is None: + loop = dispatcher.loop + self.dispatcher = dispatcher + self.skip_updates = skip_updates + self.check_ip = check_ip + self.loop = loop + + self._identity = secrets.token_urlsafe(16) + self._web_app = None + + self._on_startup_webhook = [] + self._on_startup_polling = [] + self._on_shutdown_webhook = [] + self._on_shutdown_polling = [] + + self._freeze = False + + @property + def frozen(self): + return self._freeze + + def set_web_app(self, application: web.Application): + self._web_app = application + + def on_startup(self, callback: callable, polling=True, webhook=True): + self._check_frozen() + + if not webhook and not polling: + warn('This action has no effect!', UserWarning) + return + + if polling: + self._on_startup_polling.append(callback) + if webhook: + self._on_startup_webhook.append(callback) + + def on_shutdown(self, callback: callable, polling=True, webhook=True): + self._check_frozen() + + if not webhook and not polling: + warn('This action has no effect!', UserWarning) + return + + if polling: + self._on_shutdown_polling.append(callback) + if webhook: + self._on_shutdown_webhook.append(callback) + + def _check_frozen(self): + if self.frozen: + raise RuntimeError('Executor is frozen!') + + def _prepare_polling(self): + self._check_frozen() + self._freeze = True + + self.loop.set_task_factory(context.task_factory) + + def _prepare_webhook(self, path=None, handler=WebhookRequestHandler): + self._check_frozen() + self._freeze = True + + self.loop.set_task_factory(context.task_factory) + + app = self._web_app + if app is None: + self._web_app = app = web.Application() + app[BOT_DISPATCHER_KEY] = self.dispatcher + + if self._identity in self._identity: + # App is already configured + return + + if path is not None: + app.router.add_route('*', path, handler, name='webhook_handler') + + async def _wrap_callback(cb, _): + return await cb(self.dispatcher) + + for callback in self._on_startup_webhook: + app.on_startup.append(functools.partial(_wrap_callback, callback)) + for callback in self._on_shutdown_webhook: + app.on_shutdown.append(functools.partial(_wrap_callback, callback)) + + app[APP_EXECUTOR_KEY] = self + app[BOT_DISPATCHER_KEY] = self.dispatcher + app[self._identity] = datetime.datetime.now() + app['_check_ip'] = self.check_ip + + def start_webhook(self, webhook_path=None, request_handler=WebhookRequestHandler, **kwargs): + self._prepare_webhook(webhook_path, request_handler) + self.loop.run_until_complete(self._startup_webhook()) + web.run_app(self._web_app, **kwargs) + + def start_polling(self, reset_webhook=None): + self._prepare_polling() + loop: asyncio.AbstractEventLoop = self.loop + + try: + loop.run_until_complete(self._startup_polling()) + loop.create_task(self.dispatcher.start_polling(reset_webhook=reset_webhook)) + loop.run_forever() + except (KeyboardInterrupt, SystemExit): + loop.stop() + finally: + loop.run_until_complete(self._shutdown_polling()) + log.warning("Goodbye!") + + def start(self, func): + self._check_frozen() + self._freeze = True + loop: asyncio.AbstractEventLoop = self.loop + + try: + loop.run_until_complete(self._startup_polling()) + loop.run_until_complete(func) + except (KeyboardInterrupt, SystemExit): + loop.stop() + finally: + loop.run_until_complete(self._shutdown_polling()) + log.warning("Goodbye!") + + async def _skip_updates(self): + await self.dispatcher.reset_webhook(True) + count = await self.dispatcher.skip_updates() + if count: + log.warning(f"Skipped {count} updates.") + return count + + async def _welcome(self): + user = await self.dispatcher.bot.me + log.info(f"Bot: {user.full_name} [@{user.username}]") + + async def _shutdown(self): + self.dispatcher.stop_polling() + await self.dispatcher.storage.close() + await self.dispatcher.storage.wait_closed() + await self.dispatcher.bot.close() + + async def _startup_polling(self): + await self._welcome() + + if self.skip_updates: + await self._skip_updates() + for callback in self._on_startup_polling: + await callback(self.dispatcher) + + async def _shutdown_polling(self, wait_closed=False): + await self._shutdown() + + for callback in self._on_shutdown_polling: + await callback(self.dispatcher) + + if wait_closed: + await self.dispatcher.wait_closed() + + async def _startup_webhook(self): + await self._welcome() + if self.skip_updates: + self._skip_updates()