diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 37ed32b..1d89c8b 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -10,7 +10,6 @@ from llama_stack_client.types.agent_create_params import AgentConfig
 
 from .custom_tool import CustomTool
 
-
 class Agent:
     def __init__(self, client: LlamaStackClient, agent_config: AgentConfig, custom_tools: Tuple[CustomTool] = ()):
         self.client = client
@@ -35,7 +34,30 @@ def create_session(self, session_name: str) -> int:
         self.sessions.append(self.session_id)
         return self.session_id
 
-    async def create_turn(
+    def _has_tool_call(self, chunk):
+        if chunk.event.payload.event_type != "turn_complete":
+            return False
+        message = chunk.event.payload.turn.output_message
+        if message.stop_reason == "out_of_tokens":
+            return False
+        return len(message.tool_calls) > 0
+
+    def _run_tool(self, chunk):
+        message = chunk.event.payload.turn.output_message
+        tool_call = message.tool_calls[0]
+        if tool_call.tool_name not in self.custom_tools:
+            return ToolResponseMessage(
+                call_id=tool_call.call_id,
+                tool_name=tool_call.tool_name,
+                content=f"Unknown tool `{tool_call.tool_name}` was called.",
+                role="ipython",
+            )
+        tool = self.custom_tools[tool_call.tool_name]
+        result_messages = tool.run([message])
+        next_message = result_messages[0]
+        return next_message
+
+    def create_turn(
         self,
         messages: List[Union[UserMessage, ToolResponseMessage]],
         attachments: Optional[List[Attachment]] = None,
@@ -49,40 +71,9 @@ async def create_turn(
             attachments=attachments,
             stream=True,
         )
-        turn = None
         for chunk in response:
-            if chunk.event.payload.event_type != "turn_complete":
+            if not self._has_tool_call(chunk):
                 yield chunk
             else:
-                turn = chunk.event.payload.turn
-
-        message = turn.output_message
-        if len(message.tool_calls) == 0:
-            yield chunk
-            return
-
-        if message.stop_reason == "out_of_tokens":
-            yield chunk
-            return
-
-        tool_call = message.tool_calls[0]
-        if tool_call.tool_name not in self.custom_tools:
-            m = ToolResponseMessage(
-                call_id=tool_call.call_id,
-                tool_name=tool_call.tool_name,
-                content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
-                role="ipython",
-            )
-            next_message = m
-        else:
-            tool = self.custom_tools[tool_call.tool_name]
-            result_messages = await self.execute_custom_tool(tool, message)
-            next_message = result_messages[0]
-
-        yield next_message
-
-    async def execute_custom_tool(
-        self, tool: CustomTool, message: Union[UserMessage, ToolResponseMessage]
-    ) -> List[Union[UserMessage, ToolResponseMessage]]:
-        result_messages = await tool.run([message])
-        return result_messages
+                next_message = self._run_tool(chunk)
+                yield next_message
diff --git a/src/llama_stack_client/lib/agents/custom_tool.py b/src/llama_stack_client/lib/agents/custom_tool.py
index 75907a5..e1205ca 100644
--- a/src/llama_stack_client/lib/agents/custom_tool.py
+++ b/src/llama_stack_client/lib/agents/custom_tool.py
@@ -61,7 +61,7 @@ def get_tool_definition(self) -> FunctionCallToolDefinition:
         )
 
     @abstractmethod
-    async def run(
+    def run(
         self, messages: List[Union[UserMessage, ToolResponseMessage]]
     ) -> List[Union[UserMessage, ToolResponseMessage]]:
         raise NotImplementedError
diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py
index 4e935d9..e356463 100644
--- a/src/llama_stack_client/lib/agents/event_logger.py
+++ b/src/llama_stack_client/lib/agents/event_logger.py
@@ -48,100 +48,109 @@ def print(self, flush=True):
 
 
 class EventLogger:
-    async def log(self, event_generator):
-        previous_event_type = None
-        previous_step_type = None
-
-        async for chunk in event_generator:
-            if not hasattr(chunk, "event"):
-                # Need to check for custom tool first
-                # since it does not produce event but instead
-                # a Message
-                if isinstance(chunk, ToolResponseMessage):
-                    yield LogEvent(role="CustomTool", content=chunk.content, color="green")
-                continue
-
-            event = chunk.event
-            event_type = event.payload.event_type
-
-            if event_type in {"turn_start", "turn_complete"}:
-                # Currently not logging any turn realted info
-                yield LogEvent(role=None, content="", end="", color="grey")
-                continue
-
-            step_type = event.payload.step_type
-            # handle safety
-            if step_type == "shield_call" and event_type == "step_complete":
-                violation = event.payload.step_details.violation
-                if not violation:
-                    yield LogEvent(role=step_type, content="No Violation", color="magenta")
-                else:
-                    yield LogEvent(
-                        role=step_type,
-                        content=f"{violation.metadata} {violation.user_message}",
-                        color="red",
-                    )
+    def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
+        if not hasattr(chunk, "event"):
+            # Need to check for custom tool first
+            # since it does not produce event but instead
+            # a Message
+            if isinstance(chunk, ToolResponseMessage):
+                yield LogEvent(role="CustomTool", content=chunk.content, color="green")
+            return
+
+        event = chunk.event
+        event_type = event.payload.event_type
+
+        if event_type in {"turn_start", "turn_complete"}:
+            # Currently not logging any turn realted info
+            yield LogEvent(role=None, content="", end="", color="grey")
+            return
+
+        step_type = event.payload.step_type
+        # handle safety
+        if step_type == "shield_call" and event_type == "step_complete":
+            violation = event.payload.step_details.violation
+            if not violation:
+                yield LogEvent(role=step_type, content="No Violation", color="magenta")
+            else:
+                yield LogEvent(
+                    role=step_type,
+                    content=f"{violation.metadata} {violation.user_message}",
+                    color="red",
+                )
 
-            # handle inference
-            if step_type == "inference":
-                if event_type == "step_start":
+        # handle inference
+        if step_type == "inference":
+            if event_type == "step_start":
+                yield LogEvent(role=step_type, content="", end="", color="yellow")
+            elif event_type == "step_progress":
+                # HACK: if previous was not step/event was not inference's step_progress
+                # this is the first time we are getting model inference response
+                # aka equivalent to step_start for inference. Hence,
+                # start with "Model>".
+                if previous_event_type != "step_progress" and previous_step_type != "inference":
                     yield LogEvent(role=step_type, content="", end="", color="yellow")
-                elif event_type == "step_progress":
-                    # HACK: if previous was not step/event was not inference's step_progress
-                    # this is the first time we are getting model inference response
-                    # aka equivalent to step_start for inference. Hence,
-                    # start with "Model>".
-                    if previous_event_type != "step_progress" and previous_step_type != "inference":
-                        yield LogEvent(role=step_type, content="", end="", color="yellow")
-
-                    if event.payload.tool_call_delta:
-                        if isinstance(event.payload.tool_call_delta.content, str):
-                            yield LogEvent(
-                                role=None,
-                                content=event.payload.tool_call_delta.content,
-                                end="",
-                                color="cyan",
-                            )
-                    else:
+
+                if event.payload.tool_call_delta:
+                    if isinstance(event.payload.tool_call_delta.content, str):
                         yield LogEvent(
                             role=None,
-                            content=event.payload.text_delta_model_response,
+                            content=event.payload.tool_call_delta.content,
                             end="",
-                            color="yellow",
+                            color="cyan",
                         )
                 else:
-                    # step complete
-                    yield LogEvent(role=None, content="")
-
-            # handle tool_execution
-            if step_type == "tool_execution" and event_type == "step_complete":
-                # Only print tool calls and responses at the step_complete event
-                details = event.payload.step_details
-                for t in details.tool_calls:
                     yield LogEvent(
-                        role=step_type,
-                        content=f"Tool:{t.tool_name} Args:{t.arguments}",
-                        color="green",
+                        role=None,
+                        content=event.payload.text_delta_model_response,
+                        end="",
+                        color="yellow",
                     )
+            else:
+                # step complete
+                yield LogEvent(role=None, content="")
+
+        # handle tool_execution
+        if step_type == "tool_execution" and event_type == "step_complete":
+            # Only print tool calls and responses at the step_complete event
+            details = event.payload.step_details
+            for t in details.tool_calls:
+                yield LogEvent(
+                    role=step_type,
+                    content=f"Tool:{t.tool_name} Args:{t.arguments}",
+                    color="green",
+                )
 
-                for r in details.tool_responses:
-                    yield LogEvent(
-                        role=step_type,
-                        content=f"Tool:{r.tool_name} Response:{r.content}",
-                        color="green",
-                    )
-
-            # memory retrieval
-            if step_type == "memory_retrieval" and event_type == "step_complete":
-                details = event.payload.step_details
-                content = interleaved_text_media_as_str(details.inserted_context)
-                content = content[:200] + "..." if len(content) > 200 else content
-
+            for r in details.tool_responses:
                 yield LogEvent(
                     role=step_type,
-                    content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
-                    color="cyan",
+                    content=f"Tool:{r.tool_name} Response:{r.content}",
+                    color="green",
                 )
 
-            previous_event_type = event_type
-            previous_step_type = step_type
+        # memory retrieval
+        if step_type == "memory_retrieval" and event_type == "step_complete":
+            details = event.payload.step_details
+            content = interleaved_text_media_as_str(details.inserted_context)
+            content = content[:200] + "..." if len(content) > 200 else content
+
+            yield LogEvent(
+                role=step_type,
+                content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
+                color="cyan",
+            )
+
+    def _get_event_type_step_type(self, chunk):
+        if hasattr(chunk, "event"):
+            previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
+            previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None
+            return previous_event_type, previous_step_type
+        return None, None
+
+    def log(self, event_generator):
+        previous_event_type = None
+        previous_step_type = None
+
+        for chunk in event_generator:
+            for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type):
+                yield log_event
+            previous_event_type, previous_step_type = self._get_event_type_step_type(chunk)
\ No newline at end of file
diff --git a/tests/test_client.py b/tests/test_client.py
index 45d816e..c1f8496 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1539,7 +1539,7 @@ def test_get_platform(self) -> None:
         import threading
 
         from llama_stack_client._utils import asyncify
-        from llama_stack_client._base_client import get_platform
+        from llama_stack_client._base_client import get_platform
 
         async def test_main() -> None:
            result = await asyncify(get_platform)()
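
For reference, a minimal sketch of how the now-synchronous surface might be driven end to end. This snippet is not part of the patch: the base URL, model name, session label, and AgentConfig fields are illustrative placeholders, and the LlamaStackClient/UserMessage import paths are assumed rather than taken from the diff.

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent
    from llama_stack_client.lib.agents.event_logger import EventLogger
    from llama_stack_client.types import UserMessage
    from llama_stack_client.types.agent_create_params import AgentConfig

    client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint

    # Illustrative config; the exact fields depend on the deployment.
    agent_config = AgentConfig(
        model="Llama3.1-8B-Instruct",
        instructions="You are a helpful assistant.",
        enable_session_persistence=False,
    )

    agent = Agent(client, agent_config)
    agent.create_session("demo-session")

    # create_turn is now a plain generator, so no asyncio event loop is needed.
    response = agent.create_turn(
        messages=[UserMessage(content="Hello", role="user")],
    )

    # EventLogger.log is likewise a plain generator of LogEvent objects.
    for log_event in EventLogger().log(response):
        log_event.print()

The main behavioral change to keep in mind when migrating: CustomTool.run is now synchronous and Agent._run_tool calls it directly, so existing tool implementations only need to drop the async/await keywords from their run methods.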