From 4a7bc5d5006b664848feff84b5200c361b3e170b Mon Sep 17 00:00:00 2001 From: _run Date: Mon, 8 Jul 2024 21:47:38 +0500 Subject: [PATCH] all update types are supported for states(theoretically) --- telebot/custom_filters.py | 48 +++++++++---------------------- telebot/states/__init__.py | 41 +++++++++++++++++++++++++- telebot/states/sync/context.py | 37 +++++++----------------- telebot/states/sync/middleware.py | 6 ++-- tests/test_handler_backends.py | 2 +- 5 files changed, 69 insertions(+), 65 deletions(-) diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index 41a7c780c..1c5bd0b40 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -4,7 +4,7 @@ from telebot import types - +from telebot.states import resolve_context @@ -407,17 +407,7 @@ def check(self, message, text): """ if text == '*': return True - # needs to work with callbackquery - if isinstance(message, types.Message): - chat_id = message.chat.id - user_id = message.from_user.id - - if isinstance(message, types.CallbackQuery): - - chat_id = message.message.chat.id - user_id = message.from_user.id - message = message.message - + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id) if isinstance(text, list): new_text = [] @@ -428,29 +418,17 @@ def check(self, message, text): elif isinstance(text, State): text = text.name - if message.chat.type in ['group', 'supergroup']: - group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id, - message_thread_id=message.message_thread_id) - if group_state is None and not message.is_topic_message: # needed for general topic and group messages - group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id) - - if group_state == text: - return True - elif type(text) is list and group_state in text: - return True - - - else: - user_state = self.bot.current_states.get_state( - chat_id=chat_id, - user_id=user_id, - business_connection_id=message.business_connection_id, - bot_id=self.bot._user.id - ) - if user_state == text: - return True - elif type(text) is list and user_state in text: - return True + user_state = self.bot.current_states.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + if user_state == text: + return True + elif type(text) is list and user_state in text: + return True class IsDigitFilter(SimpleCustomFilter): diff --git a/telebot/states/__init__.py b/telebot/states/__init__.py index 0a45f17e1..b8efe3bbe 100644 --- a/telebot/states/__init__.py +++ b/telebot/states/__init__.py @@ -1,7 +1,7 @@ """ Contains classes for states and state groups. """ - +from telebot import types class State: """ @@ -41,3 +41,42 @@ def __init_subclass__(cls) -> None: @classmethod def state_list(self): return self._state_list + +def resolve_context(message, bot_id: int) -> tuple: + # chat_id, user_id, business_connection_id, bot_id, message_thread_id + + # message, edited_message, channel_post, edited_channel_post, business_message, edited_business_message + if isinstance(message, types.Message): + return (message.chat.id, message.from_user.id, message.business_connection_id, bot_id, + message.message_thread_id if message.is_topic_message else None) + elif isinstance(message, types.CallbackQuery): # callback_query + return (message.message.chat.id, message.from_user.id, message.message.business_connection_id, bot_id, + message.message.message_thread_id if message.message.is_topic_message else None) + elif isinstance(message, types.BusinessConnection): # business_connection + return (message.user_chat_id, message.user.id, message.id, bot_id, None) + elif isinstance(message, types.BusinessMessagesDeleted): # deleted_business_messages + return (message.chat.id, message.chat.id, message.business_connection_id, bot_id, None) + elif isinstance(message, types.MessageReactionUpdated): # message_reaction + return (message.chat.id, message.user.id, None, bot_id, None) + elif isinstance(message, types.MessageReactionCountUpdated): # message_reaction_count + return (message.chat.id, None, None, bot_id, None) + elif isinstance(message, types.InlineQuery): # inline_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChosenInlineResult): # chosen_inline_result + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ShippingQuery): # shipping_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.PreCheckoutQuery): # pre_checkout_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.PollAnswer): # poll_answer + return (None, message.user.id, None, bot_id, None) + elif isinstance(message, types.ChatMemberUpdated): # chat_member # my_chat_member + return (message.chat.id, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChatJoinRequest): # chat_join_request + return (message.chat.id, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChatBoostRemoved): # removed_chat_boost + return (message.chat.id, message.source.user.id if message.source else None, None, bot_id, None) + elif isinstance(message, types.ChatBoostUpdated): # chat_boost + return (message.chat.id, message.boost.source.user.id if message.boost.source else None, None, bot_id, None) + else: + pass # not yet supported :( \ No newline at end of file diff --git a/telebot/states/sync/context.py b/telebot/states/sync/context.py index f0a395d01..c8009007a 100644 --- a/telebot/states/sync/context.py +++ b/telebot/states/sync/context.py @@ -1,10 +1,12 @@ from telebot.states import State, StatesGroup from telebot.types import CallbackQuery, Message -from telebot import TeleBot +from telebot import TeleBot, types +from telebot.states import resolve_context from typing import Union + class StateContext(): """ Class representing a state context. @@ -25,24 +27,6 @@ def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: self.bot: TeleBot = bot self.bot_id = self.bot.bot_id - def _resolve_context(self) -> Union[Message, CallbackQuery]: - chat_id = None - user_id = None - business_connection_id = self.message.business_connection_id - bot_id = self.bot_id - message_thread_id = None - - if isinstance(self.message, Message): - chat_id = self.message.chat.id - user_id = self.message.from_user.id - message_thread_id = self.message.message_thread_id if self.message.is_topic_message else None - elif isinstance(self.message, CallbackQuery): - chat_id = self.message.message.chat.id - user_id = self.message.from_user.id - message_thread_id = self.message.message.message_thread_id if self.message.message.is_topic_message else None - - return chat_id, user_id, business_connection_id, bot_id, message_thread_id - def set(self, state: Union[State, str]) -> None: """ Set state for current user. @@ -58,7 +42,7 @@ def start_ex(message: types.Message, state_context: StateContext): bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) if isinstance(state, State): state = state.name return self.bot.set_state( @@ -78,7 +62,7 @@ def get(self) -> str: :rtype: str """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.get_state( chat_id=chat_id, user_id=user_id, @@ -93,7 +77,7 @@ def reset_data(self) -> None: State will not be changed. """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.reset_data( chat_id=chat_id, user_id=user_id, @@ -106,7 +90,7 @@ def delete(self) -> None: """ Deletes state and data for current user. """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.delete_state( chat_id=chat_id, user_id=user_id, @@ -125,7 +109,7 @@ def data(self) -> dict: print(data) """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.retrieve_data( chat_id=chat_id, user_id=user_id, @@ -142,7 +126,7 @@ def add_data(self, **kwargs) -> None: :type kwargs: dict """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.add_data( chat_id=chat_id, user_id=user_id, @@ -150,4 +134,5 @@ def add_data(self, **kwargs) -> None: bot_id=bot_id, message_thread_id=message_thread_id, **kwargs - ) \ No newline at end of file + ) + \ No newline at end of file diff --git a/telebot/states/sync/middleware.py b/telebot/states/sync/middleware.py index e1c74cc23..b85f795da 100644 --- a/telebot/states/sync/middleware.py +++ b/telebot/states/sync/middleware.py @@ -1,17 +1,19 @@ from telebot.handler_backends import BaseMiddleware from telebot import TeleBot from telebot.states.sync.context import StateContext +from telebot.util import update_types +from telebot import types class StateMiddleware(BaseMiddleware): def __init__(self, bot: TeleBot) -> None: self.update_sensitive = False - self.update_types = ['message', 'edited_message', 'callback_query'] #TODO: support other types + self.update_types = update_types self.bot: TeleBot = bot def pre_process(self, message, data): data['state_context'] = StateContext(message, self.bot) def post_process(self, message, data, exception): - pass \ No newline at end of file + pass diff --git a/tests/test_handler_backends.py b/tests/test_handler_backends.py index bb541bf8a..f57200c1c 100644 --- a/tests/test_handler_backends.py +++ b/tests/test_handler_backends.py @@ -19,7 +19,7 @@ @pytest.fixture() def telegram_bot(): - return telebot.TeleBot('', threaded=False) + return telebot.TeleBot('', threaded=False, token_check=False) @pytest.fixture