Support sync for Agent SDK #45

Merged · 9 commits · Nov 21, 2024
63 changes: 27 additions & 36 deletions src/llama_stack_client/lib/agents/agent.py
@@ -10,7 +10,6 @@
 from llama_stack_client.types.agent_create_params import AgentConfig
 from .custom_tool import CustomTool
 
-
 class Agent:
     def __init__(self, client: LlamaStackClient, agent_config: AgentConfig, custom_tools: Tuple[CustomTool] = ()):
         self.client = client
@@ -35,7 +34,30 @@ def create_session(self, session_name: str) -> int:
         self.sessions.append(self.session_id)
         return self.session_id
 
-    async def create_turn(
+    def _has_tool_call(self, chunk):
+        if chunk.event.payload.event_type != "turn_complete":
+            return False
+        message = chunk.event.payload.turn.output_message
+        if message.stop_reason == "out_of_tokens":
+            return False
+        return len(message.tool_calls) > 0
+
+    def _run_tool(self, chunk):
+        message = chunk.event.payload.turn.output_message
+        tool_call = message.tool_calls[0]
+        if tool_call.tool_name not in self.custom_tools:
+            return ToolResponseMessage(
+                call_id=tool_call.call_id,
+                tool_name=tool_call.tool_name,
+                content=f"Unknown tool `{tool_call.tool_name}` was called.",
+                role="ipython",
+            )
+        tool = self.custom_tools[tool_call.tool_name]
+        result_messages = tool.run([message])
+        next_message = result_messages[0]
+        return next_message
+
+    def create_turn(
         self,
         messages: List[Union[UserMessage, ToolResponseMessage]],
         attachments: Optional[List[Attachment]] = None,
@@ -49,40 +71,9 @@ async def create_turn(
             attachments=attachments,
             stream=True,
         )
-        turn = None
         for chunk in response:
-            if chunk.event.payload.event_type != "turn_complete":
+            if not self._has_tool_call(chunk):
                 yield chunk
             else:
-                turn = chunk.event.payload.turn
-
-        message = turn.output_message
-        if len(message.tool_calls) == 0:
-            yield chunk
-            return
-
-        if message.stop_reason == "out_of_tokens":
-            yield chunk
-            return
-
-        tool_call = message.tool_calls[0]
-        if tool_call.tool_name not in self.custom_tools:
-            m = ToolResponseMessage(
-                call_id=tool_call.call_id,
-                tool_name=tool_call.tool_name,
-                content=f"Unknown tool `{tool_call.tool_name}` was called. Try again with something else",
-                role="ipython",
-            )
-            next_message = m
-        else:
-            tool = self.custom_tools[tool_call.tool_name]
-            result_messages = await self.execute_custom_tool(tool, message)
-            next_message = result_messages[0]
-
-        yield next_message
-
-    async def execute_custom_tool(
-        self, tool: CustomTool, message: Union[UserMessage, ToolResponseMessage]
-    ) -> List[Union[UserMessage, ToolResponseMessage]]:
-        result_messages = await tool.run([message])
-        return result_messages
+                next_message = self._run_tool(chunk)
+                yield next_message
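With this change `create_turn` is a plain generator, so a caller can stream a turn without an event loop. A minimal usage sketch; the base URL, model name, and `AgentConfig` fields below are illustrative assumptions, not values from this diff:

```python
# Illustrative sketch only: endpoint, model name, and config fields are assumed.
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types import UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig

client = LlamaStackClient(base_url="http://localhost:5000")
agent_config = AgentConfig(
    model="Llama3.1-8B-Instruct",
    instructions="You are a helpful assistant.",
    enable_session_persistence=False,
)
agent = Agent(client, agent_config)
agent.create_session("test-session")

# No asyncio.run() needed: create_turn now yields chunks synchronously.
for chunk in agent.create_turn(
    messages=[UserMessage(role="user", content="Hello")],
):
    print(chunk)
```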
2 changes: 1 addition & 1 deletion src/llama_stack_client/lib/agents/custom_tool.py
@@ -61,7 +61,7 @@ def get_tool_definition(self) -> FunctionCallToolDefinition:
         )
 
     @abstractmethod
-    async def run(
+    def run(
         self, messages: List[Union[UserMessage, ToolResponseMessage]]
    ) -> List[Union[UserMessage, ToolResponseMessage]]:
         raise NotImplementedError
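Downstream tools drop the `async` keyword to satisfy the new abstract signature. A sketch of a conforming subclass; the non-`run` hook names are assumptions based on the surrounding file, and the echo behavior is invented for illustration:

```python
# Hypothetical tool demonstrating the new synchronous run() contract.
from typing import Dict, List, Union

from llama_stack_client.lib.agents.custom_tool import CustomTool
from llama_stack_client.types import ToolResponseMessage, UserMessage


class EchoTool(CustomTool):
    def get_name(self) -> str:
        return "echo"

    def get_description(self) -> str:
        return "Echo the arguments the model called the tool with."

    def get_params_definition(self) -> Dict:
        return {}

    def run(
        self, messages: List[Union[UserMessage, ToolResponseMessage]]
    ) -> List[Union[UserMessage, ToolResponseMessage]]:
        # run() is synchronous now; Agent._run_tool() calls it without await.
        tool_call = messages[0].tool_calls[0]
        return [
            ToolResponseMessage(
                call_id=tool_call.call_id,
                tool_name=self.get_name(),
                content=str(tool_call.arguments),
                role="ipython",
            )
        ]
```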
175 changes: 92 additions & 83 deletions src/llama_stack_client/lib/agents/event_logger.py
@@ -48,100 +48,109 @@ def print(self, flush=True):
 
 
 class EventLogger:
-    async def log(self, event_generator):
-        previous_event_type = None
-        previous_step_type = None
-
-        async for chunk in event_generator:
-            if not hasattr(chunk, "event"):
-                # Need to check for custom tool first
-                # since it does not produce event but instead
-                # a Message
-                if isinstance(chunk, ToolResponseMessage):
-                    yield LogEvent(role="CustomTool", content=chunk.content, color="green")
-                continue
-
-            event = chunk.event
-            event_type = event.payload.event_type
-
-            if event_type in {"turn_start", "turn_complete"}:
-                # Currently not logging any turn realted info
-                yield LogEvent(role=None, content="", end="", color="grey")
-                continue
-
-            step_type = event.payload.step_type
-            # handle safety
-            if step_type == "shield_call" and event_type == "step_complete":
-                violation = event.payload.step_details.violation
-                if not violation:
-                    yield LogEvent(role=step_type, content="No Violation", color="magenta")
-                else:
-                    yield LogEvent(
-                        role=step_type,
-                        content=f"{violation.metadata} {violation.user_message}",
-                        color="red",
-                    )
-
-            # handle inference
-            if step_type == "inference":
-                if event_type == "step_start":
-                    yield LogEvent(role=step_type, content="", end="", color="yellow")
-                elif event_type == "step_progress":
-                    # HACK: if previous was not step/event was not inference's step_progress
-                    # this is the first time we are getting model inference response
-                    # aka equivalent to step_start for inference. Hence,
-                    # start with "Model>".
-                    if previous_event_type != "step_progress" and previous_step_type != "inference":
-                        yield LogEvent(role=step_type, content="", end="", color="yellow")
-
-                    if event.payload.tool_call_delta:
-                        if isinstance(event.payload.tool_call_delta.content, str):
-                            yield LogEvent(
-                                role=None,
-                                content=event.payload.tool_call_delta.content,
-                                end="",
-                                color="cyan",
-                            )
-                    else:
-                        yield LogEvent(
-                            role=None,
-                            content=event.payload.text_delta_model_response,
-                            end="",
-                            color="yellow",
-                        )
-                else:
-                    # step complete
-                    yield LogEvent(role=None, content="")
-
-            # handle tool_execution
-            if step_type == "tool_execution" and event_type == "step_complete":
-                # Only print tool calls and responses at the step_complete event
-                details = event.payload.step_details
-                for t in details.tool_calls:
-                    yield LogEvent(
-                        role=step_type,
-                        content=f"Tool:{t.tool_name} Args:{t.arguments}",
-                        color="green",
-                    )
-
-                for r in details.tool_responses:
-                    yield LogEvent(
-                        role=step_type,
-                        content=f"Tool:{r.tool_name} Response:{r.content}",
-                        color="green",
-                    )
-
-            # memory retrieval
-            if step_type == "memory_retrieval" and event_type == "step_complete":
-                details = event.payload.step_details
-                content = interleaved_text_media_as_str(details.inserted_context)
-                content = content[:200] + "..." if len(content) > 200 else content
-
-                yield LogEvent(
-                    role=step_type,
-                    content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
-                    color="cyan",
-                )
-
-            previous_event_type = event_type
-            previous_step_type = step_type
+    def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=None):
+        if not hasattr(chunk, "event"):
+            # Need to check for custom tool first
+            # since it does not produce event but instead
+            # a Message
+            if isinstance(chunk, ToolResponseMessage):
+                yield LogEvent(role="CustomTool", content=chunk.content, color="green")
+            return
+
+        event = chunk.event
+        event_type = event.payload.event_type
+
+        if event_type in {"turn_start", "turn_complete"}:
+            # Currently not logging any turn realted info
+            yield LogEvent(role=None, content="", end="", color="grey")
+            return
+
+        step_type = event.payload.step_type
+        # handle safety
+        if step_type == "shield_call" and event_type == "step_complete":
+            violation = event.payload.step_details.violation
+            if not violation:
+                yield LogEvent(role=step_type, content="No Violation", color="magenta")
+            else:
+                yield LogEvent(
+                    role=step_type,
+                    content=f"{violation.metadata} {violation.user_message}",
+                    color="red",
+                )
+
+        # handle inference
+        if step_type == "inference":
+            if event_type == "step_start":
+                yield LogEvent(role=step_type, content="", end="", color="yellow")
+            elif event_type == "step_progress":
+                # HACK: if previous was not step/event was not inference's step_progress
+                # this is the first time we are getting model inference response
+                # aka equivalent to step_start for inference. Hence,
+                # start with "Model>".
+                if previous_event_type != "step_progress" and previous_step_type != "inference":
+                    yield LogEvent(role=step_type, content="", end="", color="yellow")
+
+                if event.payload.tool_call_delta:
+                    if isinstance(event.payload.tool_call_delta.content, str):
+                        yield LogEvent(
+                            role=None,
+                            content=event.payload.tool_call_delta.content,
+                            end="",
+                            color="cyan",
+                        )
+                else:
+                    yield LogEvent(
+                        role=None,
+                        content=event.payload.text_delta_model_response,
+                        end="",
+                        color="yellow",
+                    )
+            else:
+                # step complete
+                yield LogEvent(role=None, content="")
+
+        # handle tool_execution
+        if step_type == "tool_execution" and event_type == "step_complete":
+            # Only print tool calls and responses at the step_complete event
+            details = event.payload.step_details
+            for t in details.tool_calls:
+                yield LogEvent(
+                    role=step_type,
+                    content=f"Tool:{t.tool_name} Args:{t.arguments}",
+                    color="green",
+                )
+
+            for r in details.tool_responses:
+                yield LogEvent(
+                    role=step_type,
+                    content=f"Tool:{r.tool_name} Response:{r.content}",
+                    color="green",
+                )
+
+        # memory retrieval
+        if step_type == "memory_retrieval" and event_type == "step_complete":
+            details = event.payload.step_details
+            content = interleaved_text_media_as_str(details.inserted_context)
+            content = content[:200] + "..." if len(content) > 200 else content
+
+            yield LogEvent(
+                role=step_type,
+                content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
+                color="cyan",
+            )
+
+    def _get_event_type_step_type(self, chunk):
+        if hasattr(chunk, "event"):
+            previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
+            previous_step_type = chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None
+            return previous_event_type, previous_step_type
+        return None, None
+
+    def log(self, event_generator):
+        previous_event_type = None
+        previous_step_type = None
+
+        for chunk in event_generator:
+            for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type):
+                yield log_event
+            previous_event_type, previous_step_type = self._get_event_type_step_type(chunk)
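Both generators can now be consumed from ordinary `for` loops. A sketch wiring them together, reusing the `agent` instance from the earlier sketch; the `UserMessage` constructor call is an assumption:

```python
# Sketch: pretty-print a streamed turn with the now-synchronous EventLogger.
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import UserMessage

response = agent.create_turn(
    messages=[UserMessage(role="user", content="What tools do you have?")],
)
for log_event in EventLogger().log(response):
    log_event.print()  # LogEvent.print() is defined just above this hunk
```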
2 changes: 1 addition & 1 deletion tests/test_client.py
@@ -1539,7 +1539,7 @@ def test_get_platform(self) -> None:
         import threading
 
         from llama_stack_client._utils import asyncify
-        from llama_stack_client._base_client import get_platform
+        from llama_stack_client._base_client import get_platform
 
         async def test_main() -> None:
             result = await asyncify(get_platform)()