Skip to content

Commit

Permalink
Fix typo & make model choosable & add new todo
Browse files Browse the repository at this point in the history
  • Loading branch information
anjieyang committed Oct 29, 2024
1 parent b1b8767 commit 293804e
Showing 1 changed file with 42 additions and 6 deletions.
48 changes: 42 additions & 6 deletions gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,52 @@
import customtkinter as ctk

from crab import Experiment
from crab.agents.backend_models import OpenAIModel
from crab.agents.backend_models import OpenAIModel, ClaudeModel, GeminiModel
from crab.agents.policies import SingleAgentPolicy
from gui.utils import get_benchmark

warnings.filterwarnings("ignore")

AVAILABLE_MODELS = {
"gpt-4o": ("OpenAIModel", "gpt-4o"),
"gpt-4turbo": ("OpenAIModel", "gpt-4-turbo"),
"gemini": ("GeminiModel", "gemini-1.5-pro-latest"),
"claude": ("ClaudeModel", "claude-3-opus-20240229"),
}

def get_model_instance(model_key: str):
if model_key not in AVAILABLE_MODELS:
raise ValueError(f"Model {model_key} not supported")

model_config = AVAILABLE_MODELS[model_key]
model_class_name = model_config[0]
model_name = model_config[1]

if model_class_name == "OpenAIModel":
return OpenAIModel(model=model_name, history_messages_len=2)
elif model_class_name == "GeminiModel":
return GeminiModel(model=model_name, history_messages_len=2)
elif model_class_name == "ClaudeModel":
return ClaudeModel(model=model_name, history_messages_len=2)

def assign_task():
task_description = input_entry.get()
input_entry.delete(0, "end")
display_message(task_description)

model = get_model_instance(selected_model.get())
agent_policy = SingleAgentPolicy(model_backend=model)

task_id = str(uuid4())
benchmark = get_benchmark(task_id, task_description)
expeirment = Experiment(
experiment = Experiment(
benchmark=benchmark,
task_id=task_id,
agent_policy=agent_policy,
log_dir=log_dir,
)
# TODO: redirect the output to the GUI
expeirment.start_benchmark()
experiment.start_benchmark()


def display_message(message, sender="user"):
Expand All @@ -56,9 +80,7 @@ def display_message(message, sender="user"):


if __name__ == "__main__":
# TODO: make this choosable by the user
model = OpenAIModel(model="gpt-4o", history_messages_len=2)
agent_policy = SingleAgentPolicy(model_backend=model)
# TODO: Handle JSON decode error from environment action endpoint and display model response in GUI
log_dir = (Path(__file__).parent / "logs").resolve()

ctk.set_appearance_mode("System")
Expand All @@ -68,6 +90,20 @@ def display_message(message, sender="user"):
app.title("CRAB")
app.geometry("400x500")

model_frame = ctk.CTkFrame(app)
model_frame.pack(pady=10, padx=10, fill="x")

model_label = ctk.CTkLabel(model_frame, text="Select Model:")
model_label.pack(side="left", padx=(0, 10))

selected_model = ctk.StringVar(value="gpt-4o")
model_dropdown = ctk.CTkOptionMenu(
model_frame,
values=list(AVAILABLE_MODELS.keys()),
variable=selected_model,
)
model_dropdown.pack(side="left", fill="x", expand=True)

chat_display_frame = ctk.CTkFrame(app, width=380, height=380)
chat_display_frame.pack(pady=10)
chat_display = ctk.CTkTextbox(
Expand Down

0 comments on commit 293804e

Please sign in to comment.