Add Black and Flake8

This commit is contained in:
Alex RootJunior 2019-06-29 19:53:18 +03:00
parent 5b9d82f1ca
commit 7c51b1a7d7
124 changed files with 5841 additions and 3772 deletions

12
.flake8 Normal file
View file

@ -0,0 +1,12 @@
[flake8]
max-line-length = 80
select = C,E,F,W,B,B950
ignore = E501,W503,E203
exclude =
.git
build
dist
venv
docs
*.egg-info
experiment.py

View file

@ -7,6 +7,7 @@
[![Supported python versions](https://img.shields.io/pypi/pyversions/aiogram.svg?style=flat-square)](https://pypi.python.org/pypi/aiogram)
[![Telegram Bot API](https://img.shields.io/badge/Telegram%20Bot%20API-4.3-blue.svg?style=flat-square&logo=telegram)](https://core.telegram.org/bots/api)
[![Documentation Status](https://img.shields.io/readthedocs/pip/stable.svg?style=flat-square)](http://aiogram.readthedocs.io/en/latest/?badge=latest)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/python/black)
[![Github issues](https://img.shields.io/github/issues/aiogram/aiogram.svg?style=flat-square)](https://github.com/aiogram/aiogram/issues)
[![MIT License](https://img.shields.io/pypi/l/aiogram.svg?style=flat-square)](https://opensource.org/licenses/MIT)

View file

@ -26,9 +26,13 @@ AIOGramBot
:alt: Telegram Bot API
.. image:: https://img.shields.io/readthedocs/pip/stable.svg?style=flat-square
:target: http://aiogram.readthedocs.io/en/latest/?badge=latest
:target: http://aiogram.readthedocs.io/en/latest/?badge=latest?style=flat-square
:alt: Documentation Status
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square
:target: https://github.com/python/black
:alt: Code style: Black
.. image:: https://img.shields.io/github/issues/aiogram/aiogram.svg?style=flat-square
:target: https://github.com/aiogram/aiogram/issues
:alt: Github issues

View file

@ -17,26 +17,26 @@ try:
except ImportError:
uvloop = None
else:
if 'DISABLE_UVLOOP' not in os.environ:
if "DISABLE_UVLOOP" not in os.environ:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
__all__ = [
'Bot',
'Dispatcher',
'__api_version__',
'__version__',
'bot',
'contrib',
'dispatcher',
'exceptions',
'executor',
'filters',
'helper',
'md',
'middlewares',
'types',
'utils'
"Bot",
"Dispatcher",
"__api_version__",
"__version__",
"bot",
"contrib",
"dispatcher",
"exceptions",
"executor",
"filters",
"helper",
"md",
"middlewares",
"types",
"utils",
]
__version__ = '2.2.1.dev1'
__api_version__ = '4.3'
__version__ = "2.2.1.dev1"
__api_version__ = "4.3"

View file

@ -18,7 +18,7 @@ class SysInfo:
@property
def python(self):
return sys.version.replace('\n', '')
return sys.version.replace("\n", "")
@property
def aiogram(self):
@ -57,27 +57,27 @@ class SysInfo:
return aiohttp.__version__
def collect(self):
yield f'{self.python_implementation}: {self.python}'
yield f'OS: {self.os}'
yield f'aiogram: {self.aiogram}'
yield f'aiohttp: {self.aiohttp}'
yield f"{self.python_implementation}: {self.python}"
yield f"OS: {self.os}"
yield f"aiogram: {self.aiogram}"
yield f"aiohttp: {self.aiohttp}"
uvloop = self.uvloop
if uvloop:
yield f'uvloop: {uvloop}'
yield f"uvloop: {uvloop}"
yield f'JSON mode: {json.mode}'
yield f"JSON mode: {json.mode}"
rapidjson = self.rapidjson
if rapidjson:
yield f'rapidjson: {rapidjson}'
yield f"rapidjson: {rapidjson}"
ujson = self.ujson
if ujson:
yield f'ujson: {ujson}'
yield f"ujson: {ujson}"
def __str__(self):
return '\n'.join(self.collect())
return "\n".join(self.collect())
if __name__ == '__main__':
if __name__ == "__main__":
print(SysInfo())

View file

@ -2,8 +2,4 @@ from . import api
from .base import BaseBot
from .bot import Bot
__all__ = [
'BaseBot',
'Bot',
'api'
]
__all__ = ["BaseBot", "Bot", "api"]

View file

@ -10,7 +10,7 @@ from ..utils import json
from ..utils.helper import Helper, HelperMode, Item
# Main aiogram logger
log = logging.getLogger('aiogram')
log = logging.getLogger("aiogram")
# API Url's
API_URL = "https://api.telegram.org/bot{token}/{method}"
@ -25,11 +25,11 @@ def check_token(token: str) -> bool:
:return:
"""
if any(x.isspace() for x in token):
raise exceptions.ValidationError('Token is invalid!')
raise exceptions.ValidationError("Token is invalid!")
left, sep, right = token.partition(':')
left, sep, right = token.partition(":")
if (not sep) or (not left.isdigit()) or (len(left) < 3):
raise exceptions.ValidationError('Token is invalid!')
raise exceptions.ValidationError("Token is invalid!")
return True
@ -51,19 +51,21 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
"""
log.debug('Response for %s: [%d] "%r"', method_name, status_code, body)
if content_type != 'application/json':
raise exceptions.NetworkError(f"Invalid response with content type {content_type}: \"{body}\"")
if content_type != "application/json":
raise exceptions.NetworkError(
f'Invalid response with content type {content_type}: "{body}"'
)
try:
result_json = json.loads(body)
except ValueError:
result_json = {}
description = result_json.get('description') or body
parameters = types.ResponseParameters(**result_json.get('parameters', {}) or {})
description = result_json.get("description") or body
parameters = types.ResponseParameters(**result_json.get("parameters", {}) or {})
if HTTPStatus.OK <= status_code <= HTTPStatus.IM_USED:
return result_json.get('result')
return result_json.get("result")
elif parameters.retry_after:
raise exceptions.RetryAfter(parameters.retry_after)
elif parameters.migrate_to_chat_id:
@ -77,10 +79,12 @@ def check_result(method_name: str, content_type: str, status_code: int, body: st
elif status_code in [HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN]:
exceptions.Unauthorized.detect(description)
elif status_code == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
raise exceptions.NetworkError('File too large for uploading. '
'Check telegram api limits https://core.telegram.org/bots/api#senddocument')
raise exceptions.NetworkError(
"File too large for uploading. "
"Check telegram api limits https://core.telegram.org/bots/api#senddocument"
)
elif status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
if 'restart' in description:
if "restart" in description:
raise exceptions.RestartingTelegram()
raise exceptions.TelegramAPIError(description)
raise exceptions.TelegramAPIError(f"{description} [{status_code}]")
@ -95,9 +99,13 @@ async def make_request(session, token, method, data=None, files=None, **kwargs):
req = compose_data(data, files)
try:
async with session.post(url, data=req, **kwargs) as response:
return check_result(method, response.content_type, response.status, await response.text())
return check_result(
method, response.content_type, response.status, await response.text()
)
except aiohttp.ClientError as e:
raise exceptions.NetworkError(f"aiohttp client throws an error: {e.__class__.__name__}: {e}")
raise exceptions.NetworkError(
f"aiohttp client throws an error: {e.__class__.__name__}: {e}"
)
def guess_filename(obj):
@ -107,8 +115,8 @@ def guess_filename(obj):
:param obj:
:return:
"""
name = getattr(obj, 'name', None)
if name and isinstance(name, str) and name[0] != '<' and name[-1] != '>':
name = getattr(obj, "name", None)
if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
return os.path.basename(name)
@ -132,7 +140,7 @@ def compose_data(params=None, files=None):
if len(f) == 2:
filename, fileobj = f
else:
raise ValueError('Tuple must have exactly 2 elements: filename, fileobj')
raise ValueError("Tuple must have exactly 2 elements: filename, fileobj")
elif isinstance(f, types.InputFile):
filename, fileobj = f.filename, f.file
else:
@ -149,6 +157,7 @@ class Methods(Helper):
List is updated to Bot API 4.3
"""
mode = HelperMode.lowerCamelCase
# Getting Updates

View file

@ -20,19 +20,22 @@ class BaseBot:
"""
Base class for bot. It's raw bot.
"""
_ctx_timeout = ContextVar('TelegramRequestTimeout')
_ctx_token = ContextVar('BotDifferentToken')
_ctx_timeout = ContextVar("TelegramRequestTimeout")
_ctx_token = ContextVar("BotDifferentToken")
def __init__(
self,
token: base.String,
loop: Optional[Union[asyncio.BaseEventLoop, asyncio.AbstractEventLoop]] = None,
connections_limit: Optional[base.Integer] = None,
proxy: Optional[base.String] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None,
validate_token: Optional[base.Boolean] = True,
parse_mode: typing.Optional[base.String] = None,
timeout: typing.Optional[typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]] = None
self,
token: base.String,
loop: Optional[asyncio.AbstractEventLoop] = None,
connections_limit: Optional[base.Integer] = None,
proxy: Optional[base.String] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None,
validate_token: Optional[base.Boolean] = True,
parse_mode: typing.Optional[base.String] = None,
timeout: typing.Optional[
typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]
] = None,
):
"""
Instructions how to get Bot token is found here: https://core.telegram.org/bots#3-how-do-i-create-a-bot
@ -72,7 +75,9 @@ class BaseBot:
# aiohttp main session
ssl_context = ssl.create_default_context(cafile=certifi.where())
if isinstance(proxy, str) and (proxy.startswith('socks5://') or proxy.startswith('socks4://')):
if isinstance(proxy, str) and (
proxy.startswith("socks5://") or proxy.startswith("socks4://")
):
from aiohttp_socks import SocksConnector
from aiohttp_socks.helpers import parse_socks_url
@ -83,25 +88,36 @@ class BaseBot:
if not password:
password = proxy_auth.password
connector = SocksConnector(socks_ver=socks_ver, host=host, port=port,
username=username, password=password,
limit=connections_limit, ssl_context=ssl_context,
rdns=True, loop=self.loop)
connector = SocksConnector(
socks_ver=socks_ver,
host=host,
port=port,
username=username,
password=password,
limit=connections_limit,
ssl_context=ssl_context,
rdns=True,
loop=self.loop,
)
self.proxy = None
self.proxy_auth = None
else:
connector = aiohttp.TCPConnector(limit=connections_limit, ssl=ssl_context, loop=self.loop)
connector = aiohttp.TCPConnector(
limit=connections_limit, ssl=ssl_context, loop=self.loop
)
self._timeout = None
self.timeout = timeout
self.session = aiohttp.ClientSession(connector=connector, loop=self.loop, json_serialize=json.dumps)
self.session = aiohttp.ClientSession(
connector=connector, loop=self.loop, json_serialize=json.dumps
)
self.parse_mode = parse_mode
@staticmethod
def _prepare_timeout(
value: typing.Optional[typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]]
value: typing.Optional[typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]]
) -> typing.Optional[aiohttp.ClientTimeout]:
if value is None or isinstance(value, aiohttp.ClientTimeout):
return value
@ -123,7 +139,9 @@ class BaseBot:
self.timeout = None
@contextlib.contextmanager
def request_timeout(self, timeout: typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]):
def request_timeout(
self, timeout: typing.Union[base.Integer, base.Float, aiohttp.ClientTimeout]
):
"""
Context manager implements opportunity to change request timeout in current context
@ -162,9 +180,13 @@ class BaseBot:
"""
await self.session.close()
async def request(self, method: base.String,
data: Optional[Dict] = None,
files: Optional[Dict] = None, **kwargs) -> Union[List, Dict, base.Boolean]:
async def request(
self,
method: base.String,
data: Optional[Dict] = None,
files: Optional[Dict] = None,
**kwargs,
) -> Union[List, Dict, base.Boolean]:
"""
Make an request to Telegram Bot API
@ -180,14 +202,26 @@ class BaseBot:
:rtype: Union[List, Dict]
:raise: :obj:`aiogram.exceptions.TelegramApiError`
"""
return await api.make_request(self.session, self.__token, method, data, files,
proxy=self.proxy, proxy_auth=self.proxy_auth, timeout=self.timeout, **kwargs)
return await api.make_request(
self.session,
self.__token,
method,
data,
files,
proxy=self.proxy,
proxy_auth=self.proxy_auth,
timeout=self.timeout,
**kwargs,
)
async def download_file(self, file_path: base.String,
destination: Optional[base.InputFile] = None,
timeout: Optional[base.Integer] = sentinel,
chunk_size: Optional[base.Integer] = 65536,
seek: Optional[base.Boolean] = True) -> Union[io.BytesIO, io.FileIO]:
async def download_file(
self,
file_path: base.String,
destination: Optional[base.InputFile] = None,
timeout: Optional[base.Integer] = sentinel,
chunk_size: Optional[base.Integer] = 65536,
seek: Optional[base.Boolean] = True,
) -> Union[io.BytesIO, io.FileIO]:
"""
Download file by file_path to destination
@ -207,8 +241,10 @@ class BaseBot:
url = self.get_file_url(file_path)
dest = destination if isinstance(destination, io.IOBase) else open(destination, 'wb')
async with self.session.get(url, timeout=timeout, proxy=self.proxy, proxy_auth=self.proxy_auth) as response:
dest = destination if isinstance(destination, io.IOBase) else open(destination, "wb")
async with self.session.get(
url, timeout=timeout, proxy=self.proxy, proxy_auth=self.proxy_auth
) as response:
while True:
chunk = await response.content.read(chunk_size)
if not chunk:
@ -247,19 +283,19 @@ class BaseBot:
@property
def parse_mode(self):
return getattr(self, '_parse_mode', None)
return getattr(self, "_parse_mode", None)
@parse_mode.setter
def parse_mode(self, value):
if value is None:
setattr(self, '_parse_mode', None)
setattr(self, "_parse_mode", None)
else:
if not isinstance(value, str):
raise TypeError(f"Parse mode must be str, not {type(value)}")
value = value.lower()
if value not in ParseMode.all():
raise ValueError(f"Parse mode must be one of {ParseMode.all()}")
setattr(self, '_parse_mode', value)
setattr(self, "_parse_mode", value)
@parse_mode.deleter
def parse_mode(self):

File diff suppressed because it is too large Load diff

View file

@ -36,11 +36,11 @@ class JSONStorage(_FileStorage):
"""
def read(self, path: pathlib.Path):
with path.open('r') as f:
with path.open("r") as f:
return json.load(f)
def write(self, path: pathlib.Path):
with path.open('w') as f:
with path.open("w") as f:
return json.dump(self.data, f, indent=4)
@ -50,9 +50,9 @@ class PickleStorage(_FileStorage):
"""
def read(self, path: pathlib.Path):
with path.open('rb') as f:
with path.open("rb") as f:
return pickle.load(f)
def write(self, path: pathlib.Path):
with path.open('wb') as f:
with path.open("wb") as f:
return pickle.dump(self.data, f, protocol=pickle.HIGHEST_PROTOCOL)

View file

@ -26,51 +26,70 @@ class MemoryStorage(BaseStorage):
if chat_id not in self.data:
self.data[chat_id] = {}
if user_id not in self.data[chat_id]:
self.data[chat_id][user_id] = {'state': None, 'data': {}, 'bucket': {}}
self.data[chat_id][user_id] = {"state": None, "data": {}, "bucket": {}}
return chat_id, user_id
async def get_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
async def get_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Optional[str]:
chat, user = self.resolve_address(chat=chat, user=user)
return self.data[chat][user]['state']
return self.data[chat][user]["state"]
async def get_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
async def get_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Dict:
chat, user = self.resolve_address(chat=chat, user=user)
return copy.deepcopy(self.data[chat][user]['data'])
return copy.deepcopy(self.data[chat][user]["data"])
async def update_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
async def update_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs,
):
if data is None:
data = {}
chat, user = self.resolve_address(chat=chat, user=user)
self.data[chat][user]['data'].update(data, **kwargs)
self.data[chat][user]["data"].update(data, **kwargs)
async def set_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.AnyStr = None):
async def set_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.AnyStr = None,
):
chat, user = self.resolve_address(chat=chat, user=user)
self.data[chat][user]['state'] = state
self.data[chat][user]["state"] = state
async def set_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
async def set_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
):
chat, user = self.resolve_address(chat=chat, user=user)
self.data[chat][user]['data'] = copy.deepcopy(data)
self.data[chat][user]["data"] = copy.deepcopy(data)
async def reset_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
with_data: typing.Optional[bool] = True):
async def reset_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
with_data: typing.Optional[bool] = True,
):
await self.set_state(chat=chat, user=user, state=None)
if with_data:
await self.set_data(chat=chat, user=user, data={})
@ -78,25 +97,35 @@ class MemoryStorage(BaseStorage):
def has_bucket(self):
return True
async def get_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
async def get_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None,
) -> typing.Dict:
chat, user = self.resolve_address(chat=chat, user=user)
return copy.deepcopy(self.data[chat][user]['bucket'])
return copy.deepcopy(self.data[chat][user]["bucket"])
async def set_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
async def set_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
):
chat, user = self.resolve_address(chat=chat, user=user)
self.data[chat][user]['bucket'] = copy.deepcopy(bucket)
self.data[chat][user]["bucket"] = copy.deepcopy(bucket)
async def update_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs):
async def update_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
**kwargs,
):
if bucket is None:
bucket = {}
chat, user = self.resolve_address(chat=chat, user=user)
self.data[chat][user]['bucket'].update(bucket, **kwargs)
self.data[chat][user]["bucket"].update(bucket, **kwargs)

View file

@ -11,9 +11,9 @@ import aioredis
from ...dispatcher.storage import BaseStorage
from ...utils import json
STATE_KEY = 'state'
STATE_DATA_KEY = 'data'
STATE_BUCKET_KEY = 'bucket'
STATE_KEY = "state"
STATE_DATA_KEY = "data"
STATE_BUCKET_KEY = "bucket"
class RedisStorage(BaseStorage):
@ -36,7 +36,9 @@ class RedisStorage(BaseStorage):
"""
def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None, loop=None, **kwargs):
def __init__(
self, host="localhost", port=6379, db=None, password=None, ssl=None, loop=None, **kwargs
):
self._host = host
self._port = port
self._db = db
@ -68,15 +70,22 @@ class RedisStorage(BaseStorage):
# Use thread-safe asyncio Lock because this method without that is not safe
async with self._connection_lock:
if self._redis is None:
self._redis = await aioredis.create_connection((self._host, self._port),
db=self._db, password=self._password, ssl=self._ssl,
loop=self._loop,
**self._kwargs)
self._redis = await aioredis.create_connection(
(self._host, self._port),
db=self._db,
password=self._password,
ssl=self._ssl,
loop=self._loop,
**self._kwargs,
)
return self._redis
async def get_record(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None) -> typing.Dict:
async def get_record(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
) -> typing.Dict:
"""
Get record from storage
@ -88,13 +97,20 @@ class RedisStorage(BaseStorage):
addr = f"fsm:{chat}:{user}"
conn = await self.redis()
data = await conn.execute('GET', addr)
data = await conn.execute("GET", addr)
if data is None:
return {'state': None, 'data': {}}
return {"state": None, "data": {}}
return json.loads(data)
async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state=None, data=None, bucket=None):
async def set_record(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state=None,
data=None,
bucket=None,
):
"""
Write record to storage
@ -113,39 +129,65 @@ class RedisStorage(BaseStorage):
chat, user = self.check_address(chat=chat, user=user)
addr = f"fsm:{chat}:{user}"
record = {'state': state, 'data': data, 'bucket': bucket}
record = {"state": state, "data": data, "bucket": bucket}
conn = await self.redis()
await conn.execute('SET', addr, json.dumps(record))
await conn.execute("SET", addr, json.dumps(record))
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
async def get_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Optional[str]:
record = await self.get_record(chat=chat, user=user)
return record['state']
return record["state"]
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
async def get_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Dict:
record = await self.get_record(chat=chat, user=user)
return record['data']
return record["data"]
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
async def set_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None,
):
record = await self.get_record(chat=chat, user=user)
await self.set_record(chat=chat, user=user, state=state, data=record['data'])
await self.set_record(chat=chat, user=user, state=state, data=record["data"])
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
async def set_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
):
record = await self.get_record(chat=chat, user=user)
await self.set_record(chat=chat, user=user, state=record['state'], data=data)
await self.set_record(chat=chat, user=user, state=record["state"], data=data)
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
async def update_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs,
):
if data is None:
data = {}
record = await self.get_record(chat=chat, user=user)
record_data = record.get('data', {})
record_data = record.get("data", {})
record_data.update(data, **kwargs)
await self.set_record(chat=chat, user=user, state=record['state'], data=record_data)
await self.set_record(chat=chat, user=user, state=record["state"], data=record_data)
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
"""
@ -156,9 +198,9 @@ class RedisStorage(BaseStorage):
conn = await self.redis()
result = []
keys = await conn.execute('KEYS', 'fsm:*')
keys = await conn.execute("KEYS", "fsm:*")
for item in keys:
*_, chat, user = item.decode('utf-8').split(':')
*_, chat, user = item.decode("utf-8").split(":")
result.append((chat, user))
return result
@ -173,33 +215,52 @@ class RedisStorage(BaseStorage):
conn = await self.redis()
if full:
await conn.execute('FLUSHDB')
await conn.execute("FLUSHDB")
else:
keys = await conn.execute('KEYS', 'fsm:*')
await conn.execute('DEL', *keys)
keys = await conn.execute("KEYS", "fsm:*")
await conn.execute("DEL", *keys)
def has_bucket(self):
return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
async def get_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Dict:
record = await self.get_record(chat=chat, user=user)
return record.get('bucket', {})
return record.get("bucket", {})
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
async def set_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
):
record = await self.get_record(chat=chat, user=user)
await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket)
await self.set_record(
chat=chat, user=user, state=record["state"], data=record["data"], bucket=bucket
)
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs):
async def update_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
**kwargs,
):
record = await self.get_record(chat=chat, user=user)
record_bucket = record.get('bucket', {})
record_bucket = record.get("bucket", {})
if bucket is None:
bucket = {}
record_bucket.update(bucket, **kwargs)
await self.set_record(chat=chat, user=user, state=record['state'], data=record_bucket, bucket=bucket)
await self.set_record(
chat=chat, user=user, state=record["state"], data=record_bucket, bucket=bucket
)
class RedisStorage2(BaseStorage):
@ -223,8 +284,18 @@ class RedisStorage2(BaseStorage):
"""
def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None,
pool_size=10, loop=None, prefix='fsm', **kwargs):
def __init__(
self,
host="localhost",
port=6379,
db=None,
password=None,
ssl=None,
pool_size=10,
loop=None,
prefix="fsm",
**kwargs,
):
self._host = host
self._port = port
self._db = db
@ -247,14 +318,20 @@ class RedisStorage2(BaseStorage):
# Use thread-safe asyncio Lock because this method without that is not safe
async with self._connection_lock:
if self._redis is None:
self._redis = await aioredis.create_redis_pool((self._host, self._port),
db=self._db, password=self._password, ssl=self._ssl,
minsize=1, maxsize=self._pool_size,
loop=self._loop, **self._kwargs)
self._redis = await aioredis.create_redis_pool(
(self._host, self._port),
db=self._db,
password=self._password,
ssl=self._ssl,
minsize=1,
maxsize=self._pool_size,
loop=self._loop,
**self._kwargs,
)
return self._redis
def generate_key(self, *parts):
return ':'.join(self._prefix + tuple(map(str, parts)))
return ":".join(self._prefix + tuple(map(str, parts)))
async def close(self):
async with self._connection_lock:
@ -269,25 +346,40 @@ class RedisStorage2(BaseStorage):
return await self._redis.wait_closed()
return True
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
async def get_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Optional[str]:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self.redis()
return await redis.get(key, encoding='utf8') or None
return await redis.get(key, encoding="utf8") or None
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
async def get_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None,
) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self.redis()
raw_result = await redis.get(key, encoding='utf8')
raw_result = await redis.get(key, encoding="utf8")
if raw_result:
return json.loads(raw_result)
return default or {}
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
async def set_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None,
):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_KEY)
redis = await self.redis()
@ -296,15 +388,26 @@ class RedisStorage2(BaseStorage):
else:
await redis.set(key, state)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
async def set_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_DATA_KEY)
redis = await self.redis()
await redis.set(key, json.dumps(data))
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
async def update_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs,
):
if data is None:
data = {}
temp_data = await self.get_data(chat=chat, user=user, default={})
@ -314,26 +417,41 @@ class RedisStorage2(BaseStorage):
def has_bucket(self):
return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
async def get_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None,
) -> typing.Dict:
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self.redis()
raw_result = await redis.get(key, encoding='utf8')
raw_result = await redis.get(key, encoding="utf8")
if raw_result:
return json.loads(raw_result)
return default or {}
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
async def set_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
):
chat, user = self.check_address(chat=chat, user=user)
key = self.generate_key(chat, user, STATE_BUCKET_KEY)
redis = await self.redis()
await redis.set(key, json.dumps(bucket))
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None, **kwargs):
async def update_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
**kwargs,
):
if bucket is None:
bucket = {}
temp_bucket = await self.get_bucket(chat=chat, user=user)
@ -352,7 +470,7 @@ class RedisStorage2(BaseStorage):
if full:
await conn.flushdb()
else:
keys = await conn.keys(self.generate_key('*'))
keys = await conn.keys(self.generate_key("*"))
await conn.delete(*keys)
async def get_states_list(self) -> typing.List[typing.Tuple[int]]:
@ -364,9 +482,9 @@ class RedisStorage2(BaseStorage):
conn = await self.redis()
result = []
keys = await conn.keys(self.generate_key('*', '*', STATE_KEY), encoding='utf8')
keys = await conn.keys(self.generate_key("*", "*", STATE_KEY), encoding="utf8")
for item in keys:
*_, chat, user, _ = item.split(':')
*_, chat, user, _ = item.split(":")
result.append((chat, user))
return result
@ -388,7 +506,7 @@ async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorag
if not isinstance(storage2, RedisStorage):
raise TypeError(f"{type(storage2)} is not RedisStorage instance.")
log = logging.getLogger('aiogram.RedisStorage')
log = logging.getLogger("aiogram.RedisStorage")
for chat, user in await storage1.get_states_list():
state = await storage1.get_state(chat=chat, user=user)

View file

@ -7,10 +7,10 @@ from rethinkdb.asyncio_net.net_asyncio import Connection
from ...dispatcher.storage import BaseStorage
__all__ = ['RethinkDBStorage']
__all__ = ["RethinkDBStorage"]
r = rethinkdb.RethinkDB()
r.set_loop_type('asyncio')
r.set_loop_type("asyncio")
class RethinkDBStorage(BaseStorage):
@ -32,17 +32,19 @@ class RethinkDBStorage(BaseStorage):
"""
def __init__(self,
host: str = 'localhost',
port: int = 28015,
db: str = 'aiogram',
table: str = 'aiogram',
auth_key: typing.Optional[str] = None,
user: typing.Optional[str] = None,
password: typing.Optional[str] = None,
timeout: int = 20,
ssl: typing.Optional[dict] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None):
def __init__(
self,
host: str = "localhost",
port: int = 28015,
db: str = "aiogram",
table: str = "aiogram",
auth_key: typing.Optional[str] = None,
user: typing.Optional[str] = None,
password: typing.Optional[str] = None,
timeout: int = 20,
ssl: typing.Optional[dict] = None,
loop: typing.Optional[asyncio.AbstractEventLoop] = None,
):
self._host = host
self._port = port
self._db = db
@ -61,15 +63,17 @@ class RethinkDBStorage(BaseStorage):
Get or create a connection.
"""
if self._conn is None:
self._conn = await r.connect(host=self._host,
port=self._port,
db=self._db,
auth_key=self._auth_key,
user=self._user,
password=self._password,
timeout=self._timeout,
ssl=self._ssl,
io_loop=self._loop)
self._conn = await r.connect(
host=self._host,
port=self._port,
db=self._db,
auth_key=self._auth_key,
user=self._user,
password=self._password,
timeout=self._timeout,
ssl=self._ssl,
io_loop=self._loop,
)
return self._conn
@contextlib.asynccontextmanager
@ -90,64 +94,126 @@ class RethinkDBStorage(BaseStorage):
"""
pass
async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
async def get_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Optional[str]:
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
return await r.table(self._table).get(chat)[user]['state'].default(default or None).run(conn)
return (
await r.table(self._table)
.get(chat)[user]["state"]
.default(default or None)
.run(conn)
)
async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
async def get_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Dict:
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
return await r.table(self._table).get(chat)[user]['data'].default(default or {}).run(conn)
return (
await r.table(self._table).get(chat)[user]["data"].default(default or {}).run(conn)
)
async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
async def set_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None,
):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
await r.table(self._table).insert({'id': chat, user: {'state': state}}, conflict="update").run(conn)
await r.table(self._table).insert(
{"id": chat, user: {"state": state}}, conflict="update"
).run(conn)
async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
async def set_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'data': r.literal(data)}}).run(conn)
await r.table(self._table).get(chat).update({user: {"data": r.literal(data)}}).run(
conn
)
else:
await r.table(self._table).insert({'id': chat, user: {'data': data}}).run(conn)
await r.table(self._table).insert({"id": chat, user: {"data": data}}).run(conn)
async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs):
async def update_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs,
):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
await r.table(self._table).insert({'id': chat, user: {'data': data}}, conflict="update").run(conn)
await r.table(self._table).insert(
{"id": chat, user: {"data": data}}, conflict="update"
).run(conn)
def has_bucket(self):
return True
async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
async def get_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None,
) -> typing.Dict:
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
return await r.table(self._table).get(chat)[user]['bucket'].default(default or {}).run(conn)
return (
await r.table(self._table)
.get(chat)[user]["bucket"]
.default(default or {})
.run(conn)
)
async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
async def set_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
if await r.table(self._table).get(chat).run(conn):
await r.table(self._table).get(chat).update({user: {'bucket': r.literal(bucket)}}).run(conn)
await r.table(self._table).get(chat).update(
{user: {"bucket": r.literal(bucket)}}
).run(conn)
else:
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}).run(conn)
await r.table(self._table).insert({"id": chat, user: {"bucket": bucket}}).run(conn)
async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None, bucket: typing.Dict = None,
**kwargs):
async def update_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
**kwargs,
):
chat, user = map(str, self.check_address(chat=chat, user=user))
async with self.connection() as conn:
await r.table(self._table).insert({'id': chat, user: {'bucket': bucket}}, conflict="update").run(conn)
await r.table(self._table).insert(
{"id": chat, user: {"bucket": bucket}}, conflict="update"
).run(conn)
async def get_states_list(self) -> typing.List[typing.Tuple[int, int]]:
"""
@ -161,7 +227,7 @@ class RethinkDBStorage(BaseStorage):
items = (await r.table(self._table).run(conn)).items
for item in items:
chat = int(item.pop('id'))
chat = int(item.pop("id"))
for key in item.keys():
user = int(key)
result.append((chat, user))

View file

@ -11,15 +11,11 @@ class EnvironmentMiddleware(BaseMiddleware):
def update_data(self, data):
dp = self.manager.dispatcher
data.update(
bot=dp.bot,
dispatcher=dp,
loop=dp.loop
)
data.update(bot=dp.bot, dispatcher=dp, loop=dp.loop)
if self.context:
data.update(self.context)
async def trigger(self, action, args):
if 'error' not in action and action.startswith('pre_process_'):
if "error" not in action and action.startswith("pre_process_"):
self.update_data(args[-1])
return True

View file

@ -6,7 +6,7 @@ from aiogram.dispatcher.storage import FSMContext
class FSMMiddleware(LifetimeControllerMiddleware):
skip_patterns = ['error', 'update']
skip_patterns = ["error", "update"]
def __init__(self):
super(FSMMiddleware, self).__init__()
@ -14,10 +14,10 @@ class FSMMiddleware(LifetimeControllerMiddleware):
async def pre_process(self, obj, data, *args):
proxy = await FSMSStorageProxy.create(self.manager.dispatcher.current_state())
data['state_data'] = proxy
data["state_data"] = proxy
async def post_process(self, obj, data, *args):
proxy = data.get('state_data', None)
proxy = data.get("state_data", None)
if isinstance(proxy, FSMSStorageProxy):
await proxy.save()

View file

@ -23,9 +23,9 @@ class I18nMiddleware(BaseMiddleware):
>>> _ = i18n = I18nMiddleware(DOMAIN_NAME, LOCALES_DIR)
"""
ctx_locale = ContextVar('ctx_user_locale', default=None)
ctx_locale = ContextVar("ctx_user_locale", default=None)
def __init__(self, domain, path=None, default='en'):
def __init__(self, domain, path=None, default="en"):
"""
:param domain: domain
:param path: path where located all *.mo files
@ -34,7 +34,7 @@ class I18nMiddleware(BaseMiddleware):
super(I18nMiddleware, self).__init__()
if path is None:
path = os.path.join(os.getcwd(), 'locales')
path = os.path.join(os.getcwd(), "locales")
self.domain = domain
self.path = path
@ -53,12 +53,12 @@ class I18nMiddleware(BaseMiddleware):
for name in os.listdir(self.path):
if not os.path.isdir(os.path.join(self.path, name)):
continue
mo_path = os.path.join(self.path, name, 'LC_MESSAGES', self.domain + '.mo')
mo_path = os.path.join(self.path, name, "LC_MESSAGES", self.domain + ".mo")
if os.path.exists(mo_path):
with open(mo_path, 'rb') as fp:
with open(mo_path, "rb") as fp:
translations[name] = gettext.GNUTranslations(fp)
elif os.path.exists(mo_path[:-2] + 'po'):
elif os.path.exists(mo_path[:-2] + "po"):
raise RuntimeError(f"Found locale '{name} but this language is not compiled!")
return translations
@ -134,7 +134,7 @@ class I18nMiddleware(BaseMiddleware):
if locale:
*_, data = args
language = data['locale'] = locale.language
language = data["locale"] = locale.language
return language
async def trigger(self, action, args):
@ -145,9 +145,7 @@ class I18nMiddleware(BaseMiddleware):
:param args: event arguments
:return:
"""
if 'update' not in action \
and 'error' not in action \
and action.startswith('pre_process'):
if "update" not in action and "error" not in action and action.startswith("pre_process"):
locale = await self.get_user_locale(action, args)
self.ctx_locale.set(locale)
return True

