Commit

Add VLLM models

dandansamax committed Sep 23, 2024
1 parent 0830a38 commit f582817
Showing 8 changed files with 257 additions and 24 deletions.
64 changes: 54 additions & 10 deletions crab-benchmark-v0/main.py
@@ -29,7 +29,7 @@
get_elements_prompt,
groundingdino_easyocr,
)
from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
from crab.agents.backend_models import BackendModelConfig
from crab.agents.policies import (
MultiAgentByEnvPolicy,
MultiAgentByFuncPolicy,
@@ -158,7 +158,7 @@ def get_benchmark(env: str, ubuntu_url: str):
default="single",
)
parser.add_argument(
"--remote-url",
"--ubuntu-url",
type=str,
help="remote url of Ubunutu environment",
default="http://127.0.0.1:8000",
@@ -170,6 +170,18 @@ def get_benchmark(env: str, ubuntu_url: str):
default="cross",
)
parser.add_argument("--task-id", type=str, help="task id")
parser.add_argument(
"--model-base-url",
type=str,
help="URL of the model API",
default="http://127.0.0.1:8000/v1",
)
parser.add_argument(
"--model-api-key",
type=str,
help="API key of the model API",
default="EMPTY",
)
parser.add_argument(
"--loglevel",
type=str,
@@ -183,16 +195,48 @@ def get_benchmark(env: str, ubuntu_url: str):
raise ValueError("Invalid log level: %s" % loglevel)
logging.basicConfig(level=numeric_level)

benchmark = get_benchmark(args.env, args.remote_url)
benchmark = get_benchmark(args.env, args.ubuntu_url)

if args.model == "gpt4o":
model = OpenAIModel(model="gpt-4o")
elif args.policy == "gpt4turbo":
model = OpenAIModel(model="gpt-4-turbo")
elif args.policy == "gemini":
model = GeminiModel(model="gemini-1.5-pro-latest")
elif args.policy == "claude":
model = ClaudeModel(model="claude-3-opus-20240229")
model = BackendModelConfig(
model_class="openai",
model_name="gpt-4o",
history_messages_len=2,
)
elif args.model == "gpt4turbo":
model = BackendModelConfig(
model_class="openai",
model_name="gpt-4-turbo",
history_messages_len=2,
)
elif args.model == "gemini":
model = BackendModelConfig(
model_class="gemini",
model_name="gemini-1.5-pro-latest",
history_messages_len=2,
)
elif args.model == "claude":
model = BackendModelConfig(
model_class="claude",
model_name="claude-3-opus-20240229",
history_messages_len=2,
)
elif args.model == "llava-1.6":
model = BackendModelConfig(
model_class="vllm",
model_name="llava-hf/llava-v1.6-34b-hf",
history_messages_len=2,
base_url=args.model_base_url,
api_key=args.model_api_key,
)
elif args.model == "pixtral":
model = BackendModelConfig(
model_class="vllm",
model_name="mistralai/Pixtral-12B-2409",
history_messages_len=1,
base_url=args.model_base_url,
api_key=args.model_api_key,
)
else:
print("Unsupported model: ", args.model)
exit()
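
The two vLLM entries above target an OpenAI-compatible endpoint. A minimal sketch of a connectivity check for the server behind `--model-base-url`, using the argparse defaults defined above (this assumes a vLLM server is already running there):

```python
# Quick sanity check against a vLLM OpenAI-compatible endpoint.
# URL and key are the argparse defaults above; adjust to your deployment.
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")
print([m.id for m in client.models.list().data])  # IDs of the served models
```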
23 changes: 22 additions & 1 deletion crab/agents/backend_models/__init__.py
@@ -22,25 +22,36 @@
from .claude_model import ClaudeModel
from .gemini_model import GeminiModel
from .openai_model import OpenAIModel
from .vllm_model import VLLMModel


class BackendModelConfig(BaseModel):
model_class: Literal["openai", "claude", "gemini", "camel"]
model_class: Literal["openai", "claude", "gemini", "camel", "vllm"]
model_name: str
history_messages_len: int = 0
parameters: dict[str, Any] = {}
tool_call_required: bool = False
base_url: str | None = None # Only used in OpenAIModel and VLLMModel currently
api_key: str | None = None # Only used in OpenAIModel and VLLMModel currently


def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
match model_config.model_class:
case "claude":
if model_config.base_url is not None or model_config.api_key is not None:
raise Warning(
"base_url and api_key are not supported for ClaudeModel currently."
)
return ClaudeModel(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
)
case "gemini":
if model_config.base_url is not None or model_config.api_key is not None:
raise Warning(
"base_url and api_key are not supported for GeminiModel currently."
)
return GeminiModel(
model=model_config.model_name,
parameters=model_config.parameters,
@@ -51,6 +62,16 @@ def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
base_url=model_config.base_url,
api_key=model_config.api_key,
)
case "vllm":
return VLLMModel(
model=model_config.model_name,
parameters=model_config.parameters,
history_messages_len=model_config.history_messages_len,
base_url=model_config.base_url,
api_key=model_config.api_key,
)
case "camel":
raise NotImplementedError("Cannot support camel model currently.")
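
With the new `vllm` branch in place, a plain config is enough to obtain a ready backend. A minimal sketch, mirroring the model name and defaults used in `main.py`:

```python
# Sketch: building a vLLM-backed model via the new factory branch.
from crab.agents.backend_models import BackendModelConfig, create_backend_model

config = BackendModelConfig(
    model_class="vllm",
    model_name="llava-hf/llava-v1.6-34b-hf",
    history_messages_len=2,
    base_url="http://127.0.0.1:8000/v1",  # required by VLLMModel
    api_key="EMPTY",
)
model = create_backend_model(config)  # returns a VLLMModel instance
```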
1 change: 1 addition & 0 deletions crab/agents/backend_models/claude_model.py
@@ -50,6 +50,7 @@ def __init__(
self.action_schema: list[dict] | None = None
self.token_usage: int = 0
self.chat_history: list[list[dict]] = []
self.support_tool_call = True

def reset(self, system_message: str, action_space: list[Action] | None) -> None:
self.system_message = system_message
1 change: 1 addition & 0 deletions crab/agents/backend_models/gemini_model.py
@@ -59,6 +59,7 @@ def __init__(
self.action_schema: list[Tool] | None = None
self.token_usage: int = 0
self.chat_history: list[list[dict]] = []
self.support_tool_call = True

def reset(self, system_message: str, action_space: list[Action] | None) -> None:
self.system_message = system_message
6 changes: 4 additions & 2 deletions crab/agents/backend_models/openai_model.py
@@ -33,6 +33,7 @@ def __init__(
history_messages_len: int = 0,
tool_call_required: bool = False,
base_url: str | None = None,
api_key: str | None = None,
) -> None:
if not openai_model_enable:
raise ImportError("Please install openai to use OpenAIModel")
@@ -43,7 +44,7 @@

assert self.history_messages_len >= 0

self.client = openai.OpenAI(base_url=base_url)
self.client = openai.OpenAI(api_key=api_key, base_url=base_url)
self.tool_call_required: bool = tool_call_required
self.system_message: str = "You are a helpful assistant."
self.openai_system_message = {
@@ -54,6 +55,7 @@ def __init__(
self.action_schema: list[dict] | None = None
self.token_usage: int = 0
self.chat_history: list[list[ChatCompletionMessage | dict]] = []
self.support_tool_call = True

def reset(self, system_message: str, action_space: list[Action] | None) -> None:
self.system_message = system_message
@@ -92,7 +94,7 @@ def record_message(
"tool_call_id": tool_call.id,
"role": "tool",
"name": tool_call.function.name,
"content": "",
"content": "success",
}
) # extend conversation with function response

80 changes: 80 additions & 0 deletions crab/agents/backend_models/vllm_model.py
@@ -0,0 +1,80 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import json
from typing import Any

from openai.types.chat import ChatCompletionMessage

from crab import Action, ActionOutput, BackendOutput
from crab.agents.backend_models.openai_model import OpenAIModel
from crab.agents.utils import extract_text_and_code_prompts


class VLLMModel(OpenAIModel):
def __init__(
self,
model: str,
parameters: dict[str, Any] = dict(),
history_messages_len: int = 0,
base_url: str | None = None,
api_key: str | None = None,
) -> None:
if base_url is None:
raise ValueError("base_url is required for VLLMModel")
super().__init__(
model,
parameters,
history_messages_len,
False,
base_url,
api_key,
)
self.support_tool_call = False

def reset(self, system_message: str, action_space: list[Action] | None) -> None:
super().reset(system_message, action_space)
self.action_schema = None

def record_message(
self, new_message: dict, response_message: ChatCompletionMessage
) -> None:
self.chat_history.append([new_message])
self.chat_history[-1].append(
{"role": "assistant", "content": response_message.content}
)

def generate_backend_output(
self, response_message: ChatCompletionMessage
) -> BackendOutput:
content = response_message.content
text_list, code_list = extract_text_and_code_prompts(content)

action_list = []
try:
for code_block in code_list:
action_object = json.loads(code_block)
action_list.append(
ActionOutput(
name=action_object["name"], arguments=action_object["arguments"]
)
)
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse code block: {code_block}") from e
except KeyError as e:
raise RuntimeError(f"Received invalid action format: {code_block}") from e

return BackendOutput(
message="".join(text_list),
action_list=action_list,
)
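
Because `VLLMModel` disables native tool calling, actions arrive as fenced JSON blocks that `generate_backend_output` decodes. A self-contained sketch of that parsing on a hypothetical completion:

```python
# Sketch: parsing one fenced JSON action block, as generate_backend_output does.
import json

# Hypothetical raw completion following the no-function-call prompt format.
sample = (
    "I see a search box labeled 5; I will click it.\n"
    "```json\n"
    '{"name": "click", "arguments": {"element": 5}}\n'
    "```"
)

block = sample.split("```json\n", 1)[1].split("\n```", 1)[0]  # extract the block
action = json.loads(block)
assert action["name"] == "click"
assert action["arguments"] == {"element": 5}
```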
45 changes: 41 additions & 4 deletions crab/agents/policies/single_agent.py
@@ -26,8 +26,8 @@


class SingleAgentPolicy(AgentPolicy):
_system_prompt = """You are a helpful assistant. Now you have to do a task as
described below:
_system_prompt_with_function_call = """\
You are a helpful assistant. Now you have to do a task as described below:
**"{task_description}."**
@@ -47,11 +47,45 @@ class SingleAgentPolicy(AgentPolicy):
you. Always do them by yourself using function calls.
"""

_system_prompt_no_function_call = """\
You are a helpful assistant. Now you have to do a task as described below:
**"{task_description}."**
You should never forget this task and always perform actions to achieve this task.
And this is the description of each given environment: {env_description}. You will
receive screenshots of the environments. The interactive UI elements on the
screenshot are labeled with numeric tags starting from 1.
A unit operation you can perform is called Action. You have a limited action space
as function calls: {action_descriptions}. You should generate JSON code blocks to
execute the actions. Each code block MUST contain only one JSON object, i.e. one
action. You can output multiple code blocks to execute multiple actions in a single
step. You must follow the JSON format below to output the action.
```json
{{"name": "action_name", "arguments": {{"arg1": "value1", "arg2": "value2"}}}}
```
or, if no arguments are needed:
```json
{{"name": "action_name", "arguments": {{}}}}
```
In each step, you MUST explain what you see in the current observation and your
plan for the next action, then use a provided action to achieve the
task. You should state what action to take and what the parameters should be. Your
answer MUST contain at least one code block. You SHOULD NEVER ask me to do anything
for you. Always do everything yourself.
"""

def __init__(
self,
model_backend: BackendModelConfig,
):
self.model_backend = create_backend_model(model_backend)
if self.model_backend.support_tool_call:
self.system_prompt = self._system_prompt_with_function_call
else:
self.system_prompt = self._system_prompt_no_function_call
self.reset(task_description="", action_spaces=None, env_descriptions={})

def reset(
@@ -62,9 +96,12 @@ def reset(
) -> list:
self.task_description = task_description
self.action_space = combine_multi_env_action_space(action_spaces)
system_message = self._system_prompt.format(
system_message = self.system_prompt.format(
task_description=task_description,
action_descriptions=generate_action_prompt(self.action_space),
action_descriptions=generate_action_prompt(
self.action_space,
expand=not self.model_backend.support_tool_call,
),
env_description=str(env_descriptions),
)
self.model_backend.reset(system_message, self.action_space)
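
For reference, a rough sketch of the template selection and formatting performed above, with stub values (this assumes `SingleAgentPolicy` is exported from `crab.agents.policies`, as `main.py`'s imports suggest):

```python
# Sketch: template selection and formatting with stub values.
from crab.agents.policies import SingleAgentPolicy

support_tool_call = False  # e.g. a vLLM backend without native tool calling
template = (
    SingleAgentPolicy._system_prompt_with_function_call
    if support_tool_call
    else SingleAgentPolicy._system_prompt_no_function_call
)
system_message = template.format(
    task_description="Open the Settings app",
    action_descriptions="click(element: int): click the labeled element",
    env_description="{'ubuntu': 'a desktop OS environment'}",
)
print(system_message[:200])  # preview the rendered system prompt
```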
