From 8f27f7df9f1c24c29c4c3f8851132f8d99287094 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 13:17:16 -0800 Subject: [PATCH 1/9] move sync create_turn, event_logger --- src/llama_stack_client/lib/agents/agent.py | 64 +++++-- .../lib/agents/event_logger.py | 169 +++++++++--------- 2 files changed, 141 insertions(+), 92 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 37ed32b..4e43874 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -9,7 +9,9 @@ from llama_stack_client.types import Attachment, ToolResponseMessage, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig from .custom_tool import CustomTool +from rich.console import Console +console = Console() class Agent: def __init__(self, client: LlamaStackClient, agent_config: AgentConfig, custom_tools: Tuple[CustomTool] = ()): @@ -35,7 +37,7 @@ def create_session(self, session_name: str) -> int: self.sessions.append(self.session_id) return self.session_id - async def create_turn( + def _create_turn( self, messages: List[Union[UserMessage, ToolResponseMessage]], attachments: Optional[List[Attachment]] = None, @@ -73,16 +75,56 @@ async def create_turn( content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", role="ipython", ) - next_message = m + yield m else: - tool = self.custom_tools[tool_call.tool_name] - result_messages = await self.execute_custom_tool(tool, message) - next_message = result_messages[0] + yield chunk - yield next_message + def _has_tool_call(self, chunk): + if chunk.event.payload.event_type != "turn_complete": + return False + message = chunk.event.payload.turn.output_message + if len(message.tool_calls) == 0: + return False + if message.stop_reason == "out_of_tokens": + return False + return True - async def execute_custom_tool( - self, tool: CustomTool, message: Union[UserMessage, ToolResponseMessage] - ) -> List[Union[UserMessage, ToolResponseMessage]]: - result_messages = await tool.run([message]) - return result_messages + def create_turn( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + attachments: Optional[List[Attachment]] = None, + session_id: Optional[str] = None, + ): + response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + attachments=attachments, + stream=True, + ) + for chunk in response: + if not self._has_tool_call(chunk): + yield chunk + else: + console.print(chunk) + + async def async_create_turn( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + attachments: Optional[List[Attachment]] = None, + session_id: Optional[str] = None, + ): + response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + attachments=attachments, + stream=True, + ) + for chunk in response: + if not self._has_tool_call(chunk): + yield chunk + else: + console.print(chunk) \ No newline at end of file diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index 4e935d9..0541559 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -48,100 +48,107 @@ def print(self, flush=True): class EventLogger: - async def log(self, event_generator): - previous_event_type = None - previous_step_type = None - - async for chunk in event_generator: - if not hasattr(chunk, "event"): - # Need to check for custom tool first - # since it does not produce event but instead - # a Message - if isinstance(chunk, ToolResponseMessage): - yield LogEvent(role="CustomTool", content=chunk.content, color="green") - continue - - event = chunk.event - event_type = event.payload.event_type - - if event_type in {"turn_start", "turn_complete"}: - # Currently not logging any turn realted info - yield LogEvent(role=None, content="", end="", color="grey") - continue - - step_type = event.payload.step_type - # handle safety - if step_type == "shield_call" and event_type == "step_complete": - violation = event.payload.step_details.violation - if not violation: - yield LogEvent(role=step_type, content="No Violation", color="magenta") - else: - yield LogEvent( - role=step_type, - content=f"{violation.metadata} {violation.user_message}", - color="red", - ) + def get_log_event(self, chunk, previous_event_type=None, previous_step_type=None): + if not hasattr(chunk, "event"): + # Need to check for custom tool first + # since it does not produce event but instead + # a Message + if isinstance(chunk, ToolResponseMessage): + yield LogEvent(role="CustomTool", content=chunk.content, color="green") + return + + event = chunk.event + event_type = event.payload.event_type + + if event_type in {"turn_start", "turn_complete"}: + # Currently not logging any turn realted info + yield LogEvent(role=None, content="", end="", color="grey") + return + + step_type = event.payload.step_type + # handle safety + if step_type == "shield_call" and event_type == "step_complete": + violation = event.payload.step_details.violation + if not violation: + yield LogEvent(role=step_type, content="No Violation", color="magenta") + else: + yield LogEvent( + role=step_type, + content=f"{violation.metadata} {violation.user_message}", + color="red", + ) - # handle inference - if step_type == "inference": - if event_type == "step_start": + # handle inference + if step_type == "inference": + if event_type == "step_start": + yield LogEvent(role=step_type, content="", end="", color="yellow") + elif event_type == "step_progress": + # HACK: if previous was not step/event was not inference's step_progress + # this is the first time we are getting model inference response + # aka equivalent to step_start for inference. Hence, + # start with "Model>". + if previous_event_type != "step_progress" and previous_step_type != "inference": yield LogEvent(role=step_type, content="", end="", color="yellow") - elif event_type == "step_progress": - # HACK: if previous was not step/event was not inference's step_progress - # this is the first time we are getting model inference response - # aka equivalent to step_start for inference. Hence, - # start with "Model>". - if previous_event_type != "step_progress" and previous_step_type != "inference": - yield LogEvent(role=step_type, content="", end="", color="yellow") - - if event.payload.tool_call_delta: - if isinstance(event.payload.tool_call_delta.content, str): - yield LogEvent( - role=None, - content=event.payload.tool_call_delta.content, - end="", - color="cyan", - ) - else: + + if event.payload.tool_call_delta: + if isinstance(event.payload.tool_call_delta.content, str): yield LogEvent( role=None, - content=event.payload.text_delta_model_response, + content=event.payload.tool_call_delta.content, end="", - color="yellow", + color="cyan", ) else: - # step complete - yield LogEvent(role=None, content="") - - # handle tool_execution - if step_type == "tool_execution" and event_type == "step_complete": - # Only print tool calls and responses at the step_complete event - details = event.payload.step_details - for t in details.tool_calls: yield LogEvent( - role=step_type, - content=f"Tool:{t.tool_name} Args:{t.arguments}", - color="green", + role=None, + content=event.payload.text_delta_model_response, + end="", + color="yellow", ) + else: + # step complete + yield LogEvent(role=None, content="") + + # handle tool_execution + if step_type == "tool_execution" and event_type == "step_complete": + # Only print tool calls and responses at the step_complete event + details = event.payload.step_details + for t in details.tool_calls: + yield LogEvent( + role=step_type, + content=f"Tool:{t.tool_name} Args:{t.arguments}", + color="green", + ) - for r in details.tool_responses: - yield LogEvent( - role=step_type, - content=f"Tool:{r.tool_name} Response:{r.content}", - color="green", - ) - - # memory retrieval - if step_type == "memory_retrieval" and event_type == "step_complete": - details = event.payload.step_details - content = interleaved_text_media_as_str(details.inserted_context) - content = content[:200] + "..." if len(content) > 200 else content - + for r in details.tool_responses: yield LogEvent( role=step_type, - content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>", - color="cyan", + content=f"Tool:{r.tool_name} Response:{r.content}", + color="green", ) + # memory retrieval + if step_type == "memory_retrieval" and event_type == "step_complete": + details = event.payload.step_details + content = interleaved_text_media_as_str(details.inserted_context) + content = content[:200] + "..." if len(content) > 200 else content + + yield LogEvent( + role=step_type, + content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>", + color="cyan", + ) + + def log(self, event_generator): + previous_event_type = None + previous_step_type = None + + for chunk in event_generator: + for log_event in self.get_log_event(chunk, previous_event_type, previous_step_type): + yield log_event + + event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None + step_type = chunk.event.payload.step_type if event_type not in {"turn_start", "turn_complete"} else None + previous_event_type = event_type previous_step_type = step_type From a10e8b195bd5599472f53e448efb7f37e2095fcb Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 13:26:42 -0800 Subject: [PATCH 2/9] agent lib refactor --- src/llama_stack_client/lib/agents/agent.py | 73 +++++++++------------- 1 file changed, 31 insertions(+), 42 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 4e43874..9c79b1c 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -37,57 +37,46 @@ def create_session(self, session_name: str) -> int: self.sessions.append(self.session_id) return self.session_id - def _create_turn( - self, - messages: List[Union[UserMessage, ToolResponseMessage]], - attachments: Optional[List[Attachment]] = None, - session_id: Optional[str] = None, - ): - response = self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - attachments=attachments, - stream=True, - ) - turn = None - for chunk in response: - if chunk.event.payload.event_type != "turn_complete": - yield chunk - else: - turn = chunk.event.payload.turn - - message = turn.output_message + def _has_tool_call(self, chunk): + if chunk.event.payload.event_type != "turn_complete": + return False + message = chunk.event.payload.turn.output_message if len(message.tool_calls) == 0: - yield chunk - return - + return False if message.stop_reason == "out_of_tokens": - yield chunk - return + return False + return True + async def _async_run_tool(self, chunk): + message = chunk.event.payload.turn.output_message tool_call = message.tool_calls[0] if tool_call.tool_name not in self.custom_tools: - m = ToolResponseMessage( + return ToolResponseMessage( call_id=tool_call.call_id, tool_name=tool_call.tool_name, content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", role="ipython", ) - yield m - else: - yield chunk + tool = self.custom_tools[tool_call.tool_name] + result_messages = await tool.async_run([message]) + next_message = result_messages[0] + return next_message - def _has_tool_call(self, chunk): - if chunk.event.payload.event_type != "turn_complete": - return False + def _run_tool(self, chunk): message = chunk.event.payload.turn.output_message - if len(message.tool_calls) == 0: - return False - if message.stop_reason == "out_of_tokens": - return False - return True + tool_call = message.tool_calls[0] + if tool_call.tool_name not in self.custom_tools: + return ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", + role="ipython", + ) + tool = self.custom_tools[tool_call.tool_name] + result_messages = tool.run([message]) + next_message = result_messages[0] + return next_message + def create_turn( self, @@ -107,8 +96,8 @@ def create_turn( if not self._has_tool_call(chunk): yield chunk else: - console.print(chunk) - + yield self._run_tool(chunk) + async def async_create_turn( self, messages: List[Union[UserMessage, ToolResponseMessage]], @@ -127,4 +116,4 @@ async def async_create_turn( if not self._has_tool_call(chunk): yield chunk else: - console.print(chunk) \ No newline at end of file + yield await self._async_run_tool(chunk) \ No newline at end of file From 68f6f59eb4870a94b6d77ade95919198edace3fc Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 13:44:42 -0800 Subject: [PATCH 3/9] sync --- src/llama_stack_client/lib/agents/agent.py | 23 +++---------------- .../lib/agents/event_logger.py | 18 +++++++++++++-- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 9c79b1c..e3b8275 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -58,26 +58,10 @@ async def _async_run_tool(self, chunk): role="ipython", ) tool = self.custom_tools[tool_call.tool_name] - result_messages = await tool.async_run([message]) + result_messages = await tool.run([message]) next_message = result_messages[0] return next_message - def _run_tool(self, chunk): - message = chunk.event.payload.turn.output_message - tool_call = message.tool_calls[0] - if tool_call.tool_name not in self.custom_tools: - return ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", - role="ipython", - ) - tool = self.custom_tools[tool_call.tool_name] - result_messages = tool.run([message]) - next_message = result_messages[0] - return next_message - - def create_turn( self, messages: List[Union[UserMessage, ToolResponseMessage]], @@ -95,8 +79,6 @@ def create_turn( for chunk in response: if not self._has_tool_call(chunk): yield chunk - else: - yield self._run_tool(chunk) async def async_create_turn( self, @@ -116,4 +98,5 @@ async def async_create_turn( if not self._has_tool_call(chunk): yield chunk else: - yield await self._async_run_tool(chunk) \ No newline at end of file + next_message = await self._async_run_tool(chunk) + yield next_message \ No newline at end of file diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index 0541559..601d45a 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -48,7 +48,7 @@ def print(self, flush=True): class EventLogger: - def get_log_event(self, chunk, previous_event_type=None, previous_step_type=None): + def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None): if not hasattr(chunk, "event"): # Need to check for custom tool first # since it does not produce event but instead @@ -139,12 +139,26 @@ def get_log_event(self, chunk, previous_event_type=None, previous_step_type=None color="cyan", ) + async def async_log(self, event_generator): + previous_event_type = None + previous_step_type = None + + async for chunk in event_generator: + for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type): + yield log_event + + event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None + step_type = chunk.event.payload.step_type if event_type not in {"turn_start", "turn_complete"} else None + + previous_event_type = event_type + previous_step_type = step_type + def log(self, event_generator): previous_event_type = None previous_step_type = None for chunk in event_generator: - for log_event in self.get_log_event(chunk, previous_event_type, previous_step_type): + for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type): yield log_event event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None From 30f7aaa6ccafb455fcfc7f4f13d0328292165c96 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 13:51:59 -0800 Subject: [PATCH 4/9] simplify --- .../lib/agents/event_logger.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index 601d45a..dfeb4f2 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -146,12 +146,10 @@ async def async_log(self, event_generator): async for chunk in event_generator: for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type): yield log_event - - event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None - step_type = chunk.event.payload.step_type if event_type not in {"turn_start", "turn_complete"} else None - - previous_event_type = event_type - previous_step_type = step_type + + if hasattr(chunk, "event"): + previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None + previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None def log(self, event_generator): previous_event_type = None @@ -161,8 +159,6 @@ def log(self, event_generator): for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type): yield log_event - event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None - step_type = chunk.event.payload.step_type if event_type not in {"turn_start", "turn_complete"} else None - - previous_event_type = event_type - previous_step_type = step_type + if hasattr(chunk, "event"): + previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None + previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None \ No newline at end of file From 00406b214732d6e6c78b6a1432aee413f2d65494 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 13:52:33 -0800 Subject: [PATCH 5/9] simplify --- src/llama_stack_client/lib/agents/agent.py | 2 +- src/llama_stack_client/lib/agents/event_logger.py | 6 +++--- tests/test_client.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index e3b8275..3c226ca 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -99,4 +99,4 @@ async def async_create_turn( yield chunk else: next_message = await self._async_run_tool(chunk) - yield next_message \ No newline at end of file + yield next_message diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index dfeb4f2..8021335 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -146,11 +146,11 @@ async def async_log(self, event_generator): async for chunk in event_generator: for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type): yield log_event - + if hasattr(chunk, "event"): previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None - + def log(self, event_generator): previous_event_type = None previous_step_type = None @@ -161,4 +161,4 @@ def log(self, event_generator): if hasattr(chunk, "event"): previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None - previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None \ No newline at end of file + previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None diff --git a/tests/test_client.py b/tests/test_client.py index 45d816e..c1f8496 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1539,7 +1539,7 @@ def test_get_platform(self) -> None: import threading from llama_stack_client._utils import asyncify - from llama_stack_client._base_client import get_platform + from llama_stack_client._base_client import get_platform async def test_main() -> None: result = await asyncify(get_platform)() From 9b170b3c7215e0968a0e40e289882098564e4901 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 13:54:42 -0800 Subject: [PATCH 6/9] remove unused --- src/llama_stack_client/lib/agents/agent.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 3c226ca..e778c9f 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -9,9 +9,6 @@ from llama_stack_client.types import Attachment, ToolResponseMessage, UserMessage from llama_stack_client.types.agent_create_params import AgentConfig from .custom_tool import CustomTool -from rich.console import Console - -console = Console() class Agent: def __init__(self, client: LlamaStackClient, agent_config: AgentConfig, custom_tools: Tuple[CustomTool] = ()): From f61b3b83960d16fbf4051274ce57df2b9e5ba463 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 14:09:59 -0800 Subject: [PATCH 7/9] improve error msg --- src/llama_stack_client/lib/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index e778c9f..fb738d4 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -51,7 +51,7 @@ async def _async_run_tool(self, chunk): return ToolResponseMessage( call_id=tool_call.call_id, tool_name=tool_call.tool_name, - content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else", + content=f"Unknown tool `{tool_call.tool_name}` was called.", role="ipython", ) tool = self.custom_tools[tool_call.tool_name] From a680de16fb6cfe5f07731d6c7ba10025d779c2ae Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 14:10:31 -0800 Subject: [PATCH 8/9] check --- src/llama_stack_client/lib/agents/agent.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index fb738d4..34ed6ac 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -38,11 +38,9 @@ def _has_tool_call(self, chunk): if chunk.event.payload.event_type != "turn_complete": return False message = chunk.event.payload.turn.output_message - if len(message.tool_calls) == 0: - return False if message.stop_reason == "out_of_tokens": return False - return True + return len(message.tool_calls) > 0 async def _async_run_tool(self, chunk): message = chunk.event.payload.turn.output_message From 72087c413e70d2cee48622c358589c9514d0ae84 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 21 Nov 2024 14:32:38 -0800 Subject: [PATCH 9/9] all sync --- src/llama_stack_client/lib/agents/agent.py | 24 +++---------------- .../lib/agents/custom_tool.py | 2 +- .../lib/agents/event_logger.py | 22 ++++++----------- 3 files changed, 11 insertions(+), 37 deletions(-) diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 34ed6ac..1d89c8b 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -42,7 +42,7 @@ def _has_tool_call(self, chunk): return False return len(message.tool_calls) > 0 - async def _async_run_tool(self, chunk): + def _run_tool(self, chunk): message = chunk.event.payload.turn.output_message tool_call = message.tool_calls[0] if tool_call.tool_name not in self.custom_tools: @@ -53,7 +53,7 @@ async def _async_run_tool(self, chunk): role="ipython", ) tool = self.custom_tools[tool_call.tool_name] - result_messages = await tool.run([message]) + result_messages = tool.run([message]) next_message = result_messages[0] return next_message @@ -62,24 +62,6 @@ def create_turn( messages: List[Union[UserMessage, ToolResponseMessage]], attachments: Optional[List[Attachment]] = None, session_id: Optional[str] = None, - ): - response = self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - attachments=attachments, - stream=True, - ) - for chunk in response: - if not self._has_tool_call(chunk): - yield chunk - - async def async_create_turn( - self, - messages: List[Union[UserMessage, ToolResponseMessage]], - attachments: Optional[List[Attachment]] = None, - session_id: Optional[str] = None, ): response = self.client.agents.turn.create( agent_id=self.agent_id, @@ -93,5 +75,5 @@ async def async_create_turn( if not self._has_tool_call(chunk): yield chunk else: - next_message = await self._async_run_tool(chunk) + next_message = self._run_tool(chunk) yield next_message diff --git a/src/llama_stack_client/lib/agents/custom_tool.py b/src/llama_stack_client/lib/agents/custom_tool.py index 75907a5..e1205ca 100644 --- a/src/llama_stack_client/lib/agents/custom_tool.py +++ b/src/llama_stack_client/lib/agents/custom_tool.py @@ -61,7 +61,7 @@ def get_tool_definition(self) -> FunctionCallToolDefinition: ) @abstractmethod - async def run( + def run( self, messages: List[Union[UserMessage, ToolResponseMessage]] ) -> List[Union[UserMessage, ToolResponseMessage]]: raise NotImplementedError diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index 8021335..e356463 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -139,17 +139,12 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non color="cyan", ) - async def async_log(self, event_generator): - previous_event_type = None - previous_step_type = None - - async for chunk in event_generator: - for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type): - yield log_event - - if hasattr(chunk, "event"): - previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None - previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None + def _get_event_type_step_type(self, chunk): + if hasattr(chunk, "event"): + previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None + previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None + return previous_event_type, previous_step_type + return None, None def log(self, event_generator): previous_event_type = None @@ -158,7 +153,4 @@ def log(self, event_generator): for chunk in event_generator: for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type): yield log_event - - if hasattr(chunk, "event"): - previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None - previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None + previous_event_type, previous_step_type = self._get_event_type_step_type(chunk) \ No newline at end of file