View file

@ -5,7 +5,7 @@ import logging
from aiogram import types
from aiogram.dispatcher.middlewares import BaseMiddleware
HANDLED_STR = ['Unhandled', 'Handled']
HANDLED_STR = ["Unhandled", "Handled"]
class LoggingMiddleware(BaseMiddleware):
@ -18,123 +18,181 @@ class LoggingMiddleware(BaseMiddleware):
super(LoggingMiddleware, self).__init__()
def check_timeout(self, obj):
start = obj.conf.get('_start', None)
start = obj.conf.get("_start", None)
if start:
del obj.conf['_start']
del obj.conf["_start"]
return round((time.time() - start) * 1000)
return -1
async def on_pre_process_update(self, update: types.Update, data: dict):
update.conf['_start'] = time.time()
update.conf["_start"] = time.time()
self.logger.debug(f"Received update [ID:{update.update_id}]")
async def on_post_process_update(self, update: types.Update, result, data: dict):
timeout = self.check_timeout(update)
if timeout > 0:
self.logger.info(f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)")
self.logger.info(
f"Process update [ID:{update.update_id}]: [success] (in {timeout} ms)"
)
async def on_pre_process_message(self, message: types.Message, data: dict):
self.logger.info(f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
self.logger.info(
f"Received message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]"
)
async def on_post_process_message(self, message: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]")
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"message [ID:{message.message_id}] in chat [{message.chat.type}:{message.chat.id}]"
)
async def on_pre_process_edited_message(self, edited_message, data: dict):
self.logger.info(f"Received edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
self.logger.info(
f"Received edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]"
)
async def on_post_process_edited_message(self, edited_message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]")
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"edited message [ID:{edited_message.message_id}] "
f"in chat [{edited_message.chat.type}:{edited_message.chat.id}]"
)
async def on_pre_process_channel_post(self, channel_post: types.Message, data: dict):
self.logger.info(f"Received channel post [ID:{channel_post.message_id}] "
f"in channel [ID:{channel_post.chat.id}]")
self.logger.info(
f"Received channel post [ID:{channel_post.message_id}] "
f"in channel [ID:{channel_post.chat.id}]"
)
async def on_post_process_channel_post(self, channel_post: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"channel post [ID:{channel_post.message_id}] "
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]")
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"channel post [ID:{channel_post.message_id}] "
f"in chat [{channel_post.chat.type}:{channel_post.chat.id}]"
)
async def on_pre_process_edited_channel_post(self, edited_channel_post: types.Message, data: dict):
self.logger.info(f"Received edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_pre_process_edited_channel_post(
self, edited_channel_post: types.Message, data: dict
):
self.logger.info(
f"Received edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]"
)
async def on_post_process_edited_channel_post(self, edited_channel_post: types.Message, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]")
async def on_post_process_edited_channel_post(
self, edited_channel_post: types.Message, results, data: dict
):
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"edited channel post [ID:{edited_channel_post.message_id}] "
f"in channel [ID:{edited_channel_post.chat.id}]"
)
async def on_pre_process_inline_query(self, inline_query: types.InlineQuery, data: dict):
self.logger.info(f"Received inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]")
self.logger.info(
f"Received inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]"
)
async def on_post_process_inline_query(self, inline_query: types.InlineQuery, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]")
async def on_post_process_inline_query(
self, inline_query: types.InlineQuery, results, data: dict
):
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"inline query [ID:{inline_query.id}] "
f"from user [ID:{inline_query.from_user.id}]"
)
async def on_pre_process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult, data: dict):
self.logger.info(f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]")
async def on_pre_process_chosen_inline_result(
self, chosen_inline_result: types.ChosenInlineResult, data: dict
):
self.logger.info(
f"Received chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]"
)
async def on_post_process_chosen_inline_result(self, chosen_inline_result, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]")
async def on_post_process_chosen_inline_result(
self, chosen_inline_result, results, data: dict
):
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"chosen inline result [Inline msg ID:{chosen_inline_result.inline_message_id}] "
f"from user [ID:{chosen_inline_result.from_user.id}] "
f"result [ID:{chosen_inline_result.result_id}]"
)
async def on_pre_process_callback_query(self, callback_query: types.CallbackQuery, data: dict):
if callback_query.message:
if callback_query.message.from_user:
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
f"from user [ID:{callback_query.message.from_user.id}]")
self.logger.info(
f"Received callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
f"from user [ID:{callback_query.message.from_user.id}]"
)
else:
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}]")
self.logger.info(
f"Received callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}]"
)
else:
self.logger.info(f"Received callback query [ID:{callback_query.id}] "
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]")
self.logger.info(
f"Received callback query [ID:{callback_query.id}] "
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]"
)
async def on_post_process_callback_query(self, callback_query, results, data: dict):
if callback_query.message:
if callback_query.message.from_user:
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
f"from user [ID:{callback_query.message.from_user.id}]")
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}] "
f"from user [ID:{callback_query.message.from_user.id}]"
)
else:
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}]")
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
f"in chat [{callback_query.message.chat.type}:{callback_query.message.chat.id}]"
)
else:
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]")
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"callback query [ID:{callback_query.id}] "
f"from inline message [ID:{callback_query.inline_message_id}] "
f"from user [ID:{callback_query.from_user.id}]"
)
async def on_pre_process_shipping_query(self, shipping_query: types.ShippingQuery, data: dict):
self.logger.info(f"Received shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]")
self.logger.info(
f"Received shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]"
)
async def on_post_process_shipping_query(self, shipping_query, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]")
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"shipping query [ID:{shipping_query.id}] "
f"from user [ID:{shipping_query.from_user.id}]"
)
async def on_pre_process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery, data: dict):
self.logger.info(f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]")
async def on_pre_process_pre_checkout_query(
self, pre_checkout_query: types.PreCheckoutQuery, data: dict
):
self.logger.info(
f"Received pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]"
)
async def on_post_process_pre_checkout_query(self, pre_checkout_query, results, data: dict):
self.logger.debug(f"{HANDLED_STR[bool(len(results))]} "
f"pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]")
self.logger.debug(
f"{HANDLED_STR[bool(len(results))]} "
f"pre-checkout query [ID:{pre_checkout_query.id}] "
f"from user [ID:{pre_checkout_query.from_user.id}]"
)
async def on_pre_process_error(self, update, error, data: dict):
timeout = self.check_timeout(update)
@ -168,7 +226,7 @@ class LoggingFilter(logging.Filter):
"""
def __init__(self, name='', prefix='tg', include_content=False):
def __init__(self, name="", prefix="tg", include_content=False):
"""
:param name:
:param prefix: prefix for all records
@ -200,34 +258,34 @@ class LoggingFilter(logging.Filter):
:param update:
:return:
"""
yield 'update_id', update.update_id
yield "update_id", update.update_id
if update.message:
yield 'update_type', 'message'
yield "update_type", "message"
yield from self.process_message(update.message)
if update.edited_message:
yield 'update_type', 'edited_message'
yield "update_type", "edited_message"
yield from self.process_message(update.edited_message)
if update.channel_post:
yield 'update_type', 'channel_post'
yield "update_type", "channel_post"
yield from self.process_message(update.channel_post)
if update.edited_channel_post:
yield 'update_type', 'edited_channel_post'
yield "update_type", "edited_channel_post"
yield from self.process_message(update.edited_channel_post)
if update.inline_query:
yield 'update_type', 'inline_query'
yield "update_type", "inline_query"
yield from self.process_inline_query(update.inline_query)
if update.chosen_inline_result:
yield 'update_type', 'chosen_inline_result'
yield "update_type", "chosen_inline_result"
yield from self.process_chosen_inline_result(update.chosen_inline_result)
if update.callback_query:
yield 'update_type', 'callback_query'
yield "update_type", "callback_query"
yield from self.process_callback_query(update.callback_query)
if update.shipping_query:
yield 'update_type', 'shipping_query'
yield "update_type", "shipping_query"
yield from self.process_shipping_query(update.shipping_query)
if update.pre_checkout_query:
yield 'update_type', 'pre_checkout_query'
yield "update_type", "pre_checkout_query"
yield from self.process_pre_checkout_query(update.pre_checkout_query)
def make_prefix(self, prefix, iterable):
@ -254,11 +312,11 @@ class LoggingFilter(logging.Filter):
if not user:
return
yield 'user_id', user.id
yield "user_id", user.id
if self.include_content:
yield 'user_full_name', user.full_name
yield "user_full_name", user.full_name
if user.username:
yield 'user_name', f"@{user.username}"
yield "user_name", f"@{user.username}"
def process_chat(self, chat: types.Chat):
"""
@ -270,15 +328,15 @@ class LoggingFilter(logging.Filter):
if not chat:
return
yield 'chat_id', chat.id
yield 'chat_type', chat.type
yield "chat_id", chat.id
yield "chat_type", chat.type
if self.include_content:
yield 'chat_title', chat.full_name
yield "chat_title", chat.full_name
if chat.username:
yield 'chat_name', f"@{chat.username}"
yield "chat_name", f"@{chat.username}"
def process_message(self, message: types.Message):
yield 'message_content_type', message.content_type
yield "message_content_type", message.content_type
yield from self.process_user(message.from_user)
yield from self.process_chat(message.chat)
@ -286,82 +344,84 @@ class LoggingFilter(logging.Filter):
return
if message.reply_to_message:
yield from self.make_prefix('reply_to', self.process_message(message.reply_to_message))
yield from self.make_prefix("reply_to", self.process_message(message.reply_to_message))
if message.forward_from:
yield from self.make_prefix('forward_from', self.process_user(message.forward_from))
yield from self.make_prefix("forward_from", self.process_user(message.forward_from))
if message.forward_from_chat:
yield from self.make_prefix('forward_from_chat', self.process_chat(message.forward_from_chat))
yield from self.make_prefix(
"forward_from_chat", self.process_chat(message.forward_from_chat)
)
if message.forward_from_message_id:
yield 'message_forward_from_message_id', message.forward_from_message_id
yield "message_forward_from_message_id", message.forward_from_message_id
if message.forward_date:
yield 'message_forward_date', message.forward_date
yield "message_forward_date", message.forward_date
if message.edit_date:
yield 'message_edit_date', message.edit_date
yield "message_edit_date", message.edit_date
if message.media_group_id:
yield 'message_media_group_id', message.media_group_id
yield "message_media_group_id", message.media_group_id
if message.author_signature:
yield 'message_author_signature', message.author_signature
yield "message_author_signature", message.author_signature
if message.text:
yield 'text', message.text or message.caption
yield 'html_text', message.html_text
yield "text", message.text or message.caption
yield "html_text", message.html_text
elif message.audio:
yield 'audio', message.audio.file_id
yield "audio", message.audio.file_id
elif message.animation:
yield 'animation', message.animation.file_id
yield "animation", message.animation.file_id
elif message.document:
yield 'document', message.document.file_id
yield "document", message.document.file_id
elif message.game:
yield 'game', message.game.title
yield "game", message.game.title
elif message.photo:
yield 'photo', message.photo[-1].file_id
yield "photo", message.photo[-1].file_id
elif message.sticker:
yield 'sticker', message.sticker.file_id
yield "sticker", message.sticker.file_id
elif message.video:
yield 'video', message.video.file_id
yield "video", message.video.file_id
elif message.video_note:
yield 'video_note', message.video_note.file_id
yield "video_note", message.video_note.file_id
elif message.voice:
yield 'voice', message.voice.file_id
yield "voice", message.voice.file_id
elif message.contact:
yield 'contact_full_name', message.contact.full_name
yield 'contact_phone_number', message.contact.phone_number
yield "contact_full_name", message.contact.full_name
yield "contact_phone_number", message.contact.phone_number
elif message.venue:
yield 'venue_address', message.venue.address
yield 'location_latitude', message.venue.location.latitude
yield 'location_longitude', message.venue.location.longitude
yield "venue_address", message.venue.address
yield "location_latitude", message.venue.location.latitude
yield "location_longitude", message.venue.location.longitude
elif message.location:
yield 'location_latitude', message.location.latitude
yield 'location_longitude', message.location.longitude
yield "location_latitude", message.location.latitude
yield "location_longitude", message.location.longitude
elif message.new_chat_members:
yield 'new_chat_members', [user.id for user in message.new_chat_members]
yield "new_chat_members", [user.id for user in message.new_chat_members]
elif message.left_chat_member:
yield 'left_chat_member', [user.id for user in message.new_chat_members]
yield "left_chat_member", [user.id for user in message.new_chat_members]
elif message.invoice:
yield 'invoice_title', message.invoice.title
yield 'invoice_description', message.invoice.description
yield 'invoice_start_parameter', message.invoice.start_parameter
yield 'invoice_currency', message.invoice.currency
yield 'invoice_total_amount', message.invoice.total_amount
yield "invoice_title", message.invoice.title
yield "invoice_description", message.invoice.description
yield "invoice_start_parameter", message.invoice.start_parameter
yield "invoice_currency", message.invoice.currency
yield "invoice_total_amount", message.invoice.total_amount
elif message.successful_payment:
yield 'successful_payment_currency', message.successful_payment.currency
yield 'successful_payment_total_amount', message.successful_payment.total_amount
yield 'successful_payment_invoice_payload', message.successful_payment.invoice_payload
yield 'successful_payment_shipping_option_id', message.successful_payment.shipping_option_id
yield 'successful_payment_telegram_payment_charge_id', message.successful_payment.telegram_payment_charge_id
yield 'successful_payment_provider_payment_charge_id', message.successful_payment.provider_payment_charge_id
yield "successful_payment_currency", message.successful_payment.currency
yield "successful_payment_total_amount", message.successful_payment.total_amount
yield "successful_payment_invoice_payload", message.successful_payment.invoice_payload
yield "successful_payment_shipping_option_id", message.successful_payment.shipping_option_id
yield "successful_payment_telegram_payment_charge_id", message.successful_payment.telegram_payment_charge_id
yield "successful_payment_provider_payment_charge_id", message.successful_payment.provider_payment_charge_id
elif message.connected_website:
yield 'connected_website', message.connected_website
yield "connected_website", message.connected_website
elif message.migrate_from_chat_id:
yield 'migrate_from_chat_id', message.migrate_from_chat_id
yield "migrate_from_chat_id", message.migrate_from_chat_id
elif message.migrate_to_chat_id:
yield 'migrate_to_chat_id', message.migrate_to_chat_id
yield "migrate_to_chat_id", message.migrate_to_chat_id
elif message.pinned_message:
yield from self.make_prefix('pinned_message', message.pinned_message)
yield from self.make_prefix("pinned_message", message.pinned_message)
elif message.new_chat_title:
yield 'new_chat_title', message.new_chat_title
yield "new_chat_title", message.new_chat_title
elif message.new_chat_photo:
yield 'new_chat_photo', message.new_chat_photo[-1].file_id
yield "new_chat_photo", message.new_chat_photo[-1].file_id
# elif message.delete_chat_photo:
# yield 'delete_chat_photo', message.delete_chat_photo
# elif message.group_chat_created:
@ -370,53 +430,55 @@ class LoggingFilter(logging.Filter):
# yield 'passport_data', message.passport_data
def process_inline_query(self, inline_query: types.InlineQuery):
yield 'inline_query_id', inline_query.id
yield "inline_query_id", inline_query.id
yield from self.process_user(inline_query.from_user)
if self.include_content:
yield 'inline_query_text', inline_query.query
yield "inline_query_text", inline_query.query
if inline_query.location:
yield 'location_latitude', inline_query.location.latitude
yield 'location_longitude', inline_query.location.longitude
yield "location_latitude", inline_query.location.latitude
yield "location_longitude", inline_query.location.longitude
if inline_query.offset:
yield 'inline_query_offset', inline_query.offset
yield "inline_query_offset", inline_query.offset
def process_chosen_inline_result(self, chosen_inline_result: types.ChosenInlineResult):
yield 'chosen_inline_result_id', chosen_inline_result.result_id
yield "chosen_inline_result_id", chosen_inline_result.result_id
yield from self.process_user(chosen_inline_result.from_user)
if self.include_content:
yield 'inline_query_text', chosen_inline_result.query
yield "inline_query_text", chosen_inline_result.query
if chosen_inline_result.location:
yield 'location_latitude', chosen_inline_result.location.latitude
yield 'location_longitude', chosen_inline_result.location.longitude
yield "location_latitude", chosen_inline_result.location.latitude
yield "location_longitude", chosen_inline_result.location.longitude
def process_callback_query(self, callback_query: types.CallbackQuery):
yield from self.process_user(callback_query.from_user)
yield 'callback_query_data', callback_query.data
yield "callback_query_data", callback_query.data
if callback_query.message:
yield from self.make_prefix('callback_query_message', self.process_message(callback_query.message))
yield from self.make_prefix(
"callback_query_message", self.process_message(callback_query.message)
)
if callback_query.inline_message_id:
yield 'callback_query_inline_message_id', callback_query.inline_message_id
yield "callback_query_inline_message_id", callback_query.inline_message_id
if callback_query.chat_instance:
yield 'callback_query_chat_instance', callback_query.chat_instance
yield "callback_query_chat_instance", callback_query.chat_instance
if callback_query.game_short_name:
yield 'callback_query_game_short_name', callback_query.game_short_name
yield "callback_query_game_short_name", callback_query.game_short_name
def process_shipping_query(self, shipping_query: types.ShippingQuery):
yield 'shipping_query_id', shipping_query.id
yield "shipping_query_id", shipping_query.id
yield from self.process_user(shipping_query.from_user)
if self.include_content:
yield 'shipping_query_invoice_payload', shipping_query.invoice_payload
yield "shipping_query_invoice_payload", shipping_query.invoice_payload
def process_pre_checkout_query(self, pre_checkout_query: types.PreCheckoutQuery):
yield 'pre_checkout_query_id', pre_checkout_query.id
yield "pre_checkout_query_id", pre_checkout_query.id
yield from self.process_user(pre_checkout_query.from_user)
if self.include_content:
yield 'pre_checkout_query_currency', pre_checkout_query.currency
yield 'pre_checkout_query_total_amount', pre_checkout_query.total_amount
yield 'pre_checkout_query_invoice_payload', pre_checkout_query.invoice_payload
yield 'pre_checkout_query_shipping_option_id', pre_checkout_query.shipping_option_id
yield "pre_checkout_query_currency", pre_checkout_query.currency
yield "pre_checkout_query_total_amount", pre_checkout_query.total_amount
yield "pre_checkout_query_invoice_payload", pre_checkout_query.invoice_payload
yield "pre_checkout_query_shipping_option_id", pre_checkout_query.shipping_option_id

View file

@ -6,12 +6,12 @@ from . import webhook
from .dispatcher import Dispatcher, FSMContext, DEFAULT_RATE_LIMIT
__all__ = [
'DEFAULT_RATE_LIMIT',
'Dispatcher',
'FSMContext',
'filters',
'handler',
'middlewares',
'storage',
'webhook'
"DEFAULT_RATE_LIMIT",
"Dispatcher",
"FSMContext",
"filters",
"handler",
"middlewares",
"storage",
"webhook",
]

View file

