From 293804ee8e8b9b490170280a75876505721be9a2 Mon Sep 17 00:00:00 2001 From: Anjie Yang Date: Tue, 29 Oct 2024 23:26:29 +0800 Subject: [PATCH] Fix typo & make model choosable & add new todo --- gui/main.py | 48 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/gui/main.py b/gui/main.py index 998aa27..8510de7 100644 --- a/gui/main.py +++ b/gui/main.py @@ -18,28 +18,52 @@ import customtkinter as ctk from crab import Experiment -from crab.agents.backend_models import OpenAIModel +from crab.agents.backend_models import OpenAIModel, ClaudeModel, GeminiModel from crab.agents.policies import SingleAgentPolicy from gui.utils import get_benchmark warnings.filterwarnings("ignore") +AVAILABLE_MODELS = { + "gpt-4o": ("OpenAIModel", "gpt-4o"), + "gpt-4turbo": ("OpenAIModel", "gpt-4-turbo"), + "gemini": ("GeminiModel", "gemini-1.5-pro-latest"), + "claude": ("ClaudeModel", "claude-3-opus-20240229"), +} + +def get_model_instance(model_key: str): + if model_key not in AVAILABLE_MODELS: + raise ValueError(f"Model {model_key} not supported") + + model_config = AVAILABLE_MODELS[model_key] + model_class_name = model_config[0] + model_name = model_config[1] + + if model_class_name == "OpenAIModel": + return OpenAIModel(model=model_name, history_messages_len=2) + elif model_class_name == "GeminiModel": + return GeminiModel(model=model_name, history_messages_len=2) + elif model_class_name == "ClaudeModel": + return ClaudeModel(model=model_name, history_messages_len=2) def assign_task(): task_description = input_entry.get() input_entry.delete(0, "end") display_message(task_description) + model = get_model_instance(selected_model.get()) + agent_policy = SingleAgentPolicy(model_backend=model) + task_id = str(uuid4()) benchmark = get_benchmark(task_id, task_description) - expeirment = Experiment( + experiment = Experiment( benchmark=benchmark, task_id=task_id, agent_policy=agent_policy, log_dir=log_dir, ) # TODO: redirect the output to the GUI - expeirment.start_benchmark() + experiment.start_benchmark() def display_message(message, sender="user"): @@ -56,9 +80,7 @@ def display_message(message, sender="user"): if __name__ == "__main__": - # TODO: make this choosable by the user - model = OpenAIModel(model="gpt-4o", history_messages_len=2) - agent_policy = SingleAgentPolicy(model_backend=model) + # TODO: Handle JSON decode error from environment action endpoint and display model response in GUI log_dir = (Path(__file__).parent / "logs").resolve() ctk.set_appearance_mode("System") @@ -68,6 +90,20 @@ def display_message(message, sender="user"): app.title("CRAB") app.geometry("400x500") + model_frame = ctk.CTkFrame(app) + model_frame.pack(pady=10, padx=10, fill="x") + + model_label = ctk.CTkLabel(model_frame, text="Select Model:") + model_label.pack(side="left", padx=(0, 10)) + + selected_model = ctk.StringVar(value="gpt-4o") + model_dropdown = ctk.CTkOptionMenu( + model_frame, + values=list(AVAILABLE_MODELS.keys()), + variable=selected_model, + ) + model_dropdown.pack(side="left", fill="x", expand=True) + chat_display_frame = ctk.CTkFrame(app, width=380, height=380) chat_display_frame.pack(pady=10) chat_display = ctk.CTkTextbox(