From fb1af781e8e0999ee78d5c0803c372598b0e3f27 Mon Sep 17 00:00:00 2001
From: Tianqi Xu
Date: Wed, 18 Sep 2024 20:23:57 +0300
Subject: [PATCH] Refactor claude model

---
 crab/agents/backend_models/claude_model.py    | 257 +++++++++---------
 crab/agents/backend_models/gemini_model.py    |   9 +-
 .../backend_models/test_claude_model.py       |   4 +-
 3 files changed, 139 insertions(+), 131 deletions(-)

diff --git a/crab/agents/backend_models/claude_model.py b/crab/agents/backend_models/claude_model.py
index cf03e55..27a499f 100644
--- a/crab/agents/backend_models/claude_model.py
+++ b/crab/agents/backend_models/claude_model.py
@@ -12,10 +12,11 @@
 # limitations under the License.
 # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
 from copy import deepcopy
-from time import sleep
 from typing import Any
 
-from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+from crab import Action, ActionOutput, BackendModel, BackendOutput, Message, MessageType
 
 try:
     import anthropic
@@ -47,37 +48,61 @@ def __init__(
     def reset(self, system_message: str, action_space: list[Action] | None) -> None:
         self.system_message = system_message
         self.action_space = action_space
-        self.action_schema = self._convert_action_to_schema(self.action_space)
+        self.action_schema = _convert_action_to_schema(self.action_space)
         self.token_usage = 0
-        self.chat_history = []
-
-    def chat(self, message: list[tuple[str, MessageType]]) -> BackendOutput:
-        # Initialize chat history
-        request = []
-        if self.history_messages_len > 0 and len(self.chat_history) > 0:
-            for history_message in self.chat_history[-self.history_messages_len :]:
-                request = request + history_message
+        self.chat_history: list[list[dict]] = []
 
-        if not isinstance(message, list):
+    def chat(self, message: list[Message] | Message) -> BackendOutput:
+        if isinstance(message, tuple):
             message = [message]
-
-        new_message = {
-            "role": "user",
-            "content": [self._convert_message(part) for part in message],
-        }
+        request = self.fetch_from_memory()
+        new_message = self.construct_new_message(message)
         request.append(new_message)
-        request = self._merge_request(request)
-
-        response = self.call_api(request)
-        response_message = response
+        response_message = self.call_api(request)
         self.record_message(new_message, response_message)
+        return self.generate_backend_output(response_message)
+
+    def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
+        parts: list[dict] = []
+        for content, msg_type in message:
+            match msg_type:
+                case MessageType.TEXT:
+                    parts.append(
+                        {
+                            "type": "text",
+                            "text": content,
+                        }
+                    )
+                case MessageType.IMAGE_JPG_BASE64:
+                    parts.append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "data": content,
+                                "type": "base64",
+                                "media_type": "image/png",
+                            },
+                        }
+                    )
+        return {
+            "role": "user",
+            "content": parts,
+        }
 
-        return self._format_response(response_message.content)
+    def fetch_from_memory(self) -> list[dict]:
+        request: list[dict] = []
+        if self.history_messages_len > 0:
+            fetch_history_len = min(self.history_messages_len, len(self.chat_history))
+            for history_message in self.chat_history[-fetch_history_len:]:
+                request = request + history_message
+        return request
 
     def get_token_usage(self):
         return self.token_usage
 
-    def record_message(self, new_message: dict, response_message: dict) -> None:
+    def record_message(
+        self, new_message: dict, response_message: anthropic.types.Message
+    ) -> None:
         self.chat_history.append([new_message])
         self.chat_history[-1].append(
             {"role": response_message.role, "content": response_message.content}
@@ -85,128 +110,106 @@ def record_message(self, new_message: dict, response_message: dict) -> None:
         )
         if self.action_schema:
             tool_calls = response_message.content
-            self.chat_history[-1].append(
-                {
-                    "role": "user",
-                    "content": [
+            tool_content = []
+            for call in tool_calls:
+                if isinstance(call, ToolUseBlock):
+                    tool_content.append(
                         {
                             "type": "tool_result",
                             "tool_use_id": call.id,
                             "content": "success",
                         }
-                        for call in tool_calls
-                        if call is ToolUseBlock
-                    ],
+                    )
+            self.chat_history[-1].append(
+                {
+                    "role": "user",
+                    "content": tool_content,
                 }
             )
 
-    def call_api(self, request_messages: list):
-        while True:
-            try:
-                if self.action_schema is not None:
-                    response = self.client.messages.create(
-                        system=self.system_message,  # <-- system prompt
-                        messages=request_messages,  # type: ignore
-                        model=self.model,
-                        tools=self.action_schema,
-                        tool_choice={
-                            "type": "any" if self.tool_call_required else "auto"
-                        },
-                        **self.parameters,
-                    )
-                else:
-                    response = self.client.messages.create(
-                        system=self.system_message,  # <-- system prompt
-                        messages=request_messages,  # type: ignore
-                        model=self.model,
-                        **self.parameters,
-                    )
-            except anthropic.RateLimitError:
-                print("Rate Limit Error: Please waiting...")
-                sleep(10)
-            except anthropic.APIStatusError:
-                print(len(request_messages))
-                raise
-            else:
-                break
+    @retry(
+        wait=wait_fixed(10),
+        stop=stop_after_attempt(7),
+        retry=retry_if_exception_type(
+            (
+                anthropic.APITimeoutError,
+                anthropic.APIConnectionError,
+                anthropic.InternalServerError,
+            )
+        ),
+    )
+    def call_api(self, request_messages: list[dict]) -> anthropic.types.Message:
+        request_messages = _merge_request(request_messages)
+        if self.action_schema is not None:
+            response = self.client.messages.create(
+                system=self.system_message,  # <-- system prompt
+                messages=request_messages,  # type: ignore
+                model=self.model,
+                tools=self.action_schema,
+                tool_choice={"type": "any" if self.tool_call_required else "auto"},
+                **self.parameters,
+            )
+        else:
+            response = self.client.messages.create(
+                system=self.system_message,  # <-- system prompt
+                messages=request_messages,  # type: ignore
+                model=self.model,
+                **self.parameters,
+            )
 
         self.token_usage += response.usage.input_tokens + response.usage.output_tokens
         return response
 
-    @staticmethod
-    def _convert_message(message: tuple[str, MessageType]):
-        match message[1]:
-            case MessageType.TEXT:
-                return {
-                    "type": "text",
-                    "text": message[0],
-                }
-            case MessageType.IMAGE_JPG_BASE64:
-                return {
-                    "type": "image",
-                    "source": {
-                        "data": message[0],
-                        "type": "base64",
-                        "media_type": "image/png",
-                    },
-                }
-
-    @staticmethod
-    def _convert_action_to_schema(action_space):
-        if action_space is None:
-            return None
-        actions = []
-        for action in action_space:
-            new_action = action.to_openai_json_schema()
-            new_action["input_schema"] = new_action.pop("parameters")
-            if "returns" in new_action:
-                new_action.pop("returns")
-            if "title" in new_action:
-                new_action.pop("title")
-            if "type" in new_action:
-                new_action["input_schema"]["type"] = new_action.pop("type")
-            if "required" in new_action:
-                new_action["input_schema"]["required"] = new_action.pop("required")
-
-            actions.append(new_action)
-        return actions
-
-    @staticmethod
-    def _convert_tool_calls_to_action_list(tool_calls) -> list[ActionOutput]:
-        if tool_calls is None:
-            return tool_calls
-        return [
-            ActionOutput(
-                name=call.name,
-                arguments=call.input,
-            )
-            for call in tool_calls
-        ]
-
-    @staticmethod
-    def _merge_request(request: list[dict]):
-        merge_request = [deepcopy(request[0])]
-        for idx in range(1, len(request)):
-            if request[idx]["role"] == merge_request[-1]["role"]:
-                merge_request[-1]["content"].extend(request[idx]["content"])
-            else:
-                merge_request.append(deepcopy(request[idx]))
-
-        return merge_request
-
-    @classmethod
-    def _format_response(cls, content: list):
-        message = None
+    def generate_backend_output(
+        self, response_message: anthropic.types.Message
+    ) -> BackendOutput:
+        message = ""
         action_list = []
-        for block in content:
+        for block in response_message.content:
             if isinstance(block, TextBlock):
-                message = block.text
+                message += block.text
             elif isinstance(block, ToolUseBlock):
-                action_list.append(block)
+                action_list.append(
+                    ActionOutput(
+                        name=block.name,
+                        arguments=block.input,  # type: ignore
+                    )
+                )
         if not action_list:
             return BackendOutput(message=message, action_list=None)
         else:
             return BackendOutput(
                 message=message,
-                action_list=cls._convert_tool_calls_to_action_list(action_list),
+                action_list=action_list,
             )
+
+
+def _merge_request(request: list[dict]) -> list[dict]:
+    merge_request = [deepcopy(request[0])]
+    for idx in range(1, len(request)):
+        if request[idx]["role"] == merge_request[-1]["role"]:
+            merge_request[-1]["content"].extend(request[idx]["content"])
+        else:
+            merge_request.append(deepcopy(request[idx]))
+
+    return merge_request
+
+
+def _convert_action_to_schema(action_space):
+    if action_space is None:
+        return None
+    actions = []
+    for action in action_space:
+        new_action = action.to_openai_json_schema()
+        new_action["input_schema"] = new_action.pop("parameters")
+        if "returns" in new_action:
+            new_action.pop("returns")
+        if "title" in new_action:
+            new_action.pop("title")
+        if "type" in new_action:
+            new_action["input_schema"]["type"] = new_action.pop("type")
+        if "required" in new_action:
+            new_action["input_schema"]["required"] = new_action.pop("required")
+
+        actions.append(new_action)
+    return actions
diff --git a/crab/agents/backend_models/gemini_model.py b/crab/agents/backend_models/gemini_model.py
index e33d16b..4a25b56 100644
--- a/crab/agents/backend_models/gemini_model.py
+++ b/crab/agents/backend_models/gemini_model.py
@@ -15,7 +15,7 @@
 from typing import Any
 
 from PIL.Image import Image
-from tenacity import retry, stop_after_attempt, wait_fixed
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
 
 from crab import Action, ActionOutput, BackendModel, BackendOutput, Message, MessageType
 from crab.utils.common import base64_to_image, json_expand_refs
@@ -28,6 +28,7 @@
         Part,
         Tool,
     )
+    from google.api_core.exceptions import ResourceExhausted
     from google.generativeai.types import content_types
 
     gemini_model_enable = True
@@ -120,7 +121,11 @@ def record_message(
             {"role": response_message.role, "parts": response_message.parts}
         )
 
-    @retry(wait=wait_fixed(10), stop=stop_after_attempt(7))
+    @retry(
+        wait=wait_fixed(10),
+        stop=stop_after_attempt(7),
+        retry=retry_if_exception_type(ResourceExhausted),
+    )
     def call_api(self, request_messages: list) -> Content:
         if self.action_schema is not None:
             tool_config = content_types.to_tool_config(
diff --git a/test/agents/backend_models/test_claude_model.py b/test/agents/backend_models/test_claude_model.py
index be3ddb8..f8e361a 100644
--- a/test/agents/backend_models/test_claude_model.py
+++ b/test/agents/backend_models/test_claude_model.py
@@ -42,7 +42,7 @@ def add(a: int, b: int):
     return a + b
 
 
-# @pytest.mark.skip(reason="Mock data to be added")
+@pytest.mark.skip(reason="Mock data to be added")
 def test_text_chat(claude_model_text):
     message = ("Hello!", MessageType.TEXT)
     output = claude_model_text.chat(message)
@@ -63,7 +63,7 @@ def test_text_chat(claude_model_text):
     assert len(claude_model_text.chat_history) == 3
 
 
-# @pytest.mark.skip(reason="Mock data to be added")
+@pytest.mark.skip(reason="Mock data to be added")
 def test_action_chat(claude_model_text):
     claude_model_text.reset("You are a helpful assistant.", [add])
     message = (
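
--
Reviewer note, outside the patch: a minimal usage sketch of the refactored
interface. It assumes the backend class is named ClaudeModel (its definition
is outside these hunks), that the constructor keeps its pre-patch signature
(model name plus optional parameters and history_messages_len), that the
`action` decorator is exported at the package root as in the test module,
and that ANTHROPIC_API_KEY is set; the model name and history length below
are illustrative, not taken from this patch.

from crab import MessageType, action
from crab.agents.backend_models.claude_model import ClaudeModel


@action
def add(a: int, b: int) -> int:
    """Add two numbers together."""
    return a + b


# Model name and history length are illustrative assumptions.
model = ClaudeModel(model="claude-3-opus-20240229", history_messages_len=2)

# Text-only round trip: with no action space, generate_backend_output()
# returns a BackendOutput whose action_list is None.
model.reset("You are a helpful assistant.", None)
output = model.chat(("Hello!", MessageType.TEXT))  # a single Message tuple is wrapped into a list
print(output.message)

# Tool-calling round trip: after reset() with an action space, tool calls
# come back as ActionOutput entries instead of raw ToolUseBlocks.
model.reset("You are a helpful assistant.", [add])
output = model.chat(("What is 2 + 3?", MessageType.TEXT))
for call in output.action_list or []:
    print(call.name, call.arguments)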