Skip to content

Commit

Permalink
sync
Browse files Browse the repository at this point in the history
  • Loading branch information
yanxi0830 committed Nov 21, 2024
1 parent a10e8b1 commit 68f6f59
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
23 changes: 3 additions & 20 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,10 @@ async def _async_run_tool(self, chunk):
role="ipython",
)
tool = self.custom_tools[tool_call.tool_name]
result_messages = await tool.async_run([message])
result_messages = await tool.run([message])
next_message = result_messages[0]
return next_message

def _run_tool(self, chunk):
message = chunk.event.payload.turn.output_message
tool_call = message.tool_calls[0]
if tool_call.tool_name not in self.custom_tools:
return ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
role="ipython",
)
tool = self.custom_tools[tool_call.tool_name]
result_messages = tool.run([message])
next_message = result_messages[0]
return next_message


def create_turn(
self,
messages: List[Union[UserMessage, ToolResponseMessage]],
Expand All @@ -95,8 +79,6 @@ def create_turn(
for chunk in response:
if not self._has_tool_call(chunk):
yield chunk
else:
yield self._run_tool(chunk)

async def async_create_turn(
self,
Expand All @@ -116,4 +98,5 @@ async def async_create_turn(
if not self._has_tool_call(chunk):
yield chunk
else:
yield await self._async_run_tool(chunk)
next_message = await self._async_run_tool(chunk)
yield next_message
18 changes: 16 additions & 2 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def print(self, flush=True):


class EventLogger:
def get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
Expand Down Expand Up @@ -139,12 +139,26 @@ def get_log_event(self, chunk, previous_event_type=None, previous_step_type=None
color="cyan",
)

async def async_log(self, event_generator):
previous_event_type = None
previous_step_type = None

async for chunk in event_generator:
for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type):
yield log_event

event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
step_type = chunk.event.payload.step_type if event_type not in {"turn_start", "turn_complete"} else None

previous_event_type = event_type
previous_step_type = step_type

def log(self, event_generator):
previous_event_type = None
previous_step_type = None

for chunk in event_generator:
for log_event in self.get_log_event(chunk, previous_event_type, previous_step_type):
for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type):
yield log_event

event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
Expand Down

0 comments on commit 68f6f59

Please sign in to comment.