From 6cf86733f9db821b51ca06e0e5e5d03f83232d15 Mon Sep 17 00:00:00 2001 From: Henry Tai Date: Thu, 21 Nov 2024 03:46:37 +0800 Subject: [PATCH] print the error sent from the server fix conflict error --- src/llama_stack_client/lib/agents/agent.py | 13 +++++++++++-- .../lib/agents/event_logger.py | 18 ++++++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 1d89c8b..3bd7d0e 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -8,10 +8,16 @@ from llama_stack_client import LlamaStackClient from llama_stack_client.types import Attachment, ToolResponseMessage, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig + from .custom_tool import CustomTool class Agent: - def __init__(self, client: LlamaStackClient, agent_config: AgentConfig, custom_tools: Tuple[CustomTool] = ()): + def __init__( + self, + client: LlamaStackClient, + agent_config: AgentConfig, + custom_tools: Tuple[CustomTool] = (), + ): self.client = client self.agent_config = agent_config self.agent_id = self._create_agent(agent_config) @@ -72,7 +78,10 @@ def create_turn( stream=True, ) for chunk in response: - if not self._has_tool_call(chunk): + if hasattr(chunk, "error"): + yield chunk + return + elif not self._has_tool_call(chunk): yield chunk else: next_message = self._run_tool(chunk) diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index e356463..ba07a0f 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -6,12 +6,14 @@ from typing import List, Optional, Union -from termcolor import cprint - from llama_stack_client.types import ToolResponseMessage +from termcolor import cprint + -def interleaved_text_media_as_str(content: Union[str, List[str]], sep: str = " ") -> str: +def interleaved_text_media_as_str( + content: Union[str, List[str]], sep: str = " " +) -> str: def _process(c) -> str: if isinstance(c, str): return c @@ -49,14 +51,18 @@ def print(self, flush=True): class EventLogger: def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None): + if hasattr(chunk, "error"): + yield LogEvent(role=None, content=chunk.error["message"], color="red") + return if not hasattr(chunk, "event"): # Need to check for custom tool first # since it does not produce event but instead # a Message if isinstance(chunk, ToolResponseMessage): - yield LogEvent(role="CustomTool", content=chunk.content, color="green") + yield LogEvent( + role="CustomTool", content=chunk.content, color="green" + ) return - event = chunk.event event_type = event.payload.event_type @@ -153,4 +159,4 @@ def log(self, event_generator): for chunk in event_generator: for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type): yield log_event - previous_event_type, previous_step_type = self._get_event_type_step_type(chunk) \ No newline at end of file + previous_event_type, previous_step_type = self._get_event_type_step_type(chunk)