Skip to content

Commit

Permalink
all update types are supported for states(theoretically)
Browse files Browse the repository at this point in the history
  • Loading branch information
coder2020official committed Jul 8, 2024
1 parent 7e5a044 commit 4a7bc5d
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 65 deletions.
48 changes: 13 additions & 35 deletions telebot/custom_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from telebot import types


from telebot.states import resolve_context



Expand Down Expand Up @@ -407,17 +407,7 @@ def check(self, message, text):
"""
if text == '*': return True

# needs to work with callbackquery
if isinstance(message, types.Message):
chat_id = message.chat.id
user_id = message.from_user.id

if isinstance(message, types.CallbackQuery):

chat_id = message.message.chat.id
user_id = message.from_user.id
message = message.message

chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id)

if isinstance(text, list):
new_text = []
Expand All @@ -428,29 +418,17 @@ def check(self, message, text):
elif isinstance(text, State):
text = text.name

if message.chat.type in ['group', 'supergroup']:
group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id,
message_thread_id=message.message_thread_id)
if group_state is None and not message.is_topic_message: # needed for general topic and group messages
group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id)

if group_state == text:
return True
elif type(text) is list and group_state in text:
return True


else:
user_state = self.bot.current_states.get_state(
chat_id=chat_id,
user_id=user_id,
business_connection_id=message.business_connection_id,
bot_id=self.bot._user.id
)
if user_state == text:
return True
elif type(text) is list and user_state in text:
return True
user_state = self.bot.current_states.get_state(
chat_id=chat_id,
user_id=user_id,
business_connection_id=business_connection_id,
bot_id=bot_id,
message_thread_id=message_thread_id
)
if user_state == text:
return True
elif type(text) is list and user_state in text:
return True


class IsDigitFilter(SimpleCustomFilter):
Expand Down
41 changes: 40 additions & 1 deletion telebot/states/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Contains classes for states and state groups.
"""

from telebot import types

class State:
"""
Expand Down Expand Up @@ -41,3 +41,42 @@ def __init_subclass__(cls) -> None:
@classmethod
def state_list(self):
return self._state_list

def resolve_context(message, bot_id: int) -> tuple:
# chat_id, user_id, business_connection_id, bot_id, message_thread_id

# message, edited_message, channel_post, edited_channel_post, business_message, edited_business_message
if isinstance(message, types.Message):
return (message.chat.id, message.from_user.id, message.business_connection_id, bot_id,
message.message_thread_id if message.is_topic_message else None)
elif isinstance(message, types.CallbackQuery): # callback_query
return (message.message.chat.id, message.from_user.id, message.message.business_connection_id, bot_id,
message.message.message_thread_id if message.message.is_topic_message else None)
elif isinstance(message, types.BusinessConnection): # business_connection
return (message.user_chat_id, message.user.id, message.id, bot_id, None)
elif isinstance(message, types.BusinessMessagesDeleted): # deleted_business_messages
return (message.chat.id, message.chat.id, message.business_connection_id, bot_id, None)
elif isinstance(message, types.MessageReactionUpdated): # message_reaction
return (message.chat.id, message.user.id, None, bot_id, None)
elif isinstance(message, types.MessageReactionCountUpdated): # message_reaction_count
return (message.chat.id, None, None, bot_id, None)
elif isinstance(message, types.InlineQuery): # inline_query
return (None, message.from_user.id, None, bot_id, None)
elif isinstance(message, types.ChosenInlineResult): # chosen_inline_result
return (None, message.from_user.id, None, bot_id, None)
elif isinstance(message, types.ShippingQuery): # shipping_query
return (None, message.from_user.id, None, bot_id, None)
elif isinstance(message, types.PreCheckoutQuery): # pre_checkout_query
return (None, message.from_user.id, None, bot_id, None)
elif isinstance(message, types.PollAnswer): # poll_answer
return (None, message.user.id, None, bot_id, None)
elif isinstance(message, types.ChatMemberUpdated): # chat_member # my_chat_member
return (message.chat.id, message.from_user.id, None, bot_id, None)
elif isinstance(message, types.ChatJoinRequest): # chat_join_request
return (message.chat.id, message.from_user.id, None, bot_id, None)
elif isinstance(message, types.ChatBoostRemoved): # removed_chat_boost
return (message.chat.id, message.source.user.id if message.source else None, None, bot_id, None)
elif isinstance(message, types.ChatBoostUpdated): # chat_boost
return (message.chat.id, message.boost.source.user.id if message.boost.source else None, None, bot_id, None)
else:
pass # not yet supported :(
37 changes: 11 additions & 26 deletions telebot/states/sync/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from telebot.states import State, StatesGroup
from telebot.types import CallbackQuery, Message
from telebot import TeleBot
from telebot import TeleBot, types
from telebot.states import resolve_context

from typing import Union



class StateContext():
"""
Class representing a state context.
Expand All @@ -25,24 +27,6 @@ def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None:
self.bot: TeleBot = bot
self.bot_id = self.bot.bot_id

