-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fce1c3d
commit 7e5a044
Showing
6 changed files
with
223 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
""" | ||
Contains classes for states and state groups. | ||
""" | ||
|
||
|
||
class State: | ||
""" | ||
Class representing a state. | ||
.. code-block:: python3 | ||
class MyStates(StatesGroup): | ||
my_state = State() # returns my_state:State string. | ||
""" | ||
def __init__(self) -> None: | ||
self.name: str = None | ||
self.group: StatesGroup = None | ||
def __str__(self) -> str: | ||
return f"<{self.group.__name__}:{self.name}>" | ||
|
||
|
||
class StatesGroup: | ||
""" | ||
Class representing common states. | ||
.. code-block:: python3 | ||
class MyStates(StatesGroup): | ||
my_state = State() # returns my_state:State string. | ||
""" | ||
def __init_subclass__(cls) -> None: | ||
state_list = [] | ||
for name, value in cls.__dict__.items(): | ||
if not name.startswith('__') and not callable(value) and isinstance(value, State): | ||
# change value of that variable | ||
value.name = ':'.join((cls.__name__, name)) | ||
value.group = cls | ||
state_list.append(value) | ||
cls._state_list = state_list | ||
|
||
@classmethod | ||
def state_list(self): | ||
return self._state_list |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .context import StateContext | ||
from .middleware import StateMiddleware | ||
|
||
__all__ = [ | ||
'StateContext', | ||
'StateMiddleware', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
from telebot.states import State, StatesGroup | ||
from telebot.types import CallbackQuery, Message | ||
from telebot import TeleBot | ||
|
||
from typing import Union | ||
|
||
|
||
class StateContext(): | ||
""" | ||
Class representing a state context. | ||
Passed through a middleware to provide easy way to set states. | ||
.. code-block:: python3 | ||
@bot.message_handler(commands=['start']) | ||
def start_ex(message: types.Message, state_context: StateContext): | ||
state_context.set(MyStates.name) | ||
bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) | ||
# also, state_context.data(), .add_data(), .reset_data(), .delete() methods available. | ||
""" | ||
|
||
def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: | ||
self.message: Union[Message, CallbackQuery] = message | ||
self.bot: TeleBot = bot | ||
self.bot_id = self.bot.bot_id | ||
|
||
def _resolve_context(self) -> Union[Message, CallbackQuery]: | ||
chat_id = None | ||
user_id = None | ||
business_connection_id = self.message.business_connection_id | ||
bot_id = self.bot_id | ||
message_thread_id = None | ||
|
||
if isinstance(self.message, Message): | ||
chat_id = self.message.chat.id | ||
user_id = self.message.from_user.id | ||
message_thread_id = self.message.message_thread_id if self.message.is_topic_message else None | ||
elif isinstance(self.message, CallbackQuery): | ||
chat_id = self.message.message.chat.id | ||
user_id = self.message.from_user.id | ||
message_thread_id = self.message.message.message_thread_id if self.message.message.is_topic_message else None | ||
|
||
return chat_id, user_id, business_connection_id, bot_id, message_thread_id | ||
|
||
def set(self, state: Union[State, str]) -> None: | ||
""" | ||
Set state for current user. | ||
:param state: State object or state name. | ||
:type state: Union[State, str] | ||
.. code-block:: python3 | ||
@bot.message_handler(commands=['start']) | ||
def start_ex(message: types.Message, state_context: StateContext): | ||
state_context.set(MyStates.name) | ||
bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) | ||
""" | ||
|
||
chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() | ||
if isinstance(state, State): | ||
state = state.name | ||
return self.bot.set_state( | ||
chat_id=chat_id, | ||
user_id=user_id, | ||
state=state, | ||
business_connection_id=business_connection_id, | ||
bot_id=bot_id, | ||
message_thread_id=message_thread_id | ||
) | ||
|
||
def get(self) -> str: | ||
""" | ||
Get current state for current user. | ||
:return: Current state name. | ||
:rtype: str | ||
""" | ||
|
||
chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() | ||
return self.bot.get_state( | ||
chat_id=chat_id, | ||
user_id=user_id, | ||
business_connection_id=business_connection_id, | ||
bot_id=bot_id, | ||
message_thread_id=message_thread_id | ||
) | ||
|
||
def reset_data(self) -> None: | ||
""" | ||
Reset data for current user. | ||
State will not be changed. | ||
""" | ||
|
||
chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() | ||
return self.bot.reset_data( | ||
chat_id=chat_id, | ||
user_id=user_id, | ||
business_connection_id=business_connection_id, | ||
bot_id=bot_id, | ||
message_thread_id=message_thread_id | ||
) | ||
|
||
def delete(self) -> None: | ||
""" | ||
Deletes state and data for current user. | ||
""" | ||
chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() | ||
return self.bot.delete_state( | ||
chat_id=chat_id, | ||
user_id=user_id, | ||
business_connection_id=business_connection_id, | ||
bot_id=bot_id, | ||
message_thread_id=message_thread_id | ||
) | ||
|
||
def data(self) -> dict: | ||
""" | ||
Get data for current user. | ||
.. code-block:: python3 | ||
with state_context.data() as data: | ||
print(data) | ||
""" | ||
|
||
chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() | ||
return self.bot.retrieve_data( | ||
chat_id=chat_id, | ||
user_id=user_id, | ||
business_connection_id=business_connection_id, | ||
bot_id=bot_id, | ||
message_thread_id=message_thread_id | ||
) | ||
|
||
def add_data(self, **kwargs) -> None: | ||
""" | ||
Add data for current user. | ||
:param kwargs: Data to add. | ||
:type kwargs: dict | ||
""" | ||
|
||
chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() | ||
return self.bot.add_data( | ||
chat_id=chat_id, | ||
user_id=user_id, | ||
business_connection_id=business_connection_id, | ||
bot_id=bot_id, | ||
message_thread_id=message_thread_id, | ||
**kwargs | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from telebot.handler_backends import BaseMiddleware | ||
from telebot import TeleBot | ||
from telebot.states.sync.context import StateContext | ||
|
||
|
||
class StateMiddleware(BaseMiddleware): | ||
|
||
def __init__(self, bot: TeleBot) -> None: | ||
self.update_sensitive = False | ||
self.update_types = ['message', 'edited_message', 'callback_query'] #TODO: support other types | ||
self.bot: TeleBot = bot | ||
|
||
def pre_process(self, message, data): | ||
data['state_context'] = StateContext(message, self.bot) | ||
|
||
def post_process(self, message, data, exception): | ||
pass |