@ -8,12 +8,29 @@ import typing
import aiohttp
from aiohttp.helpers import sentinel
from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \
RegexpCommandsFilter, StateFilter, Text
from .filters import (
Command,
ContentTypeFilter,
ExceptionsFilter,
FiltersFactory,
HashTag,
Regexp,
RegexpCommandsFilter,
StateFilter,
Text,
)
from .handler import Handler
from .middlewares import MiddlewareManager
from .storage import BaseStorage, DELTA, DisabledStorage, EXCEEDED_COUNT, FSMContext, \
LAST_CALL, RATE_LIMIT, RESULT
from .storage import (
BaseStorage,
DELTA,
DisabledStorage,
EXCEEDED_COUNT,
FSMContext,
LAST_CALL,
RATE_LIMIT,
RESULT,
)
from .webhook import BaseResponse
from .. import types
from ..bot import Bot
@ -22,7 +39,7 @@ from ..utils.mixins import ContextInstanceMixin, DataMixin
log = logging.getLogger(__name__)
DEFAULT_RATE_LIMIT = .1
DEFAULT_RATE_LIMIT = 0.1
class Dispatcher(DataMixin, ContextInstanceMixin):
@ -33,13 +50,21 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
inline queries, chosen inline results, callback queries, shipping queries, pre-checkout queries.
"""
def __init__(self, bot, loop=None, storage: typing.Optional[BaseStorage] = None,
run_tasks_by_default: bool = False,
throttling_rate_limit=DEFAULT_RATE_LIMIT, no_throttle_error=False,
filters_factory=None):
def __init__(
self,
bot,
loop=None,
storage: typing.Optional[BaseStorage] = None,
run_tasks_by_default: bool = False,
throttling_rate_limit=DEFAULT_RATE_LIMIT,
no_throttle_error=False,
filters_factory=None,
):
if not isinstance(bot, Bot):
raise TypeError(f"Argument 'bot' must be an instance of Bot, not '{type(bot).__name__}'")
raise TypeError(
f"Argument 'bot' must be an instance of Bot, not '{type(bot).__name__}'"
)
if loop is None:
loop = bot.loop
@ -57,18 +82,18 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
self.no_throttle_error = no_throttle_error
self.filters_factory: FiltersFactory = filters_factory
self.updates_handler = Handler(self, middleware_key='update')
self.message_handlers = Handler(self, middleware_key='message')
self.edited_message_handlers = Handler(self, middleware_key='edited_message')
self.channel_post_handlers = Handler(self, middleware_key='channel_post')
self.edited_channel_post_handlers = Handler(self, middleware_key='edited_channel_post')
self.inline_query_handlers = Handler(self, middleware_key='inline_query')
self.chosen_inline_result_handlers = Handler(self, middleware_key='chosen_inline_result')
self.callback_query_handlers = Handler(self, middleware_key='callback_query')
self.shipping_query_handlers = Handler(self, middleware_key='shipping_query')
self.pre_checkout_query_handlers = Handler(self, middleware_key='pre_checkout_query')
self.poll_handlers = Handler(self, middleware_key='poll')
self.errors_handlers = Handler(self, once=False, middleware_key='error')
self.updates_handler = Handler(self, middleware_key="update")
self.message_handlers = Handler(self, middleware_key="message")
self.edited_message_handlers = Handler(self, middleware_key="edited_message")
self.channel_post_handlers = Handler(self, middleware_key="channel_post")
self.edited_channel_post_handlers = Handler(self, middleware_key="edited_channel_post")
self.inline_query_handlers = Handler(self, middleware_key="inline_query")
self.chosen_inline_result_handlers = Handler(self, middleware_key="chosen_inline_result")
self.callback_query_handlers = Handler(self, middleware_key="callback_query")
self.shipping_query_handlers = Handler(self, middleware_key="shipping_query")
self.pre_checkout_query_handlers = Handler(self, middleware_key="pre_checkout_query")
self.poll_handlers = Handler(self, middleware_key="poll")
self.errors_handlers = Handler(self, once=False, middleware_key="error")
self.middleware = MiddlewareManager(self)
@ -83,37 +108,57 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
def _setup_filters(self):
filters_factory = self.filters_factory
filters_factory.bind(StateFilter, exclude_event_handlers=[
self.errors_handlers,
self.poll_handlers
])
filters_factory.bind(ContentTypeFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
]),
filters_factory.bind(Command, event_handlers=[
self.message_handlers, self.edited_message_handlers
])
filters_factory.bind(Text, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
self.callback_query_handlers, self.poll_handlers
])
filters_factory.bind(HashTag, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers
])
filters_factory.bind(Regexp, event_handlers=[
self.message_handlers, self.edited_message_handlers,
self.channel_post_handlers, self.edited_channel_post_handlers,
self.callback_query_handlers, self.poll_handlers
])
filters_factory.bind(RegexpCommandsFilter, event_handlers=[
self.message_handlers, self.edited_message_handlers
])
filters_factory.bind(ExceptionsFilter, event_handlers=[
self.errors_handlers
])
filters_factory.bind(
StateFilter, exclude_event_handlers=[self.errors_handlers, self.poll_handlers]
)
filters_factory.bind(
ContentTypeFilter,
event_handlers=[
self.message_handlers,
self.edited_message_handlers,
self.channel_post_handlers,
self.edited_channel_post_handlers,
],
),
filters_factory.bind(
Command, event_handlers=[self.message_handlers, self.edited_message_handlers]
)
filters_factory.bind(
Text,
event_handlers=[
self.message_handlers,
self.edited_message_handlers,
self.channel_post_handlers,
self.edited_channel_post_handlers,
self.callback_query_handlers,
self.poll_handlers,
],
)
filters_factory.bind(
HashTag,
event_handlers=[
self.message_handlers,
self.edited_message_handlers,
self.channel_post_handlers,
self.edited_channel_post_handlers,
],
)
filters_factory.bind(
Regexp,
event_handlers=[
self.message_handlers,
self.edited_message_handlers,
self.channel_post_handlers,
self.edited_channel_post_handlers,
self.callback_query_handlers,
self.poll_handlers,
],
)
filters_factory.bind(
RegexpCommandsFilter,
event_handlers=[self.message_handlers, self.edited_message_handlers],
)
filters_factory.bind(ExceptionsFilter, event_handlers=[self.errors_handlers])
def __del__(self):
self.stop_polling()
@ -209,13 +254,15 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
return await self.bot.delete_webhook()
async def start_polling(self,
timeout=20,
relax=0.1,
limit=None,
reset_webhook=None,
fast: typing.Optional[bool] = True,
error_sleep: int = 5):
async def start_polling(
self,
timeout=20,
relax=0.1,
limit=None,
reset_webhook=None,
fast: typing.Optional[bool] = True,
error_sleep: int = 5,
):
"""
Start long-polling
@ -227,9 +274,9 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:return:
"""
if self._polling:
raise RuntimeError('Polling already started')
raise RuntimeError("Polling already started")
log.info('Start polling.')
log.info("Start polling.")
# context.set_value(MODE, LONG_POLLING)
Dispatcher.set_current(self)
@ -245,16 +292,20 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
try:
current_request_timeout = self.bot.timeout
if current_request_timeout is not sentinel and timeout is not None:
request_timeout = aiohttp.ClientTimeout(total=current_request_timeout.total + timeout or 1)
request_timeout = aiohttp.ClientTimeout(
total=current_request_timeout.total + timeout or 1
)
else:
request_timeout = None
while self._polling:
try:
with self.bot.request_timeout(request_timeout):
updates = await self.bot.get_updates(limit=limit, offset=offset, timeout=timeout)
updates = await self.bot.get_updates(
limit=limit, offset=offset, timeout=timeout
)
except:
log.exception('Cause exception while getting updates.')
log.exception("Cause exception while getting updates.")
await asyncio.sleep(error_sleep)
continue
@ -269,7 +320,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
finally:
self._close_waiter._set_result(None)
log.warning('Polling is stopped.')
log.warning("Polling is stopped.")
async def _process_polling_updates(self, updates, fast: typing.Optional[bool] = True):
"""
@ -288,7 +339,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
try:
asyncio.gather(*need_to_call)
except TelegramAPIError:
log.exception('Cause exception while processing updates.')
log.exception("Cause exception while processing updates.")
def stop_polling(self):
"""
@ -296,8 +347,8 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:return:
"""
if hasattr(self, '_polling') and self._polling:
log.info('Stop polling...')
if hasattr(self, "_polling") and self._polling:
log.info("Stop polling...")
self._polling = False
async def wait_closed(self):
@ -316,8 +367,17 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
return self._polling
def register_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None,
state=None, run_task=None, **kwargs):
def register_message_handler(
self,
callback,
*custom_filters,
commands=None,
regexp=None,
content_types=None,
state=None,
run_task=None,
**kwargs,
):
"""
Register handler for message
@ -343,17 +403,27 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:param state:
:return: decorated function
"""
filters_set = self.filters_factory.resolve(self.message_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs)
filters_set = self.filters_factory.resolve(
self.message_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs,
)
self.message_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
def message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None, state=None,
run_task=None, **kwargs):
def message_handler(
self,
*custom_filters,
commands=None,
regexp=None,
content_types=None,
state=None,
run_task=None,
**kwargs,
):
"""
Decorator for message handler
@ -424,15 +494,31 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_message_handler(callback, *custom_filters,
commands=commands, regexp=regexp, content_types=content_types,
state=state, run_task=run_task, **kwargs)
self.register_message_handler(
callback,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
run_task=run_task,
**kwargs,
)
return callback
return decorator
def register_edited_message_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None,
state=None, run_task=None, **kwargs):
def register_edited_message_handler(
self,
callback,
*custom_filters,
commands=None,
regexp=None,
content_types=None,
state=None,
run_task=None,
**kwargs,
):
"""
Register handler for edited message
@ -446,17 +532,29 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:param kwargs:
:return: decorated function
"""
filters_set = self.filters_factory.resolve(self.edited_message_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs)
self.edited_message_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
filters_set = self.filters_factory.resolve(
self.edited_message_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs,
)
self.edited_message_handlers.register(
self._wrap_async_task(callback, run_task), filters_set
)
def edited_message_handler(self, *custom_filters, commands=None, regexp=None, content_types=None,
state=None, run_task=None, **kwargs):
def edited_message_handler(
self,
*custom_filters,
commands=None,
regexp=None,
content_types=None,
state=None,
run_task=None,
**kwargs,
):
"""
Decorator for edited message handler
@ -479,14 +577,31 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_edited_message_handler(callback, *custom_filters, commands=commands, regexp=regexp,
content_types=content_types, state=state, run_task=run_task, **kwargs)
self.register_edited_message_handler(
callback,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
run_task=run_task,
**kwargs,
)
return callback
return decorator
def register_channel_post_handler(self, callback, *custom_filters, commands=None, regexp=None, content_types=None,
state=None, run_task=None, **kwargs):
def register_channel_post_handler(
self,
callback,
*custom_filters,
commands=None,
regexp=None,
content_types=None,
state=None,
run_task=None,
**kwargs,
):
"""
Register handler for channel post
@ -500,17 +615,27 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:param kwargs:
:return: decorated function
"""
filters_set = self.filters_factory.resolve(self.channel_post_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs)
filters_set = self.filters_factory.resolve(
self.channel_post_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs,
)
self.channel_post_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
def channel_post_handler(self, *custom_filters, commands=None, regexp=None, content_types=None,
state=None, run_task=None, **kwargs):
def channel_post_handler(
self,
*custom_filters,
commands=None,
regexp=None,
content_types=None,
state=None,
run_task=None,
**kwargs,
):
"""
Decorator for channel post handler
@ -525,14 +650,31 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp,
content_types=content_types, state=state, run_task=run_task, **kwargs)
self.register_channel_post_handler(
callback,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
run_task=run_task,
**kwargs,
)
return callback
return decorator
def register_edited_channel_post_handler(self, callback, *custom_filters, commands=None, regexp=None,
content_types=None, state=None, run_task=None, **kwargs):
def register_edited_channel_post_handler(
self,
callback,
*custom_filters,
commands=None,
regexp=None,
content_types=None,
state=None,
run_task=None,
**kwargs,
):
"""
Register handler for edited channel post
@ -546,17 +688,29 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:param kwargs:
:return: decorated function
"""
filters_set = self.filters_factory.resolve(self.edited_message_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs)
self.edited_channel_post_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
filters_set = self.filters_factory.resolve(
self.edited_message_handlers,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
**kwargs,
)
self.edited_channel_post_handlers.register(
self._wrap_async_task(callback, run_task), filters_set
)
def edited_channel_post_handler(self, *custom_filters, commands=None, regexp=None, content_types=None,
state=None, run_task=None, **kwargs):
def edited_channel_post_handler(
self,
*custom_filters,
commands=None,
regexp=None,
content_types=None,
state=None,
run_task=None,
**kwargs,
):
"""
Decorator for edited channel post handler
@ -571,14 +725,23 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_edited_channel_post_handler(callback, *custom_filters, commands=commands, regexp=regexp,
content_types=content_types, state=state, run_task=run_task,
**kwargs)
self.register_edited_channel_post_handler(
callback,
*custom_filters,
commands=commands,
regexp=regexp,
content_types=content_types,
state=state,
run_task=run_task,
**kwargs,
)
return callback
return decorator
def register_inline_handler(self, callback, *custom_filters, state=None, run_task=None, **kwargs):
def register_inline_handler(
self, callback, *custom_filters, state=None, run_task=None, **kwargs
):
"""
Register handler for inline query
@ -597,10 +760,9 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
if custom_filters is None:
custom_filters = []
filters_set = self.filters_factory.resolve(self.inline_query_handlers,
*custom_filters,
state=state,
**kwargs)
filters_set = self.filters_factory.resolve(
self.inline_query_handlers, *custom_filters, state=state, **kwargs
)
self.inline_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
def inline_handler(self, *custom_filters, state=None, run_task=None, **kwargs):
@ -622,12 +784,16 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_inline_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
self.register_inline_handler(
callback, *custom_filters, state=state, run_task=run_task, **kwargs
)
return callback
return decorator
def register_chosen_inline_handler(self, callback, *custom_filters, state=None, run_task=None, **kwargs):
def register_chosen_inline_handler(
self, callback, *custom_filters, state=None, run_task=None, **kwargs
):
"""
Register handler for chosen inline query
@ -646,11 +812,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
if custom_filters is None:
custom_filters = []
filters_set = self.filters_factory.resolve(self.chosen_inline_result_handlers,
*custom_filters,
state=state,
**kwargs)
self.chosen_inline_result_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
filters_set = self.filters_factory.resolve(
self.chosen_inline_result_handlers, *custom_filters, state=state, **kwargs
)
self.chosen_inline_result_handlers.register(
self._wrap_async_task(callback, run_task), filters_set
)
def chosen_inline_handler(self, *custom_filters, state=None, run_task=None, **kwargs):
"""
@ -671,12 +838,16 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_chosen_inline_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
self.register_chosen_inline_handler(
callback, *custom_filters, state=state, run_task=run_task, **kwargs
)
return callback
return decorator
def register_callback_query_handler(self, callback, *custom_filters, state=None, run_task=None, **kwargs):
def register_callback_query_handler(
self, callback, *custom_filters, state=None, run_task=None, **kwargs
):
"""
Register handler for callback query
@ -692,11 +863,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:param run_task: run callback in task (no wait results)
:param kwargs:
"""
filters_set = self.filters_factory.resolve(self.callback_query_handlers,
*custom_filters,
state=state,
**kwargs)
self.callback_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
filters_set = self.filters_factory.resolve(
self.callback_query_handlers, *custom_filters, state=state, **kwargs
)
self.callback_query_handlers.register(
self._wrap_async_task(callback, run_task), filters_set
)
def callback_query_handler(self, *custom_filters, state=None, run_task=None, **kwargs):
"""
@ -716,13 +888,16 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_callback_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
self.register_callback_query_handler(
callback, *custom_filters, state=state, run_task=run_task, **kwargs
)
return callback
return decorator
def register_shipping_query_handler(self, callback, *custom_filters, state=None, run_task=None,
**kwargs):
def register_shipping_query_handler(
self, callback, *custom_filters, state=None, run_task=None, **kwargs
):
"""
Register handler for shipping query
@ -738,11 +913,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:param run_task: run callback in task (no wait results)
:param kwargs:
"""
filters_set = self.filters_factory.resolve(self.shipping_query_handlers,
*custom_filters,
state=state,
**kwargs)
self.shipping_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
filters_set = self.filters_factory.resolve(
self.shipping_query_handlers, *custom_filters, state=state, **kwargs
)
self.shipping_query_handlers.register(
self._wrap_async_task(callback, run_task), filters_set
)
def shipping_query_handler(self, *custom_filters, state=None, run_task=None, **kwargs):
"""
@ -762,12 +938,16 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_shipping_query_handler(callback, *custom_filters, state=state, run_task=run_task, **kwargs)
self.register_shipping_query_handler(
callback, *custom_filters, state=state, run_task=run_task, **kwargs
)
return callback
return decorator
def register_pre_checkout_query_handler(self, callback, *custom_filters, state=None, run_task=None, **kwargs):
def register_pre_checkout_query_handler(
self, callback, *custom_filters, state=None, run_task=None, **kwargs
):
"""
Register handler for pre-checkout query
@ -783,11 +963,12 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:param run_task: run callback in task (no wait results)
:param kwargs:
"""
filters_set = self.filters_factory.resolve(self.pre_checkout_query_handlers,
*custom_filters,
state=state,
**kwargs)
self.pre_checkout_query_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
filters_set = self.filters_factory.resolve(
self.pre_checkout_query_handlers, *custom_filters, state=state, **kwargs
)
self.pre_checkout_query_handlers.register(
self._wrap_async_task(callback, run_task), filters_set
)
def pre_checkout_query_handler(self, *custom_filters, state=None, run_task=None, **kwargs):
"""
@ -807,27 +988,27 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_pre_checkout_query_handler(callback, *custom_filters, state=state, run_task=run_task,
**kwargs)
self.register_pre_checkout_query_handler(
callback, *custom_filters, state=state, run_task=run_task, **kwargs
)
return callback
return decorator
def register_poll_handler(self, callback, *custom_filters, run_task=None, **kwargs):
filters_set = self.filters_factory.resolve(self.poll_handlers,
*custom_filters,
**kwargs)
filters_set = self.filters_factory.resolve(self.poll_handlers, *custom_filters, **kwargs)
self.poll_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
def poll_handler(self, *custom_filters, run_task=None, **kwargs):
def decorator(callback):
self.register_poll_handler(callback, *custom_filters, run_task=run_task,
**kwargs)
self.register_poll_handler(callback, *custom_filters, run_task=run_task, **kwargs)
return callback
return decorator
def register_errors_handler(self, callback, *custom_filters, exception=None, run_task=None, **kwargs):
def register_errors_handler(
self, callback, *custom_filters, exception=None, run_task=None, **kwargs
):
"""
Register handler for errors
@ -835,10 +1016,9 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:param exception: you can make handler for specific errors type
:param run_task: run callback in task (no wait results)
"""
filters_set = self.filters_factory.resolve(self.errors_handlers,
*custom_filters,
exception=exception,
**kwargs)
filters_set = self.filters_factory.resolve(
self.errors_handlers, *custom_filters, exception=exception, **kwargs
)
self.errors_handlers.register(self._wrap_async_task(callback, run_task), filters_set)
def errors_handler(self, *custom_filters, exception=None, run_task=None, **kwargs):
@ -851,15 +1031,22 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
"""
def decorator(callback):
self.register_errors_handler(self._wrap_async_task(callback, run_task),
*custom_filters, exception=exception, **kwargs)
self.register_errors_handler(
self._wrap_async_task(callback, run_task),
*custom_filters,
exception=exception,
**kwargs,
)
return callback
return decorator
def current_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None) -> FSMContext:
def current_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
) -> FSMContext:
"""
Get current state for user in chat as context
@ -897,7 +1084,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:return: bool
"""
if not self.storage.has_bucket():
raise RuntimeError('This storage does not provide Leaky Bucket')
raise RuntimeError("This storage does not provide Leaky Bucket")
if no_error is None:
no_error = self.no_throttle_error
@ -951,7 +1138,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:return:
"""
if not self.storage.has_bucket():
raise RuntimeError('This storage does not provide Leaky Bucket')
raise RuntimeError("This storage does not provide Leaky Bucket")
if user is None and chat is None:
user = types.User.get_current()
@ -971,7 +1158,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
:return:
"""
if not self.storage.has_bucket():
raise RuntimeError('This storage does not provide Leaky Bucket')
raise RuntimeError("This storage does not provide Leaky Bucket")
if user is None and chat is None:
user = types.User.get_current()
@ -979,7 +1166,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
bucket = await self.storage.get_bucket(chat=chat, user=user)
if bucket and key in bucket:
del bucket['key']
del bucket["key"]
await self.storage.set_bucket(chat=chat, user=user, bucket=bucket)
return True
return False
@ -1005,8 +1192,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin):
try:
response = task.result()
except Exception as e:
self.loop.create_task(
self.errors_handlers.notify(types.Update.get_current(), e))
self.loop.create_task(self.errors_handlers.notify(types.Update.get_current(), e))
else:
if isinstance(response, BaseResponse):
self.loop.create_task(response.execute_response(self.bot))

View file

@ -1,30 +1,51 @@
from .builtin import Command, CommandHelp, CommandPrivacy, CommandSettings, CommandStart, ContentTypeFilter, \
ExceptionsFilter, HashTag, Regexp, RegexpCommandsFilter, StateFilter, Text
from .builtin import (
Command,
CommandHelp,
CommandPrivacy,
CommandSettings,
CommandStart,
ContentTypeFilter,
ExceptionsFilter,
HashTag,
Regexp,
RegexpCommandsFilter,
StateFilter,
Text,
)
from .factory import FiltersFactory
from .filters import AbstractFilter, BoundFilter, Filter, FilterNotPassed, FilterRecord, execute_filter, \
check_filters, get_filter_spec, get_filters_spec
from .filters import (
AbstractFilter,
BoundFilter,
Filter,
FilterNotPassed,
FilterRecord,
execute_filter,
check_filters,
get_filter_spec,
get_filters_spec,
)
__all__ = [
'AbstractFilter',
'BoundFilter',
'Command',
'CommandStart',
'CommandHelp',
'CommandPrivacy',
'CommandSettings',
'ContentTypeFilter',
'ExceptionsFilter',
'HashTag',
'Filter',
'FilterNotPassed',
'FilterRecord',
'FiltersFactory',
'RegexpCommandsFilter',
'Regexp',
'StateFilter',
'Text',
'get_filter_spec',
'get_filters_spec',
'execute_filter',
'check_filters'
"AbstractFilter",
"BoundFilter",
"Command",
"CommandStart",
"CommandHelp",
"CommandPrivacy",
"CommandSettings",
"ContentTypeFilter",
"ExceptionsFilter",
"HashTag",
"Filter",
"FilterNotPassed",
"FilterRecord",
"FiltersFactory",
"RegexpCommandsFilter",
"Regexp",
"StateFilter",
"Text",
"get_filter_spec",
"get_filters_spec",
"execute_filter",
"check_filters",
]

View file

@ -21,10 +21,13 @@ class Command(Filter):
By default this filter is registered for messages and edited messages handlers.
"""
def __init__(self, commands: Union[Iterable, str],
prefixes: Union[Iterable, str] = '/',
ignore_case: bool = True,
ignore_mention: bool = False):
def __init__(
self,
commands: Union[Iterable, str],
prefixes: Union[Iterable, str] = "/",
ignore_case: bool = True,
ignore_mention: bool = False,
):
"""
Filter can be initialized from filters factory or by simply creating instance of this class.
@ -66,30 +69,38 @@ class Command(Filter):
:return: config or empty dict
"""
config = {}
if 'commands' in full_config:
config['commands'] = full_config.pop('commands')
if config and 'commands_prefix' in full_config:
config['prefixes'] = full_config.pop('commands_prefix')
if config and 'commands_ignore_mention' in full_config:
config['ignore_mention'] = full_config.pop('commands_ignore_mention')
if "commands" in full_config:
config["commands"] = full_config.pop("commands")
if config and "commands_prefix" in full_config:
config["prefixes"] = full_config.pop("commands_prefix")
if config and "commands_ignore_mention" in full_config:
config["ignore_mention"] = full_config.pop("commands_ignore_mention")
return config
async def check(self, message: types.Message):
return await self.check_command(message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention)
return await self.check_command(
message, self.commands, self.prefixes, self.ignore_case, self.ignore_mention
)
@staticmethod
async def check_command(message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False):
async def check_command(
message: types.Message, commands, prefixes, ignore_case=True, ignore_mention=False
):
full_command = message.text.split()[0]
prefix, (command, _, mention) = full_command[0], full_command[1:].partition('@')
prefix, (command, _, mention) = (full_command[0], full_command[1:].partition("@"))
if not ignore_mention and mention and (await message.bot.me).username.lower() != mention.lower():
if (
not ignore_mention
and mention
and (await message.bot.me).username.lower() != mention.lower()
):
return False
elif prefix not in prefixes:
return False
elif (command.lower() if ignore_case else command) not in commands:
return False
return {'command': Command.CommandObj(command=command, prefix=prefix, mention=mention)}
return {"command": Command.CommandObj(command=command, prefix=prefix, mention=mention)}
@dataclass
class CommandObj:
@ -100,9 +111,9 @@ class Command(Filter):
"""
"""Command prefix"""
prefix: str = '/'
prefix: str = "/"
"""Command without prefix and mention"""
command: str = ''
command: str = ""
"""Mention (if available)"""
mention: str = None
"""Command argument"""
@ -126,9 +137,9 @@ class Command(Filter):
"""
line = self.prefix + self.command
if self.mentioned:
line += '@' + self.mention
line += "@" + self.mention
if self.args:
line += ' ' + self.args
line += " " + self.args
return line
@ -149,7 +160,7 @@ class CommandStart(Command):
:param deep_link: string or compiled regular expression (by ``re.compile(...)``).
"""
super(CommandStart, self).__init__(['start'])
super(CommandStart, self).__init__(["start"])
self.deep_link = deep_link
async def check(self, message: types.Message):
@ -167,7 +178,7 @@ class CommandStart(Command):
match = self.deep_link.match(message.get_args())
if match:
return {'deep_link': match}
return {"deep_link": match}
return False
return check
@ -179,7 +190,7 @@ class CommandHelp(Command):
"""
def __init__(self):
super(CommandHelp, self).__init__(['help'])
super(CommandHelp, self).__init__(["help"])
class CommandSettings(Command):
@ -188,7 +199,7 @@ class CommandSettings(Command):
"""
def __init__(self):
super(CommandSettings, self).__init__(['settings'])
super(CommandSettings, self).__init__(["settings"])
class CommandPrivacy(Command):
@ -197,7 +208,7 @@ class CommandPrivacy(Command):
"""
def __init__(self):
super(CommandPrivacy, self).__init__(['privacy'])
super(CommandPrivacy, self).__init__(["privacy"])
class Text(Filter):
@ -205,12 +216,14 @@ class Text(Filter):
Simple text filter
"""
def __init__(self,
equals: Optional[Union[str, LazyProxy]] = None,
contains: Optional[Union[str, LazyProxy]] = None,
startswith: Optional[Union[str, LazyProxy]] = None,
endswith: Optional[Union[str, LazyProxy]] = None,
ignore_case=False):
def __init__(
self,
equals: Optional[Union[str, LazyProxy]] = None,
contains: Optional[Union[str, LazyProxy]] = None,
startswith: Optional[Union[str, LazyProxy]] = None,
endswith: Optional[Union[str, LazyProxy]] = None,
ignore_case=False,
):
"""
Check text for one of pattern. Only one mode can be used in one filter.
@ -223,11 +236,18 @@ class Text(Filter):
# Only one mode can be used. check it.
check = sum(map(bool, (equals, contains, startswith, endswith)))
if check > 1:
args = "' and '".join([arg[0] for arg in [('equals', equals),
('contains', contains),
('startswith', startswith),
('endswith', endswith)
] if arg[1]])
args = "' and '".join(
[
arg[0]
for arg in [
("equals", equals),
("contains", contains),
("startswith", startswith),
("endswith", endswith),
]
if arg[1]
]
)
raise ValueError(f"Arguments '{args}' cannot be used together.")
elif check == 0:
raise ValueError(f"No one mode is specified!")
@ -240,18 +260,18 @@ class Text(Filter):
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'text' in full_config:
return {'equals': full_config.pop('text')}
elif 'text_contains' in full_config:
return {'contains': full_config.pop('text_contains')}
elif 'text_startswith' in full_config:
return {'startswith': full_config.pop('text_startswith')}
elif 'text_endswith' in full_config:
return {'endswith': full_config.pop('text_endswith')}
if "text" in full_config:
return {"equals": full_config.pop("text")}
elif "text_contains" in full_config:
return {"contains": full_config.pop("text_contains")}
elif "text_startswith" in full_config:
return {"startswith": full_config.pop("text_startswith")}
elif "text_endswith" in full_config:
return {"endswith": full_config.pop("text_endswith")}
async def check(self, obj: Union[Message, CallbackQuery, InlineQuery]):
if isinstance(obj, Message):
text = obj.text or obj.caption or ''
text = obj.text or obj.caption or ""
if not text and obj.poll:
text = obj.poll.question
elif isinstance(obj, CallbackQuery):
@ -287,7 +307,7 @@ class HashTag(Filter):
def __init__(self, hashtags=None, cashtags=None):
if not hashtags and not cashtags:
raise ValueError('No one hashtag or cashtag is specified!')
raise ValueError("No one hashtag or cashtag is specified!")
if hashtags is None:
hashtags = []
@ -307,10 +327,10 @@ class HashTag(Filter):
@classmethod
def validate(cls, full_config: Dict[str, Any]):
config = {}
if 'hashtags' in full_config:
config['hashtags'] = full_config.pop('hashtags')
if 'cashtags' in full_config:
config['cashtags'] = full_config.pop('cashtags')
if "hashtags" in full_config:
config["hashtags"] = full_config.pop("hashtags")
if "cashtags" in full_config:
config["cashtags"] = full_config.pop("cashtags")
return config
async def check(self, message: types.Message):
@ -324,9 +344,13 @@ class HashTag(Filter):
return False
hashtags, cashtags = self._get_tags(text, entities)
if self.hashtags and set(hashtags) & set(self.hashtags) \
or self.cashtags and set(cashtags) & set(self.cashtags):
return {'hashtags': hashtags, 'cashtags': cashtags}
if (
self.hashtags
and set(hashtags) & set(self.hashtags)
or self.cashtags
and set(cashtags) & set(self.cashtags)
):
return {"hashtags": hashtags, "cashtags": cashtags}
def _get_tags(self, text, entities):
hashtags = []
@ -334,11 +358,11 @@ class HashTag(Filter):
for entity in entities:
if entity.type == types.MessageEntityType.HASHTAG:
value = entity.get_text(text).lstrip('#')
value = entity.get_text(text).lstrip("#")
hashtags.append(value)
elif entity.type == types.MessageEntityType.CASHTAG:
value = entity.get_text(text).lstrip('$')
value = entity.get_text(text).lstrip("$")
cashtags.append(value)
return hashtags, cashtags
@ -356,12 +380,12 @@ class Regexp(Filter):
@classmethod
def validate(cls, full_config: Dict[str, Any]):
if 'regexp' in full_config:
return {'regexp': full_config.pop('regexp')}
if "regexp" in full_config:
return {"regexp": full_config.pop("regexp")}
async def check(self, obj: Union[Message, CallbackQuery]):
if isinstance(obj, Message):
content = obj.text or obj.caption or ''
content = obj.text or obj.caption or ""
if not content and obj.poll:
content = obj.poll.question
elif isinstance(obj, CallbackQuery) and obj.data:
@ -372,7 +396,7 @@ class Regexp(Filter):
match = self.regexp.search(content)
if match:
return {'regexp': match}
return {"regexp": match}
return False
@ -381,17 +405,19 @@ class RegexpCommandsFilter(BoundFilter):
Check commands by regexp in message
"""
key = 'regexp_commands'
key = "regexp_commands"
def __init__(self, regexp_commands):
self.regexp_commands = [re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands]
self.regexp_commands = [
re.compile(command, flags=re.IGNORECASE | re.MULTILINE) for command in regexp_commands
]
async def check(self, message):
if not message.is_command():
return False
command = message.text.split()[0][1:]
command, _, mention = command.partition('@')
command, _, mention = command.partition("@")
if mention and mention != (await message.bot.me).username:
return False
@ -399,7 +425,7 @@ class RegexpCommandsFilter(BoundFilter):
for command in self.regexp_commands:
search = command.search(message.text)
if search:
return {'regexp_command': search}
return {"regexp_command": search}
return False
@ -408,7 +434,7 @@ class ContentTypeFilter(BoundFilter):
Check message content type
"""
key = 'content_types'
key = "content_types"
required = True
default = types.ContentTypes.TEXT
@ -416,18 +442,21 @@ class ContentTypeFilter(BoundFilter):
self.content_types = content_types
async def check(self, message):
return types.ContentType.ANY in self.content_types or \
message.content_type in self.content_types
return (
types.ContentType.ANY in self.content_types
or message.content_type in self.content_types
)
class StateFilter(BoundFilter):
"""
Check user state
"""
key = 'state'
key = "state"
required = True
ctx_state = ContextVar('user_state')
ctx_state = ContextVar("user_state")
def __init__(self, dispatcher, state):
from aiogram.dispatcher.filters.state import State, StatesGroup
@ -435,7 +464,7 @@ class StateFilter(BoundFilter):
self.dispatcher = dispatcher
states = []
if not isinstance(state, (list, set, tuple, frozenset)) or state is None:
state = [state, ]
state = [state]
for item in state:
if isinstance(item, State):
states.append(item.state)
@ -446,11 +475,14 @@ class StateFilter(BoundFilter):
self.states = states
def get_target(self, obj):
return getattr(getattr(obj, 'chat', None), 'id', None), getattr(getattr(obj, 'from_user', None), 'id', None)
return (
getattr(getattr(obj, "chat", None), "id", None),
getattr(getattr(obj, "from_user", None), "id", None),
)
async def check(self, obj):
if '*' in self.states:
return {'state': self.dispatcher.current_state()}
if "*" in self.states:
return {"state": self.dispatcher.current_state()}
try:
state = self.ctx_state.get()
@ -461,11 +493,11 @@ class StateFilter(BoundFilter):
state = await self.dispatcher.storage.get_state(chat=chat, user=user)
self.ctx_state.set(state)
if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state}
return {"state": self.dispatcher.current_state(), "raw_state": state}
else:
if state in self.states:
return {'state': self.dispatcher.current_state(), 'raw_state': state}
return {"state": self.dispatcher.current_state(), "raw_state": state}
return False
@ -475,7 +507,7 @@ class ExceptionsFilter(BoundFilter):
Filter for exceptions
"""
key = 'exception'
key = "exception"
def __init__(self, exception):
self.exception = exception

View file

@ -13,10 +13,13 @@ class FiltersFactory:
self._dispatcher = dispatcher
self._registered: typing.List[FilterRecord] = []
def bind(self, callback: typing.Union[typing.Callable, AbstractFilter],
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.List[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
def bind(
self,
callback: typing.Union[typing.Callable, AbstractFilter],
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.List[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
):
"""
Register filter
@ -38,8 +41,9 @@ class FiltersFactory:
if record.callback == callback:
self._registered.remove(record)
def resolve(self, event_handler, *custom_filters, **full_config
) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]:
def resolve(
self, event_handler, *custom_filters, **full_config
) -> typing.List[typing.Union[typing.Callable, AbstractFilter]]:
"""
Resolve filters to filters-set
@ -49,8 +53,11 @@ class FiltersFactory:
:return:
"""
filters_set = []
filters_set.extend(self._resolve_registered(event_handler,
{k: v for k, v in full_config.items() if v is not None}))
filters_set.extend(
self._resolve_registered(
event_handler, {k: v for k, v in full_config.items() if v is not None}
)
)
if custom_filters:
filters_set.extend(custom_filters)
@ -70,4 +77,4 @@ class FiltersFactory:
yield filter_
if full_config:
raise NameError('Invalid filter name(s): \'' + '\', '.join(full_config.keys()) + '\'')
raise NameError("Invalid filter name(s): '" + "', ".join(full_config.keys()) + "'")

