Skip to content

Commit

Permalink
Enhancement: CAMEL model update (#28)
Browse files Browse the repository at this point in the history
Co-authored-by: Tianqi Xu <[email protected]>
  • Loading branch information
WHALEEYE and dandansamax authored Sep 16, 2024
1 parent b9ddf11 commit 8503080
Show file tree
Hide file tree
Showing 13 changed files with 1,442 additions and 1,296 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.2
rev: v0.6.5
hooks:
# Run the linter.
- id: ruff
Expand All @@ -13,4 +13,4 @@ repos:
name: Check License
entry: python licenses/update_license.py . licenses/license_template.txt
language: system
types: [python]
types: [python]
6 changes: 3 additions & 3 deletions crab-benchmark-v0/dataset/ubuntu_subtasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,9 @@ def get_rgb_values_outside_bbox(

# Create a mask for the bounding box area with margin
mask = np.ones(img.shape[:2], dtype=bool)
mask[
y_min_with_margin:y_max_with_margin, x_min_with_margin:x_max_with_margin
] = False
mask[y_min_with_margin:y_max_with_margin, x_min_with_margin:x_max_with_margin] = (
False
)

# Extract the RGB values outside the bounding box with margin
rgb_values = img[mask]
Expand Down
3 changes: 1 addition & 2 deletions crab-benchmark-v0/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,9 @@ def get_benchmark(env: str, ubuntu_url: str):
loglevel = args.loglevel
numeric_level = getattr(logging, loglevel.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError('Invalid log level: %s' % loglevel)
raise ValueError("Invalid log level: %s" % loglevel)
logging.basicConfig(level=numeric_level)


benchmark = get_benchmark(args.env, args.remote_url)

if args.model == "gpt4o":
Expand Down
6 changes: 4 additions & 2 deletions crab/actions/android_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def swipe(element: int, direction: SwipeDirection, dist: SwipeDist, env) -> None
offset = unit_dist, 0
else:
return "ERROR"
adb_command = f"shell input swipe {x} {y} {x+offset[0]} {y+offset[1]} 200"
adb_command = f"shell input swipe {x} {y} {x + offset[0]} {y + offset[1]} 200"
execute_adb(adb_command, env)
sleep(_DURATION)

Expand Down Expand Up @@ -213,7 +213,9 @@ def stop_all_apps(env) -> None:
execute_adb("shell input keyevent KEYCODE_HOME", env)
execute_adb("shell input keyevent KEYCODE_APP_SWITCH", env)
sleep(0.5)
command = f"shell input swipe 100 {env.height/2} {env.width-100} {env.height/2} 200"
command = (
f"shell input swipe 100 {env.height / 2} {env.width - 100} {env.height / 2} 200"
)
execute_adb(command, env)
sleep(0.5)
execute_adb("shell input tap 300 1400", env)
Expand Down
79 changes: 36 additions & 43 deletions crab/agents/backend_models/camel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import base64
import io
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

from openai.types.chat import ChatCompletionMessageToolCall
from PIL import Image

from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
Expand All @@ -34,7 +33,7 @@
CAMEL_ENABLED = False


def find_model_platform_type(model_platform_name: str) -> ModelPlatformType:
def _find_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
for platform in ModelPlatformType:
if platform.value.lower() == model_platform_name.lower():
return platform
Expand All @@ -44,33 +43,51 @@ def find_model_platform_type(model_platform_name: str) -> ModelPlatformType:
)


def find_model_type(model_name: str) -> Union[ModelType, str]:
def _find_model_type(model_name: str) -> "str | ModelType":
    """Resolve a model name to a ``ModelType`` member when one matches.

    Members are compared by their ``value`` against ``model_name``, ignoring
    case. The first matching member is returned; when no member matches, the
    original ``model_name`` string is handed back unchanged.
    """
    matches = (
        candidate
        for candidate in ModelType
        if candidate.value.lower() == model_name.lower()
    )
    return next(matches, model_name)


def decode_image(encoded_image: str) -> Image:
data = base64.b64decode(encoded_image)
return Image.open(io.BytesIO(data))
def _convert_action_to_schema(
    action_space: list[Action] | None,
) -> "list[OpenAIFunction] | None":
    """Build one ``OpenAIFunction`` per action from its ``entry`` attribute.

    Returns ``None`` when ``action_space`` is ``None``; otherwise a list with
    one ``OpenAIFunction`` for each action, in the same order.
    """
    if action_space is not None:
        return [OpenAIFunction(item.entry) for item in action_space]
    return None


def _convert_tool_calls_to_action_list(
tool_calls: list[ChatCompletionMessageToolCall] | None,
) -> list[ActionOutput] | None:
if tool_calls is None:
return None

return [
ActionOutput(
name=call.function.name,
arguments=json.loads(call.function.arguments),
)
for call in tool_calls
]


class CamelModel(BackendModel):
def __init__(
self,
model: str,
model_platform: str,
parameters: Optional[Dict[str, Any]] = None,
parameters: dict[str, Any] | None = None,
history_messages_len: int = 0,
) -> None:
if not CAMEL_ENABLED:
raise ImportError("Please install camel-ai to use CamelModel")
self.parameters = parameters or {}
# TODO: a better way?
self.model_type = find_model_type(model)
self.model_platform_type = find_model_platform_type(model_platform)
self.client: Optional[ChatAgent] = None
self.model_type = _find_model_type(model)
self.model_platform_type = _find_model_platform_type(model_platform)
self.client: ChatAgent | None = None
self.token_usage = 0

super().__init__(
Expand All @@ -79,11 +96,11 @@ def __init__(
history_messages_len,
)

def get_token_usage(self):
def get_token_usage(self) -> int:
return self.token_usage

def reset(self, system_message: str, action_space: Optional[List[Action]]) -> None:
action_schema = self._convert_action_to_schema(action_space)
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
action_schema = _convert_action_to_schema(action_space)
config = self.parameters.copy()
if action_schema is not None:
config["tool_choice"] = "required"
Expand All @@ -109,30 +126,9 @@ def reset(self, system_message: str, action_space: Optional[List[Action]]) -> No
)
self.token_usage = 0

@staticmethod
def _convert_action_to_schema(
action_space: Optional[List[Action]],
) -> Optional[List[OpenAIFunction]]:
if action_space is None:
return None
return [OpenAIFunction(action.entry) for action in action_space]

@staticmethod
def _convert_tool_calls_to_action_list(tool_calls) -> List[ActionOutput]:
if tool_calls is None:
return tool_calls

return [
ActionOutput(
name=call.function.name,
arguments=json.loads(call.function.arguments),
)
for call in tool_calls
]

def chat(self, messages: List[Tuple[str, MessageType]]):
def chat(self, messages: list[tuple[str, MessageType]]) -> BackendOutput:
# TODO: handle multiple text messages after message refactoring
image_list: List[Image.Image] = []
image_list: list[Image.Image] = []
content = ""
for message in messages:
if message[1] == MessageType.IMAGE_JPG_BASE64:
Expand All @@ -147,12 +143,9 @@ def chat(self, messages: List[Tuple[str, MessageType]]):
)
response = self.client.step(usermsg)
self.token_usage += response.info["usage"]["total_tokens"]
tool_call_request = response.info.get("tool_call_request")

# TODO: delete this after record_message is refactored
self.client.record_message(response.msg)
tool_call_request = response.info.get("external_tool_request")

return BackendOutput(
message=response.msg.content,
action_list=self._convert_tool_calls_to_action_list([tool_call_request]),
action_list=_convert_tool_calls_to_action_list([tool_call_request]),
)
12 changes: 4 additions & 8 deletions crab/core/agent_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,21 @@ class AgentPolicy(ABC):
def chat(
self,
observation: dict[str, list[tuple[str, MessageType]]],
) -> list[ActionOutput]:
...
) -> list[ActionOutput]: ...

