Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into refactor/agent
Browse files Browse the repository at this point in the history
  • Loading branch information
dandansamax committed Oct 15, 2024
2 parents 8f01717 + 71e95fb commit 05f87aa
Show file tree
Hide file tree
Showing 3 changed files with 1,569 additions and 1,402 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,6 @@ _build/
# model parameter
*.pth

logs/
logs/

.DS_Store
48 changes: 25 additions & 23 deletions crab/agents/backend_models/camel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

try:
from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.messages import BaseMessage
from camel.models import ModelFactory
from camel.toolkits import OpenAIFunction
Expand All @@ -33,29 +32,34 @@
CAMEL_ENABLED = False


def _find_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
for platform in ModelPlatformType:
if platform.value.lower() == model_platform_name.lower():
return platform
all_models = [platform.value for platform in ModelPlatformType]
raise ValueError(
f"Model {model_platform_name} not found. Supported models are {all_models}"
)
def _get_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
try:
return ModelPlatformType(model_platform_name)
except ValueError:
all_models = [platform.value for platform in ModelPlatformType]
raise ValueError(
f"Model {model_platform_name} not found. Supported models are {all_models}"
)


def _find_model_type(model_name: str) -> "str | ModelType":
for model in ModelType:
if model.value.lower() == model_name.lower():
return model
return model_name
def _get_model_type(model_name: str) -> "str | ModelType":
try:
return ModelType(model_name)
except ValueError:
return model_name


def _convert_action_to_schema(
action_space: list[Action] | None,
) -> "list[OpenAIFunction] | None":
if action_space is None:
return None
return [OpenAIFunction(action.entry) for action in action_space]
schema_list = []
for action in action_space:
new_action = action.to_openai_json_schema()
schema = {"type": "function", "function": new_action}
schema_list.append(OpenAIFunction(action.entry, schema))
return schema_list


def _convert_tool_calls_to_action_list(
Expand Down Expand Up @@ -84,9 +88,8 @@ def __init__(
if not CAMEL_ENABLED:
raise ImportError("Please install camel-ai to use CamelModel")
self.parameters = parameters or {}
# TODO: a better way?
self.model_type = _find_model_type(model)
self.model_platform_type = _find_model_platform_type(model_platform)
self.model_type = _get_model_type(model)
self.model_platform_type = _get_model_platform_type(model_platform)
self.client: ChatAgent | None = None
self.token_usage = 0

Expand All @@ -104,15 +107,14 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
config = self.parameters.copy()
if action_schema is not None:
config["tool_choice"] = "required"
config["tools"] = action_schema
config["tools"] = [
schema.get_openai_tool_schema() for schema in action_schema
]

chatgpt_config = ChatGPTConfig(
**config,
)
backend_model = ModelFactory.create(
self.model_platform_type,
self.model_type,
model_config_dict=chatgpt_config.as_dict(),
model_config_dict=config,
)
sysmsg = BaseMessage.make_assistant_message(
role_name="Assistant",
Expand Down
Loading

0 comments on commit 05f87aa

Please sign in to comment.