def _resolve_context(self) -> Union[Message, CallbackQuery]:
chat_id = None
user_id = None
business_connection_id = self.message.business_connection_id
bot_id = self.bot_id
message_thread_id = None

if isinstance(self.message, Message):
chat_id = self.message.chat.id
user_id = self.message.from_user.id
message_thread_id = self.message.message_thread_id if self.message.is_topic_message else None
elif isinstance(self.message, CallbackQuery):
chat_id = self.message.message.chat.id
user_id = self.message.from_user.id
message_thread_id = self.message.message.message_thread_id if self.message.message.is_topic_message else None

return chat_id, user_id, business_connection_id, bot_id, message_thread_id

def set(self, state: Union[State, str]) -> None:
"""
Set state for current user.
Expand All @@ -58,7 +42,7 @@ def start_ex(message: types.Message, state_context: StateContext):
bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id)
"""

chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context()
chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
if isinstance(state, State):
state = state.name
return self.bot.set_state(
Expand All @@ -78,7 +62,7 @@ def get(self) -> str:
:rtype: str
"""

chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context()
chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
return self.bot.get_state(
chat_id=chat_id,
user_id=user_id,
Expand All @@ -93,7 +77,7 @@ def reset_data(self) -> None:
State will not be changed.
"""

chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context()
chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
return self.bot.reset_data(
chat_id=chat_id,
user_id=user_id,
Expand All @@ -106,7 +90,7 @@ def delete(self) -> None:
"""
Deletes state and data for current user.
"""
chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context()
chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
return self.bot.delete_state(
chat_id=chat_id,
user_id=user_id,
Expand All @@ -125,7 +109,7 @@ def data(self) -> dict:
print(data)
"""

chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context()
chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
return self.bot.retrieve_data(
chat_id=chat_id,
user_id=user_id,
Expand All @@ -142,12 +126,13 @@ def add_data(self, **kwargs) -> None:
:type kwargs: dict
"""

chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context()
chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
return self.bot.add_data(
chat_id=chat_id,
user_id=user_id,
business_connection_id=business_connection_id,
bot_id=bot_id,
message_thread_id=message_thread_id,
**kwargs
)
)

6 changes: 4 additions & 2 deletions telebot/states/sync/middleware.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from telebot.handler_backends import BaseMiddleware
from telebot import TeleBot
from telebot.states.sync.context import StateContext
from telebot.util import update_types
from telebot import types


class StateMiddleware(BaseMiddleware):

def __init__(self, bot: TeleBot) -> None:
self.update_sensitive = False
self.update_types = ['message', 'edited_message', 'callback_query'] #TODO: support other types
self.update_types = update_types
self.bot: TeleBot = bot

def pre_process(self, message, data):
data['state_context'] = StateContext(message, self.bot)

def post_process(self, message, data, exception):
pass
pass
2 changes: 1 addition & 1 deletion tests/test_handler_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@pytest.fixture()
def telegram_bot():
return telebot.TeleBot('', threaded=False)
return telebot.TeleBot('', threaded=False, token_check=False)


@pytest.fixture
Expand Down

0 comments on commit 4a7bc5d

Please sign in to comment.