Skip to content

Commit

Permalink
Async version fully supported, partially tested.
Browse files Browse the repository at this point in the history
  • Loading branch information
coder2020official committed Jul 19, 2024
1 parent 15bced9 commit ff6485d
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 252 deletions.
206 changes: 122 additions & 84 deletions telebot/asyncio_storage/pickle_storage.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,43 @@
from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext
import os

try:
import aiofiles
except ImportError:
aiofiles_installed = False

import os
import pickle
import asyncio
from typing import Optional, Union, Callable, Any

from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext

def with_lock(func: Callable) -> Callable:
async def wrapper(self, *args, **kwargs):
async with self.lock:
return await func(self, *args, **kwargs)
return wrapper

class StatePickleStorage(StateStorageBase):
def __init__(self, file_path="./.state-save/states.pkl") -> None:
def __init__(self, file_path: str = "./.state-save/states.pkl",
prefix='telebot', separator: Optional[str] = ":") -> None:

if not aiofiles_installed:
raise ImportError("Please install aiofiles using `pip install aiofiles`")

self.file_path = file_path
self.prefix = prefix
self.separator = separator
self.lock = asyncio.Lock()
self.create_dir()
self.data = self.read()

async def convert_old_to_new(self):
# old looks like:
# {1: {'state': 'start', 'data': {'name': 'John'}}
# we should update old version pickle to new.
# new looks like:
# {1: {2: {'state': 'start', 'data': {'name': 'John'}}}}
new_data = {}
for key, value in self.data.items():
# this returns us id and dict with data and state
new_data[key] = {key: value} # convert this to new
# pass it to global data
self.data = new_data
self.update_data() # update data in file

async def _read_from_file(self) -> dict:
async with aiofiles.open(self.file_path, 'rb') as f:
data = await f.read()
return pickle.loads(data)

async def _write_to_file(self, data: dict) -> None:
async with aiofiles.open(self.file_path, 'wb') as f:
await f.write(pickle.dumps(data))

def create_dir(self):
"""
Expand All @@ -34,77 +49,100 @@ def create_dir(self):
with open(self.file_path,'wb') as file:
pickle.dump({}, file)

def read(self):
file = open(self.file_path, 'rb')
data = pickle.load(file)
file.close()
return data

def update_data(self):
file = open(self.file_path, 'wb+')
pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL)
file.close()

async def set_state(self, chat_id, user_id, state):
if hasattr(state, 'name'):
state = state.name
if chat_id in self.data:
if user_id in self.data[chat_id]:
self.data[chat_id][user_id]['state'] = state
self.update_data()
return True
else:
self.data[chat_id][user_id] = {'state': state, 'data': {}}
self.update_data()
return True
self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
self.update_data()

@with_lock
async def set_state(self, chat_id: int, user_id: int, state: str,
business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None,
bot_id: Optional[int] = None) -> bool:
_key = self._get_key(
chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id
)
data = await self._read_from_file()
if _key not in data:
data[_key] = {"state": state, "data": {}}
else:
data[_key]["state"] = state
await self._write_to_file(data)
return True

async def delete_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
del self.data[chat_id][user_id]
if chat_id == user_id:
del self.data[chat_id]
self.update_data()
return True

@with_lock
async def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None,
message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]:
_key = self._get_key(
chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id
)
data = await self._read_from_file()
return data.get(_key, {}).get("state")

@with_lock
async def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None,
message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool:
_key = self._get_key(
chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id
)
data = await self._read_from_file()
if _key in data:
del data[_key]
await self._write_to_file(data)
return True
return False


async def get_state(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['state']

return None
async def get_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
return self.data[chat_id][user_id]['data']

return None

async def reset_data(self, chat_id, user_id):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'] = {}
self.update_data()
return True
@with_lock
async def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict],
business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None,
bot_id: Optional[int] = None) -> bool:
_key = self._get_key(
chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id
)
data = await self._read_from_file()
state_data = data.get(_key, {})
state_data["data"][key] = value
if _key not in data:
data[_key] = {"state": None, "data": state_data}
else:
data[_key]["data"][key] = value
await self._write_to_file(data)
return True

@with_lock
async def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None,
message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict:
_key = self._get_key(
chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id
)
data = await self._read_from_file()
return data.get(_key, {}).get("data", {})

@with_lock
async def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None,
message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool:
_key = self._get_key(
chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id
)
data = await self._read_from_file()
if _key in data:
data[_key]["data"] = {}
await self._write_to_file(data)
return True
return False

async def set_data(self, chat_id, user_id, key, value):
if self.data.get(chat_id):
if self.data[chat_id].get(user_id):
self.data[chat_id][user_id]['data'][key] = value
self.update_data()
return True
raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None,
message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]:
return StateDataContext(
self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id,
message_thread_id=message_thread_id, bot_id=bot_id
)

def get_interactive_data(self, chat_id, user_id):
return StateDataContext(self, chat_id, user_id)
@with_lock
async def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None,
message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool:
_key = self._get_key(
chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id
)
data = await self._read_from_file()
data[_key]["data"] = data
await self._write_to_file(data)
return True

async def save(self, chat_id, user_id, data):
self.data[chat_id][user_id]['data'] = data
self.update_data()
def __str__(self) -> str:
return f"StatePickleStorage({self.file_path}, {self.prefix})"
Loading

0 comments on commit ff6485d

Please sign in to comment.