Refactor claude model
dandansamax committed Sep 18, 2024
1 parent 943011b commit fb1af78
Showing 3 changed files with 139 additions and 131 deletions.
257 changes: 130 additions & 127 deletions crab/agents/backend_models/claude_model.py
@@ -12,10 +12,11 @@
 # limitations under the License.
 # =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
 from copy import deepcopy
-from time import sleep
 from typing import Any
 
-from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+from crab import Action, ActionOutput, BackendModel, BackendOutput, Message, MessageType
 
 try:
     import anthropic
@@ -47,166 +48,168 @@ def __init__(
     def reset(self, system_message: str, action_space: list[Action] | None) -> None:
         self.system_message = system_message
         self.action_space = action_space
-        self.action_schema = self._convert_action_to_schema(self.action_space)
+        self.action_schema = _convert_action_to_schema(self.action_space)
         self.token_usage = 0
-        self.chat_history = []
+        self.chat_history: list[list[dict]] = []
 
-    def chat(self, message: list[tuple[str, MessageType]]) -> BackendOutput:
-        # Initialize chat history
-        request = []
-        if self.history_messages_len > 0 and len(self.chat_history) > 0:
-            for history_message in self.chat_history[-self.history_messages_len :]:
-                request = request + history_message
-
-        if not isinstance(message, list):
+    def chat(self, message: list[Message] | Message) -> BackendOutput:
+        if isinstance(message, tuple):
             message = [message]
-
-        new_message = {
-            "role": "user",
-            "content": [self._convert_message(part) for part in message],
-        }
+        request = self.fetch_from_memory()
+        new_message = self.construct_new_message(message)
         request.append(new_message)
-        request = self._merge_request(request)
-
-        response = self.call_api(request)
-        response_message = response
+        response_message = self.call_api(request)
         self.record_message(new_message, response_message)
-
-        return self._format_response(response_message.content)
+        return self.generate_backend_output(response_message)
+
+    def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
+        parts: list[dict] = []
+        for content, msg_type in message:
+            match msg_type:
+                case MessageType.TEXT:
+                    parts.append(
+                        {
+                            "type": "text",
+                            "text": content,
+                        }
+                    )
+                case MessageType.IMAGE_JPG_BASE64:
+                    parts.append(
+                        {
+                            "type": "image",
+                            "source": {
+                                "data": content,
+                                "type": "base64",
+                                "media_type": "image/png",
+                            },
+                        }
+                    )
+        return {
+            "role": "user",
+            "content": parts,
+        }
+
+    def fetch_from_memory(self) -> list[dict]:
+        request: list[dict] = []
+        if self.history_messages_len > 0:
+            fetch_history_len = min(self.history_messages_len, len(self.chat_history))
+            for history_message in self.chat_history[-fetch_history_len:]:
+                request = request + history_message
+        return request
 
     def get_token_usage(self):
         return self.token_usage
 
-    def record_message(self, new_message: dict, response_message: dict) -> None:
+    def record_message(
+        self, new_message: dict, response_message: anthropic.types.Message
+    ) -> None:
         self.chat_history.append([new_message])
         self.chat_history[-1].append(
             {"role": response_message.role, "content": response_message.content}
         )
 
         if self.action_schema:
             tool_calls = response_message.content
+            tool_content = []
+            for call in tool_calls:
+                if isinstance(call, ToolUseBlock):
+                    tool_content.append(
+                        {
+                            "type": "tool_result",
+                            "tool_use_id": call.id,
+                            "content": "success",
+                        }
+                    )
             self.chat_history[-1].append(
                 {
                     "role": "user",
-                    "content": [
-                        {
-                            "type": "tool_result",
-                            "tool_use_id": call.id,
-                            "content": "success",
-                        }
-                        for call in tool_calls
-                        if call is ToolUseBlock
-                    ],
+                    "content": tool_content,
                 }
             )
 
-    def call_api(self, request_messages: list):
-        while True:
-            try:
-                if self.action_schema is not None:
-                    response = self.client.messages.create(
-                        system=self.system_message,  # <-- system prompt
-                        messages=request_messages,  # type: ignore
-                        model=self.model,
-                        tools=self.action_schema,
-                        tool_choice={
-                            "type": "any" if self.tool_call_required else "auto"
-                        },
-                        **self.parameters,
-                    )
-                else:
-                    response = self.client.messages.create(
-                        system=self.system_message,  # <-- system prompt
-                        messages=request_messages,  # type: ignore
-                        model=self.model,
-                        **self.parameters,
-                    )
-            except anthropic.RateLimitError:
-                print("Rate Limit Error: Please waiting...")
-                sleep(10)
-            except anthropic.APIStatusError:
-                print(len(request_messages))
-                raise
-            else:
-                break
+    @retry(
+        wait=wait_fixed(10),
+        stop=stop_after_attempt(7),
+        retry=retry_if_exception_type(
+            (
+                anthropic.APITimeoutError,
+                anthropic.APIConnectionError,
+                anthropic.InternalServerError,
+            )
+        ),
+    )
+    def call_api(self, request_messages: list[dict]) -> anthropic.types.Message:
+        request_messages = _merge_request(request_messages)
+        if self.action_schema is not None:
+            response = self.client.messages.create(
+                system=self.system_message,  # <-- system prompt
+                messages=request_messages,  # type: ignore
+                model=self.model,
+                tools=self.action_schema,
+                tool_choice={"type": "any" if self.tool_call_required else "auto"},
+                **self.parameters,
+            )
+        else:
+            response = self.client.messages.create(
+                system=self.system_message,  # <-- system prompt
+                messages=request_messages,  # type: ignore
+                model=self.model,
+                **self.parameters,
+            )
 
         self.token_usage += response.usage.input_tokens + response.usage.output_tokens
         return response
 
-    @staticmethod
-    def _convert_message(message: tuple[str, MessageType]):
-        match message[1]:
-            case MessageType.TEXT:
-                return {
-                    "type": "text",
-                    "text": message[0],
-                }
-            case MessageType.IMAGE_JPG_BASE64:
-                return {
-                    "type": "image",
-                    "source": {
-                        "data": message[0],
-                        "type": "base64",
-                        "media_type": "image/png",
-                    },
-                }
-
-    @staticmethod
-    def _convert_action_to_schema(action_space):
-        if action_space is None:
-            return None
-        actions = []
-        for action in action_space:
-            new_action = action.to_openai_json_schema()
-            new_action["input_schema"] = new_action.pop("parameters")
-            if "returns" in new_action:
-                new_action.pop("returns")
-            if "title" in new_action:
-                new_action.pop("title")
-            if "type" in new_action:
-                new_action["input_schema"]["type"] = new_action.pop("type")
-            if "required" in new_action:
-                new_action["input_schema"]["required"] = new_action.pop("required")
-
-            actions.append(new_action)
-        return actions
-
-    @staticmethod
-    def _convert_tool_calls_to_action_list(tool_calls) -> list[ActionOutput]:
-        if tool_calls is None:
-            return tool_calls
-        return [
-            ActionOutput(
-                name=call.name,
-                arguments=call.input,
-            )
-            for call in tool_calls
-        ]
-
-    @staticmethod
-    def _merge_request(request: list[dict]):
-        merge_request = [deepcopy(request[0])]
-        for idx in range(1, len(request)):
-            if request[idx]["role"] == merge_request[-1]["role"]:
-                merge_request[-1]["content"].extend(request[idx]["content"])
-            else:
-                merge_request.append(deepcopy(request[idx]))
-
-        return merge_request
-
     @classmethod
-    def _format_response(cls, content: list):
-        message = None
+    def generate_backend_output(
+        cls, response_message: anthropic.types.Message
+    ) -> BackendOutput:
+        message = ""
         action_list = []
-        for block in content:
+        for block in response_message.content:
             if isinstance(block, TextBlock):
-                message = block.text
+                message += block.text
             elif isinstance(block, ToolUseBlock):
-                action_list.append(block)
+                action_list.append(
+                    ActionOutput(
+                        name=block.name,
+                        arguments=block.input,  # type: ignore
+                    )
+                )
         if not action_list:
             return BackendOutput(message=message, action_list=None)
         else:
             return BackendOutput(
                 message=message,
-                action_list=cls._convert_tool_calls_to_action_list(action_list),
+                action_list=action_list,
             )
+
+
+def _merge_request(request: list[dict]) -> list[dict]:
+    merge_request = [deepcopy(request[0])]
+    for idx in range(1, len(request)):
+        if request[idx]["role"] == merge_request[-1]["role"]:
+            merge_request[-1]["content"].extend(request[idx]["content"])
+        else:
+            merge_request.append(deepcopy(request[idx]))
+
+    return merge_request
+
+
+def _convert_action_to_schema(action_space):
+    if action_space is None:
+        return None
+    actions = []
+    for action in action_space:
+        new_action = action.to_openai_json_schema()
+        new_action["input_schema"] = new_action.pop("parameters")
+        if "returns" in new_action:
+            new_action.pop("returns")
+        if "title" in new_action:
+            new_action.pop("title")
+        if "type" in new_action:
+            new_action["input_schema"]["type"] = new_action.pop("type")
+        if "required" in new_action:
+            new_action["input_schema"]["required"] = new_action.pop("required")
+
+        actions.append(new_action)
+    return actions
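
Note: the refactor moves _merge_request out of the class into a module-level helper that call_api now invokes first, presumably because the Anthropic Messages API rejects consecutive messages with the same role, and history fetched from memory can leave two adjacent "user" entries in the request. A minimal, self-contained sketch of the merging behavior (the sample messages below are illustrative, not taken from the repository):

    from copy import deepcopy

    def merge_request(request: list[dict]) -> list[dict]:
        # Fold consecutive same-role messages into one message whose
        # "content" list concatenates the parts (mirrors _merge_request above).
        merged = [deepcopy(request[0])]
        for message in request[1:]:
            if message["role"] == merged[-1]["role"]:
                merged[-1]["content"].extend(message["content"])
            else:
                merged.append(deepcopy(message))
        return merged

    history_tail = {"role": "user", "content": [{"type": "text", "text": "earlier turn"}]}
    new_turn = {"role": "user", "content": [{"type": "text", "text": "new turn"}]}
    merged = merge_request([history_tail, new_turn])
    assert len(merged) == 1 and len(merged[0]["content"]) == 2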
9 changes: 7 additions & 2 deletions crab/agents/backend_models/gemini_model.py
@@ -15,7 +15,7 @@
 from typing import Any
 
 from PIL.Image import Image
-from tenacity import retry, stop_after_attempt, wait_fixed
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
 
 from crab import Action, ActionOutput, BackendModel, BackendOutput, Message, MessageType
 from crab.utils.common import base64_to_image, json_expand_refs
@@ -28,6 +28,7 @@
         Part,
         Tool,
     )
+    from google.api_core.exceptions import ResourceExhausted
     from google.generativeai.types import content_types
 
     gemini_model_enable = True
@@ -120,7 +121,11 @@ def record_message(
             {"role": response_message.role, "parts": response_message.parts}
         )
 
-    @retry(wait=wait_fixed(10), stop=stop_after_attempt(7))
+    @retry(
+        wait=wait_fixed(10),
+        stop=stop_after_attempt(7),
+        retry=retry_if_exception_type(ResourceExhausted),
+    )
     def call_api(self, request_messages: list) -> Content:
         if self.action_schema is not None:
             tool_config = content_types.to_tool_config(
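
Both backends now share the same recovery pattern: rather than retrying on every exception (or looping with sleep, as removed from claude_model.py above), tenacity re-invokes the call only for exception types that are plausibly transient, such as quota errors. A self-contained sketch of the pattern, with a hypothetical TransientError standing in for ResourceExhausted or the anthropic errors:

    from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

    class TransientError(Exception):
        # Stand-in for ResourceExhausted, anthropic.APITimeoutError, etc.
        pass

    attempts = 0

    @retry(
        wait=wait_fixed(10),  # wait 10 seconds between attempts
        stop=stop_after_attempt(7),  # re-raise the last error after 7 attempts
        retry=retry_if_exception_type(TransientError),
    )
    def flaky_call() -> str:
        global attempts
        attempts += 1
        if attempts < 3:
            raise TransientError("temporarily unavailable")
        return "ok"

    print(flaky_call())  # succeeds on the third attempt; other exception types propagate immediately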
4 changes: 2 additions & 2 deletions test/agents/backend_models/test_claude_model.py
@@ -42,7 +42,7 @@ def add(a: int, b: int):
     return a + b
 
 
-# @pytest.mark.skip(reason="Mock data to be added")
+@pytest.mark.skip(reason="Mock data to be added")
 def test_text_chat(claude_model_text):
     message = ("Hello!", MessageType.TEXT)
     output = claude_model_text.chat(message)
@@ -63,7 +63,7 @@ def test_text_chat(claude_model_text):
     assert len(claude_model_text.chat_history) == 3
 
 
-# @pytest.mark.skip(reason="Mock data to be added")
+@pytest.mark.skip(reason="Mock data to be added")
 def test_action_chat(claude_model_text):
     claude_model_text.reset("You are a helpful assistant.", [add])
     message = (
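
Both tests are now skipped until mock data lands. One hypothetical way to supply it (not part of this commit) is to stub the model's client with a canned response, so the isinstance(block, TextBlock) check in generate_backend_output still passes without a network call:

    from unittest.mock import MagicMock

    from anthropic.types import TextBlock

    def make_mock_client(reply: str) -> MagicMock:
        # Build a canned response shaped like anthropic.types.Message.
        block = TextBlock(type="text", text=reply)  # a real block, so isinstance checks pass
        response = MagicMock()
        response.role = "assistant"
        response.content = [block]
        response.usage.input_tokens = 10
        response.usage.output_tokens = 5
        client = MagicMock()
        client.messages.create.return_value = response
        return client

    # In a test: claude_model_text.client = make_mock_client("Hello!")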