View file

@ -13,9 +13,11 @@ def wrap_async(func):
async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs)
if inspect.isawaitable(func) \
or inspect.iscoroutinefunction(func) \
or isinstance(func, AbstractFilter):
if (
inspect.isawaitable(func)
or inspect.iscoroutinefunction(func)
or isinstance(func, AbstractFilter)
):
return func
return async_wrapper
@ -23,14 +25,16 @@ def wrap_async(func):
def get_filter_spec(dispatcher, filter_: callable):
kwargs = {}
if not callable(filter_):
raise TypeError('Filter must be callable and/or awaitable!')
raise TypeError("Filter must be callable and/or awaitable!")
spec = inspect.getfullargspec(filter_)
if 'dispatcher' in spec:
kwargs['dispatcher'] = dispatcher
if inspect.isawaitable(filter_) \
or inspect.iscoroutinefunction(filter_) \
or isinstance(filter_, AbstractFilter):
if "dispatcher" in spec:
kwargs["dispatcher"] = dispatcher
if (
inspect.isawaitable(filter_)
or inspect.iscoroutinefunction(filter_)
or isinstance(filter_, AbstractFilter)
):
return FilterObj(filter=filter_, kwargs=kwargs, is_async=True)
else:
return FilterObj(filter=filter_, kwargs=kwargs, is_async=False)
@ -82,12 +86,17 @@ class FilterRecord:
Filters record for factory
"""
def __init__(self, callback: typing.Callable,
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None):
def __init__(
self,
callback: typing.Callable,
validator: typing.Optional[typing.Callable] = None,
event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
exclude_event_handlers: typing.Optional[typing.Iterable[Handler]] = None,
):
if event_handlers and exclude_event_handlers:
raise ValueError("'event_handlers' and 'exclude_event_handlers' arguments cannot be used together.")
raise ValueError(
"'event_handlers' and 'exclude_event_handlers' arguments cannot be used together."
)
self.callback = callback
self.event_handlers = event_handlers
@ -100,17 +109,17 @@ class FilterRecord:
elif issubclass(callback, AbstractFilter):
self.resolver = callback.validate
else:
raise RuntimeError('validator is required!')
raise RuntimeError("validator is required!")
def resolve(self, dispatcher, event_handler, full_config):
if not self._check_event_handler(event_handler):
return
config = self.resolver(full_config)
if config:
if 'dispatcher' not in config:
if "dispatcher" not in config:
spec = inspect.getfullargspec(self.callback)
if 'dispatcher' in spec.args:
config['dispatcher'] = dispatcher
if "dispatcher" in spec.args:
config["dispatcher"] = dispatcher
for key in config:
if key in full_config:
@ -133,7 +142,9 @@ class AbstractFilter(abc.ABC):
@classmethod
@abc.abstractmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
"""
Validate and parse config.
@ -184,7 +195,9 @@ class Filter(AbstractFilter):
"""
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
def validate(
cls, full_config: typing.Dict[str, typing.Any]
) -> typing.Optional[typing.Dict[str, typing.Any]]:
"""
Here method ``validate`` is optional.
If you need to use filter from filters factory you need to override this method.
@ -228,7 +241,7 @@ class BoundFilter(Filter):
class _LogicFilter(Filter):
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]):
raise ValueError('That filter can\'t be used in filters factory!')
raise ValueError("That filter can't be used in filters factory!")
class NotFilter(_LogicFilter):
@ -240,7 +253,6 @@ class NotFilter(_LogicFilter):
class AndFilter(_LogicFilter):
def __init__(self, *targets):
self.targets = list(wrap_async(target) for target in targets)

View file

@ -17,7 +17,7 @@ class State:
@property
def group(self):
if not self._group:
raise RuntimeError('This state is not in any group.')
raise RuntimeError("This state is not in any group.")
return self._group
def get_root(self):
@ -27,19 +27,19 @@ class State:
def state(self):
if self._state is None:
return None
elif self._state == '*':
elif self._state == "*":
return self._state
elif self._group_name is None and self._group:
group = self._group.__full_group_name__
elif self._group_name:
group = self._group_name
else:
group = '@'
group = "@"
return f"{group}:{self._state}"
def set_parent(self, group):
if not issubclass(group, StatesGroup):
raise ValueError('Group must be subclass of StatesGroup')
raise ValueError("Group must be subclass of StatesGroup")
self._group = group
def __set_name__(self, owner, name):
@ -89,7 +89,7 @@ class StatesGroupMeta(type):
@property
def __full_group_name__(cls):
if cls._parent:
return cls._parent.__full_group_name__ + '.' + cls._group_name
return cls._parent.__full_group_name__ + "." + cls._group_name
return cls._group_name
@property
@ -195,4 +195,4 @@ class StatesGroup(metaclass=StatesGroupMeta):
default_state = State()
any_state = State(state='*')
any_state = State(state="*")

View file

@ -3,8 +3,8 @@ from contextvars import ContextVar
from dataclasses import dataclass
from typing import Optional, Iterable
ctx_data = ContextVar('ctx_handler_data')
current_handler = ContextVar('current_handler')
ctx_data = ContextVar("ctx_handler_data")
current_handler = ContextVar("current_handler")
@dataclass
@ -23,7 +23,7 @@ class CancelHandler(Exception):
def _get_spec(func: callable):
while hasattr(func, '__wrapped__'): # Try to resolve decorated callbacks
while hasattr(func, "__wrapped__"): # Try to resolve decorated callbacks
func = func.__wrapped__
spec = inspect.getfullargspec(func)
@ -47,6 +47,7 @@ class Handler:
def register(self, handler, filters=None, index=None):
from .filters import get_filters_spec
"""
Register callback
@ -80,7 +81,7 @@ class Handler:
if handler is registered:
self.handlers.remove(handler_obj)
return True
raise ValueError('This handler is not registered!')
raise ValueError("This handler is not registered!")
async def notify(self, *args):
"""
@ -98,7 +99,9 @@ class Handler:
if self.middleware_key:
try:
await self.dispatcher.middleware.trigger(f"pre_process_{self.middleware_key}", args + (data,))
await self.dispatcher.middleware.trigger(
f"pre_process_{self.middleware_key}", args + (data,)
)
except CancelHandler: # Allow to cancel current event
return results
@ -112,7 +115,9 @@ class Handler:
ctx_token = current_handler.set(handler_obj.handler)
try:
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"process_{self.middleware_key}", args + (data,))
await self.dispatcher.middleware.trigger(
f"process_{self.middleware_key}", args + (data,)
)
partial_data = _check_spec(handler_obj.spec, data)
response = await handler_obj.handler(*args, **partial_data)
if response is not None:
@ -127,8 +132,9 @@ class Handler:
current_handler.reset(ctx_token)
finally:
if self.middleware_key:
await self.dispatcher.middleware.trigger(f"post_process_{self.middleware_key}",
args + (results, data,))
await self.dispatcher.middleware.trigger(
f"post_process_{self.middleware_key}", args + (results, data)
)
return results

View file

@ -1,7 +1,7 @@
import logging
import typing
log = logging.getLogger('aiogram.Middleware')
log = logging.getLogger("aiogram.Middleware")
class MiddlewareManager:
@ -29,9 +29,11 @@ class MiddlewareManager:
:return:
"""
if not isinstance(middleware, BaseMiddleware):
raise TypeError(f"`middleware` must be an instance of BaseMiddleware, not {type(middleware)}")
raise TypeError(
f"`middleware` must be an instance of BaseMiddleware, not {type(middleware)}"
)
if middleware.is_configured():
raise ValueError('That middleware is already used!')
raise ValueError("That middleware is already used!")
self.applications.append(middleware)
middleware.setup(self)
@ -67,7 +69,7 @@ class BaseMiddleware:
Instance of MiddlewareManager
"""
if self._manager is None:
raise RuntimeError('Middleware is not configured!')
raise RuntimeError("Middleware is not configured!")
return self._manager
def setup(self, manager):
@ -119,9 +121,9 @@ class LifetimeControllerMiddleware(BaseMiddleware):
return False
obj, *args, data = args
if action.startswith('pre_process_'):
if action.startswith("pre_process_"):
await self.pre_process(obj, data, *args)
elif action.startswith('post_process_'):
elif action.startswith("post_process_"):
await self.post_process(obj, data, *args)
else:
return False

View file

@ -5,13 +5,13 @@ from ..utils.deprecated import warn_deprecated as warn
from ..utils.exceptions import FSMStorageWarning
# Leak bucket
KEY = 'key'
LAST_CALL = 'called_at'
RATE_LIMIT = 'rate_limit'
RESULT = 'result'
EXCEEDED_COUNT = 'exceeded'
DELTA = 'delta'
THROTTLE_MANAGER = '$throttle_manager'
KEY = "key"
LAST_CALL = "called_at"
RATE_LIMIT = "rate_limit"
RESULT = "result"
EXCEEDED_COUNT = "exceeded"
DELTA = "delta"
THROTTLE_MANAGER = "$throttle_manager"
class BaseStorage:
@ -38,9 +38,12 @@ class BaseStorage:
raise NotImplementedError
@classmethod
def check_address(cls, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None) -> (typing.Union[str, int], typing.Union[str, int]):
def check_address(
cls,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
) -> (typing.Union[str, int], typing.Union[str, int]):
"""
In all storage's methods chat or user is always required.
If one of them is not provided, you have to set missing value based on the provided one.
@ -52,7 +55,7 @@ class BaseStorage:
:return:
"""
if chat is None and user is None:
raise ValueError('`user` or `chat` parameter is required but no one is provided!')
raise ValueError("`user` or `chat` parameter is required but no one is provided!")
if user is None and chat is not None:
user = chat
@ -60,10 +63,13 @@ class BaseStorage:
chat = user
return chat, user
async def get_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
async def get_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Optional[str]:
"""
Get current state of user in chat. Return `default` if no record is found.
@ -77,10 +83,13 @@ class BaseStorage:
"""
raise NotImplementedError
async def get_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[typing.Dict] = None) -> typing.Dict:
async def get_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[typing.Dict] = None,
) -> typing.Dict:
"""
Get state-data for user in chat. Return `default` if no data is provided in storage.
@ -94,10 +103,13 @@ class BaseStorage:
"""
raise NotImplementedError
async def set_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
async def set_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None,
):
"""
Set new state for user in chat
@ -110,10 +122,13 @@ class BaseStorage:
"""
raise NotImplementedError
async def set_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
async def set_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
):
"""
Set data for user in chat
@ -126,11 +141,14 @@ class BaseStorage:
"""
raise NotImplementedError
async def update_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs):
async def update_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs,
):
"""
Update data for user in chat
@ -147,9 +165,12 @@ class BaseStorage:
"""
raise NotImplementedError
async def reset_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None):
async def reset_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
):
"""
Reset data for user in chat.
@ -162,10 +183,13 @@ class BaseStorage:
"""
await self.set_data(chat=chat, user=user, data={})
async def reset_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
with_data: typing.Optional[bool] = True):
async def reset_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
with_data: typing.Optional[bool] = True,
):
"""
Reset state for user in chat.
You may desire to use this method when finishing conversations.
@ -183,9 +207,12 @@ class BaseStorage:
if with_data:
await self.set_data(chat=chat, user=user, data={})
async def finish(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None):
async def finish(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
):
"""
Finish conversation for user in chat.
@ -201,10 +228,13 @@ class BaseStorage:
def has_bucket(self):
return False
async def get_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None) -> typing.Dict:
async def get_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] = None,
) -> typing.Dict:
"""
Get bucket for user in chat. Return `default` if no data is provided in storage.
@ -218,10 +248,13 @@ class BaseStorage:
"""
raise NotImplementedError
async def set_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None):
async def set_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
):
"""
Set bucket for user in chat
@ -234,11 +267,14 @@ class BaseStorage:
"""
raise NotImplementedError
async def update_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
**kwargs):
async def update_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict = None,
**kwargs,
):
"""
Update bucket for user in chat
@ -255,9 +291,12 @@ class BaseStorage:
"""
raise NotImplementedError
async def reset_bucket(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None):
async def reset_bucket(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
):
"""
Reset bucket dor user in chat.
@ -292,7 +331,9 @@ class FSMContext:
return str(value)
async def get_state(self, default: typing.Optional[str] = None) -> typing.Optional[str]:
return await self.storage.get_state(chat=self.chat, user=self.user, default=self._resolve_state(default))
return await self.storage.get_state(
chat=self.chat, user=self.user, default=self._resolve_state(default)
)
async def get_data(self, default: typing.Optional[str] = None) -> typing.Dict:
return await self.storage.get_data(chat=self.chat, user=self.user, default=default)
@ -301,7 +342,9 @@ class FSMContext:
await self.storage.update_data(chat=self.chat, user=self.user, data=data, **kwargs)
async def set_state(self, state: typing.Union[typing.AnyStr, None] = None):
await self.storage.set_state(chat=self.chat, user=self.user, state=self._resolve_state(state))
await self.storage.set_state(
chat=self.chat, user=self.user, state=self._resolve_state(state)
)
async def set_data(self, data: typing.Dict = None):
await self.storage.set_data(chat=self.chat, user=self.user, data=data)
@ -338,7 +381,7 @@ class FSMContextProxy:
def _check_closed(self):
if self._closed:
raise LookupError('Proxy is closed!')
raise LookupError("Proxy is closed!")
@classmethod
async def create(cls, fsm_context: FSMContext):
@ -447,7 +490,7 @@ class FSMContextProxy:
readable_state = f"'{self.state}'" if self.state else "<default>"
result = f"{self.__class__.__name__} state = {readable_state}, data = {self._data}"
if self._closed:
result += ', closed = True'
result += ", closed = True"
return result
@ -462,39 +505,58 @@ class DisabledStorage(BaseStorage):
async def wait_closed(self):
pass
async def get_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Optional[str]:
async def get_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Optional[str]:
return None
async def get_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None) -> typing.Dict:
async def get_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[str] = None,
) -> typing.Dict:
self._warn()
return {}
async def update_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None, **kwargs):
async def update_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
**kwargs,
):
self._warn()
async def set_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None):
async def set_state(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None,
):
self._warn()
async def set_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None):
async def set_data(
self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict = None,
):
self._warn()
@staticmethod
def _warn():
warn(f"You havent set any storage yet so no states and no data will be saved. \n"
f"You can connect MemoryStorage for debug purposes or non-essential data.",
FSMStorageWarning, 5)
warn(
f"You havent set any storage yet so no states and no data will be saved. \n"
f"You can connect MemoryStorage for debug purposes or non-essential data.",
FSMStorageWarning,
5,
)

File diff suppressed because it is too large Load diff

View file

@ -19,17 +19,46 @@ from .game import Game
from .game_high_score import GameHighScore
from .inline_keyboard import InlineKeyboardButton, InlineKeyboardMarkup
from .inline_query import InlineQuery
from .inline_query_result import InlineQueryResult, InlineQueryResultArticle, InlineQueryResultAudio, \
InlineQueryResultCachedAudio, InlineQueryResultCachedDocument, InlineQueryResultCachedGif, \
InlineQueryResultCachedMpeg4Gif, InlineQueryResultCachedPhoto, InlineQueryResultCachedSticker, \
InlineQueryResultCachedVideo, InlineQueryResultCachedVoice, InlineQueryResultContact, InlineQueryResultDocument, \
InlineQueryResultGame, InlineQueryResultGif, InlineQueryResultLocation, InlineQueryResultMpeg4Gif, \
InlineQueryResultPhoto, InlineQueryResultVenue, InlineQueryResultVideo, InlineQueryResultVoice
from .inline_query_result import (
InlineQueryResult,
InlineQueryResultArticle,
InlineQueryResultAudio,
InlineQueryResultCachedAudio,
InlineQueryResultCachedDocument,
InlineQueryResultCachedGif,
InlineQueryResultCachedMpeg4Gif,
InlineQueryResultCachedPhoto,
InlineQueryResultCachedSticker,
InlineQueryResultCachedVideo,
InlineQueryResultCachedVoice,
InlineQueryResultContact,
InlineQueryResultDocument,
InlineQueryResultGame,
InlineQueryResultGif,
InlineQueryResultLocation,
InlineQueryResultMpeg4Gif,
InlineQueryResultPhoto,
InlineQueryResultVenue,
InlineQueryResultVideo,
InlineQueryResultVoice,
)
from .input_file import InputFile
from .input_media import InputMedia, InputMediaAnimation, InputMediaAudio, InputMediaDocument, InputMediaPhoto, \
InputMediaVideo, MediaGroup
from .input_message_content import InputContactMessageContent, InputLocationMessageContent, InputMessageContent, \
InputTextMessageContent, InputVenueMessageContent
from .input_media import (
InputMedia,
InputMediaAnimation,
InputMediaAudio,
InputMediaDocument,
InputMediaPhoto,
InputMediaVideo,
MediaGroup,
)
from .input_message_content import (
InputContactMessageContent,
InputLocationMessageContent,
InputMessageContent,
InputTextMessageContent,
InputVenueMessageContent,
)
from .invoice import Invoice
from .labeled_price import LabeledPrice
from .location import Location
@ -39,9 +68,15 @@ from .message import ContentType, ContentTypes, Message, ParseMode
from .message_entity import MessageEntity, MessageEntityType
from .order_info import OrderInfo
from .passport_data import PassportData
from .passport_element_error import PassportElementError, PassportElementErrorDataField, PassportElementErrorFile, \
PassportElementErrorFiles, PassportElementErrorFrontSide, PassportElementErrorReverseSide, \
PassportElementErrorSelfie
from .passport_element_error import (
PassportElementError,
PassportElementErrorDataField,
PassportElementErrorFile,
PassportElementErrorFiles,
PassportElementErrorFrontSide,
PassportElementErrorReverseSide,
PassportElementErrorSelfie,
)
from .passport_file import PassportFile
from .photo_size import PhotoSize
from .poll import PollOption, Poll
@ -64,107 +99,107 @@ from .voice import Voice
from .webhook_info import WebhookInfo
__all__ = (
'AllowedUpdates',
'Animation',
'Audio',
'AuthWidgetData',
'CallbackGame',
'CallbackQuery',
'Chat',
'ChatActions',
'ChatMember',
'ChatMemberStatus',
'ChatPhoto',
'ChatType',
'ChosenInlineResult',
'Contact',
'ContentType',
'ContentTypes',
'Document',
'EncryptedCredentials',
'EncryptedPassportElement',
'File',
'ForceReply',
'Game',
'GameHighScore',
'InlineKeyboardButton',
'InlineKeyboardMarkup',
'InlineQuery',
'InlineQueryResult',
'InlineQueryResultArticle',
'InlineQueryResultAudio',
'InlineQueryResultCachedAudio',
'InlineQueryResultCachedDocument',
'InlineQueryResultCachedGif',
'InlineQueryResultCachedMpeg4Gif',
'InlineQueryResultCachedPhoto',
'InlineQueryResultCachedSticker',
'InlineQueryResultCachedVideo',
'InlineQueryResultCachedVoice',
'InlineQueryResultContact',
'InlineQueryResultDocument',
'InlineQueryResultGame',
'InlineQueryResultGif',
'InlineQueryResultLocation',
'InlineQueryResultMpeg4Gif',
'InlineQueryResultPhoto',
'InlineQueryResultVenue',
'InlineQueryResultVideo',
'InlineQueryResultVoice',
'InputContactMessageContent',
'InputFile',
'InputLocationMessageContent',
'InputMedia',
'InputMediaAnimation',
'InputMediaAudio',
'InputMediaDocument',
'InputMediaPhoto',
'InputMediaVideo',
'InputMessageContent',
'InputTextMessageContent',
'InputVenueMessageContent',
'Invoice',
'KeyboardButton',
'LabeledPrice',
'Location',
'LoginUrl',
'MaskPosition',
'MediaGroup',
'Message',
'MessageEntity',
'MessageEntityType',
'OrderInfo',
'ParseMode',
'PassportData',
'PassportElementError',
'PassportElementErrorDataField',
'PassportElementErrorFile',
'PassportElementErrorFiles',
'PassportElementErrorFrontSide',
'PassportElementErrorReverseSide',
'PassportElementErrorSelfie',
'PassportFile',
'PhotoSize',
'Poll',
'PollOption',
'PreCheckoutQuery',
'ReplyKeyboardMarkup',
'ReplyKeyboardRemove',
'ResponseParameters',
'ShippingAddress',
'ShippingOption',
'ShippingQuery',
'Sticker',
'StickerSet',
'SuccessfulPayment',
'Update',
'User',
'UserProfilePhotos',
'Venue',
'Video',
'VideoNote',
'Voice',
'WebhookInfo',
'base',
'fields',
"AllowedUpdates",
"Animation",
"Audio",
"AuthWidgetData",
"CallbackGame",
"CallbackQuery",
"Chat",
"ChatActions",
"ChatMember",
"ChatMemberStatus",
"ChatPhoto",
"ChatType",
"ChosenInlineResult",
"Contact",
"ContentType",
"ContentTypes",
"Document",
"EncryptedCredentials",
"EncryptedPassportElement",
"File",
"ForceReply",
"Game",
"GameHighScore",
"InlineKeyboardButton",
"InlineKeyboardMarkup",
"InlineQuery",
"InlineQueryResult",
"InlineQueryResultArticle",
"InlineQueryResultAudio",
"InlineQueryResultCachedAudio",
"InlineQueryResultCachedDocument",
"InlineQueryResultCachedGif",
"InlineQueryResultCachedMpeg4Gif",
"InlineQueryResultCachedPhoto",
"InlineQueryResultCachedSticker",
"InlineQueryResultCachedVideo",
"InlineQueryResultCachedVoice",
"InlineQueryResultContact",
"InlineQueryResultDocument",
"InlineQueryResultGame",
"InlineQueryResultGif",
"InlineQueryResultLocation",
"InlineQueryResultMpeg4Gif",
"InlineQueryResultPhoto",
"InlineQueryResultVenue",
"InlineQueryResultVideo",
"InlineQueryResultVoice",
"InputContactMessageContent",
"InputFile",
"InputLocationMessageContent",
"InputMedia",
"InputMediaAnimation",
"InputMediaAudio",
"InputMediaDocument",
"InputMediaPhoto",
"InputMediaVideo",
"InputMessageContent",
"InputTextMessageContent",
"InputVenueMessageContent",
"Invoice",
"KeyboardButton",
"LabeledPrice",
"Location",
"LoginUrl",
"MaskPosition",
"MediaGroup",
"Message",
"MessageEntity",
"MessageEntityType",
"OrderInfo",
"ParseMode",
"PassportData",
"PassportElementError",
"PassportElementErrorDataField",
"PassportElementErrorFile",
"PassportElementErrorFiles",
"PassportElementErrorFrontSide",
"PassportElementErrorReverseSide",
"PassportElementErrorSelfie",
"PassportFile",
"PhotoSize",
"Poll",
"PollOption",
"PreCheckoutQuery",
"ReplyKeyboardMarkup",
"ReplyKeyboardRemove",
"ResponseParameters",
"ShippingAddress",
"ShippingOption",
"ShippingQuery",
"Sticker",
"StickerSet",
"SuccessfulPayment",
"Update",
"User",
"UserProfilePhotos",
"Venue",
"Video",
"VideoNote",
"Voice",
"WebhookInfo",
"base",
"fields",
)

View file

@ -10,6 +10,7 @@ class Audio(base.TelegramObject, mixins.Downloadable):
https://core.telegram.org/bots/api#audio
"""
file_id: base.String = fields.Field()
duration: base.Integer = fields.Field()
performer: base.String = fields.Field()

View file

