From f58281781f84f0a49b4fa99219edfccfed5836ce Mon Sep 17 00:00:00 2001
From: Tianqi Xu
Date: Mon, 23 Sep 2024 16:02:43 +0300
Subject: [PATCH] Add VLLM models

---
 crab-benchmark-v0/main.py                  | 64 ++++++++++++++---
 crab/agents/backend_models/__init__.py     | 23 ++++++-
 crab/agents/backend_models/claude_model.py |  1 +
 crab/agents/backend_models/gemini_model.py |  1 +
 crab/agents/backend_models/openai_model.py |  6 +-
 crab/agents/backend_models/vllm_model.py   | 80 ++++++++++++++++++++++
 crab/agents/policies/single_agent.py       | 45 ++++++++++--
 crab/agents/utils.py                       | 61 +++++++++++++++--
 8 files changed, 257 insertions(+), 24 deletions(-)
 create mode 100644 crab/agents/backend_models/vllm_model.py

diff --git a/crab-benchmark-v0/main.py b/crab-benchmark-v0/main.py
index 231d4d0..07c4ba0 100644
--- a/crab-benchmark-v0/main.py
+++ b/crab-benchmark-v0/main.py
@@ -29,7 +29,7 @@
     get_elements_prompt,
     groundingdino_easyocr,
 )
-from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
+from crab.agents.backend_models import BackendModelConfig
 from crab.agents.policies import (
     MultiAgentByEnvPolicy,
     MultiAgentByFuncPolicy,
@@ -158,7 +158,7 @@ def get_benchmark(env: str, ubuntu_url: str):
         default="single",
     )
     parser.add_argument(
-        "--remote-url",
+        "--ubuntu-url",
         type=str,
         help="remote url of Ubuntu environment",
         default="http://127.0.0.1:8000",
@@ -170,6 +170,18 @@ def get_benchmark(env: str, ubuntu_url: str):
         default="cross",
     )
     parser.add_argument("--task-id", type=str, help="task id")
+    parser.add_argument(
+        "--model-base-url",
+        type=str,
+        help="URL of the model API",
+        default="http://127.0.0.1:8000/v1",
+    )
+    parser.add_argument(
+        "--model-api-key",
+        type=str,
+        help="API key of the model API",
+        default="EMPTY",
+    )
     parser.add_argument(
         "--loglevel",
         type=str,
@@ -183,16 +195,48 @@ def get_benchmark(env: str, ubuntu_url: str):
         raise ValueError("Invalid log level: %s" % loglevel)
     logging.basicConfig(level=numeric_level)
 
-    benchmark = get_benchmark(args.env, args.remote_url)
+    benchmark = get_benchmark(args.env, args.ubuntu_url)
 
     if args.model == "gpt4o":
-        model = OpenAIModel(model="gpt-4o")
-    elif args.policy == "gpt4turbo":
-        model = OpenAIModel(model="gpt-4-turbo")
-    elif args.policy == "gemini":
-        model = GeminiModel(model="gemini-1.5-pro-latest")
-    elif args.policy == "claude":
-        model = ClaudeModel(model="claude-3-opus-20240229")
+        model = BackendModelConfig(
+            model_class="openai",
+            model_name="gpt-4o",
+            history_messages_len=2,
+        )
+    elif args.model == "gpt4turbo":
+        model = BackendModelConfig(
+            model_class="openai",
+            model_name="gpt-4-turbo",
+            history_messages_len=2,
+        )
+    elif args.model == "gemini":
+        model = BackendModelConfig(
+            model_class="gemini",
+            model_name="gemini-1.5-pro-latest",
+            history_messages_len=2,
+        )
+    elif args.model == "claude":
+        model = BackendModelConfig(
+            model_class="claude",
+            model_name="claude-3-opus-20240229",
+            history_messages_len=2,
+        )
+    elif args.model == "llava-1.6":
+        model = BackendModelConfig(
+            model_class="vllm",
+            model_name="llava-hf/llava-v1.6-34b-hf",
+            history_messages_len=2,
+            base_url=args.model_base_url,
+            api_key=args.model_api_key,
+        )
+    elif args.model == "pixtral":
+        model = BackendModelConfig(
+            model_class="vllm",
+            model_name="mistralai/Pixtral-12B-2409",
+            history_messages_len=1,
+            base_url=args.model_base_url,
+            api_key=args.model_api_key,
+        )
     else:
         print("Unsupported model: ", args.model)
         exit()
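For orientation, here is a minimal sketch (not part of the patch) of the config object the new `pixtral` branch builds, with the new CLI defaults inlined; the model name and endpoint are placeholders that must match whatever the vLLM server actually serves:

```python
# Sketch: the equivalent of running main.py with --model pixtral and the
# default --model-base-url / --model-api-key values.
from crab.agents.backend_models import BackendModelConfig

model = BackendModelConfig(
    model_class="vllm",
    model_name="mistralai/Pixtral-12B-2409",  # must match the served model
    history_messages_len=1,  # the pixtral branch uses a shorter history window
    base_url="http://127.0.0.1:8000/v1",  # default of --model-base-url
    api_key="EMPTY",  # default of --model-api-key; a placeholder value
)
```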
diff --git a/crab/agents/backend_models/__init__.py b/crab/agents/backend_models/__init__.py
index c087ca0..172b6a1 100644
--- a/crab/agents/backend_models/__init__.py
+++ b/crab/agents/backend_models/__init__.py
@@ -22,25 +22,36 @@
 from .claude_model import ClaudeModel
 from .gemini_model import GeminiModel
 from .openai_model import OpenAIModel
+from .vllm_model import VLLMModel
 
 
 class BackendModelConfig(BaseModel):
-    model_class: Literal["openai", "claude", "gemini", "camel"]
+    model_class: Literal["openai", "claude", "gemini", "camel", "vllm"]
     model_name: str
     history_messages_len: int = 0
     parameters: dict[str, Any] = {}
     tool_call_required: bool = False
+    base_url: str | None = None  # Only used in OpenAIModel and VLLMModel currently
+    api_key: str | None = None  # Only used in OpenAIModel and VLLMModel currently
 
 
 def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
     match model_config.model_class:
         case "claude":
+            if model_config.base_url is not None or model_config.api_key is not None:
+                raise ValueError(
+                    "base_url and api_key are not supported for ClaudeModel currently."
+                )
             return ClaudeModel(
                 model=model_config.model_name,
                 parameters=model_config.parameters,
                 history_messages_len=model_config.history_messages_len,
             )
         case "gemini":
+            if model_config.base_url is not None or model_config.api_key is not None:
+                raise ValueError(
+                    "base_url and api_key are not supported for GeminiModel currently."
+                )
             return GeminiModel(
                 model=model_config.model_name,
                 parameters=model_config.parameters,
@@ -51,6 +62,16 @@ def create_backend_model(model_config: BackendModelConfig) -> BackendModel:
                 model=model_config.model_name,
                 parameters=model_config.parameters,
                 history_messages_len=model_config.history_messages_len,
+                base_url=model_config.base_url,
+                api_key=model_config.api_key,
+            )
+        case "vllm":
+            return VLLMModel(
+                model=model_config.model_name,
+                parameters=model_config.parameters,
+                history_messages_len=model_config.history_messages_len,
+                base_url=model_config.base_url,
+                api_key=model_config.api_key,
             )
         case "camel":
             raise NotImplementedError("Cannot support camel model currently.")
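A usage sketch of the new dispatch path (values are illustrative): `create_backend_model` routes `model_class="vllm"` to `VLLMModel`, which requires `base_url`, while Claude/Gemini configs now reject these fields instead of silently ignoring them.

```python
# Sketch: dispatching a vLLM config through the factory added above.
from crab.agents.backend_models import BackendModelConfig, create_backend_model

backend = create_backend_model(
    BackendModelConfig(
        model_class="vllm",
        model_name="llava-hf/llava-v1.6-34b-hf",  # any vLLM-served model
        history_messages_len=2,
        base_url="http://127.0.0.1:8000/v1",  # mandatory for VLLMModel
        api_key="EMPTY",
    )
)
assert backend.support_tool_call is False  # drives the prompt selection below
```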
diff --git a/crab/agents/backend_models/claude_model.py b/crab/agents/backend_models/claude_model.py
index d2c8252..ed37f47 100644
--- a/crab/agents/backend_models/claude_model.py
+++ b/crab/agents/backend_models/claude_model.py
@@ -50,6 +50,7 @@ def __init__(
         self.action_schema: list[dict] | None = None
         self.token_usage: int = 0
         self.chat_history: list[list[dict]] = []
+        self.support_tool_call = True
 
     def reset(self, system_message: str, action_space: list[Action] | None) -> None:
         self.system_message = system_message
diff --git a/crab/agents/backend_models/gemini_model.py b/crab/agents/backend_models/gemini_model.py
index 213b7e3..3032d94 100644
--- a/crab/agents/backend_models/gemini_model.py
+++ b/crab/agents/backend_models/gemini_model.py
@@ -59,6 +59,7 @@ def __init__(
         self.action_schema: list[Tool] | None = None
         self.token_usage: int = 0
         self.chat_history: list[list[dict]] = []
+        self.support_tool_call = True
 
     def reset(self, system_message: str, action_space: list[Action] | None) -> None:
         self.system_message = system_message
diff --git a/crab/agents/backend_models/openai_model.py b/crab/agents/backend_models/openai_model.py
index 2d1eb23..e8a11eb 100644
--- a/crab/agents/backend_models/openai_model.py
+++ b/crab/agents/backend_models/openai_model.py
@@ -33,6 +33,7 @@ def __init__(
         history_messages_len: int = 0,
         tool_call_required: bool = False,
         base_url: str | None = None,
+        api_key: str | None = None,
     ) -> None:
         if not openai_model_enable:
             raise ImportError("Please install openai to use OpenAIModel")
@@ -43,7 +44,7 @@ def __init__(
 
         assert self.history_messages_len >= 0
 
-        self.client = openai.OpenAI(base_url=base_url)
+        self.client = openai.OpenAI(api_key=api_key, base_url=base_url)
         self.tool_call_required: bool = tool_call_required
         self.system_message: str = "You are a helpful assistant."
         self.openai_system_message = {
@@ -54,6 +55,7 @@ def __init__(
         self.action_schema: list[dict] | None = None
         self.token_usage: int = 0
         self.chat_history: list[list[ChatCompletionMessage | dict]] = []
+        self.support_tool_call = True
 
     def reset(self, system_message: str, action_space: list[Action] | None) -> None:
         self.system_message = system_message
@@ -92,7 +94,7 @@ def record_message(
                     "tool_call_id": tool_call.id,
                     "role": "tool",
                     "name": tool_call.function.name,
-                    "content": "",
+                    "content": "success",
                 }
             )  # extend conversation with function response
 
diff --git a/crab/agents/backend_models/vllm_model.py b/crab/agents/backend_models/vllm_model.py
new file mode 100644
index 0000000..18ed12c
--- /dev/null
+++ b/crab/agents/backend_models/vllm_model.py
@@ -0,0 +1,80 @@
+# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
+import json
+from typing import Any
+
+from openai.types.chat import ChatCompletionMessage
+
+from crab import Action, ActionOutput, BackendOutput
+from crab.agents.backend_models.openai_model import OpenAIModel
+from crab.agents.utils import extract_text_and_code_prompts
+
+
+class VLLMModel(OpenAIModel):
+    def __init__(
+        self,
+        model: str,
+        parameters: dict[str, Any] = dict(),
+        history_messages_len: int = 0,
+        base_url: str | None = None,
+        api_key: str | None = None,
+    ) -> None:
+        if base_url is None:
+            raise ValueError("base_url is required for VLLMModel")
+        super().__init__(
+            model,
+            parameters,
+            history_messages_len,
+            False,
+            base_url,
+            api_key,
+        )
+        self.support_tool_call = False
+
+    def reset(self, system_message: str, action_space: list[Action] | None) -> None:
+        super().reset(system_message, action_space)
+        self.action_schema = None
+
+    def record_message(
+        self, new_message: dict, response_message: ChatCompletionMessage
+    ) -> None:
+        self.chat_history.append([new_message])
+        self.chat_history[-1].append(
+            {"role": "assistant", "content": response_message.content}
+        )
+
+    def generate_backend_output(
+        self, response_message: ChatCompletionMessage
+    ) -> BackendOutput:
+        content = response_message.content
+        text_list, code_list = extract_text_and_code_prompts(content)
+
+        action_list = []
+        try:
+            for code_block in code_list:
+                action_object = json.loads(code_block)
+                action_list.append(
+                    ActionOutput(
+                        name=action_object["name"], arguments=action_object["arguments"]
+                    )
+                )
+        except json.JSONDecodeError as e:
+            raise RuntimeError(f"Failed to parse code block: {code_block}") from e
+        except KeyError as e:
+            raise RuntimeError(f"Received invalid action format: {code_block}") from e
+
+        return BackendOutput(
+            message="".join(text_list),
+            action_list=action_list,
+        )
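To make the parsing contract concrete, a small sketch (the reply text is invented) of how a raw completion flows through `extract_text_and_code_prompts` and `json.loads` inside `generate_backend_output`:

```python
# Sketch: prose becomes the message, each JSON code block becomes one action.
import json

from crab.agents.utils import extract_text_and_code_prompts

fence = "```"  # built indirectly to avoid nesting literal fences here
reply = (
    "I see a search box labeled 3, so I will click it.\n"
    f"{fence}json\n"
    '{"name": "click", "arguments": {"element": 3}}\n'
    f"{fence}"
)
text_list, code_list = extract_text_and_code_prompts(reply)
action = json.loads(code_list[0])
assert action["name"] == "click"
assert action["arguments"] == {"element": 3}
```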
message="".join(text_list), + action_list=action_list, + ) diff --git a/crab/agents/policies/single_agent.py b/crab/agents/policies/single_agent.py index 7746c53..74a6cd6 100644 --- a/crab/agents/policies/single_agent.py +++ b/crab/agents/policies/single_agent.py @@ -26,8 +26,8 @@ class SingleAgentPolicy(AgentPolicy): - _system_prompt = """You are a helpful assistant. Now you have to do a task as - described below: + _system_prompt_with_function_call = """\ + You are a helpful assistant. Now you have to do a task as described below: **"{task_description}."** @@ -47,11 +47,45 @@ class SingleAgentPolicy(AgentPolicy): you. Always do them by yourself using function calls. """ + _system_prompt_no_function_call = """\ + You are a helpful assistant. Now you have to do a task as described below: + + **"{task_description}."** + + You should never forget this task and always perform actions to achieve this task. + And this is the description of each given environment: {env_description}. You will + receive screenshots of the environments. The interactive UI elements on the + screenshot are labeled with numeric tags starting from 1. + + A unit operation you can perform is called Action. You have a limited action space + as function calls: {action_descriptions}. You should generate JSON code blocks to + execute the actions. Each code block MUST contains only one json object, i.e. one + action. You can output multiple code blocks to execute multiple actions in a single + step. You must follow the JSON format below to output the action. + ```json + {{"name": "action_name", "arguments": {{"arg1": "value1", "arg2": "value2"}}}} + ``` + or if not arguments needed: + ```json + {{"name": "action_name", "arguments": {{}}}} + ``` + + In each step, You MUST explain what do you see from the current observation and the + plan of the next action, then use a provided action in each step to achieve the + task. You should state what action to take and what the parameters should be. Your + answer MUST contain at least one code block. You SHOULD NEVER ask me to do anything + for you. Always do them by yourself. 
+ """ + def __init__( self, model_backend: BackendModelConfig, ): self.model_backend = create_backend_model(model_backend) + if self.model_backend.support_tool_call: + self.system_prompt = self._system_prompt_with_function_call + else: + self.system_prompt = self._system_prompt_no_function_call self.reset(task_description="", action_spaces=None, env_descriptions={}) def reset( @@ -62,9 +96,12 @@ def reset( ) -> list: self.task_description = task_description self.action_space = combine_multi_env_action_space(action_spaces) - system_message = self._system_prompt.format( + system_message = self.system_prompt.format( task_description=task_description, - action_descriptions=generate_action_prompt(self.action_space), + action_descriptions=generate_action_prompt( + self.action_space, + expand=not self.model_backend.support_tool_call, + ), env_description=str(env_descriptions), ) self.model_backend.reset(system_message, self.action_space) diff --git a/crab/agents/utils.py b/crab/agents/utils.py index e3a18c7..b174b92 100644 --- a/crab/agents/utils.py +++ b/crab/agents/utils.py @@ -24,7 +24,7 @@ def combine_multi_env_action_space( for env in action_space: for action in action_space[env]: new_action = action.model_copy() - new_action.name = new_action.name + "__in__" + env + new_action.name = new_action.name + "_in_" + env new_action.description = f"In {env} environment, " + new_action.description result.append(new_action) return result @@ -38,10 +38,10 @@ def decode_combined_action( """ result = [] for output in output_actions: - name_env = output.name.split("__in__") + name_env = output.name.split("_in_") if len(name_env) != 2: raise RuntimeError( - 'The decoded action name should contain the splitter "__in__".' + 'The decoded action name should contain the splitter "_in_".' ) new_output = output.model_copy() new_output.name = name_env[0] @@ -50,7 +50,54 @@ def decode_combined_action( return result -def generate_action_prompt(action_space: list[Action]) -> str: - return "".join( - [f"[{action.name}: {action.description}]\n" for action in action_space] - ) +def generate_action_prompt(action_space: list[Action], expand: bool = False) -> str: + if expand: + return "".join( + [ + f"[**{action.name}**:\n" + f"action description: {action.description}\n" + f"action arguments json schema: {action.to_openai_json_schema()}\n" + "]\n" + for action in action_space + ] + ) + else: + return "".join( + [f"[{action.name}: {action.description}]\n" for action in action_space] + ) + + +def extract_text_and_code_prompts(content: str) -> tuple[list[str], list[str]]: + r"""Extract text and code prompts from the message content. + + Returns: + A tuple (text_list, code_list) where, text_list is a list of text and code_list + is a list of extracted codes both from the content. + """ + text_prompts: list[str] = [] + code_prompts: list[str] = [] + + lines = content.split("\n") + idx = 0 + start_idx = 0 + while idx < len(lines): + while idx < len(lines) and (not lines[idx].lstrip().startswith("```")): + idx += 1 + text = "\n".join(lines[start_idx:idx]).strip() + text_prompts.append(text) + + if idx >= len(lines): + break + + # code_type = lines[idx].strip()[3:].strip() + idx += 1 + start_idx = idx + while not lines[idx].lstrip().startswith("```"): + idx += 1 + code = "\n".join(lines[start_idx:idx]).strip() + code_prompts.append(code) + + idx += 1 + start_idx = idx + + return text_prompts, code_prompts