diff --git a/crab-benchmark-v0/main.py b/crab-benchmark-v0/main.py index 9be9655..231d4d0 100644 --- a/crab-benchmark-v0/main.py +++ b/crab-benchmark-v0/main.py @@ -180,10 +180,9 @@ def get_benchmark(env: str, ubuntu_url: str): loglevel = args.loglevel numeric_level = getattr(logging, loglevel.upper(), None) if not isinstance(numeric_level, int): - raise ValueError('Invalid log level: %s' % loglevel) + raise ValueError("Invalid log level: %s" % loglevel) logging.basicConfig(level=numeric_level) - benchmark = get_benchmark(args.env, args.remote_url) if args.model == "gpt4o": diff --git a/crab/agents/backend_models/camel_model.py b/crab/agents/backend_models/camel_model.py index f58fe6a..38dcabe 100644 --- a/crab/agents/backend_models/camel_model.py +++ b/crab/agents/backend_models/camel_model.py @@ -14,6 +14,7 @@ import json from typing import Any +from openai.types.chat import ChatCompletionMessageToolCall from PIL import Image from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType @@ -32,6 +33,46 @@ CAMEL_ENABLED = False +def _find_model_platform_type(model_platform_name: str) -> "ModelPlatformType": + for platform in ModelPlatformType: + if platform.value.lower() == model_platform_name.lower(): + return platform + all_models = [platform.value for platform in ModelPlatformType] + raise ValueError( + f"Model {model_platform_name} not found. 
Supported models are {all_models}" + ) + + +def _find_model_type(model_name: str) -> "str | ModelType": + for model in ModelType: + if model.value.lower() == model_name.lower(): + return model + return model_name + + +def _convert_action_to_schema( + action_space: list[Action] | None, +) -> "list[OpenAIFunction] | None": + if action_space is None: + return None + return [OpenAIFunction(action.entry) for action in action_space] + + +def _convert_tool_calls_to_action_list( + tool_calls: list[ChatCompletionMessageToolCall] | None, +) -> list[ActionOutput] | None: + if tool_calls is None: + return None + + return [ + ActionOutput( + name=call.function.name, + arguments=json.loads(call.function.arguments), + ) + for call in tool_calls + ] + + class CamelModel(BackendModel): def __init__( self, @@ -44,8 +85,8 @@ def __init__( raise ImportError("Please install camel-ai to use CamelModel") self.parameters = parameters or {} # TODO: a better way? - self.model_type = self.find_model_type(model) - self.model_platform_type = self.find_model_platform_type(model_platform) + self.model_type = _find_model_type(model) + self.model_platform_type = _find_model_platform_type(model_platform) self.client: ChatAgent | None = None self.token_usage = 0 @@ -55,11 +96,11 @@ def __init__( history_messages_len, ) - def get_token_usage(self): + def get_token_usage(self) -> int: return self.token_usage def reset(self, system_message: str, action_space: list[Action] | None) -> None: - action_schema = self._convert_action_to_schema(action_space) + action_schema = _convert_action_to_schema(action_space) config = self.parameters.copy() if action_schema is not None: config["tool_choice"] = "required" @@ -85,45 +126,7 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None: ) self.token_usage = 0 - @staticmethod - def find_model_platform_type(model_platform_name: str) -> "ModelPlatformType": - for platform in ModelPlatformType: - if platform.value.lower() == 
model_platform_name.lower(): - return platform - all_models = [platform.value for platform in ModelPlatformType] - raise ValueError( - f"Model {model_platform_name} not found. Supported models are {all_models}" - ) - - @staticmethod - def find_model_type(model_name: str) -> "str | ModelType": - for model in ModelType: - if model.value.lower() == model_name.lower(): - return model - return model_name - - @staticmethod - def _convert_action_to_schema( - action_space: list[Action] | None, - ) -> "list[OpenAIFunction] | None": - if action_space is None: - return None - return [OpenAIFunction(action.entry) for action in action_space] - - @staticmethod - def _convert_tool_calls_to_action_list(tool_calls) -> list[ActionOutput]: - if tool_calls is None: - return tool_calls - - return [ - ActionOutput( - name=call.function.name, - arguments=json.loads(call.function.arguments), - ) - for call in tool_calls - ] - - def chat(self, messages: list[tuple[str, MessageType]]): + def chat(self, messages: list[tuple[str, MessageType]]) -> BackendOutput: # TODO: handle multiple text messages after message refactoring image_list: list[Image.Image] = [] content = "" @@ -144,5 +147,5 @@ def chat(self, messages: list[tuple[str, MessageType]]): return BackendOutput( message=response.msg.content, - action_list=self._convert_tool_calls_to_action_list([tool_call_request]), + action_list=_convert_tool_calls_to_action_list([tool_call_request]), ) diff --git a/crab/utils/measure.py b/crab/utils/measure.py index df42d43..bfa4497 100644 --- a/crab/utils/measure.py +++ b/crab/utils/measure.py @@ -1,3 +1,16 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== import logging import time from functools import wraps diff --git a/pyproject.toml b/pyproject.toml index 0e01c98..96f6dd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ google-generativeai = { version = "^0.6.0", optional = true } anthropic = { version = "^0.29.0", optional = true } groq = { version = "^0.5.0", optional = true } ollama = { version = "^0.2.0", optional = true } -camel-ai = { version="^0.1.8", extras=["all"], optional=true } +camel-ai = { version = "^0.1.8", extras = ["all"], optional = true } # text ocr easyocr = { version = "^1.7.1", optional = true } diff --git a/test/actions/test_visual_prompt_actions.py b/test/actions/test_visual_prompt_actions.py index 14164b6..e0b9792 100644 --- a/test/actions/test_visual_prompt_actions.py +++ b/test/actions/test_visual_prompt_actions.py @@ -1,3 +1,16 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== from pathlib import Path import pytest