@ -26,11 +26,11 @@ class AuthWidgetData(base.TelegramObject):
"""
try:
query = dict(request.query)
query['id'] = int(query['id'])
query['auth_date'] = int(query['auth_date'])
query["id"] = int(query["id"])
query["auth_date"] = int(query["auth_date"])
widget = AuthWidgetData(**query)
except (ValueError, KeyError):
raise web.HTTPBadRequest(text='Invalid auth data')
raise web.HTTPBadRequest(text="Invalid auth data")
else:
return widget
@ -41,7 +41,7 @@ class AuthWidgetData(base.TelegramObject):
def full_name(self):
result = self.first_name
if self.last_name:
result += ' '
result += " "
result += self.last_name
return result

View file

@ -10,24 +10,33 @@ from .fields import BaseField
from ..utils import json
from ..utils.mixins import ContextInstanceMixin
__all__ = ('MetaTelegramObject', 'TelegramObject', 'InputFile', 'String', 'Integer', 'Float', 'Boolean')
__all__ = (
"MetaTelegramObject",
"TelegramObject",
"InputFile",
"String",
"Integer",
"Float",
"Boolean",
)
PROPS_ATTR_NAME = '_props'
VALUES_ATTR_NAME = '_values'
ALIASES_ATTR_NAME = '_aliases'
PROPS_ATTR_NAME = "_props"
VALUES_ATTR_NAME = "_values"
ALIASES_ATTR_NAME = "_aliases"
# Binding of builtin types
InputFile = TypeVar('InputFile', 'InputFile', io.BytesIO, io.FileIO, str)
String = TypeVar('String', bound=str)
Integer = TypeVar('Integer', bound=int)
Float = TypeVar('Float', bound=float)
Boolean = TypeVar('Boolean', bound=bool)
InputFile = TypeVar("InputFile", "InputFile", io.BytesIO, io.FileIO, str)
String = TypeVar("String", bound=str)
Integer = TypeVar("Integer", bound=int)
Float = TypeVar("Float", bound=float)
Boolean = TypeVar("Boolean", bound=bool)
class MetaTelegramObject(type):
"""
Metaclass for telegram objects
"""
_objects = {}
def __new__(mcs, name, bases, namespace, **kwargs):
@ -46,7 +55,9 @@ class MetaTelegramObject(type):
aliases.update(getattr(base, ALIASES_ATTR_NAME))
# Scan current object for props
for name, prop in ((name, prop) for name, prop in namespace.items() if isinstance(prop, BaseField)):
for name, prop in (
(name, prop) for name, prop in namespace.items() if isinstance(prop, BaseField)
):
props[prop.alias] = prop
if prop.default is not None:
values[prop.alias] = prop.default
@ -147,9 +158,11 @@ class TelegramObject(ContextInstanceMixin, metaclass=MetaTelegramObject):
bot = Bot.get_current()
if bot is None:
raise RuntimeError("Can't get bot instance from context. "
"You can fix it with setting current instance: "
"'Bot.set_current(bot_instance)'")
raise RuntimeError(
"Can't get bot instance from context. "
"You can fix it with setting current instance: "
"'Bot.set_current(bot_instance)'"
)
return bot
def to_python(self) -> typing.Dict:
@ -219,7 +232,7 @@ class TelegramObject(ContextInstanceMixin, metaclass=MetaTelegramObject):
:return:
"""
if key in self.props:
return self.props[key].set_value(self, value, self.conf.get('parent', None))
return self.props[key].set_value(self, value, self.conf.get("parent", None))
raise KeyError(key)
def __contains__(self, item):

View file

@ -7,4 +7,5 @@ class CallbackGame(base.TelegramObject):
https://core.telegram.org/bots/api#callbackgame
"""
pass

View file

@ -20,18 +20,22 @@ class CallbackQuery(base.TelegramObject):
https://core.telegram.org/bots/api#callbackquery
"""
id: base.String = fields.Field()
from_user: User = fields.Field(alias='from', base=User)
from_user: User = fields.Field(alias="from", base=User)
message: Message = fields.Field(base=Message)
inline_message_id: base.String = fields.Field()
chat_instance: base.String = fields.Field()
data: base.String = fields.Field()
game_short_name: base.String = fields.Field()
async def answer(self, text: typing.Union[base.String, None] = None,
show_alert: typing.Union[base.Boolean, None] = None,
url: typing.Union[base.String, None] = None,
cache_time: typing.Union[base.Integer, None] = None):
async def answer(
self,
text: typing.Union[base.String, None] = None,
show_alert: typing.Union[base.Boolean, None] = None,
url: typing.Union[base.String, None] = None,
cache_time: typing.Union[base.Integer, None] = None,
):
"""
Use this method to send answers to callback queries sent from inline keyboards.
The answer will be displayed to the user as a notification at the top of the chat screen or as an alert.
@ -54,8 +58,13 @@ class CallbackQuery(base.TelegramObject):
:type cache_time: :obj:`typing.Union[base.Integer, None]`
:return: On success, True is returned.
:rtype: :obj:`base.Boolean`"""
await self.bot.answer_callback_query(callback_query_id=self.id, text=text,
show_alert=show_alert, url=url, cache_time=cache_time)
await self.bot.answer_callback_query(
callback_query_id=self.id,
text=text,
show_alert=show_alert,
url=url,
cache_time=cache_time,
)
def __hash__(self):
return hash(self.id)

View file

@ -16,6 +16,7 @@ class Chat(base.TelegramObject):
https://core.telegram.org/bots/api#chat
"""
id: base.Integer = fields.Field()
type: base.String = fields.Field()
title: base.String = fields.Field()
@ -26,7 +27,7 @@ class Chat(base.TelegramObject):
photo: ChatPhoto = fields.Field(base=ChatPhoto)
description: base.String = fields.Field()
invite_link: base.String = fields.Field()
pinned_message: 'Message' = fields.Field(base='Message')
pinned_message: "Message" = fields.Field(base="Message")
sticker_set_name: base.String = fields.Field()
can_set_sticker_set: base.Boolean = fields.Field()
@ -38,7 +39,7 @@ class Chat(base.TelegramObject):
if self.type == ChatType.PRIVATE:
full_name = self.first_name
if self.last_name:
full_name += ' ' + self.last_name
full_name += " " + self.last_name
return full_name
return self.title
@ -48,7 +49,7 @@ class Chat(base.TelegramObject):
Get mention if a Chat has a username, or get full name if this is a Private Chat, otherwise None is returned
"""
if self.username:
return '@' + self.username
return "@" + self.username
if self.type == ChatType.PRIVATE:
return self.full_name
return None
@ -56,7 +57,7 @@ class Chat(base.TelegramObject):
@property
def user_url(self):
if self.type != ChatType.PRIVATE:
raise TypeError('`user_url` property is only available in private chats!')
raise TypeError("`user_url` property is only available in private chats!")
return f"tg://user?id={self.id}"
@ -79,7 +80,7 @@ class Chat(base.TelegramObject):
return f"tg://user?id={self.id}"
if self.username:
return f'https://t.me/{self.username}'
return f"https://t.me/{self.username}"
if self.invite_link:
return self.invite_link
@ -161,8 +162,9 @@ class Chat(base.TelegramObject):
"""
return await self.bot.delete_chat_description(self.id, description)
async def kick(self, user_id: base.Integer,
until_date: typing.Union[base.Integer, None] = None):
async def kick(
self, user_id: base.Integer, until_date: typing.Union[base.Integer, None] = None
):
"""
Use this method to kick a user from a group, a supergroup or a channel.
In the case of supergroups and channels, the user will not be able to return to the group
@ -201,12 +203,15 @@ class Chat(base.TelegramObject):
"""
return await self.bot.unban_chat_member(self.id, user_id=user_id)
async def restrict(self, user_id: base.Integer,
until_date: typing.Union[base.Integer, None] = None,
can_send_messages: typing.Union[base.Boolean, None] = None,
can_send_media_messages: typing.Union[base.Boolean, None] = None,
can_send_other_messages: typing.Union[base.Boolean, None] = None,
can_add_web_page_previews: typing.Union[base.Boolean, None] = None) -> base.Boolean:
async def restrict(
self,
user_id: base.Integer,
until_date: typing.Union[base.Integer, None] = None,
can_send_messages: typing.Union[base.Boolean, None] = None,
can_send_media_messages: typing.Union[base.Boolean, None] = None,
can_send_other_messages: typing.Union[base.Boolean, None] = None,
can_add_web_page_previews: typing.Union[base.Boolean, None] = None,
) -> base.Boolean:
"""
Use this method to restrict a user in a supergroup.
The bot must be an administrator in the supergroup for this to work and must have the appropriate admin rights.
@ -232,21 +237,28 @@ class Chat(base.TelegramObject):
:return: Returns True on success.
:rtype: :obj:`base.Boolean`
"""
return await self.bot.restrict_chat_member(self.id, user_id=user_id, until_date=until_date,
can_send_messages=can_send_messages,
can_send_media_messages=can_send_media_messages,
can_send_other_messages=can_send_other_messages,
can_add_web_page_previews=can_add_web_page_previews)
return await self.bot.restrict_chat_member(
self.id,
user_id=user_id,
until_date=until_date,
can_send_messages=can_send_messages,
can_send_media_messages=can_send_media_messages,
can_send_other_messages=can_send_other_messages,
can_add_web_page_previews=can_add_web_page_previews,
)
async def promote(self, user_id: base.Integer,
can_change_info: typing.Union[base.Boolean, None] = None,
can_post_messages: typing.Union[base.Boolean, None] = None,
can_edit_messages: typing.Union[base.Boolean, None] = None,
can_delete_messages: typing.Union[base.Boolean, None] = None,
can_invite_users: typing.Union[base.Boolean, None] = None,
can_restrict_members: typing.Union[base.Boolean, None] = None,
can_pin_messages: typing.Union[base.Boolean, None] = None,
can_promote_members: typing.Union[base.Boolean, None] = None) -> base.Boolean:
async def promote(
self,
user_id: base.Integer,
can_change_info: typing.Union[base.Boolean, None] = None,
can_post_messages: typing.Union[base.Boolean, None] = None,
can_edit_messages: typing.Union[base.Boolean, None] = None,
can_delete_messages: typing.Union[base.Boolean, None] = None,
can_invite_users: typing.Union[base.Boolean, None] = None,
can_restrict_members: typing.Union[base.Boolean, None] = None,
can_pin_messages: typing.Union[base.Boolean, None] = None,
can_promote_members: typing.Union[base.Boolean, None] = None,
) -> base.Boolean:
"""
Use this method to promote or demote a user in a supergroup or a channel.
The bot must be an administrator in the chat for this to work and must have the appropriate admin rights.
@ -277,16 +289,18 @@ class Chat(base.TelegramObject):
:return: Returns True on success.
:rtype: :obj:`base.Boolean`
"""
return await self.bot.promote_chat_member(self.id,
user_id=user_id,
can_change_info=can_change_info,
can_post_messages=can_post_messages,
can_edit_messages=can_edit_messages,
can_delete_messages=can_delete_messages,
can_invite_users=can_invite_users,
can_restrict_members=can_restrict_members,
can_pin_messages=can_pin_messages,
can_promote_members=can_promote_members)
return await self.bot.promote_chat_member(
self.id,
user_id=user_id,
can_change_info=can_change_info,
can_post_messages=can_post_messages,
can_edit_messages=can_edit_messages,
can_delete_messages=can_delete_messages,
can_invite_users=can_invite_users,
can_restrict_members=can_restrict_members,
can_pin_messages=can_pin_messages,
can_promote_members=can_promote_members,
)
async def pin_message(self, message_id: int, disable_notification: bool = False):
"""
@ -422,9 +436,9 @@ class ChatType(helper.Helper):
@staticmethod
def _check(obj, chat_types) -> bool:
if hasattr(obj, 'chat'):
if hasattr(obj, "chat"):
obj = obj.chat
if not hasattr(obj, 'type'):
if not hasattr(obj, "type"):
return False
return obj.type in chat_types
@ -511,12 +525,13 @@ class ChatActions(helper.Helper):
@classmethod
async def _do(cls, action: str, sleep=None):
from aiogram import Bot
await Bot.get_current().send_chat_action(Chat.get_current().id, action)
if sleep:
await asyncio.sleep(sleep)
@classmethod
def calc_timeout(cls, text, timeout=.8):
def calc_timeout(cls, text, timeout=0.8):
"""
Calculate timeout for text

View file

@ -13,6 +13,7 @@ class ChatMember(base.TelegramObject):
https://core.telegram.org/bots/api#chatmember
"""
user: User = fields.Field(base=User)
status: base.String = fields.Field()
until_date: datetime.datetime = fields.DateTimeField()
@ -32,9 +33,12 @@ class ChatMember(base.TelegramObject):
can_add_web_page_previews: base.Boolean = fields.Field()
def is_admin(self):
warnings.warn('`is_admin` method deprecated due to updates in Bot API 4.2. '
'This method renamed to `is_chat_admin` and will be available until aiogram 2.3',
DeprecationWarning, stacklevel=2)
warnings.warn(
"`is_admin` method deprecated due to updates in Bot API 4.2. "
"This method renamed to `is_chat_admin` and will be available until aiogram 2.3",
DeprecationWarning,
stacklevel=2,
)
return self.is_chat_admin()
def is_chat_admin(self):
@ -62,16 +66,22 @@ class ChatMemberStatus(helper.Helper):
@classmethod
def is_admin(cls, role):
warnings.warn('`is_admin` method deprecated due to updates in Bot API 4.2. '
'This method renamed to `is_chat_admin` and will be available until aiogram 2.3',
DeprecationWarning, stacklevel=2)
warnings.warn(
"`is_admin` method deprecated due to updates in Bot API 4.2. "
"This method renamed to `is_chat_admin` and will be available until aiogram 2.3",
DeprecationWarning,
stacklevel=2,
)
return cls.is_chat_admin(role)
@classmethod
def is_member(cls, role):
warnings.warn('`is_member` method deprecated due to updates in Bot API 4.2. '
'This method renamed to `is_chat_member` and will be available until aiogram 2.3',
DeprecationWarning, stacklevel=2)
warnings.warn(
"`is_member` method deprecated due to updates in Bot API 4.2. "
"This method renamed to `is_chat_member` and will be available until aiogram 2.3",
DeprecationWarning,
stacklevel=2,
)
return cls.is_chat_member(role)
@classmethod

View file

@ -11,10 +11,13 @@ class ChatPhoto(base.TelegramObject):
https://core.telegram.org/bots/api#chatphoto
"""
small_file_id: base.String = fields.Field()
big_file_id: base.String = fields.Field()
async def download_small(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True):
async def download_small(
self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True
):
"""
Download file
@ -38,10 +41,17 @@ class ChatPhoto(base.TelegramObject):
if is_path and make_dirs:
os.makedirs(os.path.dirname(destination), exist_ok=True)
return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout,
chunk_size=chunk_size, seek=seek)
return await self.bot.download_file(
file_path=file.file_path,
destination=destination,
timeout=timeout,
chunk_size=chunk_size,
seek=seek,
)
async def download_big(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True):
async def download_big(
self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True
):
"""
Download file
@ -65,8 +75,13 @@ class ChatPhoto(base.TelegramObject):
if is_path and make_dirs:
os.makedirs(os.path.dirname(destination), exist_ok=True)
return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout,
chunk_size=chunk_size, seek=seek)
return await self.bot.download_file(
file_path=file.file_path,
destination=destination,
timeout=timeout,
chunk_size=chunk_size,
seek=seek,
)
async def get_small_file(self):
return await self.bot.get_file(self.small_file_id)

View file

@ -15,8 +15,9 @@ class ChosenInlineResult(base.TelegramObject):
https://core.telegram.org/bots/api#choseninlineresult
"""
result_id: base.String = fields.Field()
from_user: User = fields.Field(alias='from', base=User)
from_user: User = fields.Field(alias="from", base=User)
location: Location = fields.Field(base=Location)
inline_message_id: base.String = fields.Field()
query: base.String = fields.Field()

View file

@ -8,6 +8,7 @@ class Contact(base.TelegramObject):
https://core.telegram.org/bots/api#contact
"""
phone_number: base.String = fields.Field()
first_name: base.String = fields.Field()
last_name: base.String = fields.Field()
@ -18,7 +19,7 @@ class Contact(base.TelegramObject):
def full_name(self):
name = self.first_name
if self.last_name is not None:
name += ' ' + self.last_name
name += " " + self.last_name
return name
def __hash__(self):

View file

@ -10,6 +10,7 @@ class Document(base.TelegramObject, mixins.Downloadable):
https://core.telegram.org/bots/api#document
"""
file_id: base.String = fields.Field()
thumb: PhotoSize = fields.Field(base=PhotoSize)
file_name: base.String = fields.Field()

View file

@ -1,7 +1,7 @@
import abc
import datetime
__all__ = ('BaseField', 'Field', 'ListField', 'DateTimeField', 'TextField', 'ListOfLists')
__all__ = ("BaseField", "Field", "ListField", "DateTimeField", "TextField", "ListOfLists")
class BaseField(metaclass=abc.ABCMeta):
@ -29,7 +29,7 @@ class BaseField(metaclass=abc.ABCMeta):
self.alias = name
def resolve_base(self, instance):
if self.base_object is None or hasattr(self.base_object, 'telegram_types'):
if self.base_object is None or hasattr(self.base_object, "telegram_types"):
return
elif isinstance(self.base_object, str):
self.base_object = instance.telegram_types.get(self.base_object)
@ -100,16 +100,18 @@ class Field(BaseField):
"""
def serialize(self, value):
if self.base_object is not None and hasattr(value, 'to_python'):
if self.base_object is not None and hasattr(value, "to_python"):
return value.to_python()
return value
def deserialize(self, value, parent=None):
if isinstance(value, dict) \
and self.base_object is not None \
and not hasattr(value, 'base_object') \
and not hasattr(value, 'to_python'):
return self.base_object(conf={'parent': parent}, **value)
if (
isinstance(value, dict)
and self.base_object is not None
and not hasattr(value, "base_object")
and not hasattr(value, "to_python")
):
return self.base_object(conf={"parent": parent}, **value)
return value
@ -119,7 +121,7 @@ class ListField(Field):
"""
def __init__(self, *args, **kwargs):
default = kwargs.pop('default', None)
default = kwargs.pop("default", None)
if default is None:
default = []
@ -154,7 +156,7 @@ class ListOfLists(Field):
def deserialize(self, value, parent=None):
result = []
deserialize = super(ListOfLists, self).deserialize
if hasattr(value, '__iter__'):
if hasattr(value, "__iter__"):
for row in value:
row_result = []
for item in row:

View file

@ -16,6 +16,7 @@ class File(base.TelegramObject, mixins.Downloadable):
https://core.telegram.org/bots/api#file
"""
file_id: base.String = fields.Field()
file_size: base.Integer = fields.Field()
file_path: base.String = fields.Field()

View file

@ -22,6 +22,7 @@ class ForceReply(base.TelegramObject):
https://core.telegram.org/bots/api#forcereply
"""
force_reply: base.Boolean = fields.Field(default=True)
selective: base.Boolean = fields.Field()

View file

@ -15,6 +15,7 @@ class Game(base.TelegramObject):
https://core.telegram.org/bots/api#game
"""
title: base.String = fields.Field()
description: base.String = fields.Field()
photo: typing.List[PhotoSize] = fields.ListField(base=PhotoSize)

View file

@ -11,6 +11,7 @@ class GameHighScore(base.TelegramObject):
https://core.telegram.org/bots/api#gamehighscore
"""
position: base.Integer = fields.Field()
user: User = fields.Field(base=User)
score: base.Integer = fields.Field()

View file

@ -15,26 +15,29 @@ class InlineKeyboardMarkup(base.TelegramObject):
https://core.telegram.org/bots/api#inlinekeyboardmarkup
"""
inline_keyboard: 'typing.List[typing.List[InlineKeyboardButton]]' = fields.ListOfLists(base='InlineKeyboardButton')
inline_keyboard: "typing.List[typing.List[InlineKeyboardButton]]" = fields.ListOfLists(
base="InlineKeyboardButton"
)
def __init__(self, row_width=3, inline_keyboard=None, **kwargs):
if inline_keyboard is None:
inline_keyboard = []
conf = kwargs.pop('conf', {}) or {}
conf['row_width'] = row_width
conf = kwargs.pop("conf", {}) or {}
conf["row_width"] = row_width
super(InlineKeyboardMarkup, self).__init__(**kwargs,
conf=conf,
inline_keyboard=inline_keyboard)
super(InlineKeyboardMarkup, self).__init__(
**kwargs, conf=conf, inline_keyboard=inline_keyboard
)
@property
def row_width(self):
return self.conf.get('row_width', 3)
return self.conf.get("row_width", 3)
@row_width.setter
def row_width(self, value):
self.conf['row_width'] = value
self.conf["row_width"] = value
def add(self, *args):
"""
@ -89,6 +92,7 @@ class InlineKeyboardButton(base.TelegramObject):
https://core.telegram.org/bots/api#inlinekeyboardbutton
"""
text: base.String = fields.Field()
url: base.String = fields.Field()
login_url: LoginUrl = fields.Field(base=LoginUrl)
@ -98,19 +102,26 @@ class InlineKeyboardButton(base.TelegramObject):
callback_game: CallbackGame = fields.Field(base=CallbackGame)
pay: base.Boolean = fields.Field()
def __init__(self, text: base.String,
url: base.String = None,
login_url: LoginUrl = None,
callback_data: base.String = None,
switch_inline_query: base.String = None,
switch_inline_query_current_chat: base.String = None,
callback_game: CallbackGame = None,
pay: base.Boolean = None, **kwargs):
super(InlineKeyboardButton, self).__init__(text=text,
url=url,
login_url=login_url,
callback_data=callback_data,
switch_inline_query=switch_inline_query,
switch_inline_query_current_chat=switch_inline_query_current_chat,
callback_game=callback_game,
pay=pay, **kwargs)
def __init__(
self,
text: base.String,
url: base.String = None,
login_url: LoginUrl = None,
callback_data: base.String = None,
switch_inline_query: base.String = None,
switch_inline_query_current_chat: base.String = None,
callback_game: CallbackGame = None,
pay: base.Boolean = None,
**kwargs,
):
super(InlineKeyboardButton, self).__init__(
text=text,
url=url,
login_url=login_url,
callback_data=callback_data,
switch_inline_query=switch_inline_query,
switch_inline_query_current_chat=switch_inline_query_current_chat,
callback_game=callback_game,
pay=pay,
**kwargs,
)

View file

@ -15,19 +15,22 @@ class InlineQuery(base.TelegramObject):
https://core.telegram.org/bots/api#inlinequery
"""
id: base.String = fields.Field()
from_user: User = fields.Field(alias='from', base=User)
from_user: User = fields.Field(alias="from", base=User)
location: Location = fields.Field(base=Location)
query: base.String = fields.Field()
offset: base.String = fields.Field()
async def answer(self,
results: typing.List[InlineQueryResult],
cache_time: typing.Union[base.Integer, None] = None,
is_personal: typing.Union[base.Boolean, None] = None,
next_offset: typing.Union[base.String, None] = None,
switch_pm_text: typing.Union[base.String, None] = None,
switch_pm_parameter: typing.Union[base.String, None] = None):
async def answer(
self,
results: typing.List[InlineQueryResult],
cache_time: typing.Union[base.Integer, None] = None,
is_personal: typing.Union[base.Boolean, None] = None,
next_offset: typing.Union[base.String, None] = None,
switch_pm_text: typing.Union[base.String, None] = None,
switch_pm_parameter: typing.Union[base.String, None] = None,
):
"""
Use this method to send answers to an inline query.
No more than 50 results per query are allowed.
@ -57,10 +60,12 @@ class InlineQuery(base.TelegramObject):
:return: On success, True is returned
:rtype: :obj:`base.Boolean`
"""
return await self.bot.answer_inline_query(self.id,
results=results,
cache_time=cache_time,
is_personal=is_personal,
next_offset=next_offset,
switch_pm_text=switch_pm_text,
switch_pm_parameter=switch_pm_parameter)
return await self.bot.answer_inline_query(
self.id,
results=results,
cache_time=cache_time,
is_personal=is_personal,
next_offset=next_offset,
switch_pm_text=switch_pm_text,
switch_pm_parameter=switch_pm_parameter,
)

File diff suppressed because it is too large Load diff

View file

