From 4c20ede53d26a6b315871c67df9290bd72a7ad9b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 24 Sep 2024 09:26:34 -0700 Subject: [PATCH] add event logger util --- src/llama_stack/lib/agents/__init__.py | 5 + src/llama_stack/lib/agents/event_logger.py | 162 +++++++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 src/llama_stack/lib/agents/__init__.py create mode 100644 src/llama_stack/lib/agents/event_logger.py diff --git a/src/llama_stack/lib/agents/__init__.py b/src/llama_stack/lib/agents/__init__.py new file mode 100644 index 0000000..756f351 --- /dev/null +++ b/src/llama_stack/lib/agents/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/src/llama_stack/lib/agents/event_logger.py b/src/llama_stack/lib/agents/event_logger.py new file mode 100644 index 0000000..d894f59 --- /dev/null +++ b/src/llama_stack/lib/agents/event_logger.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List, Optional, Union + +from llama_stack.types import ToolResponseMessage +from llama_stack.types.agents import AgentsTurnStreamChunk +from termcolor import cprint + + +def interleaved_text_media_as_str( + content: Union[str, List[str]], sep: str = " " +) -> str: + def _process(c) -> str: + if isinstance(c, str): + return c + else: + return "" + + if isinstance(content, list): + return sep.join(_process(c) for c in content) + else: + return _process(content) + + +class LogEvent: + def __init__( + self, + role: Optional[str] = None, + content: str = "", + end: str = "\n", + color="white", + ): + self.role = role + self.content = content + self.color = color + self.end = "\n" if end is None else end + + def __str__(self): + if self.role is not None: + return f"{self.role}> {self.content}" + else: + return f"{self.content}" + + def print(self, flush=True): + cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush) + + +class EventLogger: + async def log(self, event_generator): + previous_event_type = None + previous_step_type = None + + async for chunk in event_generator: + if not hasattr(chunk, "event"): + # Need to check for custom tool first + # since it does not produce event but instead + # a Message + if isinstance(chunk, ToolResponseMessage): + yield LogEvent( + role="CustomTool", content=chunk.content, color="grey" + ) + continue + + if not isinstance(chunk, AgentsTurnStreamChunk): + yield LogEvent(chunk, color="yellow") + continue + + event = chunk.event + event_type = event.payload.event_type + + if event_type in {"turn_start", "turn_complete"}: + # Currently not logging any turn realted info + yield None + continue + + step_type = event.payload.step_type + # handle safety + if step_type == "shield_call" and event_type == "step_complete": + violation = event.payload.step_details.violation + if not violation: + yield LogEvent( + role=step_type, content="No Violation", color="magenta" + ) + else: + yield LogEvent( + role=step_type, + content=f"{violation.metadata} {violation.user_message}", + color="red", + ) + + # handle inference + if step_type == "inference": + if event_type == "step_start": + yield LogEvent(role=step_type, content="", end="", color="yellow") + elif event_type == "step_progress": + # HACK: if previous was not step/event was not inference's step_progress + # this is the first time we are getting model inference response + # aka equivalent to step_start for inference. Hence, + # start with "Model>". + if ( + previous_event_type != "step_progress" + and previous_step_type != "inference" + ): + yield LogEvent( + role=step_type, content="", end="", color="yellow" + ) + + if event.payload.tool_call_delta: + if isinstance(event.payload.tool_call_delta.content, str): + yield LogEvent( + role=None, + content=event.payload.tool_call_delta.content, + end="", + color="cyan", + ) + else: + yield LogEvent( + role=None, + content=event.payload.text_delta_model_response, + end="", + color="yellow", + ) + else: + # step complete + yield LogEvent(role=None, content="") + + # handle tool_execution + if step_type == "tool_execution" and event_type == "step_complete": + # Only print tool calls and responses at the step_complete event + details = event.payload.step_details + for t in details.tool_calls: + yield LogEvent( + role=step_type, + content=f"Tool:{t.tool_name} Args:{t.arguments}", + color="green", + ) + + for r in details.tool_responses: + yield LogEvent( + role=step_type, + content=f"Tool:{r.tool_name} Response:{r.content}", + color="green", + ) + + # memory retrieval + if step_type == "memory_retrieval" and event_type == "step_complete": + details = event.payload.step_details + content = interleaved_text_media_as_str(details.inserted_context) + content = content[:200] + "..." if len(content) > 200 else content + + yield LogEvent( + role=step_type, + content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>", + color="cyan", + ) + + preivous_event_type = event_type + previous_step_type = step_type