Skip to content

Commit

Permalink
update frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
WHALEEYE committed Oct 31, 2024
1 parent daf3a53 commit 9797e35
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 30 deletions.
25 changes: 9 additions & 16 deletions gui/gui_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pathlib import Path
from typing import Literal

from crab import AgentPolicy, Benchmark, Experiment, MessageType
from crab import ActionOutput, AgentPolicy, Benchmark, Experiment, MessageType


class GuiExperiment(Experiment):
Expand Down Expand Up @@ -54,23 +54,21 @@ def get_prompt(self):

def step(self, it) -> bool:
if self.display_callback:
self.display_callback(f"Step {self.step_cnt}:", "ai")
self.display_callback("CRAB is Thinking...", "system")

prompt = self.get_prompt()
self.log_prompt(prompt)

try:
response = self.agent_policy.chat(prompt)
if self.display_callback:
self.display_callback(f"Planning next action...", "ai")
except Exception as e:
if self.display_callback:
self.display_callback(f"Error: {str(e)}", "ai")
self.display_callback(f"Error: {str(e)}", "error")
self.write_main_csv_row("agent_exception")
return True

if self.display_callback:
self.display_callback(f"Executing: {response}", "ai")
self.display_callback(f"Acting: {response}", "action")
return self.execute_action(response)

def execute_action(self, response: list[ActionOutput]) -> bool:
Expand All @@ -85,30 +83,25 @@ def execute_action(self, response: list[ActionOutput]) -> bool:
if benchmark_result.terminated:
if self.display_callback:
self.display_callback(
f"✓ Task completed! Results: {self.metrics}", "ai"
f"✓ Task completed! Results: {self.metrics}", "system"
)
self.write_current_log_row(action)
self.write_current_log_row(benchmark_result.info["terminate_reason"])
return True

if self.display_callback:
self.display_callback(
f'Action "{action.name}" completed in {action.env}. '
f"Progress: {self.metrics}", "ai"
)
self.display_callback("Action completed.\n>>>>>", "system")
self.write_current_log_row(action)
self.step_cnt += 1
return False

def start_benchmark(self):
if self.display_callback:
self.display_callback("Starting benchmark...", "ai")
try:
super().start_benchmark()
except KeyboardInterrupt:
if self.display_callback:
self.display_callback("Experiment interrupted.", "ai")
self.display_callback("Experiment interrupted.", "error")
self.write_main_csv_row("experiment_interrupted")
finally:
if self.display_callback:
self.display_callback("Experiment finished.", "ai")
self.display_callback("Experiment finished.", "error")
44 changes: 30 additions & 14 deletions gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import warnings
from pathlib import Path
from typing import Literal
from uuid import uuid4

import customtkinter as ctk
Expand Down Expand Up @@ -50,8 +51,10 @@ def get_model_instance(model_key: str):

def assign_task():
task_description = input_entry.get()
if not task_description.strip():
return
input_entry.delete(0, "end")
display_message(task_description)
display_message(task_description, "user")

try:
model = get_model_instance(model_dropdown.get())
Expand All @@ -65,31 +68,44 @@ def assign_task():
agent_policy=agent_policy,
log_dir=log_dir,
)

experiment.set_display_callback(display_message)

def run_experiment():
try:
experiment.start_benchmark()
except Exception as e:
display_message(f"Error: {str(e)}", "ai")
display_message(f"Error: {str(e)}", "error")

import threading

thread = threading.Thread(target=run_experiment, daemon=True)
thread.start()

except Exception as e:
display_message(f"Error: {str(e)}", "ai")
display_message(f"Error: {str(e)}", "error")


def display_message(message, sender="user"):
def display_message(
message, category: Literal["system", "user", "action", "error"] = "system"
):
chat_display.configure(state="normal")
if sender == "user":
chat_display.insert("end", f"User: {message}\n", "user")
else:
chat_display.insert("end", f"AI: {message}\n", "ai")
chat_display.tag_config("user", justify="left", foreground="blue")
chat_display.tag_config("ai", justify="right", foreground="green")
if category == "user":
chat_display.insert("end", f"{message}\n", "user")
elif category == "system":
chat_display.insert("end", f"{message}\n", "system")
elif category == "error":
chat_display.insert("end", f"{message}\n", "error")
elif category == "action":
chat_display.insert("end", f"{message}\n", "action")
chat_display.tag_config(
"user", justify="right", foreground="lightblue", wrap="word"
)
chat_display.tag_config("system", justify="left", foreground="gray", wrap="word")
chat_display.tag_config(
"action", justify="left", foreground="lightgreen", wrap="word"
)
chat_display.tag_config("error", justify="left", foreground="red", wrap="word")
chat_display.configure(state="disabled")
chat_display.see("end")
app.update_idletasks()
Expand All @@ -98,7 +114,7 @@ def display_message(message, sender="user"):
if __name__ == "__main__":
log_dir = (Path(__file__).parent / "logs").resolve()

ctk.set_appearance_mode("System")
ctk.set_appearance_mode("dark")
ctk.set_default_color_theme("blue")

app = ctk.CTk()
Expand All @@ -118,7 +134,7 @@ def display_message(message, sender="user"):
model_dropdown.pack(pady=10, padx=10, fill="x")

chat_display_frame = ctk.CTkFrame(app, width=480, height=880)
chat_display_frame.pack(pady=10, expand=True, fill="y")
chat_display_frame.pack(pady=10, padx=10, expand=True, fill="both")
chat_display = ctk.CTkTextbox(
chat_display_frame, width=480, height=880, state="disabled", font=normal_font
)
Expand Down
3 changes: 3 additions & 0 deletions gui/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@

_CACHED_HOST_OS = None


def check_host_os() -> HostOS:
global _CACHED_HOST_OS

if _CACHED_HOST_OS is None:
import platform

host_os = platform.system().lower()

if host_os == "linux":
Expand All @@ -47,6 +49,7 @@ def check_host_os() -> HostOS:

return _CACHED_HOST_OS


@evaluator(env_name="ubuntu")
def empty_evaluator_linux() -> bool:
return False
Expand Down

0 comments on commit 9797e35

Please sign in to comment.