diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py index fcffbb289..0c7da7eb1 100644 --- a/telebot/asyncio_storage/pickle_storage.py +++ b/telebot/asyncio_storage/pickle_storage.py @@ -1,28 +1,43 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext -import os +try: + import aiofiles +except ImportError: + aiofiles_installed = False + +import os import pickle +import asyncio +from typing import Optional, Union, Callable, Any +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext + +def with_lock(func: Callable) -> Callable: + async def wrapper(self, *args, **kwargs): + async with self.lock: + return await func(self, *args, **kwargs) + return wrapper class StatePickleStorage(StateStorageBase): - def __init__(self, file_path="./.state-save/states.pkl") -> None: + def __init__(self, file_path: str = "./.state-save/states.pkl", + prefix='telebot', separator: Optional[str] = ":") -> None: + + if not aiofiles_installed: + raise ImportError("Please install aiofiles using `pip install aiofiles`") + self.file_path = file_path + self.prefix = prefix + self.separator = separator + self.lock = asyncio.Lock() self.create_dir() - self.data = self.read() - - async def convert_old_to_new(self): - # old looks like: - # {1: {'state': 'start', 'data': {'name': 'John'}} - # we should update old version pickle to new. - # new looks like: - # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}} - new_data = {} - for key, value in self.data.items(): - # this returns us id and dict with data and state - new_data[key] = {key: value} # convert this to new - # pass it to global data - self.data = new_data - self.update_data() # update data in file + + async def _read_from_file(self) -> dict: + async with aiofiles.open(self.file_path, 'rb') as f: + data = await f.read() + return pickle.loads(data) + + async def _write_to_file(self, data: dict) -> None: + async with aiofiles.open(self.file_path, 'wb') as f: + await f.write(pickle.dumps(data)) def create_dir(self): """ @@ -34,77 +49,100 @@ def create_dir(self): with open(self.file_path,'wb') as file: pickle.dump({}, file) - def read(self): - file = open(self.file_path, 'rb') - data = pickle.load(file) - file.close() - return data - - def update_data(self): - file = open(self.file_path, 'wb+') - pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL) - file.close() - - async def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): - state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - self.update_data() - return True - else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - self.update_data() - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} - self.update_data() + + @with_lock + async def set_state(self, chat_id: int, user_id: int, state: str, + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + if _key not in data: + data[_key] = {"state": state, "data": {}} + else: + data[_key]["state"] = state + await self._write_to_file(data) return True - - async def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - self.update_data() - return True + @with_lock + async def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + return data.get(_key, {}).get("state") + + @with_lock + async def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + if _key in data: + del data[_key] + await self._write_to_file(data) + return True return False - - async def get_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] - - return None - async def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] - - return None - - async def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - self.update_data() - return True + @with_lock + async def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + state_data = data.get(_key, {}) + state_data["data"][key] = value + if _key not in data: + data[_key] = {"state": None, "data": state_data} + else: + data[_key]["data"][key] = value + await self._write_to_file(data) + return True + + @with_lock + async def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + return data.get(_key, {}).get("data", {}) + + @with_lock + async def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + if _key in data: + data[_key]["data"] = {} + await self._write_to_file(data) + return True return False - async def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - self.update_data() - return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: + return StateDataContext( + self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id + ) - def get_interactive_data(self, chat_id, user_id): - return StateDataContext(self, chat_id, user_id) + @with_lock + async def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + data[_key]["data"] = data + await self._write_to_file(data) + return True - async def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data - self.update_data() \ No newline at end of file + def __str__(self) -> str: + return f"StatePickleStorage({self.file_path}, {self.prefix})" diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index 86fde3716..a6b19780d 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -1,179 +1,131 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext -import json -redis_installed = True -is_actual_aioredis = False try: - import aioredis - is_actual_aioredis = True + import redis + from redis.asyncio import Redis, ConnectionPool except ImportError: - try: - from redis import asyncio as aioredis - except ImportError: - redis_installed = False + redis_installed = False + +import json +from typing import Optional, Union, Callable, Coroutine +import asyncio + +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext + + +def async_with_lock(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: + async def wrapper(self, *args, **kwargs): + async with self.lock: + return await func(self, *args, **kwargs) + return wrapper +def async_with_pipeline(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: + async def wrapper(self, *args, **kwargs): + async with self.redis.pipeline() as pipe: + pipe.multi() + result = await func(self, pipe, *args, **kwargs) + await pipe.execute() + return result + return wrapper class StateRedisStorage(StateStorageBase): - """ - This class is for Redis storage. - This will work only for states. - To use it, just pass this class to: - TeleBot(storage=StateRedisStorage()) - """ - def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_', redis_url=None): + def __init__(self, host='localhost', port=6379, db=0, password=None, + prefix='telebot', + redis_url=None, + connection_pool: 'ConnectionPool'=None, + separator: Optional[str] = ":", + ) -> None: + if not redis_installed: - raise ImportError('AioRedis is not installed. Install it via "pip install aioredis"') + raise ImportError("Please install redis using `pip install redis`") + + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") - if is_actual_aioredis: - aioredis_version = tuple(map(int, aioredis.__version__.split(".")[0])) - if aioredis_version < (2,): - raise ImportError('Invalid aioredis version. Aioredis version should be >= 2.0.0') if redis_url: - self.redis = aioredis.Redis.from_url(redis_url) + self.redis = redis.asyncio.from_url(redis_url) + elif connection_pool: + self.redis = Redis(connection_pool=connection_pool) else: - self.redis = aioredis.Redis(host=host, port=port, db=db, password=password) + self.redis = Redis(host=host, port=port, db=db, password=password) + + self.lock = asyncio.Lock() + + @async_with_lock + @async_with_pipeline + async def set_state(self, pipe, chat_id: int, user_id: int, state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + if hasattr(state, "name"): + state = state.name - self.prefix = prefix - #self.con = Redis(connection_pool=self.redis) -> use this when necessary - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - - async def get_record(self, key): - """ - Function to get record from database. - It has nothing to do with states. - Made for backward compatibility - """ - result = await self.redis.get(self.prefix+str(key)) - if result: return json.loads(result) - return - - async def set_record(self, key, value): - """ - Function to set record to database. - It has nothing to do with states. - Made for backward compatibility - """ - - await self.redis.set(self.prefix+str(key), json.dumps(value)) + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + await pipe.hset(_key, "state", state) return True - async def delete_record(self, key): - """ - Function to delete record from database. - It has nothing to do with states. - Made for backward compatibility - """ - await self.redis.delete(self.prefix+str(key)) + async def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + state_bytes = await self.redis.hget(_key, "state") + return state_bytes.decode('utf-8') if state_bytes else None + + async def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + result = await self.redis.delete(_key) + return result > 0 + + @async_with_lock + @async_with_pipeline + async def set_data(self, pipe, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + data = await pipe.hget(_key, "data") + data = await pipe.execute() + data = data[0] + if data is None: + await pipe.hset(_key, "data", json.dumps({key: value})) + else: + data = json.loads(data) + data[key] = value + await pipe.hset(_key, "data", json.dumps(data)) return True - async def set_state(self, chat_id, user_id, state): - """ - Set state for a particular user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if hasattr(state, 'name'): - state = state.name - if response: - if user_id in response: - response[user_id]['state'] = state - else: - response[user_id] = {'state': state, 'data': {}} + async def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + data = await self.redis.hget(_key, "data") + return json.loads(data) if data else {} + + @async_with_lock + @async_with_pipeline + async def reset_data(self, pipe, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + if await pipe.exists(_key): + await pipe.hset(_key, "data", "{}") else: - response = {user_id: {'state': state, 'data': {}}} - await self.set_record(chat_id, response) + return False + return True + def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: + return StateDataContext(self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id) + + @async_with_lock + @async_with_pipeline + async def save(self, pipe, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + if await pipe.exists(_key): + await pipe.hset(_key, "data", json.dumps(data)) + else: + return False return True - - async def delete_state(self, chat_id, user_id): - """ - Delete state for a particular user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - del response[user_id] - if user_id == str(chat_id): - await self.delete_record(chat_id) - return True - else: await self.set_record(chat_id, response) - return True - return False - - async def get_value(self, chat_id, user_id, key): - """ - Get value for a data of a user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - if key in response[user_id]['data']: - return response[user_id]['data'][key] - return None - - async def get_state(self, chat_id, user_id): - """ - Get state of a user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['state'] - - return None - - async def get_data(self, chat_id, user_id): - """ - Get data of particular user in a particular chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['data'] - return None - - async def reset_data(self, chat_id, user_id): - """ - Reset data of a user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = {} - await self.set_record(chat_id, response) - return True - - async def set_data(self, chat_id, user_id, key, value): - """ - Set data without interactive data. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'][key] = value - await self.set_record(chat_id, response) - return True - return False - - def get_interactive_data(self, chat_id, user_id): - """ - Get Data in interactive way. - You can use with() with this function. - """ - return StateDataContext(self, chat_id, user_id) - - async def save(self, chat_id, user_id, data): - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = data - await self.set_record(chat_id, response) - return True + + def __str__(self) -> str: + # include some connection info + return f"StateRedisStorage({self.redis})" diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index 917491814..32a9653c7 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -127,7 +127,4 @@ def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: O return True def __str__(self) -> str: - with self.lock: - with open(self.file_path, 'rb') as f: - data = pickle.load(f) - return f"" + return f"StatePickleStorage({self.file_path}, {self.prefix})" diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 0aa7c43ea..d41da05c8 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -151,6 +151,4 @@ def save_action(pipe): return True def __str__(self) -> str: - keys = self.redis.keys(f"{self.prefix}{self.separator}*") - data = {key.decode(): self.redis.hgetall(key) for key in keys} - return f"" + return f"StateRedisStorage({self.redis})"