@ -12,7 +12,7 @@ from ..bot import api
CHUNK_SIZE = 65536
log = logging.getLogger('aiogram')
log = logging.getLogger("aiogram")
class InputFile(base.TelegramObject):
@ -35,7 +35,7 @@ class InputFile(base.TelegramObject):
super(InputFile, self).__init__(conf=conf)
if isinstance(path_or_bytesio, str):
# As path
self._file = open(path_or_bytesio, 'rb')
self._file = open(path_or_bytesio, "rb")
self._path = path_or_bytesio
if filename is None:
filename = os.path.split(path_or_bytesio)[-1]
@ -46,7 +46,7 @@ class InputFile(base.TelegramObject):
self._path = None
self._file = path_or_bytesio
else:
raise TypeError('Not supported file type.')
raise TypeError("Not supported file type.")
self._filename = filename
@ -56,7 +56,7 @@ class InputFile(base.TelegramObject):
"""
Close file descriptor
"""
if not hasattr(self, '_file'):
if not hasattr(self, "_file"):
return
if inspect.iscoroutinefunction(self._file.close):
@ -123,7 +123,7 @@ class InputFile(base.TelegramObject):
:param filename:
:param chunk_size:
"""
with open(filename, 'wb') as fp:
with open(filename, "wb") as fp:
while True:
# Chunk writer
data = self.file.read(chunk_size)
@ -143,11 +143,11 @@ class InputFile(base.TelegramObject):
__repr__ = __str__
def to_python(self):
raise TypeError('Object of this type is not exportable!')
raise TypeError("Object of this type is not exportable!")
@classmethod
def to_object(cls, data):
raise TypeError('Object of this type is not importable!')
raise TypeError("Object of this type is not importable!")
class _WebPipe:
@ -165,7 +165,7 @@ class _WebPipe:
@property
def name(self):
if not self._name:
*_, part = self.url.rpartition('/')
*_, part = self.url.rpartition("/")
if part:
self._name = part
else:
@ -206,7 +206,7 @@ class _WebPipe:
async def read(self, chunk_size=-1):
if not self._response:
raise LookupError('I/O operation on closed stream')
raise LookupError("I/O operation on closed stream")
response: aiohttp.ClientResponse = self._response
reader: aiohttp.StreamReader = response.content
@ -214,6 +214,6 @@ class _WebPipe:
def __str__(self):
result = f"WebPipe url='{self.url}', name='{self.name}'"
return '<' + result + '>'
return "<" + result + ">"
__repr__ = __str__

View file

@ -6,7 +6,7 @@ from . import base
from . import fields
from .input_file import InputFile
ATTACHMENT_PREFIX = 'attach://'
ATTACHMENT_PREFIX = "attach://"
class InputMedia(base.TelegramObject):
@ -22,9 +22,12 @@ class InputMedia(base.TelegramObject):
https://core.telegram.org/bots/api#inputmedia
"""
type: base.String = fields.Field(default='photo')
media: base.String = fields.Field(alias='media', on_change='_media_changed')
thumb: typing.Union[base.InputFile, base.String] = fields.Field(alias='thumb', on_change='_thumb_changed')
type: base.String = fields.Field(default="photo")
media: base.String = fields.Field(alias="media", on_change="_media_changed")
thumb: typing.Union[base.InputFile, base.String] = fields.Field(
alias="thumb", on_change="_thumb_changed"
)
caption: base.String = fields.Field()
parse_mode: base.Boolean = fields.Field()
@ -32,13 +35,13 @@ class InputMedia(base.TelegramObject):
self._thumb_file = None
self._media_file = None
media = kwargs.pop('media', None)
media = kwargs.pop("media", None)
if isinstance(media, (io.IOBase, InputFile)):
self.file = media
elif media is not None:
self.media = media
thumb = kwargs.pop('thumb', None)
thumb = kwargs.pop("thumb", None)
if isinstance(thumb, (io.IOBase, InputFile)):
self.thumb_file = thumb
elif thumb is not None:
@ -58,7 +61,7 @@ class InputMedia(base.TelegramObject):
@file.setter
def file(self, file: io.IOBase):
self.media = 'attach://' + secrets.token_urlsafe(16)
self.media = "attach://" + secrets.token_urlsafe(16)
self._media_file = file
@file.deleter
@ -67,7 +70,7 @@ class InputMedia(base.TelegramObject):
self._media_file = None
def _media_changed(self, value):
if value is None or isinstance(value, str) and not value.startswith('attach://'):
if value is None or isinstance(value, str) and not value.startswith("attach://"):
self._media_file = None
@property
@ -76,7 +79,7 @@ class InputMedia(base.TelegramObject):
@thumb_file.setter
def thumb_file(self, file: io.IOBase):
self.thumb = 'attach://' + secrets.token_urlsafe(16)
self.thumb = "attach://" + secrets.token_urlsafe(16)
self._thumb_file = file
@thumb_file.deleter
@ -85,7 +88,7 @@ class InputMedia(base.TelegramObject):
self._thumb_file = None
def _thumb_changed(self, value):
if value is None or isinstance(value, str) and not value.startswith('attach://'):
if value is None or isinstance(value, str) and not value.startswith("attach://"):
self._thumb_file = None
def get_files(self):
@ -106,14 +109,28 @@ class InputMediaAnimation(InputMedia):
height: base.Integer = fields.Field()
duration: base.Integer = fields.Field()
def __init__(self, media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None,
parse_mode: base.Boolean = None, **kwargs):
super(InputMediaAnimation, self).__init__(type='animation', media=media, thumb=thumb, caption=caption,
width=width, height=height, duration=duration,
parse_mode=parse_mode, conf=kwargs)
def __init__(
self,
media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None,
height: base.Integer = None,
duration: base.Integer = None,
parse_mode: base.Boolean = None,
**kwargs,
):
super(InputMediaAnimation, self).__init__(
type="animation",
media=media,
thumb=thumb,
caption=caption,
width=width,
height=height,
duration=duration,
parse_mode=parse_mode,
conf=kwargs,
)
class InputMediaDocument(InputMedia):
@ -123,11 +140,22 @@ class InputMediaDocument(InputMedia):
https://core.telegram.org/bots/api#inputmediadocument
"""
def __init__(self, media: base.InputFile, thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None, parse_mode: base.Boolean = None, **kwargs):
super(InputMediaDocument, self).__init__(type='document', media=media, thumb=thumb,
caption=caption, parse_mode=parse_mode,
conf=kwargs)
def __init__(
self,
media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
parse_mode: base.Boolean = None,
**kwargs,
):
super(InputMediaDocument, self).__init__(
type="document",
media=media,
thumb=thumb,
caption=caption,
parse_mode=parse_mode,
conf=kwargs,
)
class InputMediaAudio(InputMedia):
@ -143,18 +171,32 @@ class InputMediaAudio(InputMedia):
performer: base.String = fields.Field()
title: base.String = fields.Field()
def __init__(self, media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None, height: base.Integer = None,
duration: base.Integer = None,
performer: base.String = None,
title: base.String = None,
parse_mode: base.Boolean = None, **kwargs):
super(InputMediaAudio, self).__init__(type='audio', media=media, thumb=thumb, caption=caption,
width=width, height=height, duration=duration,
performer=performer, title=title,
parse_mode=parse_mode, conf=kwargs)
def __init__(
self,
media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None,
height: base.Integer = None,
duration: base.Integer = None,
performer: base.String = None,
title: base.String = None,
parse_mode: base.Boolean = None,
**kwargs,
):
super(InputMediaAudio, self).__init__(
type="audio",
media=media,
thumb=thumb,
caption=caption,
width=width,
height=height,
duration=duration,
performer=performer,
title=title,
parse_mode=parse_mode,
conf=kwargs,
)
class InputMediaPhoto(InputMedia):
@ -164,11 +206,22 @@ class InputMediaPhoto(InputMedia):
https://core.telegram.org/bots/api#inputmediaphoto
"""
def __init__(self, media: base.InputFile, thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None, parse_mode: base.Boolean = None, **kwargs):
super(InputMediaPhoto, self).__init__(type='photo', media=media, thumb=thumb,
caption=caption, parse_mode=parse_mode,
conf=kwargs)
def __init__(
self,
media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
parse_mode: base.Boolean = None,
**kwargs,
):
super(InputMediaPhoto, self).__init__(
type="photo",
media=media,
thumb=thumb,
caption=caption,
parse_mode=parse_mode,
conf=kwargs,
)
class InputMediaVideo(InputMedia):
@ -177,21 +230,36 @@ class InputMediaVideo(InputMedia):
https://core.telegram.org/bots/api#inputmediavideo
"""
width: base.Integer = fields.Field()
height: base.Integer = fields.Field()
duration: base.Integer = fields.Field()
supports_streaming: base.Boolean = fields.Field()
def __init__(self, media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None,
parse_mode: base.Boolean = None,
supports_streaming: base.Boolean = None, **kwargs):
super(InputMediaVideo, self).__init__(type='video', media=media, thumb=thumb, caption=caption,
width=width, height=height, duration=duration,
parse_mode=parse_mode,
supports_streaming=supports_streaming, conf=kwargs)
def __init__(
self,
media: base.InputFile,
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None,
height: base.Integer = None,
duration: base.Integer = None,
parse_mode: base.Boolean = None,
supports_streaming: base.Boolean = None,
**kwargs,
):
super(InputMediaVideo, self).__init__(
type="video",
media=media,
thumb=thumb,
caption=caption,
width=width,
height=height,
duration=duration,
parse_mode=parse_mode,
supports_streaming=supports_streaming,
conf=kwargs,
)
class MediaGroup(base.TelegramObject):
@ -199,7 +267,9 @@ class MediaGroup(base.TelegramObject):
Helper for sending media group
"""
def __init__(self, medias: typing.Optional[typing.List[typing.Union[InputMedia, typing.Dict]]] = None):
def __init__(
self, medias: typing.Optional[typing.List[typing.Union[InputMedia, typing.Dict]]] = None
):
super(MediaGroup, self).__init__()
self.media = []
@ -222,13 +292,13 @@ class MediaGroup(base.TelegramObject):
:param media:
"""
if isinstance(media, dict):
if 'type' not in media:
if "type" not in media:
raise ValueError(f"Invalid media!")
media_type = media['type']
if media_type == 'photo':
media_type = media["type"]
if media_type == "photo":
media = InputMediaPhoto(**media)
elif media_type == 'video':
elif media_type == "video":
media = InputMediaVideo(**media)
# elif media_type == 'document':
# media = InputMediaDocument(**media)
@ -240,9 +310,11 @@ class MediaGroup(base.TelegramObject):
raise TypeError(f"Invalid media type '{media_type}'!")
elif not isinstance(media, InputMedia):
raise TypeError(f"Media must be an instance of InputMedia or dict, not {type(media).__name__}")
raise TypeError(
f"Media must be an instance of InputMedia or dict, not {type(media).__name__}"
)
elif media.type in ['document', 'audio', 'animation']:
elif media.type in ["document", "audio", "animation"]:
raise ValueError(f"This type of media is not supported by media groups!")
self.media.append(media)
@ -313,8 +385,9 @@ class MediaGroup(base.TelegramObject):
self.attach(document)
'''
def attach_photo(self, photo: typing.Union[InputMediaPhoto, base.InputFile],
caption: base.String = None):
def attach_photo(
self, photo: typing.Union[InputMediaPhoto, base.InputFile], caption: base.String = None
):
"""
Attach photo
@ -325,10 +398,15 @@ class MediaGroup(base.TelegramObject):
photo = InputMediaPhoto(media=photo, caption=caption)
self.attach(photo)
def attach_video(self, video: typing.Union[InputMediaVideo, base.InputFile],
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None, height: base.Integer = None, duration: base.Integer = None):
def attach_video(
self,
video: typing.Union[InputMediaVideo, base.InputFile],
thumb: typing.Union[base.InputFile, base.String] = None,
caption: base.String = None,
width: base.Integer = None,
height: base.Integer = None,
duration: base.Integer = None,
):
"""
Attach video
@ -339,8 +417,14 @@ class MediaGroup(base.TelegramObject):
:param duration:
"""
if not isinstance(video, InputMedia):
video = InputMediaVideo(media=video, thumb=thumb, caption=caption,
width=width, height=height, duration=duration)
video = InputMediaVideo(
media=video,
thumb=thumb,
caption=caption,
width=width,
height=height,
duration=duration,
)
self.attach(video)
def to_python(self) -> typing.List:

View file

@ -12,6 +12,7 @@ class InputMessageContent(base.TelegramObject):
https://core.telegram.org/bots/api#inputmessagecontent
"""
pass
@ -24,16 +25,21 @@ class InputContactMessageContent(InputMessageContent):
https://core.telegram.org/bots/api#inputcontactmessagecontent
"""
phone_number: base.String = fields.Field()
first_name: base.String = fields.Field()
last_name: base.String = fields.Field()
vcard: base.String = fields.Field()
def __init__(self, phone_number: base.String,
first_name: typing.Optional[base.String] = None,
last_name: typing.Optional[base.String] = None):
super(InputContactMessageContent, self).__init__(phone_number=phone_number, first_name=first_name,
last_name=last_name)
def __init__(
self,
phone_number: base.String,
first_name: typing.Optional[base.String] = None,
last_name: typing.Optional[base.String] = None,
):
super(InputContactMessageContent, self).__init__(
phone_number=phone_number, first_name=first_name, last_name=last_name
)
class InputLocationMessageContent(InputMessageContent):
@ -45,11 +51,11 @@ class InputLocationMessageContent(InputMessageContent):
https://core.telegram.org/bots/api#inputlocationmessagecontent
"""
latitude: base.Float = fields.Field()
longitude: base.Float = fields.Field()
def __init__(self, latitude: base.Float,
longitude: base.Float):
def __init__(self, latitude: base.Float, longitude: base.Float):
super(InputLocationMessageContent, self).__init__(latitude=latitude, longitude=longitude)
@ -59,6 +65,7 @@ class InputTextMessageContent(InputMessageContent):
https://core.telegram.org/bots/api#inputtextmessagecontent
"""
message_text: base.String = fields.Field()
parse_mode: base.String = fields.Field()
disable_web_page_preview: base.Boolean = fields.Field()
@ -69,14 +76,20 @@ class InputTextMessageContent(InputMessageContent):
except RuntimeError:
pass
def __init__(self, message_text: typing.Optional[base.String] = None,
parse_mode: typing.Optional[base.String] = None,
disable_web_page_preview: typing.Optional[base.Boolean] = None):
def __init__(
self,
message_text: typing.Optional[base.String] = None,
parse_mode: typing.Optional[base.String] = None,
disable_web_page_preview: typing.Optional[base.Boolean] = None,
):
if parse_mode is None:
parse_mode = self.safe_get_parse_mode()
super(InputTextMessageContent, self).__init__(message_text=message_text, parse_mode=parse_mode,
disable_web_page_preview=disable_web_page_preview)
super(InputTextMessageContent, self).__init__(
message_text=message_text,
parse_mode=parse_mode,
disable_web_page_preview=disable_web_page_preview,
)
class InputVenueMessageContent(InputMessageContent):
@ -88,16 +101,25 @@ class InputVenueMessageContent(InputMessageContent):
https://core.telegram.org/bots/api#inputvenuemessagecontent
"""
latitude: base.Float = fields.Field()
longitude: base.Float = fields.Field()
title: base.String = fields.Field()
address: base.String = fields.Field()
foursquare_id: base.String = fields.Field()
def __init__(self, latitude: typing.Optional[base.Float] = None,
longitude: typing.Optional[base.Float] = None,
title: typing.Optional[base.String] = None,
address: typing.Optional[base.String] = None,
foursquare_id: typing.Optional[base.String] = None):
super(InputVenueMessageContent, self).__init__(latitude=latitude, longitude=longitude, title=title,
address=address, foursquare_id=foursquare_id)
def __init__(
self,
latitude: typing.Optional[base.Float] = None,
longitude: typing.Optional[base.Float] = None,
title: typing.Optional[base.String] = None,
address: typing.Optional[base.String] = None,
foursquare_id: typing.Optional[base.String] = None,
):
super(InputVenueMessageContent, self).__init__(
latitude=latitude,
longitude=longitude,
title=title,
address=address,
foursquare_id=foursquare_id,
)

View file

@ -8,6 +8,7 @@ class Invoice(base.TelegramObject):
https://core.telegram.org/bots/api#invoice
"""
title: base.String = fields.Field()
description: base.String = fields.Field()
start_parameter: base.String = fields.Field()

View file

@ -8,6 +8,7 @@ class LabeledPrice(base.TelegramObject):
https://core.telegram.org/bots/api#labeledprice
"""
label: base.String = fields.Field()
amount: base.Integer = fields.Field()

View file

@ -8,5 +8,6 @@ class Location(base.TelegramObject):
https://core.telegram.org/bots/api#location
"""
longitude: base.Float = fields.Field()
latitude: base.Float = fields.Field()

View file

@ -10,21 +10,24 @@ class LoginUrl(base.TelegramObject):
https://core.telegram.org/bots/api#loginurl
"""
url: base.String = fields.Field()
forward_text: base.String = fields.Field()
bot_username: base.String = fields.Field()
request_write_access: base.Boolean = fields.Field()
def __init__(self,
url: base.String,
forward_text: base.String = None,
bot_username: base.String = None,
request_write_access: base.Boolean = None,
**kwargs):
def __init__(
self,
url: base.String,
forward_text: base.String = None,
bot_username: base.String = None,
request_write_access: base.Boolean = None,
**kwargs,
):
super(LoginUrl, self).__init__(
url=url,
forward_text=forward_text,
bot_username=bot_username,
request_write_access=request_write_access,
**kwargs
**kwargs,
)

View file

@ -8,6 +8,7 @@ class MaskPosition(base.TelegramObject):
https://core.telegram.org/bots/api#maskposition
"""
point: base.String = fields.Field()
x_shift: base.Float = fields.Field()
y_shift: base.Float = fields.Field()

File diff suppressed because it is too large Load diff

View file

@ -12,6 +12,7 @@ class MessageEntity(base.TelegramObject):
https://core.telegram.org/bots/api#messageentity
"""
type: base.String = fields.Field()
offset: base.Integer = fields.Field()
length: base.Integer = fields.Field()
@ -25,16 +26,16 @@ class MessageEntity(base.TelegramObject):
:param text: full text
:return: part of text
"""
if sys.maxunicode == 0xffff:
return text[self.offset:self.offset + self.length]
if sys.maxunicode == 0xFFFF:
return text[self.offset : self.offset + self.length]
if not isinstance(text, bytes):
entity_text = text.encode('utf-16-le')
entity_text = text.encode("utf-16-le")
else:
entity_text = text
entity_text = entity_text[self.offset * 2:(self.offset + self.length) * 2]
return entity_text.decode('utf-16-le')
entity_text = entity_text[self.offset * 2 : (self.offset + self.length) * 2]
return entity_text.decode("utf-16-le")
def parse(self, text, as_html=True):
"""
@ -95,6 +96,7 @@ class MessageEntityType(helper.Helper):
:key: TEXT_LINK
:key: TEXT_MENTION
"""
mode = helper.HelperMode.snake_case
MENTION = helper.Item() # mention - @username

View file

@ -7,7 +7,9 @@ class Downloadable:
Mixin for files
"""
async def download(self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True):
async def download(
self, destination=None, timeout=30, chunk_size=65536, seek=True, make_dirs=True
):
"""
Download file
@ -31,8 +33,13 @@ class Downloadable:
if is_path and make_dirs:
os.makedirs(os.path.dirname(destination), exist_ok=True)
return await self.bot.download_file(file_path=file.file_path, destination=destination, timeout=timeout,
chunk_size=chunk_size, seek=seek)
return await self.bot.download_file(
file_path=file.file_path,
destination=destination,
timeout=timeout,
chunk_size=chunk_size,
seek=seek,
)
async def get_file(self):
"""
@ -40,7 +47,7 @@ class Downloadable:
:return: :obj:`aiogram.types.File`
"""
if hasattr(self, 'file_path'):
if hasattr(self, "file_path"):
return self
else:
return await self.bot.get_file(self.file_id)

View file

@ -9,6 +9,7 @@ class OrderInfo(base.TelegramObject):
https://core.telegram.org/bots/api#orderinfo
"""
name: base.String = fields.Field()
phone_number: base.String = fields.Field()
email: base.String = fields.Field()

View file

@ -28,10 +28,17 @@ class PassportElementErrorDataField(PassportElementError):
field_name: base.String = fields.Field()
data_hash: base.String = fields.Field()
def __init__(self, source: base.String, type: base.String, field_name: base.String,
data_hash: base.String, message: base.String):
super(PassportElementErrorDataField, self).__init__(source=source, type=type, field_name=field_name,
data_hash=data_hash, message=message)
def __init__(
self,
source: base.String,
type: base.String,
field_name: base.String,
data_hash: base.String,
message: base.String,
):
super(PassportElementErrorDataField, self).__init__(
source=source, type=type, field_name=field_name, data_hash=data_hash, message=message
)
class PassportElementErrorFile(PassportElementError):
@ -44,9 +51,12 @@ class PassportElementErrorFile(PassportElementError):
file_hash: base.String = fields.Field()
def __init__(self, source: base.String, type: base.String, file_hash: base.String, message: base.String):
super(PassportElementErrorFile, self).__init__(source=source, type=type, file_hash=file_hash,
message=message)
def __init__(
self, source: base.String, type: base.String, file_hash: base.String, message: base.String
):
super(PassportElementErrorFile, self).__init__(
source=source, type=type, file_hash=file_hash, message=message
)
class PassportElementErrorFiles(PassportElementError):
@ -59,10 +69,16 @@ class PassportElementErrorFiles(PassportElementError):
file_hashes: typing.List[base.String] = fields.ListField()
def __init__(self, source: base.String, type: base.String, file_hashes: typing.List[base.String],
message: base.String):
super(PassportElementErrorFiles, self).__init__(source=source, type=type, file_hashes=file_hashes,
message=message)
def __init__(
self,
source: base.String,
type: base.String,
file_hashes: typing.List[base.String],
message: base.String,
):
super(PassportElementErrorFiles, self).__init__(
source=source, type=type, file_hashes=file_hashes, message=message
)
class PassportElementErrorFrontSide(PassportElementError):
@ -75,9 +91,12 @@ class PassportElementErrorFrontSide(PassportElementError):
file_hash: base.String = fields.Field()
def __init__(self, source: base.String, type: base.String, file_hash: base.String, message: base.String):
super(PassportElementErrorFrontSide, self).__init__(source=source, type=type, file_hash=file_hash,
message=message)
def __init__(
self, source: base.String, type: base.String, file_hash: base.String, message: base.String
):
super(PassportElementErrorFrontSide, self).__init__(
source=source, type=type, file_hash=file_hash, message=message
)
class PassportElementErrorReverseSide(PassportElementError):
@ -90,9 +109,12 @@ class PassportElementErrorReverseSide(PassportElementError):
file_hash: base.String = fields.Field()
def __init__(self, source: base.String, type: base.String, file_hash: base.String, message: base.String):
super(PassportElementErrorReverseSide, self).__init__(source=source, type=type, file_hash=file_hash,
message=message)
def __init__(
self, source: base.String, type: base.String, file_hash: base.String, message: base.String
):
super(PassportElementErrorReverseSide, self).__init__(
source=source, type=type, file_hash=file_hash, message=message
)
class PassportElementErrorSelfie(PassportElementError):
@ -105,6 +127,9 @@ class PassportElementErrorSelfie(PassportElementError):
file_hash: base.String = fields.Field()
def __init__(self, source: base.String, type: base.String, file_hash: base.String, message: base.String):
super(PassportElementErrorSelfie, self).__init__(source=source, type=type, file_hash=file_hash,
message=message)
def __init__(
self, source: base.String, type: base.String, file_hash: base.String, message: base.String
):
super(PassportElementErrorSelfie, self).__init__(
source=source, type=type, file_hash=file_hash, message=message
)

View file

@ -9,6 +9,7 @@ class PhotoSize(base.TelegramObject, mixins.Downloadable):
https://core.telegram.org/bots/api#photosize
"""
file_id: base.String = fields.Field()
width: base.Integer = fields.Field()
height: base.Integer = fields.Field()

View file

@ -17,8 +17,9 @@ class PreCheckoutQuery(base.TelegramObject):
https://core.telegram.org/bots/api#precheckoutquery
"""
id: base.String = fields.Field()
from_user: User = fields.Field(alias='from', base=User)
from_user: User = fields.Field(alias="from", base=User)
currency: base.String = fields.Field()
total_amount: base.Integer = fields.Field()
invoice_payload: base.String = fields.Field()

View file

@ -10,27 +10,37 @@ class ReplyKeyboardMarkup(base.TelegramObject):
https://core.telegram.org/bots/api#replykeyboardmarkup
"""
keyboard: 'typing.List[typing.List[KeyboardButton]]' = fields.ListOfLists(base='KeyboardButton', default=[])
keyboard: "typing.List[typing.List[KeyboardButton]]" = fields.ListOfLists(
base="KeyboardButton", default=[]
)
resize_keyboard: base.Boolean = fields.Field()
one_time_keyboard: base.Boolean = fields.Field()
selective: base.Boolean = fields.Field()
def __init__(self, keyboard: 'typing.List[typing.List[KeyboardButton]]' = None,
resize_keyboard: base.Boolean = None,
one_time_keyboard: base.Boolean = None,
selective: base.Boolean = None,
row_width: base.Integer = 3):
super(ReplyKeyboardMarkup, self).__init__(keyboard=keyboard, resize_keyboard=resize_keyboard,
one_time_keyboard=one_time_keyboard, selective=selective,
conf={'row_width': row_width})
def __init__(
self,
keyboard: "typing.List[typing.List[KeyboardButton]]" = None,
resize_keyboard: base.Boolean = None,
one_time_keyboard: base.Boolean = None,
selective: base.Boolean = None,
row_width: base.Integer = 3,
):
super(ReplyKeyboardMarkup, self).__init__(
keyboard=keyboard,
resize_keyboard=resize_keyboard,
one_time_keyboard=one_time_keyboard,
selective=selective,
conf={"row_width": row_width},
)
@property
def row_width(self):
return self.conf.get('row_width', 3)
return self.conf.get("row_width", 3)
@row_width.setter
def row_width(self, value):
self.conf['row_width'] = value
self.conf["row_width"] = value
def add(self, *args):
"""
@ -86,16 +96,20 @@ class KeyboardButton(base.TelegramObject):
https://core.telegram.org/bots/api#keyboardbutton
"""
text: base.String = fields.Field()
request_contact: base.Boolean = fields.Field()
request_location: base.Boolean = fields.Field()
def __init__(self, text: base.String,
request_contact: base.Boolean = None,
request_location: base.Boolean = None):
super(KeyboardButton, self).__init__(text=text,
request_contact=request_contact,
request_location=request_location)
def __init__(
self,
text: base.String,
request_contact: base.Boolean = None,
request_location: base.Boolean = None,
):
super(KeyboardButton, self).__init__(
text=text, request_contact=request_contact, request_location=request_location
)
class ReplyKeyboardRemove(base.TelegramObject):
@ -104,6 +118,7 @@ class ReplyKeyboardRemove(base.TelegramObject):
https://core.telegram.org/bots/api#replykeyboardremove
"""
remove_keyboard: base.Boolean = fields.Field(default=True)
selective: base.Boolean = fields.Field()

View file

@ -8,5 +8,6 @@ class ResponseParameters(base.TelegramObject):
https://core.telegram.org/bots/api#responseparameters
"""
migrate_to_chat_id: base.Integer = fields.Field()
retry_after: base.Integer = fields.Field()

View file

@ -8,6 +8,7 @@ class ShippingAddress(base.TelegramObject):
https://core.telegram.org/bots/api#shippingaddress
"""
country_code: base.String = fields.Field()
state: base.String = fields.Field()
city: base.String = fields.Field()

View file

@ -11,11 +11,14 @@ class ShippingOption(base.TelegramObject):
https://core.telegram.org/bots/api#shippingoption
"""
id: base.String = fields.Field()
title: base.String = fields.Field()
prices: typing.List[LabeledPrice] = fields.ListField(base=LabeledPrice)
def __init__(self, id: base.String, title: base.String, prices: typing.List[LabeledPrice] = None):
def __init__(
self, id: base.String, title: base.String, prices: typing.List[LabeledPrice] = None
):
if prices is None:
prices = []

View file

@ -10,8 +10,9 @@ class ShippingQuery(base.TelegramObject):
https://core.telegram.org/bots/api#shippingquery
"""
id: base.String = fields.Field()
from_user: User = fields.Field(alias='from', base=User)
from_user: User = fields.Field(alias="from", base=User)
invoice_payload: base.String = fields.Field()
shipping_address: ShippingAddress = fields.Field(base=ShippingAddress)

View file

@ -11,6 +11,7 @@ class Sticker(base.TelegramObject, mixins.Downloadable):
https://core.telegram.org/bots/api#sticker
"""
file_id: base.String = fields.Field()
width: base.Integer = fields.Field()
height: base.Integer = fields.Field()

View file

@ -11,6 +11,7 @@ class StickerSet(base.TelegramObject):
https://core.telegram.org/bots/api#stickerset
"""
name: base.String = fields.Field()
title: base.String = fields.Field()
contains_masks: base.Boolean = fields.Field()

View file

@ -9,6 +9,7 @@ class SuccessfulPayment(base.TelegramObject):
https://core.telegram.org/bots/api#successfulpayment
"""
currency: base.String = fields.Field()
total_amount: base.Integer = fields.Field()
invoice_payload: base.String = fields.Field()

View file

@ -19,6 +19,7 @@ class Update(base.TelegramObject):
https://core.telegram.org/bots/api#update
"""
update_id: base.Integer = fields.Field()
message: Message = fields.Field(base=Message)
edited_message: Message = fields.Field(base=Message)
@ -47,6 +48,7 @@ class AllowedUpdates(helper.Helper):
Example:
>>> bot.get_updates(allowed_updates=AllowedUpdates.MESSAGE + AllowedUpdates.EDITED_MESSAGE)
"""
mode = helper.HelperMode.snake_case
MESSAGE = helper.ListItem() # message

View file

@ -13,6 +13,7 @@ class User(base.TelegramObject):
https://core.telegram.org/bots/api#user
"""
id: base.Integer = fields.Field()
is_bot: base.Boolean = fields.Field()
first_name: base.String = fields.Field()
@ -29,7 +30,7 @@ class User(base.TelegramObject):
"""
full_name = self.first_name
if self.last_name:
full_name += ' ' + self.last_name
full_name += " " + self.last_name
return full_name
@property
@ -41,7 +42,7 @@ class User(base.TelegramObject):
:return: str
"""
if self.username:
return '@' + self.username
return "@" + self.username
return self.full_name
@property
@ -53,16 +54,16 @@ class User(base.TelegramObject):
"""
if not self.language_code:
return None
if not hasattr(self, '_locale'):
setattr(self, '_locale', babel.core.Locale.parse(self.language_code, sep='-'))
return getattr(self, '_locale')
if not hasattr(self, "_locale"):
setattr(self, "_locale", babel.core.Locale.parse(self.language_code, sep="-"))
return getattr(self, "_locale")
@property
def url(self):
return f"tg://user?id={self.id}"
def get_mention(self, name=None, as_html=None):
if as_html is None and self.bot.parse_mode and self.bot.parse_mode.lower() == 'html':
if as_html is None and self.bot.parse_mode and self.bot.parse_mode.lower() == "html":
as_html = True
if name is None:

View file

@ -11,5 +11,6 @@ class UserProfilePhotos(base.TelegramObject):
https://core.telegram.org/bots/api#userprofilephotos
"""
total_count: base.Integer = fields.Field()
photos: typing.List[typing.List[PhotoSize]] = fields.ListOfLists(base=PhotoSize)

View file

@ -9,6 +9,7 @@ class Venue(base.TelegramObject):
https://core.telegram.org/bots/api#venue
"""
location: Location = fields.Field(base=Location)
title: base.String = fields.Field()
address: base.String = fields.Field()

View file

@ -10,6 +10,7 @@ class Video(base.TelegramObject, mixins.Downloadable):
https://core.telegram.org/bots/api#video
"""
file_id: base.String = fields.Field()
width: base.Integer = fields.Field()
height: base.Integer = fields.Field()

View file

@ -10,6 +10,7 @@ class VideoNote(base.TelegramObject, mixins.Downloadable):
https://core.telegram.org/bots/api#videonote
"""
file_id: base.String = fields.Field()
length: base.Integer = fields.Field()
duration: base.Integer = fields.Field()

View file

@ -9,6 +9,7 @@ class Voice(base.TelegramObject, mixins.Downloadable):
https://core.telegram.org/bots/api#voice
"""
file_id: base.String = fields.Field()
duration: base.Integer = fields.Field()
mime_type: base.String = fields.Field()

View file

@ -10,6 +10,7 @@ class WebhookInfo(base.TelegramObject):
https://core.telegram.org/bots/api#webhookinfo
"""
url: base.String = fields.Field()
has_custom_certificate: base.Boolean = fields.Field()
pending_update_count: base.Integer = fields.Field()