@abstractmethod
def reset(
self,
task_description: str,
action_spaces: dict[str, list[Action]],
env_descriptions: dict[str, str],
) -> None:
...
) -> None: ...

@abstractmethod
def get_token_usage(self):
...
def get_token_usage(self): ...

@abstractmethod
def get_backend_model_name(self) -> str:
...
def get_backend_model_name(self) -> str: ...

@staticmethod
def combine_multi_env_action_space(
Expand Down
9 changes: 3 additions & 6 deletions crab/core/backend_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,14 @@ def __init__(
self.reset("You are a helpful assistant.", None)

@abstractmethod
def chat(self, contents: list[tuple[str, MessageType]]) -> BackendOutput:
...
def chat(self, contents: list[tuple[str, MessageType]]) -> BackendOutput: ...

@abstractmethod
def reset(
self,
system_message: str,
action_space: list[Action] | None,
):
...
): ...

@abstractmethod
def get_token_usage(self):
...
def get_token_usage(self): ...
2 changes: 1 addition & 1 deletion crab/core/models/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Evaluator(Action):
@field_validator("returns", mode="after")
@classmethod
def must_return_bool(cls, v: type[BaseModel]) -> type[BaseModel]:
if v.model_fields["returns"].annotation != bool:
if v.model_fields["returns"].annotation is not bool:
raise ValueError("Evaluator must return bool.")
return v

Expand Down
13 changes: 13 additions & 0 deletions crab/utils/measure.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import logging
import time
from functools import wraps
Expand Down
Loading

0 comments on commit 8503080

Please sign in to comment.