View file

@ -18,10 +18,10 @@ def generate_hash(data: dict, token: str) -> str:
:return:
"""
secret = hashlib.sha256()
secret.update(token.encode('utf-8'))
secret.update(token.encode("utf-8"))
sorted_params = collections.OrderedDict(sorted(data.items()))
msg = '\n'.join("{}={}".format(k, v) for k, v in sorted_params.items() if k != 'hash')
return hmac.new(secret.digest(), msg.encode('utf-8'), digestmod=hashlib.sha256).hexdigest()
msg = "\n".join("{}={}".format(k, v) for k, v in sorted_params.items() if k != "hash")
return hmac.new(secret.digest(), msg.encode("utf-8"), digestmod=hashlib.sha256).hexdigest()
def check_token(data: dict, token: str) -> bool:
@ -32,5 +32,5 @@ def check_token(data: dict, token: str) -> bool:
:param token:
:return:
"""
param_hash = data.get('hash', '') or ''
param_hash = data.get("hash", "") or ""
return param_hash == generate_hash(data, token)

View file

@ -26,15 +26,15 @@ class CallbackData:
Callback data factory
"""
def __init__(self, prefix, *parts, sep=':'):
def __init__(self, prefix, *parts, sep=":"):
if not isinstance(prefix, str):
raise TypeError(f"Prefix must be instance of str not {type(prefix).__name__}")
elif not prefix:
raise ValueError('Prefix can\'t be empty')
raise ValueError("Prefix can't be empty")
elif sep in prefix:
raise ValueError(f"Separator '{sep}' can't be used in prefix")
elif not parts:
raise TypeError('Parts is not passed!')
raise TypeError("Parts is not passed!")
self.prefix = prefix
self.sep = sep
@ -72,11 +72,11 @@ class CallbackData:
data.append(value)
if args or kwargs:
raise TypeError('Too many arguments is passed!')
raise TypeError("Too many arguments is passed!")
callback_data = self.sep.join(data)
if len(callback_data) > 64:
raise ValueError('Resulted callback data is too long!')
raise ValueError("Resulted callback data is too long!")
return callback_data
@ -91,9 +91,9 @@ class CallbackData:
if prefix != self.prefix:
raise ValueError("Passed callback data can't be parsed with that prefix.")
elif len(parts) != len(self._part_names):
raise ValueError('Invalid parts count!')
raise ValueError("Invalid parts count!")
result = {'@': prefix}
result = {"@": prefix}
result.update(zip(self._part_names, parts))
return result
@ -117,7 +117,7 @@ class CallbackDataFilter(Filter):
@classmethod
def validate(cls, full_config: typing.Dict[str, typing.Any]):
raise ValueError('That filter can\'t be used in filters factory!')
raise ValueError("That filter can't be used in filters factory!")
async def check(self, query: types.CallbackQuery):
try:
@ -132,4 +132,4 @@ class CallbackDataFilter(Filter):
else:
if value != data.get(key):
return False
return {'callback_data': data}
return {"callback_data": data}

View file

@ -34,7 +34,7 @@ def deprecated(reason):
@functools.wraps(func)
def wrapper(*args, **kwargs):
warn_deprecated(msg.format(name=func.__name__, reason=reason))
warnings.simplefilter('default', DeprecationWarning)
warnings.simplefilter("default", DeprecationWarning)
return func(*args, **kwargs)
return wrapper
@ -70,6 +70,6 @@ def deprecated(reason):
def warn_deprecated(message, warning=DeprecationWarning, stacklevel=2):
warnings.simplefilter('always', warning)
warnings.simplefilter("always", warning)
warnings.warn(message, category=warning, stacklevel=stacklevel)
warnings.simplefilter('default', warning)
warnings.simplefilter("default", warning)

View file

@ -90,13 +90,13 @@ import time
# TODO: Use exceptions detector from `aiograph`.
_PREFIXES = ['error: ', '[error]: ', 'bad request: ', 'conflict: ', 'not found: ']
_PREFIXES = ["error: ", "[error]: ", "bad request: ", "conflict: ", "not found: "]
def _clean_message(text):
for prefix in _PREFIXES:
if text.startswith(prefix):
text = text[len(prefix):]
text = text[len(prefix) :]
return (text[0].upper() + text[1:]).strip()
@ -106,7 +106,7 @@ class TelegramAPIError(Exception):
class _MatchErrorMixin:
match = ''
match = ""
text = None
__subclasses = []
@ -166,67 +166,72 @@ class MessageNotModified(MessageError):
"""
Will be raised when you try to set new text is equals to current text.
"""
match = 'message is not modified'
match = "message is not modified"
class MessageToForwardNotFound(MessageError):
"""
Will be raised when you try to forward very old or deleted or unknown message.
"""
match = 'message to forward not found'
match = "message to forward not found"
class MessageToDeleteNotFound(MessageError):
"""
Will be raised when you try to delete very old or deleted or unknown message.
"""
match = 'message to delete not found'
match = "message to delete not found"
class MessageToReplyNotFound(MessageError):
"""
Will be raised when you try to reply to very old or deleted or unknown message.
"""
match = 'message to reply not found'
match = "message to reply not found"
class MessageIdentifierNotSpecified(MessageError):
match = 'message identifier is not specified'
match = "message identifier is not specified"
class MessageTextIsEmpty(MessageError):
match = 'Message text is empty'
match = "Message text is empty"
class MessageCantBeEdited(MessageError):
match = 'message can\'t be edited'
match = "message can't be edited"
class MessageCantBeDeleted(MessageError):
match = 'message can\'t be deleted'
match = "message can't be deleted"
class MessageToEditNotFound(MessageError):
match = 'message to edit not found'
match = "message to edit not found"
class MessageIsTooLong(MessageError):
match = 'message is too long'
match = "message is too long"
class ToMuchMessages(MessageError):
"""
Will be raised when you try to send media group with more than 10 items.
"""
match = 'Too much messages to send as an album'
match = "Too much messages to send as an album"
class ObjectExpectedAsReplyMarkup(BadRequest):
match = 'object expected as reply markup'
match = "object expected as reply markup"
class InlineKeyboardExpected(BadRequest):
match = 'inline keyboard expected'
match = "inline keyboard expected"
class PollError(BadRequest):
@ -238,7 +243,7 @@ class PollCantBeStopped(PollError):
class PollHasAlreadyBeenClosed(PollError):
match = 'poll has already been closed'
match = "poll has already been closed"
class PollsCantBeSentToPrivateChats(PollError):
@ -277,109 +282,112 @@ class MessageWithPollNotFound(PollError, MessageError):
"""
Will be raised when you try to stop poll with message without poll
"""
match = 'message with poll to stop not found'
match = "message with poll to stop not found"
class MessageIsNotAPoll(PollError, MessageError):
"""
Will be raised when you try to stop poll with message without poll
"""
match = 'message is not a poll'
match = "message is not a poll"
class ChatNotFound(BadRequest):
match = 'chat not found'
match = "chat not found"
class ChatIdIsEmpty(BadRequest):
match = 'chat_id is empty'
match = "chat_id is empty"
class InvalidUserId(BadRequest):
match = 'user_id_invalid'
text = 'Invalid user id'
match = "user_id_invalid"
text = "Invalid user id"
class ChatDescriptionIsNotModified(BadRequest):
match = 'chat description is not modified'
match = "chat description is not modified"
class InvalidQueryID(BadRequest):
match = 'query is too old and response timeout expired or query id is invalid'
match = "query is too old and response timeout expired or query id is invalid"
class InvalidPeerID(BadRequest):
match = 'PEER_ID_INVALID'
text = 'Invalid peer ID'
match = "PEER_ID_INVALID"
text = "Invalid peer ID"
class InvalidHTTPUrlContent(BadRequest):
match = 'Failed to get HTTP URL content'
match = "Failed to get HTTP URL content"
class ButtonURLInvalid(BadRequest):
match = 'BUTTON_URL_INVALID'
text = 'Button URL invalid'
match = "BUTTON_URL_INVALID"
text = "Button URL invalid"
class URLHostIsEmpty(BadRequest):
match = 'URL host is empty'
match = "URL host is empty"
class StartParamInvalid(BadRequest):
match = 'START_PARAM_INVALID'
text = 'Start param invalid'
match = "START_PARAM_INVALID"
text = "Start param invalid"
class ButtonDataInvalid(BadRequest):
match = 'BUTTON_DATA_INVALID'
text = 'Button data invalid'
match = "BUTTON_DATA_INVALID"
text = "Button data invalid"
class WrongFileIdentifier(BadRequest):
match = 'wrong file identifier/HTTP URL specified'
match = "wrong file identifier/HTTP URL specified"
class GroupDeactivated(BadRequest):
match = 'group is deactivated'
match = "group is deactivated"
class PhotoAsInputFileRequired(BadRequest):
"""
Will be raised when you try to set chat photo from file ID.
"""
match = 'Photo should be uploaded as an InputFile'
match = "Photo should be uploaded as an InputFile"
class InvalidStickersSet(BadRequest):
match = 'STICKERSET_INVALID'
text = 'Stickers set is invalid'
match = "STICKERSET_INVALID"
text = "Stickers set is invalid"
class NoStickerInRequest(BadRequest):
match = 'there is no sticker in the request'
match = "there is no sticker in the request"
class ChatAdminRequired(BadRequest):
match = 'CHAT_ADMIN_REQUIRED'
text = 'Admin permissions is required!'
match = "CHAT_ADMIN_REQUIRED"
text = "Admin permissions is required!"
class NeedAdministratorRightsInTheChannel(BadRequest):
match = 'need administrator rights in the channel chat'
text = 'Admin permissions is required!'
match = "need administrator rights in the channel chat"
text = "Admin permissions is required!"
class NotEnoughRightsToPinMessage(BadRequest):
match = 'not enough rights to pin a message'
match = "not enough rights to pin a message"
class MethodNotAvailableInPrivateChats(BadRequest):
match = 'method is available only for supergroups and channel'
match = "method is available only for supergroups and channel"
class CantDemoteChatCreator(BadRequest):
match = 'can\'t demote chat creator'
match = "can't demote chat creator"
class CantRestrictSelf(BadRequest):
@ -388,34 +396,34 @@ class CantRestrictSelf(BadRequest):
class NotEnoughRightsToRestrict(BadRequest):
match = 'not enough rights to restrict/unrestrict chat member'
match = "not enough rights to restrict/unrestrict chat member"
class PhotoDimensions(BadRequest):
match = 'PHOTO_INVALID_DIMENSIONS'
text = 'Invalid photo dimensions'
match = "PHOTO_INVALID_DIMENSIONS"
text = "Invalid photo dimensions"
class UnavailableMembers(BadRequest):
match = 'supergroup members are unavailable'
match = "supergroup members are unavailable"
class TypeOfFileMismatch(BadRequest):
match = 'type of file mismatch'
match = "type of file mismatch"
class WrongRemoteFileIdSpecified(BadRequest):
match = 'wrong remote file id specified'
match = "wrong remote file id specified"
class PaymentProviderInvalid(BadRequest):
match = 'PAYMENT_PROVIDER_INVALID'
text = 'payment provider invalid'
match = "PAYMENT_PROVIDER_INVALID"
text = "payment provider invalid"
class CurrencyTotalAmountInvalid(BadRequest):
match = 'currency_total_amount_invalid'
text = 'currency total amount invalid'
match = "currency_total_amount_invalid"
text = "currency total amount invalid"
class BadWebhook(BadRequest):
@ -423,44 +431,44 @@ class BadWebhook(BadRequest):
class WebhookRequireHTTPS(BadWebhook):
match = 'HTTPS url must be provided for webhook'
text = 'bad webhook: ' + match
match = "HTTPS url must be provided for webhook"
text = "bad webhook: " + match
class BadWebhookPort(BadWebhook):
match = 'Webhook can be set up only on ports 80, 88, 443 or 8443'
text = 'bad webhook: ' + match
match = "Webhook can be set up only on ports 80, 88, 443 or 8443"
text = "bad webhook: " + match
class BadWebhookAddrInfo(BadWebhook):
match = 'getaddrinfo: Temporary failure in name resolution'
text = 'bad webhook: ' + match
match = "getaddrinfo: Temporary failure in name resolution"
text = "bad webhook: " + match
class BadWebhookNoAddressAssociatedWithHostname(BadWebhook):
match = 'failed to resolve host: no address associated with hostname'
match = "failed to resolve host: no address associated with hostname"
class CantParseUrl(BadRequest):
match = 'can\'t parse URL'
match = "can't parse URL"
class UnsupportedUrlProtocol(BadRequest):
match = 'unsupported URL protocol'
match = "unsupported URL protocol"
class CantParseEntities(BadRequest):
match = 'can\'t parse entities'
match = "can't parse entities"
class ResultIdDuplicate(BadRequest):
match = 'result_id_duplicate'
text = 'Result ID duplicate'
match = "result_id_duplicate"
text = "Result ID duplicate"
class BotDomainInvalid(BadRequest):
match = 'bot_domain_invalid'
text = 'Invalid bot domain'
match = "bot_domain_invalid"
text = "Invalid bot domain"
class NotFound(TelegramAPIError, _MatchErrorMixin):
@ -468,7 +476,7 @@ class NotFound(TelegramAPIError, _MatchErrorMixin):
class MethodNotKnown(NotFound):
match = 'method not found'
match = "method not found"
class ConflictError(TelegramAPIError, _MatchErrorMixin):
@ -476,13 +484,15 @@ class ConflictError(TelegramAPIError, _MatchErrorMixin):
class TerminatedByOtherGetUpdates(ConflictError):
match = 'terminated by other getUpdates request'
text = 'Terminated by other getUpdates request; ' \
'Make sure that only one bot instance is running'
match = "terminated by other getUpdates request"
text = (
"Terminated by other getUpdates request; "
"Make sure that only one bot instance is running"
)
class CantGetUpdates(ConflictError):
match = 'can\'t use getUpdates method while webhook is active'
match = "can't use getUpdates method while webhook is active"
class Unauthorized(TelegramAPIError, _MatchErrorMixin):
@ -490,23 +500,23 @@ class Unauthorized(TelegramAPIError, _MatchErrorMixin):
class BotKicked(Unauthorized):
match = 'Bot was kicked from a chat'
match = "Bot was kicked from a chat"
class BotBlocked(Unauthorized):
match = 'bot was blocked by the user'
match = "bot was blocked by the user"
class UserDeactivated(Unauthorized):
match = 'user is deactivated'
match = "user is deactivated"
class CantInitiateConversation(Unauthorized):
match = 'bot can\'t initiate conversation with a user'
match = "bot can't initiate conversation with a user"
class CantTalkWithBots(Unauthorized):
match = 'bot can\'t send messages to bots'
match = "bot can't send messages to bots"
class NetworkError(TelegramAPIError):
@ -515,34 +525,43 @@ class NetworkError(TelegramAPIError):
class RestartingTelegram(TelegramAPIError):
def __init__(self):
super(RestartingTelegram, self).__init__('The Telegram Bot API service is restarting. Wait few second.')
super(RestartingTelegram, self).__init__(
"The Telegram Bot API service is restarting. Wait few second."
)
class RetryAfter(TelegramAPIError):
def __init__(self, retry_after):
super(RetryAfter, self).__init__(f"Flood control exceeded. Retry in {retry_after} seconds.")
super(RetryAfter, self).__init__(
f"Flood control exceeded. Retry in {retry_after} seconds."
)
self.timeout = retry_after
class MigrateToChat(TelegramAPIError):
def __init__(self, chat_id):
super(MigrateToChat, self).__init__(f"The group has been migrated to a supergroup. New id: {chat_id}.")
super(MigrateToChat, self).__init__(
f"The group has been migrated to a supergroup. New id: {chat_id}."
)
self.migrate_to_chat_id = chat_id
class Throttled(TelegramAPIError):
def __init__(self, **kwargs):
from ..dispatcher.storage import DELTA, EXCEEDED_COUNT, KEY, LAST_CALL, RATE_LIMIT, RESULT
self.key = kwargs.pop(KEY, '<None>')
self.key = kwargs.pop(KEY, "<None>")
self.called_at = kwargs.pop(LAST_CALL, time.time())
self.rate = kwargs.pop(RATE_LIMIT, None)
self.result = kwargs.pop(RESULT, False)
self.exceeded_count = kwargs.pop(EXCEEDED_COUNT, 0)
self.delta = kwargs.pop(DELTA, 0)
self.user = kwargs.pop('user', None)
self.chat = kwargs.pop('chat', None)
self.user = kwargs.pop("user", None)
self.chat = kwargs.pop("chat", None)
def __str__(self):
return f"Rate limit exceeded! (Limit: {self.rate} s, " \
f"exceeded: {self.exceeded_count}, " \
return (
f"Rate limit exceeded! (Limit: {self.rate} s, "
f"exceeded: {self.exceeded_count}, "
f"time delta: {round(self.delta, 3)} s)"
)

View file

@ -12,7 +12,7 @@ from ..bot.api import log
from ..dispatcher.dispatcher import Dispatcher
from ..dispatcher.webhook import BOT_DISPATCHER_KEY, DEFAULT_ROUTE_NAME, WebhookRequestHandler
APP_EXECUTOR_KEY = 'APP_EXECUTOR'
APP_EXECUTOR_KEY = "APP_EXECUTOR"
def _setup_callbacks(executor, on_startup=None, on_shutdown=None):
@ -22,8 +22,17 @@ def _setup_callbacks(executor, on_startup=None, on_shutdown=None):
executor.on_shutdown(on_shutdown)
def start_polling(dispatcher, *, loop=None, skip_updates=False, reset_webhook=True,
on_startup=None, on_shutdown=None, timeout=20, fast=True):
def start_polling(
dispatcher,
*,
loop=None,
skip_updates=False,
reset_webhook=True,
on_startup=None,
on_shutdown=None,
timeout=20,
fast=True,
):
"""
Start bot in long-polling mode
@ -41,11 +50,19 @@ def start_polling(dispatcher, *, loop=None, skip_updates=False, reset_webhook=Tr
executor.start_polling(reset_webhook=reset_webhook, timeout=timeout, fast=fast)
def set_webhook(dispatcher: Dispatcher, webhook_path: str, *, loop: Optional[asyncio.AbstractEventLoop] = None,
skip_updates: bool = None, on_startup: Optional[Callable] = None,
on_shutdown: Optional[Callable] = None, check_ip: bool = False,
retry_after: Optional[Union[str, int]] = None, route_name: str = DEFAULT_ROUTE_NAME,
web_app: Optional[Application] = None):
def set_webhook(
dispatcher: Dispatcher,
webhook_path: str,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
skip_updates: bool = None,
on_startup: Optional[Callable] = None,
on_shutdown: Optional[Callable] = None,
check_ip: bool = False,
retry_after: Optional[Union[str, int]] = None,
route_name: str = DEFAULT_ROUTE_NAME,
web_app: Optional[Application] = None,
):
"""
Set webhook for bot
@ -61,17 +78,32 @@ def set_webhook(dispatcher: Dispatcher, webhook_path: str, *, loop: Optional[asy
:param web_app: Optional[Application] (default: None)
:return:
"""
executor = Executor(dispatcher, skip_updates=skip_updates, check_ip=check_ip, retry_after=retry_after,
loop=loop)
executor = Executor(
dispatcher,
skip_updates=skip_updates,
check_ip=check_ip,
retry_after=retry_after,
loop=loop,
)
_setup_callbacks(executor, on_startup, on_shutdown)
executor.set_webhook(webhook_path, route_name=route_name, web_app=web_app)
return executor
def start_webhook(dispatcher, webhook_path, *, loop=None, skip_updates=None,
on_startup=None, on_shutdown=None, check_ip=False, retry_after=None, route_name=DEFAULT_ROUTE_NAME,
**kwargs):
def start_webhook(
dispatcher,
webhook_path,
*,
loop=None,
skip_updates=None,
on_startup=None,
on_shutdown=None,
check_ip=False,
retry_after=None,
route_name=DEFAULT_ROUTE_NAME,
**kwargs,
):
"""
Start bot in webhook mode
@ -86,20 +118,21 @@ def start_webhook(dispatcher, webhook_path, *, loop=None, skip_updates=None,
:param kwargs:
:return:
"""
executor = set_webhook(dispatcher=dispatcher,
webhook_path=webhook_path,
loop=loop,
skip_updates=skip_updates,
on_startup=on_startup,
on_shutdown=on_shutdown,
check_ip=check_ip,
retry_after=retry_after,
route_name=route_name)
executor = set_webhook(
dispatcher=dispatcher,
webhook_path=webhook_path,
loop=loop,
skip_updates=skip_updates,
on_startup=on_startup,
on_shutdown=on_shutdown,
check_ip=check_ip,
retry_after=retry_after,
route_name=route_name,
)
executor.run_app(**kwargs)
def start(dispatcher, future, *, loop=None, skip_updates=None,
on_startup=None, on_shutdown=None):
def start(dispatcher, future, *, loop=None, skip_updates=None, on_startup=None, on_shutdown=None):
"""
Execute Future.
@ -142,6 +175,7 @@ class Executor:
self._freeze = False
from aiogram import Bot, Dispatcher
Bot.set_current(dispatcher.bot)
Dispatcher.set_current(dispatcher)
@ -160,7 +194,7 @@ class Executor:
@property
def web_app(self) -> web.Application:
if self._web_app is None:
raise RuntimeError('web.Application() is not configured!')
raise RuntimeError("web.Application() is not configured!")
return self._web_app
def on_startup(self, callback: callable, polling=True, webhook=True):
@ -173,7 +207,7 @@ class Executor:
"""
self._check_frozen()
if not webhook and not polling:
warn('This action has no effect!', UserWarning)
warn("This action has no effect!", UserWarning)
return
if isinstance(callback, (list, tuple, set)):
@ -196,7 +230,7 @@ class Executor:
"""
self._check_frozen()
if not webhook and not polling:
warn('This action has no effect!', UserWarning)
warn("This action has no effect!", UserWarning)
return
if isinstance(callback, (list, tuple, set)):
@ -211,7 +245,7 @@ class Executor:
def _check_frozen(self):
if self.frozen:
raise RuntimeError('Executor is frozen!')
raise RuntimeError("Executor is frozen!")
def _prepare_polling(self):
self._check_frozen()
@ -219,7 +253,9 @@ class Executor:
# self.loop.set_task_factory(context.task_factory)
def _prepare_webhook(self, path=None, handler=WebhookRequestHandler, route_name=DEFAULT_ROUTE_NAME, app=None):
def _prepare_webhook(
self, path=None, handler=WebhookRequestHandler, route_name=DEFAULT_ROUTE_NAME, app=None
):
self._check_frozen()
self._freeze = True
@ -233,14 +269,14 @@ class Executor:
raise RuntimeError("web.Application() is already configured!")
if self.retry_after:
app['RETRY_AFTER'] = self.retry_after
app["RETRY_AFTER"] = self.retry_after
if self._identity == app.get(self._identity):
# App is already configured
return
if path is not None:
app.router.add_route('*', path, handler, name=route_name)
app.router.add_route("*", path, handler, name=route_name)
async def _wrap_callback(cb, _):
return await cb(self.dispatcher)
@ -258,10 +294,15 @@ class Executor:
app[APP_EXECUTOR_KEY] = self
app[BOT_DISPATCHER_KEY] = self.dispatcher
app[self._identity] = datetime.datetime.now()
app['_check_ip'] = self.check_ip
app["_check_ip"] = self.check_ip
def set_webhook(self, webhook_path: Optional[str] = None, request_handler: Any = WebhookRequestHandler,
route_name: str = DEFAULT_ROUTE_NAME, web_app: Optional[Application] = None):
def set_webhook(
self,
webhook_path: Optional[str] = None,
request_handler: Any = WebhookRequestHandler,
route_name: str = DEFAULT_ROUTE_NAME,
web_app: Optional[Application] = None,
):
"""
Set webhook for bot
@ -277,8 +318,13 @@ class Executor:
def run_app(self, **kwargs):
web.run_app(self._web_app, **kwargs)
def start_webhook(self, webhook_path=None, request_handler=WebhookRequestHandler, route_name=DEFAULT_ROUTE_NAME,
**kwargs):
def start_webhook(
self,
webhook_path=None,
request_handler=WebhookRequestHandler,
route_name=DEFAULT_ROUTE_NAME,
**kwargs,
):
"""
Start bot in webhook mode
@ -288,7 +334,9 @@ class Executor:
:param kwargs:
:return:
"""
self.set_webhook(webhook_path=webhook_path, request_handler=request_handler, route_name=route_name)
self.set_webhook(
webhook_path=webhook_path, request_handler=request_handler, route_name=route_name
)
self.run_app(**kwargs)
def start_polling(self, reset_webhook=None, timeout=20, fast=True):
@ -303,7 +351,11 @@ class Executor:
try:
loop.run_until_complete(self._startup_polling())
loop.create_task(self.dispatcher.start_polling(reset_webhook=reset_webhook, timeout=timeout, fast=fast))
loop.create_task(
self.dispatcher.start_polling(
reset_webhook=reset_webhook, timeout=timeout, fast=fast
)
)
loop.run_forever()
except (KeyboardInterrupt, SystemExit):
# loop.stop()

View file

@ -16,7 +16,7 @@ Example:
class Helper:
mode = ''
mode = ""
@classmethod
def all(cls):
@ -37,13 +37,13 @@ class Helper:
class HelperMode(Helper):
mode = 'original'
mode = "original"
SCREAMING_SNAKE_CASE = 'SCREAMING_SNAKE_CASE'
lowerCamelCase = 'lowerCamelCase'
CamelCase = 'CamelCase'
snake_case = 'snake_case'
lowercase = 'lowercase'
SCREAMING_SNAKE_CASE = "SCREAMING_SNAKE_CASE"
lowerCamelCase = "lowerCamelCase"
CamelCase = "CamelCase"
snake_case = "snake_case"
lowercase = "lowercase"
@classmethod
def all(cls):
@ -65,10 +65,10 @@ class HelperMode(Helper):
"""
if text.isupper():
return text
result = ''
result = ""
for pos, symbol in enumerate(text):
if symbol.isupper() and pos > 0:
result += '_' + symbol
result += "_" + symbol
else:
result += symbol.upper()
return result
@ -94,10 +94,10 @@ class HelperMode(Helper):
:param first_upper: first symbol must be upper?
:return:
"""
result = ''
result = ""
need_upper = False
for pos, symbol in enumerate(text):
if symbol == '_' and pos > 0:
if symbol == "_" and pos > 0:
need_upper = True
else:
if need_upper:
@ -123,7 +123,7 @@ class HelperMode(Helper):
elif mode == cls.snake_case:
return cls._snake_case(text)
elif mode == cls.lowercase:
return cls._snake_case(text).replace('_', '')
return cls._snake_case(text).replace("_", "")
elif mode == cls.lowerCamelCase:
return cls._camel_case(text)
elif mode == cls.CamelCase:
@ -149,10 +149,10 @@ class Item:
def __set_name__(self, owner, name):
if not name.isupper():
raise NameError('Name for helper item must be in uppercase!')
raise NameError("Name for helper item must be in uppercase!")
if not self._value:
if hasattr(owner, 'mode'):
self._value = HelperMode.apply(name, getattr(owner, 'mode'))
if hasattr(owner, "mode"):
self._value = HelperMode.apply(name, getattr(owner, "mode"))
class ListItem(Item):

View file

@ -1,14 +1,14 @@
import importlib
import os
JSON = 'json'
RAPIDJSON = 'rapidjson'
UJSON = 'ujson'
JSON = "json"
RAPIDJSON = "rapidjson"
UJSON = "ujson"
# Detect mode
mode = JSON
for json_lib in (RAPIDJSON, UJSON):
if 'DISABLE_' + json_lib.upper() in os.environ:
if "DISABLE_" + json_lib.upper() in os.environ:
continue
try:
@ -20,30 +20,35 @@ for json_lib in (RAPIDJSON, UJSON):
break
if mode == RAPIDJSON:
def dumps(data):
return json.dumps(data, ensure_ascii=False, number_mode=json.NM_NATIVE,
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
def dumps(data):
return json.dumps(
data,
ensure_ascii=False,
number_mode=json.NM_NATIVE,
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC,
)
def loads(data):
return json.loads(data, number_mode=json.NM_NATIVE,
datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC)
return json.loads(
data, number_mode=json.NM_NATIVE, datetime_mode=json.DM_ISO8601 | json.DM_NAIVE_IS_UTC
)
elif mode == UJSON:
def loads(data):
return json.loads(data)
def dumps(data):
return json.dumps(data, ensure_ascii=False)
else:
import json
def dumps(data):
return json.dumps(data, ensure_ascii=False)
def loads(data):
return json.loads(data)

View file

@ -1,37 +1,32 @@
LIST_MD_SYMBOLS = '*_`['
LIST_MD_SYMBOLS = "*_`["
MD_SYMBOLS = (
(LIST_MD_SYMBOLS[0], LIST_MD_SYMBOLS[0]),
(LIST_MD_SYMBOLS[1], LIST_MD_SYMBOLS[1]),
(LIST_MD_SYMBOLS[2], LIST_MD_SYMBOLS[2]),
(LIST_MD_SYMBOLS[2] * 3 + '\n', '\n' + LIST_MD_SYMBOLS[2] * 3),
('<b>', '</b>'),
('<i>', '</i>'),
('<code>', '</code>'),
('<pre>', '</pre>'),
(LIST_MD_SYMBOLS[2] * 3 + "\n", "\n" + LIST_MD_SYMBOLS[2] * 3),
("<b>", "</b>"),
("<i>", "</i>"),
("<code>", "</code>"),
("<pre>", "</pre>"),
)
HTML_QUOTES_MAP = {
'<': '&lt;',
'>': '&gt;',
'&': '&amp;',
'"': '&quot;'
}
HTML_QUOTES_MAP = {"<": "&lt;", ">": "&gt;", "&": "&amp;", '"': "&quot;"}
_HQS = HTML_QUOTES_MAP.keys() # HQS for HTML QUOTES SYMBOLS
def _join(*content, sep=' '):
def _join(*content, sep=" "):
return sep.join(map(str, content))
def _escape(s, symbols=LIST_MD_SYMBOLS):
for symbol in symbols:
s = s.replace(symbol, '\\' + symbol)
s = s.replace(symbol, "\\" + symbol)
return s
def _md(string, symbols=('', '')):
def _md(string, symbols=("", "")):
start, end = symbols
return start + string + end
@ -47,13 +42,13 @@ def quote_html(content):
:param content: str
:return: str
"""
new_content = ''
new_content = ""
for symbol in content:
new_content += HTML_QUOTES_MAP[symbol] if symbol in _HQS else symbol
return new_content
def text(*content, sep=' '):
def text(*content, sep=" "):
"""
Join all elements with a separator
@ -64,7 +59,7 @@ def text(*content, sep=' '):
return _join(*content, sep=sep)
def bold(*content, sep=' '):
def bold(*content, sep=" "):
"""
Make bold text (Markdown)
@ -75,7 +70,7 @@ def bold(*content, sep=' '):
return _md(_join(*content, sep=sep), symbols=MD_SYMBOLS[0])
def hbold(*content, sep=' '):
def hbold(*content, sep=" "):
"""
Make bold text (HTML)
@ -86,7 +81,7 @@ def hbold(*content, sep=' '):
return _md(quote_html(_join(*content, sep=sep)), symbols=MD_SYMBOLS[4])
def italic(*content, sep=' '):
def italic(*content, sep=" "):
"""
Make italic text (Markdown)
@ -97,7 +92,7 @@ def italic(*content, sep=' '):
return _md(_join(*content, sep=sep), symbols=MD_SYMBOLS[1])
def hitalic(*content, sep=' '):
def hitalic(*content, sep=" "):
"""
Make italic text (HTML)
@ -108,7 +103,7 @@ def hitalic(*content, sep=' '):
return _md(quote_html(_join(*content, sep=sep)), symbols=MD_SYMBOLS[5])
def code(*content, sep=' '):
def code(*content, sep=" "):
"""
Make mono-width text (Markdown)
@ -119,7 +114,7 @@ def code(*content, sep=' '):
return _md(_join(*content, sep=sep), symbols=MD_SYMBOLS[2])
def hcode(*content, sep=' '):
def hcode(*content, sep=" "):
"""
Make mono-width text (HTML)
@ -130,7 +125,7 @@ def hcode(*content, sep=' '):
return _md(quote_html(_join(*content, sep=sep)), symbols=MD_SYMBOLS[6])
def pre(*content, sep='\n'):
def pre(*content, sep="\n"):
"""
Make mono-width text block (Markdown)
@ -141,7 +136,7 @@ def pre(*content, sep='\n'):
return _md(_join(*content, sep=sep), symbols=MD_SYMBOLS[3])
def hpre(*content, sep='\n'):
def hpre(*content, sep="\n"):
"""
Make mono-width text block (HTML)
@ -174,7 +169,7 @@ def hlink(title, url):
return '<a href="{0}">{1}</a>'.format(url, quote_html(title))
def escape_md(*content, sep=' '):
def escape_md(*content, sep=" "):
"""
Escape markdown text

View file

@ -1,16 +1,16 @@
import contextvars
from typing import TypeVar, Type
__all__ = ('DataMixin', 'ContextInstanceMixin')
__all__ = ("DataMixin", "ContextInstanceMixin")
class DataMixin:
@property
def data(self):
data = getattr(self, '_data', None)
data = getattr(self, "_data", None)
if data is None:
data = {}
setattr(self, '_data', data)
setattr(self, "_data", data)
return data
def __getitem__(self, item):
@ -26,12 +26,12 @@ class DataMixin:
return self.data.get(key, default)
T = TypeVar('T')
T = TypeVar("T")
class ContextInstanceMixin:
def __init_subclass__(cls, **kwargs):
cls.__context_instance = contextvars.ContextVar('instance_' + cls.__name__)
cls.__context_instance = contextvars.ContextVar("instance_" + cls.__name__)
return cls
@classmethod
@ -43,5 +43,7 @@ class ContextInstanceMixin:
@classmethod
def set_current(cls: Type[T], value: T):
if not isinstance(value, cls):
raise TypeError(f"Value should be instance of '{cls.__name__}' not '{type(value).__name__}'")
raise TypeError(
f"Value should be instance of '{cls.__name__}' not '{type(value).__name__}'"
)
cls.__context_instance.set(value)

View file

@ -12,7 +12,7 @@ def split_text(text: str, length: int = MAX_MESSAGE_LENGTH) -> typing.List[str]:
:return: list of parts
:rtype: :obj:`typing.List[str]`
"""
return [text[i:i + length] for i in range(0, len(text), length)]
return [text[i : i + length] for i in range(0, len(text), length)]
def safe_split_text(text: str, length: int = MAX_MESSAGE_LENGTH) -> typing.List[str]:
@ -30,7 +30,7 @@ def safe_split_text(text: str, length: int = MAX_MESSAGE_LENGTH) -> typing.List[
while temp_text:
if len(temp_text) > length:
try:
split_pos = temp_text[:length].rindex(' ')
split_pos = temp_text[:length].rindex(" ")
except ValueError:
split_pos = length
if split_pos < length // 4 * 3:
@ -56,4 +56,4 @@ def paginate(data: typing.Iterable, page: int = 0, limit: int = 10) -> typing.It
:return: sliced object
:rtype: :obj:`typing.Iterable`
"""
return data[page * limit:page * limit + limit]
return data[page * limit : page * limit + limit]

View file

@ -6,7 +6,7 @@ from babel.support import LazyProxy
from aiogram import types
from . import json
DEFAULT_FILTER = ['self', 'cls']
DEFAULT_FILTER = ["self", "cls"]
def generate_payload(exclude=None, **kwargs):
@ -21,10 +21,11 @@ def generate_payload(exclude=None, **kwargs):
"""
if exclude is None:
exclude = []
return {key: value for key, value in kwargs.items() if
key not in exclude + DEFAULT_FILTER
and value is not None
and not key.startswith('_')}
return {
key: value
for key, value in kwargs.items()
if key not in exclude + DEFAULT_FILTER and value is not None and not key.startswith("_")
}
def _normalize(obj):
@ -38,7 +39,7 @@ def _normalize(obj):
return [_normalize(item) for item in obj]
elif isinstance(obj, dict):
return {k: _normalize(v) for k, v in obj.items() if v is not None}
elif hasattr(obj, 'to_python'):
elif hasattr(obj, "to_python"):
return obj.to_python()
return obj
@ -52,7 +53,7 @@ def prepare_arg(value):
"""
if value is None:
return value
elif isinstance(value, (list, dict)) or hasattr(value, 'to_python'):
elif isinstance(value, (list, dict)) or hasattr(value, "to_python"):
return json.dumps(_normalize(value))
elif isinstance(value, datetime.timedelta):
now = datetime.datetime.now()

View file

@ -27,9 +27,13 @@ Welcome to aiogram's documentation!
:alt: Telegram Bot API
.. image:: https://img.shields.io/readthedocs/pip/stable.svg?style=flat-square
:target: http://aiogram.readthedocs.io/en/latest/?badge=latest
:target: http://aiogram.readthedocs.io/en/latest/?badge=latest?style=flat-square
:alt: Documentation Status
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square
:target: https://github.com/python/black
:alt: Code style: Black
.. image:: https://img.shields.io/github/issues/aiogram/aiogram.svg?style=flat-square
:target: https://github.com/aiogram/aiogram/issues
:alt: Github issues

View file

@ -35,16 +35,18 @@ from aiogram.utils.executor import start_polling, start_webhook
logging.basicConfig(level=logging.INFO)
# Configure arguments parser.
parser = argparse.ArgumentParser(description='Python telegram bot')
parser.add_argument('--token', '-t', nargs='?', type=str, default=None, help='Set working directory')
parser.add_argument('--sock', help='UNIX Socket path')
parser.add_argument('--host', help='Webserver host')
parser.add_argument('--port', type=int, help='Webserver port')
parser.add_argument('--cert', help='Path to SSL certificate')
parser.add_argument('--pkey', help='Path to SSL private key')
parser.add_argument('--host-name', help='Set webhook host name')
parser.add_argument('--webhook-port', type=int, help='Port for webhook (default=port)')
parser.add_argument('--webhook-path', default='/webhook', help='Port for webhook (default=port)')
parser = argparse.ArgumentParser(description="Python telegram bot")
parser.add_argument(
"--token", "-t", nargs="?", type=str, default=None, help="Set working directory"
)
parser.add_argument("--sock", help="UNIX Socket path")
parser.add_argument("--host", help="Webserver host")
parser.add_argument("--port", type=int, help="Webserver port")
parser.add_argument("--cert", help="Path to SSL certificate")
parser.add_argument("--pkey", help="Path to SSL private key")
parser.add_argument("--host-name", help="Set webhook host name")
parser.add_argument("--webhook-port", type=int, help="Port for webhook (default=port)")
parser.add_argument("--webhook-path", default="/webhook", help="Port for webhook (default=port)")
async def cmd_start(message: types.Message):
@ -53,7 +55,7 @@ async def cmd_start(message: types.Message):
def setup_handlers(dispatcher: Dispatcher):
# This example has only one messages handler
dispatcher.register_message_handler(cmd_start, commands=['start', 'welcome'])
dispatcher.register_message_handler(cmd_start, commands=["start", "welcome"])
async def on_startup(dispatcher, url=None, cert=None):
@ -73,7 +75,7 @@ async def on_startup(dispatcher, url=None, cert=None):
# Set new URL for webhook
if cert:
with open(cert, 'rb') as cert_file:
with open(cert, "rb") as cert_file:
await bot.set_webhook(url, certificate=cert_file)
else:
await bot.set_webhook(url)
@ -83,7 +85,7 @@ async def on_startup(dispatcher, url=None, cert=None):
async def on_shutdown(dispatcher):
print('Shutdown.')
print("Shutdown.")
def main(arguments):
@ -99,8 +101,8 @@ def main(arguments):
webhook_path = args.webhook_path
# Fi webhook path
if not webhook_path.startswith('/'):
webhook_path = '/' + webhook_path
if not webhook_path.startswith("/"):
webhook_path = "/" + webhook_path
# Generate webhook URL
webhook_url = f"https://{host_name}:{webhook_port}{webhook_path}"
@ -116,15 +118,21 @@ def main(arguments):
else:
ssl_context = None
start_webhook(dispatcher, webhook_path,
on_startup=functools.partial(on_startup, url=webhook_url, cert=cert),
on_shutdown=on_shutdown,
host=host, port=port, path=sock, ssl_context=ssl_context)
start_webhook(
dispatcher,
webhook_path,
on_startup=functools.partial(on_startup, url=webhook_url, cert=cert),
on_shutdown=on_shutdown,
host=host,
port=port,
path=sock,
ssl_context=ssl_context,
)
else:
start_polling(dispatcher, on_startup=on_startup, on_shutdown=on_shutdown)
if __name__ == '__main__':
if __name__ == "__main__":
argv = sys.argv[1:]
if not len(argv):

View file

@ -4,10 +4,10 @@ import logging
from aiogram import Bot, Dispatcher, types
from aiogram.utils import exceptions, executor
API_TOKEN = 'BOT TOKEN HERE'
API_TOKEN = "BOT TOKEN HERE"
logging.basicConfig(level=logging.INFO)
log = logging.getLogger('broadcast')
log = logging.getLogger("broadcast")
loop = asyncio.get_event_loop()
bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.HTML)
@ -61,15 +61,15 @@ async def broadcaster() -> int:
count = 0
try:
for user_id in get_users():
if await send_message(user_id, '<b>Hello!</b>'):
if await send_message(user_id, "<b>Hello!</b>"):
count += 1
await asyncio.sleep(.05) # 20 messages per second (Limit: 30 messages per second)
await asyncio.sleep(0.05) # 20 messages per second (Limit: 30 messages per second)
finally:
log.info(f"{count} messages successful sent.")
return count
if __name__ == '__main__':
if __name__ == "__main__":
# Execute broadcaster
executor.start(dp, broadcaster())

View file

@ -11,7 +11,7 @@ from aiogram.utils.exceptions import MessageNotModified, Throttled
logging.basicConfig(level=logging.INFO)
API_TOKEN = 'BOT TOKEN HERE'
API_TOKEN = "BOT TOKEN HERE"
loop = asyncio.get_event_loop()
bot = Bot(token=API_TOKEN, loop=loop, parse_mode=types.ParseMode.HTML)
@ -21,16 +21,17 @@ dp.middleware.setup(LoggingMiddleware())
POSTS = {
str(uuid.uuid4()): {
'title': f"Post {index}",
'body': 'Lorem ipsum dolor sit amet, '
'consectetur adipiscing elit, '
'sed do eiusmod tempor incididunt ut '
'labore et dolore magna aliqua',
'votes': random.randint(-2, 5)
} for index in range(1, 6)
"title": f"Post {index}",
"body": "Lorem ipsum dolor sit amet, "
"consectetur adipiscing elit, "
"sed do eiusmod tempor incididunt ut "
"labore et dolore magna aliqua",
"votes": random.randint(-2, 5),
}
for index in range(1, 6)
}
posts_cb = CallbackData('post', 'id', 'action') # post:<id>:<action>
posts_cb = CallbackData("post", "id", "action") # post:<id>:<action>
def get_keyboard() -> types.InlineKeyboardMarkup:
@ -41,69 +42,73 @@ def get_keyboard() -> types.InlineKeyboardMarkup:
for post_id, post in POSTS.items():
markup.add(
types.InlineKeyboardButton(
post['title'],
callback_data=posts_cb.new(id=post_id, action='view'))
post["title"], callback_data=posts_cb.new(id=post_id, action="view")
)
)
return markup
def format_post(post_id: str, post: dict) -> (str, types.InlineKeyboardMarkup):
text = f"{md.hbold(post['title'])}\n" \
f"{md.quote_html(post['body'])}\n" \
f"\n" \
text = (
f"{md.hbold(post['title'])}\n"
f"{md.quote_html(post['body'])}\n"
f"\n"
f"Votes: {post['votes']}"
)
markup = types.InlineKeyboardMarkup()
markup.row(
types.InlineKeyboardButton('👍', callback_data=posts_cb.new(id=post_id, action='like')),
types.InlineKeyboardButton('👎', callback_data=posts_cb.new(id=post_id, action='unlike')),
types.InlineKeyboardButton("👍", callback_data=posts_cb.new(id=post_id, action="like")),
types.InlineKeyboardButton("👎", callback_data=posts_cb.new(id=post_id, action="unlike")),
)
markup.add(
types.InlineKeyboardButton("<< Back", callback_data=posts_cb.new(id="-", action="list"))
)
markup.add(types.InlineKeyboardButton('<< Back', callback_data=posts_cb.new(id='-', action='list')))
return text, markup
@dp.message_handler(commands='start')
@dp.message_handler(commands="start")
async def cmd_start(message: types.Message):
await message.reply('Posts', reply_markup=get_keyboard())
await message.reply("Posts", reply_markup=get_keyboard())
@dp.callback_query_handler(posts_cb.filter(action='list'))
@dp.callback_query_handler(posts_cb.filter(action="list"))
async def query_show_list(query: types.CallbackQuery):
await query.message.edit_text('Posts', reply_markup=get_keyboard())
await query.message.edit_text("Posts", reply_markup=get_keyboard())
@dp.callback_query_handler(posts_cb.filter(action='view'))
@dp.callback_query_handler(posts_cb.filter(action="view"))
async def query_view(query: types.CallbackQuery, callback_data: dict):
post_id = callback_data['id']
post_id = callback_data["id"]
post = POSTS.get(post_id, None)
if not post:
return await query.answer('Unknown post!')
return await query.answer("Unknown post!")
text, markup = format_post(post_id, post)
await query.message.edit_text(text, reply_markup=markup)
@dp.callback_query_handler(posts_cb.filter(action=['like', 'unlike']))
@dp.callback_query_handler(posts_cb.filter(action=["like", "unlike"]))
async def query_post_vote(query: types.CallbackQuery, callback_data: dict):
try:
await dp.throttle('vote', rate=1)
await dp.throttle("vote", rate=1)
except Throttled:
return await query.answer('Too many requests.')
return await query.answer("Too many requests.")
post_id = callback_data['id']
action = callback_data['action']
post_id = callback_data["id"]
action = callback_data["action"]
post = POSTS.get(post_id, None)
if not post:
return await query.answer('Unknown post!')
return await query.answer("Unknown post!")
if action == 'like':
post['votes'] += 1
elif action == 'unlike':
post['votes'] -= 1
if action == "like":
post["votes"] += 1
elif action == "unlike":
post["votes"] -= 1
await query.answer('Voted.')
await query.answer("Voted.")
text, markup = format_post(post_id, post)
await query.message.edit_text(text, reply_markup=markup)
@ -113,5 +118,5 @@ async def message_not_modified_handler(update, error):
return True
if __name__ == '__main__':
if __name__ == "__main__":
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -7,7 +7,7 @@ import logging
from aiogram import Bot, Dispatcher, executor, md, types
API_TOKEN = 'BOT TOKEN HERE'
API_TOKEN = "BOT TOKEN HERE"
logging.basicConfig(level=logging.INFO)
@ -20,14 +20,17 @@ dp = Dispatcher(bot)
async def check_language(message: types.Message):
locale = message.from_user.locale
await message.reply(md.text(
md.bold('Info about your language:'),
md.text(' 🔸', md.bold('Code:'), md.italic(locale.locale)),
md.text(' 🔸', md.bold('Territory:'), md.italic(locale.territory or 'Unknown')),
md.text(' 🔸', md.bold('Language name:'), md.italic(locale.language_name)),
md.text(' 🔸', md.bold('English language name:'), md.italic(locale.english_name)),
sep='\n'))
await message.reply(
md.text(
md.bold("Info about your language:"),
md.text(" 🔸", md.bold("Code:"), md.italic(locale.locale)),
md.text(" 🔸", md.bold("Territory:"), md.italic(locale.territory or "Unknown")),
md.text(" 🔸", md.bold("Language name:"), md.italic(locale.language_name)),
md.text(" 🔸", md.bold("English language name:"), md.italic(locale.english_name)),
sep="\n",
)
)
if __name__ == '__main__':
if __name__ == "__main__":
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -7,7 +7,7 @@ import logging
from aiogram import Bot, Dispatcher, executor, types
API_TOKEN = 'BOT TOKEN HERE'
API_TOKEN = "BOT TOKEN HERE"
# Configure logging
logging.basicConfig(level=logging.INFO)
@ -17,7 +17,7 @@ bot = Bot(token=API_TOKEN)
dp = Dispatcher(bot)
@dp.message_handler(commands=['start', 'help'])
@dp.message_handler(commands=["start", "help"])
async def send_welcome(message: types.Message):
"""
This handler will be called when client send `/start` or `/help` commands.
@ -25,11 +25,15 @@ async def send_welcome(message: types.Message):
await message.reply("Hi!\nI'm EchoBot!\nPowered by aiogram.")
@dp.message_handler(regexp='(^cat[s]?$|puss)')
@dp.message_handler(regexp="(^cat[s]?$|puss)")
async def cats(message: types.Message):
with open('data/cats.jpg', 'rb') as photo:
await bot.send_photo(message.chat.id, photo, caption='Cats is here 😺',
reply_to_message_id=message.message_id)
with open("data/cats.jpg", "rb") as photo:
await bot.send_photo(
message.chat.id,
photo,
caption="Cats is here 😺",
reply_to_message_id=message.message_id,
)
@dp.message_handler()
@ -37,5 +41,5 @@ async def echo(message: types.Message):
await bot.send_message(message.chat.id, message.text)
if __name__ == '__main__':
if __name__ == "__main__":
executor.start_polling(dp, skip_updates=True)

View file

@ -9,7 +9,7 @@ from aiogram.dispatcher.filters.state import State, StatesGroup
from aiogram.types import ParseMode
from aiogram.utils import executor
API_TOKEN = 'BOT TOKEN HERE'
API_TOKEN = "BOT TOKEN HERE"
loop = asyncio.get_event_loop()
@ -27,7 +27,7 @@ class Form(StatesGroup):
gender = State() # Will be represented in storage as 'Form:gender'
@dp.message_handler(commands=['start'])
@dp.message_handler(commands=["start"])
async def cmd_start(message: types.Message):
"""
Conversation's entry point
@ -39,9 +39,11 @@ async def cmd_start(message: types.Message):
# You can use state '*' if you need to handle all states
@dp.message_handler(state='*', commands=['cancel'])
@dp.message_handler(lambda message: message.text.lower() == 'cancel', state='*')
async def cancel_handler(message: types.Message, state: FSMContext, raw_state: Optional[str] = None):
@dp.message_handler(state="*", commands=["cancel"])
@dp.message_handler(lambda message: message.text.lower() == "cancel", state="*")
async def cancel_handler(
message: types.Message, state: FSMContext, raw_state: Optional[str] = None
):
"""
Allow user to cancel any action
"""
@ -51,7 +53,7 @@ async def cancel_handler(message: types.Message, state: FSMContext, raw_state: O
# Cancel state and inform user about it
await state.finish()
# And remove keyboard (just in case)
await message.reply('Canceled.', reply_markup=types.ReplyKeyboardRemove())
await message.reply("Canceled.", reply_markup=types.ReplyKeyboardRemove())
@dp.message_handler(state=Form.name)
@ -60,7 +62,7 @@ async def process_name(message: types.Message, state: FSMContext):
Process user name
"""
async with state.proxy() as data:
data['name'] = message.text
data["name"] = message.text
await Form.next()
await message.reply("How old are you?")
@ -89,7 +91,9 @@ async def process_age(message: types.Message, state: FSMContext):
await message.reply("What is your gender?", reply_markup=markup)
@dp.message_handler(lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender)
@dp.message_handler(
lambda message: message.text not in ["Male", "Female", "Other"], state=Form.gender
)
async def failed_process_gender(message: types.Message):
"""
In this example gender has to be one of: Male, Female, Other.
@ -100,21 +104,27 @@ async def failed_process_gender(message: types.Message):
@dp.message_handler(state=Form.gender)
async def process_gender(message: types.Message, state: FSMContext):
async with state.proxy() as data:
data['gender'] = message.text
data["gender"] = message.text
# Remove keyboard
markup = types.ReplyKeyboardRemove()
# And send message
await bot.send_message(message.chat.id, md.text(
md.text('Hi! Nice to meet you,', md.bold(data['name'])),
md.text('Age:', data['age']),
md.text('Gender:', data['gender']),
sep='\n'), reply_markup=markup, parse_mode=ParseMode.MARKDOWN)
await bot.send_message(
message.chat.id,
md.text(
md.text("Hi! Nice to meet you,", md.bold(data["name"])),
md.text("Age:", data["age"]),
md.text("Gender:", data["gender"]),
sep="\n",
),
reply_markup=markup,
parse_mode=ParseMode.MARKDOWN,
)
# Finish conversation
data.state = None
if __name__ == '__main__':
if __name__ == "__main__":
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -24,11 +24,11 @@ from pathlib import Path
from aiogram import Bot, Dispatcher, executor, types
from aiogram.contrib.middlewares.i18n import I18nMiddleware
TOKEN = 'BOT TOKEN HERE'
I18N_DOMAIN = 'mybot'
TOKEN = "BOT TOKEN HERE"
I18N_DOMAIN = "mybot"
BASE_DIR = Path(__file__).parent
LOCALES_DIR = BASE_DIR / 'locales'
LOCALES_DIR = BASE_DIR / "locales"
bot = Bot(TOKEN, parse_mode=types.ParseMode.HTML)
dp = Dispatcher(bot)
@ -41,16 +41,16 @@ dp.middleware.setup(i18n)
_ = i18n.gettext
@dp.message_handler(commands=['start'])
@dp.message_handler(commands=["start"])
async def cmd_start(message: types.Message):
# Simply use `_('message')` instead of `'message'` and never use f-strings for translatable texts.
await message.reply(_('Hello, <b>{user}</b>!').format(user=message.from_user.full_name))
await message.reply(_("Hello, <b>{user}</b>!").format(user=message.from_user.full_name))
@dp.message_handler(commands=['lang'])
@dp.message_handler(commands=["lang"])
async def cmd_lang(message: types.Message, locale):
await message.reply(_('Your current language: <i>{language}</i>').format(language=locale))
await message.reply(_("Your current language: <i>{language}</i>").format(language=locale))
if __name__ == '__main__':
if __name__ == "__main__":
executor.start_polling(dp, skip_updates=True)

View file

@ -3,7 +3,7 @@ import logging
from aiogram import Bot, types, Dispatcher, executor
API_TOKEN = 'BOT TOKEN HERE'
API_TOKEN = "BOT TOKEN HERE"
logging.basicConfig(level=logging.DEBUG)
@ -14,11 +14,12 @@ dp = Dispatcher(bot)
@dp.inline_handler()
async def inline_echo(inline_query: types.InlineQuery):
input_content = types.InputTextMessageContent(inline_query.query or 'echo')
item = types.InlineQueryResultArticle(id='1', title='echo',
input_message_content=input_content)
input_content = types.InputTextMessageContent(inline_query.query or "echo")
item = types.InlineQueryResultArticle(
id="1", title="echo", input_message_content=input_content
)
await bot.answer_inline_query(inline_query.id, results=[item], cache_time=1)
if __name__ == '__main__':
if __name__ == "__main__":
executor.start_polling(dp, loop=loop, skip_updates=True)

View file

@ -2,7 +2,7 @@ import asyncio
from aiogram import Bot, Dispatcher, executor, filters, types
API_TOKEN = 'BOT TOKEN HERE'
API_TOKEN = "BOT TOKEN HERE"
loop = asyncio.get_event_loop()
bot = Bot(token=API_TOKEN, loop=loop)
@ -24,13 +24,13 @@ async def send_welcome(message: types.Message):
media = types.MediaGroup()
# Attach local file
media.attach_photo(types.InputFile('data/cat.jpg'), 'Cat!')
media.attach_photo(types.InputFile("data/cat.jpg"), "Cat!")
# More local files and more cats!
media.attach_photo(types.InputFile('data/cats.jpg'), 'More cats!')
media.attach_photo(types.InputFile("data/cats.jpg"), "More cats!")
# You can also use URL's
# For example: get random puss:
media.attach_photo('http://lorempixel.com/400/200/cats/', 'Random cat.')
media.attach_photo("http://lorempixel.com/400/200/cats/", "Random cat.")
# And you can also use file ID:
# media.attach_photo('<file_id>', 'cat-cat-cat.')
@ -39,5 +39,5 @@ async def send_welcome(message: types.Message):
await message.reply_media_group(media=media)
if __name__ == '__main__':
if __name__ == "__main__":
executor.start_polling(dp, loop=loop, skip_updates=True)

Some files were not shown because too many files have changed in this diff Show more