From f90e9c2003e93b78c1ecda21acf9faa4fdb52fb4 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 20 Dec 2024 14:46:32 -0800 Subject: [PATCH 01/53] agents to use tools api --- llama_stack/apis/agents/agents.py | 93 +------ llama_stack/apis/tools/tools.py | 8 +- llama_stack/distribution/datatypes.py | 1 + llama_stack/distribution/resolver.py | 5 +- .../distribution/routers/routing_tables.py | 2 + llama_stack/distribution/stack.py | 5 +- .../providers/tests/agents/conftest.py | 5 + .../inline/agents/meta_reference/__init__.py | 2 + .../agents/meta_reference/agent_instance.py | 245 ++++------------- .../inline/agents/meta_reference/agents.py | 9 +- .../agents/meta_reference/persistence.py | 14 - .../inline/tool_runtime/memory/__init__.py | 20 ++ .../inline/tool_runtime/memory/config.py | 93 +++++++ .../memory}/context_retriever.py | 11 +- .../inline/tool_runtime/memory/memory.py | 253 ++++++++++++++++++ llama_stack/providers/registry/agents.py | 2 + .../providers/registry/tool_runtime.py | 8 + .../providers/tests/agents/conftest.py | 10 +- .../providers/tests/agents/fixtures.py | 63 ++++- .../providers/tests/agents/test_agents.py | 14 +- llama_stack/providers/tests/resolver.py | 4 +- 21 files changed, 538 insertions(+), 329 deletions(-) create mode 100644 llama_stack/llama_stack/providers/tests/agents/conftest.py create mode 100644 llama_stack/providers/inline/tool_runtime/memory/__init__.py create mode 100644 llama_stack/providers/inline/tool_runtime/memory/config.py rename llama_stack/providers/inline/{agents/meta_reference/rag => tool_runtime/memory}/context_retriever.py (98%) create mode 100644 llama_stack/providers/inline/tool_runtime/memory/memory.py diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5748b4e41b..65be923488 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -14,18 +14,16 @@ Literal, Optional, Protocol, - runtime_checkable, Union, + runtime_checkable, ) from llama_models.llama3.api.datatypes import ToolParamDefinition - from llama_models.schema_utils import json_schema_type, webmethod - from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig from llama_stack.apis.inference import ( CompletionMessage, @@ -40,7 +38,6 @@ ) from llama_stack.apis.memory import MemoryBank from llama_stack.apis.safety import SafetyViolation - from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -110,85 +107,6 @@ class FunctionCallToolDefinition(ToolDefinitionCommon): remote_execution: Optional[RestAPIExecutionConfig] = None -class _MemoryBankConfigCommon(BaseModel): - bank_id: str - - -class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["vector"] = "vector" - - -class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["keyvalue"] = "keyvalue" - keys: List[str] # what keys to focus on - - -class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["keyword"] = "keyword" - - -class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon): - type: Literal["graph"] = "graph" - entities: List[str] # what entities to focus on - - -MemoryBankConfig = Annotated[ - Union[ - AgentVectorMemoryBankConfig, - AgentKeyValueMemoryBankConfig, - AgentKeywordMemoryBankConfig, - 
AgentGraphMemoryBankConfig, - ], - Field(discriminator="type"), -] - - -class MemoryQueryGenerator(Enum): - default = "default" - llm = "llm" - custom = "custom" - - -class DefaultMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.default.value] = ( - MemoryQueryGenerator.default.value - ) - sep: str = " " - - -class LLMMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value - model: str - template: str - - -class CustomMemoryQueryGeneratorConfig(BaseModel): - type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value - - -MemoryQueryGeneratorConfig = Annotated[ - Union[ - DefaultMemoryQueryGeneratorConfig, - LLMMemoryQueryGeneratorConfig, - CustomMemoryQueryGeneratorConfig, - ], - Field(discriminator="type"), -] - - -@json_schema_type -class MemoryToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.memory.value] = AgentTool.memory.value - memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) - # This config defines how a query is generated using the messages - # for memory bank retrieval. - query_generator_config: MemoryQueryGeneratorConfig = Field( - default=DefaultMemoryQueryGeneratorConfig() - ) - max_tokens_in_context: int = 4096 - max_chunks: int = 10 - - AgentToolDefinition = Annotated[ Union[ SearchToolDefinition, @@ -196,7 +114,6 @@ class MemoryToolDefinition(ToolDefinitionCommon): PhotogenToolDefinition, CodeInterpreterToolDefinition, FunctionCallToolDefinition, - MemoryToolDefinition, ], Field(discriminator="type"), ] @@ -295,7 +212,11 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list) + tools: Optional[List[AgentToolDefinition]] = Field( + default_factory=list, deprecated=True + ) + available_tools: Optional[List[str]] = Field(default_factory=list) + preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 23110543bd..60b2bdab9f 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -68,10 +68,16 @@ class UserDefinedToolGroupDef(BaseModel): Annotated[ Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type") ], - name="ToolGroup", + name="ToolGroupDef", ) +class ToolGroupInput(BaseModel): + tool_group_id: str + tool_group: ToolGroupDef + provider_id: Optional[str] = None + + class ToolGroup(Resource): type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index dec62bfaee..ba7ba62bd6 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -161,6 +161,7 @@ class StackRunConfig(BaseModel): datasets: List[DatasetInput] = Field(default_factory=list) scoring_fns: List[ScoringFnInput] = Field(default_factory=list) eval_tasks: List[EvalTaskInput] = Field(default_factory=list) + tool_groups: List[ToolGroupInput] = Field(default_factory=list) class BuildConfig(BaseModel): diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 0a6eed3458..3ea93301ff 100644 --- 
a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -5,9 +5,7 @@ # the root directory of this source tree. import importlib import inspect - import logging - from typing import Any, Dict, List, Set from llama_stack.apis.agents import Agents @@ -28,7 +26,6 @@ from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.distribution.client import get_client_impl - from llama_stack.distribution.datatypes import ( AutoRoutedProviderSpec, Provider, @@ -38,7 +35,7 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type - +from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import ( Api, DatasetsProtocolPrivate, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index ab1becfdd9..8d622a5c27 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -523,6 +523,8 @@ async def register_tool_group( ) provider_id = list(self.impls_by_provider_id.keys())[0] + # parse tool group to the type if dict + tool_group = parse_obj_as(ToolGroupDef, tool_group) if isinstance(tool_group, MCPToolGroupDef): tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( tool_group diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 7fc2c76502..9d12303c99 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -12,7 +12,7 @@ import pkg_resources import yaml - +from llama_models.llama3.api.datatypes import * # noqa: F403 from termcolor import colored from llama_stack.apis.agents import Agents @@ -33,14 +33,12 @@ from llama_stack.apis.shields import Shields from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.telemetry import Telemetry - from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.datatypes import Api - log = logging.getLogger(__name__) LLAMA_STACK_API_VERSION = "alpha" @@ -81,6 +79,7 @@ class LlamaStack( "list_scoring_functions", ), ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"), + ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"), ] diff --git a/llama_stack/llama_stack/providers/tests/agents/conftest.py b/llama_stack/llama_stack/providers/tests/agents/conftest.py new file mode 100644 index 0000000000..756f351d88 --- /dev/null +++ b/llama_stack/llama_stack/providers/tests/agents/conftest.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
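Note on the routing_tables.py hunk above: run configs arrive from YAML as plain dicts, so register_tool_group re-parses the incoming value against the ToolGroupDef discriminated union before dispatching on its type. A minimal standalone sketch of that pattern follows — the field names and discriminator values here are illustrative assumptions, not the actual llama_stack.apis.tools definitions:

    from typing import Annotated, List, Literal, Union

    from pydantic import BaseModel, Field, parse_obj_as


    class MCPToolGroupDef(BaseModel):
        # illustrative stand-in for the real class in llama_stack.apis.tools
        type: Literal["model_context_protocol"] = "model_context_protocol"
        endpoint: str


    class UserDefinedToolGroupDef(BaseModel):
        type: Literal["user_defined"] = "user_defined"
        tools: List[dict] = Field(default_factory=list)


    ToolGroupDef = Annotated[
        Union[MCPToolGroupDef, UserDefinedToolGroupDef],
        Field(discriminator="type"),
    ]

    # A dict loaded from a run config picks up the concrete subtype on
    # parse, so the isinstance() checks in register_tool_group work.
    tool_group = parse_obj_as(ToolGroupDef, {"type": "user_defined", "tools": []})
    assert isinstance(tool_group, UserDefinedToolGroupDef)

This is also why StackRunConfig can carry ToolGroupInput entries and have stack.py register them generically at startup alongside models, shields, and the other resources.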
diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 156de9a17b..50f61fb426 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -22,6 +22,8 @@ async def get_provider_impl( deps[Api.memory], deps[Api.safety], deps[Api.memory_banks], + deps[Api.tool_runtime], + deps[Api.tool_groups], ) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 09738d7b7d..00d8bbd363 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -4,25 +4,21 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio import copy import logging import os -import re import secrets import string import uuid from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional, Tuple +from typing import AsyncGenerator, Dict, List from urllib.parse import urlparse import httpx - from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, - AgentTool, AgentTurnCreateRequest, AgentTurnResponseEvent, AgentTurnResponseEventType, @@ -36,8 +32,6 @@ CodeInterpreterToolDefinition, FunctionCallToolDefinition, InferenceStep, - MemoryRetrievalStep, - MemoryToolDefinition, PhotogenToolDefinition, SearchToolDefinition, ShieldCallStep, @@ -46,11 +40,9 @@ Turn, WolframAlphaToolDefinition, ) - from llama_stack.apis.common.content_types import ( - InterleavedContent, - TextContentItem, URL, + TextContentItem, ) from llama_stack.apis.inference import ( ChatCompletionResponseEventType, @@ -62,30 +54,26 @@ SystemMessage, ToolCallDelta, ToolCallParseStatus, - ToolChoice, ToolDefinition, ToolResponse, ToolResponseMessage, UserMessage, ) -from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse -from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams +from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety - from llama_stack.providers.utils.kvstore import KVStore -from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing from .persistence import AgentPersistence -from .rag.context_retriever import generate_rag_query from .safety import SafetyException, ShieldRunnerMixin from .tools.base import BaseTool from .tools.builtin import ( CodeInterpreterTool, - interpret_content_as_attachment, PhotogenTool, SearchTool, WolframAlphaTool, + interpret_content_as_attachment, ) from .tools.safety import SafeTool @@ -108,6 +96,8 @@ def __init__( memory_api: Memory, memory_banks_api: MemoryBanks, safety_api: Safety, + tool_runtime_api: ToolRuntime, + tool_groups_api: ToolGroups, persistence_store: KVStore, ): self.agent_id = agent_id @@ -118,6 +108,8 @@ def __init__( self.memory_banks_api = memory_banks_api self.safety_api = safety_api self.storage = AgentPersistence(agent_id, persistence_store) + self.tool_runtime_api = tool_runtime_api + self.tool_groups_api = tool_groups_api builtin_tools = [] for tool_defn in agent_config.tools: @@ -392,62 +384,50 @@ async def _run( 
sampling_params: SamplingParams, stream: bool = False, ) -> AsyncGenerator: - enabled_tools = set(t.type for t in self.agent_config.tools) - need_rag_context = await self._should_retrieve_context( - input_messages, attachments - ) - if need_rag_context: - step_id = str(uuid.uuid4()) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.memory_retrieval.value, - step_id=step_id, + if self.agent_config.preprocessing_tools: + with tracing.span("preprocessing_tools") as span: + for tool_name in self.agent_config.preprocessing_tools: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( + step_type=StepType.tool_execution.value, + step_id=str(uuid.uuid4()), + ) + ) ) - ) - ) - - # TODO: find older context from the session and either replace it - # or append with a sliding window. this is really a very simplistic implementation - with tracing.span("retrieve_rag_context") as span: - rag_context, bank_ids = await self._retrieve_context( - session_id, input_messages, attachments - ) - span.set_attribute( - "input", [m.model_dump_json() for m in input_messages] - ) - span.set_attribute("output", rag_context) - span.set_attribute("bank_ids", bank_ids) - - step_id = str(uuid.uuid4()) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.memory_retrieval.value, - step_id=step_id, - step_details=MemoryRetrievalStep( - turn_id=turn_id, - step_id=step_id, - memory_bank_ids=bank_ids, - inserted_context=rag_context or "", - ), + args = dict( + session_id=session_id, + input_messages=input_messages, + attachments=attachments, ) - ) - ) - - if rag_context: - last_message = input_messages[-1] - last_message.context = rag_context - - elif attachments and AgentTool.code_interpreter.value in enabled_tools: - urls = [a.content for a in attachments if isinstance(a.content, URL)] - # TODO: we need to migrate URL away from str type - pattern = re.compile("^(https?://|file://|data:)") - urls += [ - URL(uri=a.content) for a in attachments if pattern.match(a.content) - ] - msg = await attachment_message(self.tempdir, urls) - input_messages.append(msg) + result = await self.tool_runtime_api.invoke_tool( + tool_name=tool_name, + args=args, + ) + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + step_type=StepType.tool_execution.value, + step_id=str(uuid.uuid4()), + tool_call_delta=ToolCallDelta( + parse_status=ToolCallParseStatus.success, + content=ToolCall( + call_id="", tool_name=tool_name, arguments={} + ), + ), + ) + ) + ) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute("output", result.content) + span.set_attribute("error_code", result.error_code) + span.set_attribute("error_message", result.error_message) + span.set_attribute("tool_name", tool_name) + if result.error_code != 0 and result.content: + last_message = input_messages[-1] + last_message.context = result.content output_attachments = [] @@ -659,129 +639,6 @@ async def _run( n_iter += 1 - async def _ensure_memory_bank(self, session_id: str) -> str: - session_info = await self.storage.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") - - if session_info.memory_bank_id is None: - bank_id = f"memory_bank_{session_id}" - await 
self.memory_banks_api.register_memory_bank( - memory_bank_id=bank_id, - params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - ), - ) - await self.storage.add_memory_bank_to_session(session_id, bank_id) - else: - bank_id = session_info.memory_bank_id - - return bank_id - - async def _should_retrieve_context( - self, messages: List[Message], attachments: List[Attachment] - ) -> bool: - enabled_tools = set(t.type for t in self.agent_config.tools) - if attachments: - if ( - AgentTool.code_interpreter.value in enabled_tools - and self.agent_config.tool_choice == ToolChoice.required - ): - return False - else: - return True - - return AgentTool.memory.value in enabled_tools - - def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]: - for t in self.agent_config.tools: - if t.type == AgentTool.memory.value: - return t - - return None - - async def _retrieve_context( - self, session_id: str, messages: List[Message], attachments: List[Attachment] - ) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids) - bank_ids = [] - - memory = self._memory_tool_definition() - assert memory is not None, "Memory tool not configured" - bank_ids.extend(c.bank_id for c in memory.memory_bank_configs) - - if attachments: - bank_id = await self._ensure_memory_bank(session_id) - bank_ids.append(bank_id) - - documents = [ - MemoryBankDocument( - document_id=str(uuid.uuid4()), - content=a.content, - mime_type=a.mime_type, - metadata={}, - ) - for a in attachments - ] - with tracing.span("insert_documents"): - await self.memory_api.insert_documents(bank_id, documents) - else: - session_info = await self.storage.get_session_info(session_id) - if session_info.memory_bank_id: - bank_ids.append(session_info.memory_bank_id) - - if not bank_ids: - # this can happen if the per-session memory bank is not yet populated - # (i.e., no prior turns uploaded an Attachment) - return None, [] - - query = await generate_rag_query( - memory.query_generator_config, messages, inference_api=self.inference_api - ) - tasks = [ - self.memory_api.query_documents( - bank_id=bank_id, - query=query, - params={ - "max_chunks": 5, - }, - ) - for bank_id in bank_ids - ] - results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks) - chunks = [c for r in results for c in r.chunks] - scores = [s for r in results for s in r.scores] - - if not chunks: - return None, bank_ids - - # sort by score - chunks, scores = zip( - *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True) - ) - - tokens = 0 - picked = [] - for c in chunks[: memory.max_chunks]: - tokens += c.token_count - if tokens > memory.max_tokens_in_context: - log.error( - f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", - ) - break - picked.append(f"id:{c.document_id}; content:{c.content}") - - return ( - concat_interleaved_content( - [ - "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", - *picked, - "\n=== END-RETRIEVED-CONTEXT ===\n", - ] - ), - bank_ids, - ) - def _get_tools(self) -> List[ToolDefinition]: ret = [] for t in self.agent_config.tools: diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 93bfab5f46..89b38a7fc6 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -24,12 +24,11 @@ Session, Turn, ) - from llama_stack.apis.inference import Inference, 
ToolResponseMessage, UserMessage from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.safety import Safety - +from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from .agent_instance import ChatAgent @@ -47,12 +46,16 @@ def __init__( memory_api: Memory, safety_api: Safety, memory_banks_api: MemoryBanks, + tool_runtime_api: ToolRuntime, + tool_groups_api: ToolGroups, ): self.config = config self.inference_api = inference_api self.memory_api = memory_api self.safety_api = safety_api self.memory_banks_api = memory_banks_api + self.tool_runtime_api = tool_runtime_api + self.tool_groups_api = tool_groups_api self.in_memory_store = InmemoryKVStoreImpl() self.tempdir = tempfile.mkdtemp() @@ -112,6 +115,8 @@ async def get_agent(self, agent_id: str) -> ChatAgent: safety_api=self.safety_api, memory_api=self.memory_api, memory_banks_api=self.memory_banks_api, + tool_runtime_api=self.tool_runtime_api, + tool_groups_api=self.tool_groups_api, persistence_store=( self.persistence_store if agent_config.enable_session_persistence diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index a4b1af616c..144f65863f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -8,13 +8,11 @@ import logging import uuid from datetime import datetime - from typing import List, Optional from pydantic import BaseModel from llama_stack.apis.agents import Turn - from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -23,7 +21,6 @@ class AgentSessionInfo(BaseModel): session_id: str session_name: str - memory_bank_id: Optional[str] = None started_at: datetime @@ -54,17 +51,6 @@ async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]: return AgentSessionInfo(**json.loads(value)) - async def add_memory_bank_to_session(self, session_id: str, bank_id: str): - session_info = await self.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") - - session_info.memory_bank_id = bank_id - await self.kvstore.set( - key=f"session:{self.agent_id}:{session_id}", - value=session_info.model_dump_json(), - ) - async def add_turn_to_session(self, session_id: str, turn: Turn): await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", diff --git a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/memory/__init__.py new file mode 100644 index 0000000000..36377f1471 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/memory/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, Dict + +from llama_stack.providers.datatypes import Api + +from .config import MemoryToolConfig +from .memory import MemoryToolRuntimeImpl + + +async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]): + impl = MemoryToolRuntimeImpl( + config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference] + ) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/memory/config.py new file mode 100644 index 0000000000..cb24883dc0 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/memory/config.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Annotated, List, Literal, Union + +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig + +from pydantic import BaseModel, Field + + +class _MemoryBankConfigCommon(BaseModel): + bank_id: str + + +class VectorMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal["vector"] = "vector" + + +class KeyValueMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal["keyvalue"] = "keyvalue" + keys: List[str] # what keys to focus on + + +class KeywordMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal["keyword"] = "keyword" + + +class GraphMemoryBankConfig(_MemoryBankConfigCommon): + type: Literal["graph"] = "graph" + entities: List[str] # what entities to focus on + + +MemoryBankConfig = Annotated[ + Union[ + VectorMemoryBankConfig, + KeyValueMemoryBankConfig, + KeywordMemoryBankConfig, + GraphMemoryBankConfig, + ], + Field(discriminator="type"), +] + + +class MemoryQueryGenerator(Enum): + default = "default" + llm = "llm" + custom = "custom" + + +class DefaultMemoryQueryGeneratorConfig(BaseModel): + type: Literal[MemoryQueryGenerator.default.value] = ( + MemoryQueryGenerator.default.value + ) + sep: str = " " + + +class LLMMemoryQueryGeneratorConfig(BaseModel): + type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value + model: str + template: str + + +class CustomMemoryQueryGeneratorConfig(BaseModel): + type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value + + +MemoryQueryGeneratorConfig = Annotated[ + Union[ + DefaultMemoryQueryGeneratorConfig, + LLMMemoryQueryGeneratorConfig, + CustomMemoryQueryGeneratorConfig, + ], + Field(discriminator="type"), +] + + +class MemoryToolConfig(BaseModel): + memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) + # This config defines how a query is generated using the messages + # for memory bank retrieval. 
+ query_generator_config: MemoryQueryGeneratorConfig = Field( + default=DefaultMemoryQueryGeneratorConfig() + ) + max_tokens_in_context: int = 4096 + max_chunks: int = 10 + kvstore_config: KVStoreConfig = SqliteKVStoreConfig( + db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix() + ) diff --git a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py similarity index 98% rename from llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py rename to llama_stack/providers/inline/tool_runtime/memory/context_retriever.py index 74eb91c53a..da97cb3a3a 100644 --- a/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py @@ -8,16 +8,17 @@ from jinja2 import Template -from llama_stack.apis.agents import ( +from llama_stack.apis.inference import Message, UserMessage +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) + +from .config import ( DefaultMemoryQueryGeneratorConfig, LLMMemoryQueryGeneratorConfig, MemoryQueryGenerator, MemoryQueryGeneratorConfig, ) -from llama_stack.apis.inference import Message, UserMessage -from llama_stack.providers.utils.inference.prompt_adapter import ( - interleaved_content_as_str, -) async def generate_rag_query( diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py new file mode 100644 index 0000000000..3a08bf1f98 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio +import json +import logging +import os +import re +import secrets +import string +import tempfile +import uuid +from typing import Any, Dict, List, Optional +from urllib.parse import urlparse + +import httpx + +from llama_stack.apis.agents import Attachment +from llama_stack.apis.common.content_types import TextContentItem, URL +from llama_stack.apis.inference import Inference, InterleavedContent, Message +from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse +from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams +from llama_stack.apis.tools import ( + ToolDef, + ToolGroupDef, + ToolInvocationResult, + ToolRuntime, +) +from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content +from pydantic import BaseModel + +from .config import MemoryToolConfig +from .context_retriever import generate_rag_query + +log = logging.getLogger(__name__) + + +class MemorySessionInfo(BaseModel): + session_id: str + session_name: str + memory_bank_id: Optional[str] = None + + +def make_random_string(length: int = 8): + return "".join( + secrets.choice(string.ascii_letters + string.digits) for _ in range(length) + ) + + +class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__( + self, + config: MemoryToolConfig, + memory_api: Memory, + memory_banks_api: MemoryBanks, + inference_api: Inference, + ): + self.config = config + self.memory_api = memory_api + self.memory_banks_api = memory_banks_api + self.tempdir = tempfile.mkdtemp() + self.inference_api = inference_api + + async def initialize(self): + self.kvstore = await kvstore_impl(self.config.kvstore_config) + + async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: + return [] + + async def create_session(self, session_id: str) -> MemorySessionInfo: + session_info = MemorySessionInfo( + session_id=session_id, + session_name=f"session_{session_id}", + ) + await self.kvstore.set( + key=f"memory::session:{session_id}", + value=session_info.model_dump_json(), + ) + return session_info + + async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]: + value = await self.kvstore.get( + key=f"memory::session:{session_id}", + ) + if not value: + session_info = await self.create_session(session_id) + return session_info + + return MemorySessionInfo(**json.loads(value)) + + async def add_memory_bank_to_session(self, session_id: str, bank_id: str): + session_info = await self.get_session_info(session_id) + + session_info.memory_bank_id = bank_id + await self.kvstore.set( + key=f"memory::session:{session_id}", + value=session_info.model_dump_json(), + ) + + async def _ensure_memory_bank(self, session_id: str) -> str: + session_info = await self.get_session_info(session_id) + + if session_info.memory_bank_id is None: + bank_id = f"memory_bank_{session_id}" + await self.memory_banks_api.register_memory_bank( + memory_bank_id=bank_id, + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + ), + ) + await self.add_memory_bank_to_session(session_id, bank_id) + else: + bank_id = session_info.memory_bank_id + + return bank_id + + async def attachment_message( + self, tempdir: str, urls: List[URL] + ) -> List[TextContentItem]: + content = [] + + for url in urls: + uri = url.uri + if uri.startswith("file://"): + filepath = uri[len("file://") :] + 
elif uri.startswith("http"): + path = urlparse(uri).path + basename = os.path.basename(path) + filepath = f"{tempdir}/{make_random_string() + basename}" + log.info(f"Downloading {url} -> {filepath}") + + async with httpx.AsyncClient() as client: + r = await client.get(uri) + resp = r.text + with open(filepath, "w") as fp: + fp.write(resp) + else: + raise ValueError(f"Unsupported URL {url}") + + content.append( + TextContentItem( + text=f'# There is a file accessible to you at "{filepath}"\n' + ) + ) + + return content + + async def _retrieve_context( + self, session_id: str, messages: List[Message] + ) -> Optional[List[InterleavedContent]]: + bank_ids = [] + + bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs) + + session_info = await self.get_session_info(session_id) + if session_info.memory_bank_id: + bank_ids.append(session_info.memory_bank_id) + + if not bank_ids: + # this can happen if the per-session memory bank is not yet populated + # (i.e., no prior turns uploaded an Attachment) + return None + + query = await generate_rag_query( + self.config.query_generator_config, + messages, + inference_api=self.inference_api, + ) + tasks = [ + self.memory_api.query_documents( + bank_id=bank_id, + query=query, + params={ + "max_chunks": 5, + }, + ) + for bank_id in bank_ids + ] + results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks) + chunks = [c for r in results for c in r.chunks] + scores = [s for r in results for s in r.scores] + + if not chunks: + return None + + # sort by score + chunks, scores = zip( + *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True) + ) + + tokens = 0 + picked = [] + for c in chunks[: self.config.max_chunks]: + tokens += c.token_count + if tokens > self.config.max_tokens_in_context: + log.error( + f"Using {len(picked)} chunks; reached max tokens in context: {tokens}", + ) + break + picked.append(f"id:{c.document_id}; content:{c.content}") + + return [ + "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", + *picked, + "\n=== END-RETRIEVED-CONTEXT ===\n", + ] + + async def _process_attachments( + self, session_id: str, attachments: List[Attachment] + ): + bank_id = await self._ensure_memory_bank(session_id) + + documents = [ + MemoryBankDocument( + document_id=str(uuid.uuid4()), + content=a.content, + mime_type=a.mime_type, + metadata={}, + ) + for a in attachments + if isinstance(a.content, str) + ] + await self.memory_api.insert_documents(bank_id, documents) + + urls = [a.content for a in attachments if isinstance(a.content, URL)] + # TODO: we need to migrate URL away from str type + pattern = re.compile("^(https?://|file://|data:)") + urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)] + return await self.attachment_message(self.tempdir, urls) + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + if args["session_id"] is None: + raise ValueError("session_id is required") + + context = await self._retrieve_context( + args["session_id"], args["input_messages"] + ) + if context is None: + context = [] + attachments = args["attachments"] + if attachments and len(attachments) > 0: + context += await self._process_attachments(args["session_id"], attachments) + return ToolInvocationResult( + content=concat_interleaved_content(context), error_code=0 + ) diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 6595b1955e..3e38b1adc8 100644 --- 
a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]: Api.safety, Api.memory, Api.memory_banks, + Api.tool_runtime, + Api.tool_groups, ], ), remote_provider_spec( diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 042aef9d9e..d0493810c8 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -25,6 +25,14 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", ), + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::memory-runtime", + pip_packages=[], + module="llama_stack.providers.inline.tool_runtime.memory", + config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig", + api_dependencies=[Api.memory, Api.memory_banks, Api.inference], + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index dbf79e7130..d80013fae3 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -7,12 +7,10 @@ import pytest from ..conftest import get_provider_fixture_overrides - from ..inference.fixtures import INFERENCE_FIXTURES from ..memory.fixtures import MEMORY_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield -from .fixtures import AGENTS_FIXTURES - +from .fixtures import AGENTS_FIXTURES, TOOL_RUNTIME_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( @@ -21,6 +19,7 @@ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", + "tool_runtime": "memory", }, id="meta_reference", marks=pytest.mark.meta_reference, @@ -31,6 +30,7 @@ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", + "tool_runtime": "memory", }, id="ollama", marks=pytest.mark.ollama, @@ -42,6 +42,7 @@ # make this work with Weaviate which is what the together distro supports "memory": "faiss", "agents": "meta_reference", + "tool_runtime": "memory", }, id="together", marks=pytest.mark.together, @@ -52,6 +53,7 @@ "safety": "llama_guard", "memory": "faiss", "agents": "meta_reference", + "tool_runtime": "memory", }, id="fireworks", marks=pytest.mark.fireworks, @@ -62,6 +64,7 @@ "safety": "remote", "memory": "remote", "agents": "remote", + "tool_runtime": "memory", }, id="remote", marks=pytest.mark.remote, @@ -117,6 +120,7 @@ def pytest_generate_tests(metafunc): "safety": SAFETY_FIXTURES, "memory": MEMORY_FIXTURES, "agents": AGENTS_FIXTURES, + "tool_runtime": TOOL_RUNTIME_FIXTURES, } combinations = ( get_provider_fixture_overrides(metafunc.config, available_fixtures) diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 9f8e7a12bb..dd9882aa6a 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -10,14 +10,19 @@ import pytest_asyncio from llama_stack.apis.models import ModelInput, ModelType +from llama_stack.apis.tools import ( + ToolDef, + ToolGroupInput, + ToolParameter, + UserDefinedToolGroupDef, +) from llama_stack.distribution.datatypes import Api, Provider - from 
llama_stack.providers.inline.agents.meta_reference import ( MetaReferenceAgentsImplConfig, ) - from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + from ..conftest import ProviderFixture, remote_stack_fixture @@ -55,7 +60,21 @@ def agents_meta_reference() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def tool_runtime_memory() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="memory-runtime", + provider_type="inline::memory-runtime", + config={}, + ) + ], + ) + + AGENTS_FIXTURES = ["meta_reference", "remote"] +TOOL_RUNTIME_FIXTURES = ["memory"] @pytest_asyncio.fixture(scope="session") @@ -64,7 +83,7 @@ async def agents_stack(request, inference_model, safety_shield): providers = {} provider_data = {} - for key in ["inference", "safety", "memory", "agents"]: + for key in ["inference", "safety", "memory", "agents", "tool_runtime"]: fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") providers[key] = fixture.providers if key == "inference": @@ -111,12 +130,48 @@ async def agents_stack(request, inference_model, safety_shield): metadata={"embedding_dimension": 384}, ) ) + tool_groups = [ + ToolGroupInput( + tool_group_id="memory_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + ToolDef( + name="memory", + description="memory", + parameters=[ + ToolParameter( + name="session_id", + description="session id", + parameter_type="string", + required=True, + ), + ToolParameter( + name="input_messages", + description="messages", + parameter_type="list", + required=True, + ), + ToolParameter( + name="attachments", + description="attachments", + parameter_type="list", + required=False, + ), + ], + metadata={}, + ) + ], + ), + provider_id="memory-runtime", + ) + ] test_stack = await construct_stack_for_test( - [Api.agents, Api.inference, Api.safety, Api.memory], + [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime], providers, provider_data, models=models, shields=[safety_shield] if safety_shield else [], + tool_groups=tool_groups, ) return test_stack diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index dc95fa6a65..4ff94e4fee 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -35,7 +35,6 @@ # # pytest -v -s llama_stack/providers/tests/agents/test_agents.py # -m "meta_reference" - from .fixtures import pick_inference_model from .utils import create_agent_session @@ -255,17 +254,8 @@ async def test_rag_agent_as_attachments( agent_config = AgentConfig( **{ **common_params, - "tools": [ - MemoryToolDefinition( - memory_bank_configs=[], - query_generator_config={ - "type": "default", - "sep": " ", - }, - max_tokens_in_context=4096, - max_chunks=10, - ), - ], + "tools": [], + "preprocessing_tools": ["memory"], "tool_choice": ToolChoice.auto, } ) diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 5a38aaecc9..6f37334083 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -16,7 +16,7 @@ from llama_stack.apis.models import ModelInput from llama_stack.apis.scoring_functions import ScoringFnInput from llama_stack.apis.shields import ShieldInput - +from llama_stack.apis.tools import ToolGroupInput from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure 
import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Provider, StackRunConfig @@ -43,6 +43,7 @@ async def construct_stack_for_test( datasets: Optional[List[DatasetInput]] = None, scoring_fns: Optional[List[ScoringFnInput]] = None, eval_tasks: Optional[List[EvalTaskInput]] = None, + tool_groups: Optional[List[ToolGroupInput]] = None, ) -> TestStack: sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") run_config = dict( @@ -56,6 +57,7 @@ async def construct_stack_for_test( datasets=datasets or [], scoring_fns=scoring_fns or [], eval_tasks=eval_tasks or [], + tool_groups=tool_groups or [], ) run_config = parse_and_maybe_upgrade_config(run_config) try: From dcdf9da6ef6eb3ea430331cf9084ca37a11cbe7e Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 20 Dec 2024 17:30:36 -0800 Subject: [PATCH 02/53] remove all usages of builtin tools in agents --- llama_stack/apis/agents/agents.py | 76 ---- llama_stack/apis/tools/tools.py | 2 + .../providers/tests/agents/conftest.py | 5 - .../agents/meta_reference/agent_instance.py | 117 +++--- .../agents/meta_reference/rag/__init__.py | 5 - .../agents/meta_reference/tests/__init__.py | 5 - .../meta_reference/tests/code_execution.py | 93 ---- .../agents/meta_reference/tools/__init__.py | 5 - .../agents/meta_reference/tools/base.py | 20 - .../agents/meta_reference/tools/builtin.py | 396 ------------------ .../tools/ipython_tool/__init__.py | 5 - .../tools/ipython_tool/code_env_prefix.py | 133 ------ .../tools/ipython_tool/code_execution.py | 256 ----------- .../ipython_tool/matplotlib_custom_backend.py | 90 ---- .../tools/ipython_tool/utils.py | 21 - .../providers/tests/agents/fixtures.py | 26 +- .../providers/tests/agents/test_agents.py | 34 +- 17 files changed, 89 insertions(+), 1200 deletions(-) delete mode 100644 llama_stack/llama_stack/providers/tests/agents/conftest.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/rag/__init__.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tests/__init__.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/__init__.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/base.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/builtin.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 65be923488..325ce94903 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -47,78 +47,6 @@ class Attachment(BaseModel): mime_type: str -class AgentTool(Enum): - brave_search = "brave_search" - wolfram_alpha = "wolfram_alpha" - photogen = "photogen" - code_interpreter = "code_interpreter" - - function_call = "function_call" - memory = "memory" - - -class ToolDefinitionCommon(BaseModel): - input_shields: Optional[List[str]] = Field(default_factory=list) - output_shields: 
Optional[List[str]] = Field(default_factory=list) - - -class SearchEngineType(Enum): - bing = "bing" - brave = "brave" - tavily = "tavily" - - -@json_schema_type -class SearchToolDefinition(ToolDefinitionCommon): - # NOTE: brave_search is just a placeholder since model always uses - # brave_search as tool call name - type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value - api_key: str - engine: SearchEngineType = SearchEngineType.brave - remote_execution: Optional[RestAPIExecutionConfig] = None - - -@json_schema_type -class WolframAlphaToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value - api_key: str - remote_execution: Optional[RestAPIExecutionConfig] = None - - -@json_schema_type -class PhotogenToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value - remote_execution: Optional[RestAPIExecutionConfig] = None - - -@json_schema_type -class CodeInterpreterToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value - enable_inline_code_execution: bool = True - remote_execution: Optional[RestAPIExecutionConfig] = None - - -@json_schema_type -class FunctionCallToolDefinition(ToolDefinitionCommon): - type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value - function_name: str - description: str - parameters: Dict[str, ToolParamDefinition] - remote_execution: Optional[RestAPIExecutionConfig] = None - - -AgentToolDefinition = Annotated[ - Union[ - SearchToolDefinition, - WolframAlphaToolDefinition, - PhotogenToolDefinition, - CodeInterpreterToolDefinition, - FunctionCallToolDefinition, - ], - Field(discriminator="type"), -] - - class StepCommon(BaseModel): turn_id: str step_id: str @@ -211,10 +139,6 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - - tools: Optional[List[AgentToolDefinition]] = Field( - default_factory=list, deprecated=True - ) available_tools: Optional[List[str]] = Field(default_factory=list) preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 60b2bdab9f..15d59ca8fb 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -21,6 +21,8 @@ class ToolParameter(BaseModel): name: str parameter_type: str description: str + required: bool + default: Optional[Any] = None @json_schema_type diff --git a/llama_stack/llama_stack/providers/tests/agents/conftest.py b/llama_stack/llama_stack/providers/tests/agents/conftest.py deleted file mode 100644 index 756f351d88..0000000000 --- a/llama_stack/llama_stack/providers/tests/agents/conftest.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
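With the builtin tool classes removed from the agents API above, a tool is now described purely as data: ToolParameter picks up the required/default fields added in the tools.py hunk, and AgentConfig references registered tools by name instead of carrying typed definitions. A rough sketch of the resulting shape, restated outside the diff — the "memory" tool and its session_id parameter mirror the test fixtures earlier in this series; treat the rest as illustrative:

    from typing import Any, Optional

    from pydantic import BaseModel


    class ToolParameter(BaseModel):
        name: str
        parameter_type: str
        description: str
        required: bool                 # new in this patch
        default: Optional[Any] = None  # new in this patch


    session_id = ToolParameter(
        name="session_id",
        description="session id",
        parameter_type="string",
        required=True,
    )

    # AgentConfig now carries plain tool names; preprocessing_tools run
    # before each inference turn (e.g. the memory tool for RAG context).
    agent_config_fragment = {
        "available_tools": [],
        "preprocessing_tools": ["memory"],
    }

The effect of the change: tool behavior lives behind the tool_runtime API and is registered once per distribution, so the agent implementation no longer hardcodes search, Wolfram Alpha, photogen, or the code interpreter.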
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 00d8bbd363..8d52ac1b9e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -5,13 +5,15 @@ # the root directory of this source tree. import copy +import json import logging import os +import re import secrets import string import uuid from datetime import datetime -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, List from urllib.parse import urlparse import httpx @@ -29,16 +31,11 @@ AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, Attachment, - CodeInterpreterToolDefinition, - FunctionCallToolDefinition, InferenceStep, - PhotogenToolDefinition, - SearchToolDefinition, ShieldCallStep, StepType, ToolExecutionStep, Turn, - WolframAlphaToolDefinition, ) from llama_stack.apis.common.content_types import ( URL, @@ -67,15 +64,6 @@ from .persistence import AgentPersistence from .safety import SafetyException, ShieldRunnerMixin -from .tools.base import BaseTool -from .tools.builtin import ( - CodeInterpreterTool, - PhotogenTool, - SearchTool, - WolframAlphaTool, - interpret_content_as_attachment, -) -from .tools.safety import SafeTool log = logging.getLogger(__name__) @@ -86,6 +74,9 @@ def make_random_string(length: int = 8): ) +TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") + + class ChatAgent(ShieldRunnerMixin): def __init__( self, @@ -111,29 +102,6 @@ def __init__( self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api - builtin_tools = [] - for tool_defn in agent_config.tools: - if isinstance(tool_defn, WolframAlphaToolDefinition): - tool = WolframAlphaTool(tool_defn.api_key) - elif isinstance(tool_defn, SearchToolDefinition): - tool = SearchTool(tool_defn.engine, tool_defn.api_key) - elif isinstance(tool_defn, CodeInterpreterToolDefinition): - tool = CodeInterpreterTool() - elif isinstance(tool_defn, PhotogenToolDefinition): - tool = PhotogenTool(dump_dir=self.tempdir) - else: - continue - - builtin_tools.append( - SafeTool( - tool, - safety_api, - tool_defn.input_shields, - tool_defn.output_shields, - ) - ) - self.tools_dict = {t.get_name(): t for t in builtin_tools} - ShieldRunnerMixin.__init__( self, safety_api, @@ -453,7 +421,7 @@ async def _run( async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=self._get_tools(), + tools=await self._get_tools(), tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, sampling_params=sampling_params, @@ -595,7 +563,8 @@ async def _run( }, ) as span: result_messages = await execute_tool_call_maybe( - self.tools_dict, + self.tool_runtime_api, + session_id, [message], ) assert ( @@ -627,6 +596,20 @@ async def _run( # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially + def interpret_content_as_attachment( + content: str, + ) -> Optional[Attachment]: + match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) + if match: + snippet = match.group(1) + data = json.loads(snippet) + return Attachment( + url=URL(uri="file://" + data["filepath"]), + mime_type=data["mimetype"], + ) + + return None + if out_attachment := interpret_content_as_attachment( result_message.content ): @@ -639,25 +622,25 @@ async def _run( n_iter += 1 - def 
_get_tools(self) -> List[ToolDefinition]: + async def _get_tools(self) -> List[ToolDefinition]: ret = [] - for t in self.agent_config.tools: - if isinstance(t, SearchToolDefinition): - ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search)) - elif isinstance(t, WolframAlphaToolDefinition): - ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha)) - elif isinstance(t, PhotogenToolDefinition): - ret.append(ToolDefinition(tool_name=BuiltinTool.photogen)) - elif isinstance(t, CodeInterpreterToolDefinition): - ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter)) - elif isinstance(t, FunctionCallToolDefinition): - ret.append( - ToolDefinition( - tool_name=t.function_name, - description=t.description, - parameters=t.parameters, - ) + for tool_name in self.agent_config.available_tools: + tool = await self.tool_groups_api.get_tool(tool_name) + params = {} + for param in tool.parameters: + params[param.name] = ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + ret.append( + ToolDefinition( + tool_name=tool.identifier, + description=tool.description, + parameters=params, ) + ) return ret @@ -696,7 +679,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa async def execute_tool_call_maybe( - tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage] + tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage] ) -> List[ToolResponseMessage]: # While Tools.run interface takes a list of messages, # All tools currently only run on a single message @@ -712,7 +695,17 @@ async def execute_tool_call_maybe( name = name.value - assert name in tools_dict, f"Tool {name} not found" - tool = tools_dict[name] - result_messages = await tool.run(messages) - return result_messages + result = await tool_runtime_api.invoke_tool( + tool_name=name, + args=dict( + session_id=session_id, + **tool_call.arguments, + ), + ) + return [ + ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=result.content, + ) + ] diff --git a/llama_stack/providers/inline/agents/meta_reference/rag/__init__.py b/llama_stack/providers/inline/agents/meta_reference/rag/__init__.py deleted file mode 100644 index 756f351d88..0000000000 --- a/llama_stack/providers/inline/agents/meta_reference/rag/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tests/__init__.py deleted file mode 100644 index 756f351d88..0000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tests/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py b/llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py deleted file mode 100644 index 495cd2c92d..0000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import unittest - -from llama_models.llama3.api.datatypes import ( - Attachment, - BuiltinTool, - CompletionMessage, - StopReason, - ToolCall, -) - -from ..tools.builtin import CodeInterpreterTool - - -class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase): - async def test_matplotlib(self): - tool = CodeInterpreterTool() - code = """ -import matplotlib.pyplot as plt -import numpy as np - -x = np.array([1, 1]) -y = np.array([0, 10]) - -plt.plot(x, y) -plt.title('x = 1') -plt.xlabel('x') -plt.ylabel('y') -plt.grid(True) -plt.axvline(x=1, color='r') -plt.show() - """ - message = CompletionMessage( - role="assistant", - content="", - tool_calls=[ - ToolCall( - call_id="call_id", - tool_name=BuiltinTool.code_interpreter, - arguments={"code": code}, - ) - ], - stop_reason=StopReason.end_of_message, - ) - ret = await tool.run([message]) - - self.assertEqual(len(ret), 1) - - output = ret[0].content - self.assertIsInstance(output, Attachment) - self.assertEqual(output.mime_type, "image/png") - - async def test_path_unlink(self): - tool = CodeInterpreterTool() - code = """ -import os -from pathlib import Path -import tempfile - -dpath = Path(os.environ["MPLCONFIGDIR"]) -with open(dpath / "test", "w") as f: - f.write("hello") - -Path(dpath / "test").unlink() -print("_OK_") - """ - message = CompletionMessage( - role="assistant", - content="", - tool_calls=[ - ToolCall( - call_id="call_id", - tool_name=BuiltinTool.code_interpreter, - arguments={"code": code}, - ) - ], - stop_reason=StopReason.end_of_message, - ) - ret = await tool.run([message]) - - self.assertEqual(len(ret), 1) - - output = ret[0].content - self.assertTrue("_OK_" in output) - - -if __name__ == "__main__": - unittest.main() diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tools/__init__.py deleted file mode 100644 index 756f351d88..0000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/base.py b/llama_stack/providers/inline/agents/meta_reference/tools/base.py deleted file mode 100644 index 15fba7e2e7..0000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/base.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from abc import ABC, abstractmethod -from typing import List - -from llama_stack.apis.inference import Message - - -class BaseTool(ABC): - @abstractmethod - def get_name(self) -> str: - raise NotImplementedError - - @abstractmethod - async def run(self, messages: List[Message]) -> List[Message]: - raise NotImplementedError diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py b/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py deleted file mode 100644 index 5045bf32df..0000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py +++ /dev/null @@ -1,396 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import logging -import re -import tempfile - -from abc import abstractmethod -from typing import List, Optional - -import requests - -from .ipython_tool.code_execution import ( - CodeExecutionContext, - CodeExecutionRequest, - CodeExecutor, - TOOLS_ATTACHMENT_KEY_REGEX, -) - -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.apis.agents import * # noqa: F403 - -from .base import BaseTool - - -log = logging.getLogger(__name__) - - -def interpret_content_as_attachment(content: str) -> Optional[Attachment]: - match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) - if match: - snippet = match.group(1) - data = json.loads(snippet) - return Attachment( - url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"] - ) - - return None - - -class SingleMessageBuiltinTool(BaseTool): - async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: - assert len(messages) == 1, f"Expected single message, got {len(messages)}" - - message = messages[0] - assert len(message.tool_calls) == 1, "Expected a single tool call" - - tool_call = messages[0].tool_calls[0] - - query = tool_call.arguments["query"] - response: str = await self.run_impl(query) - - message = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=response, - ) - return [message] - - @abstractmethod - async def run_impl(self, query: str) -> str: - raise NotImplementedError() - - -class PhotogenTool(SingleMessageBuiltinTool): - def __init__(self, dump_dir: str) -> None: - self.dump_dir = dump_dir - - def get_name(self) -> str: - return BuiltinTool.photogen.value - - async def run_impl(self, query: str) -> str: - """ - Implement this to give the model an ability to generate images. 
-
-        Return:
-            info = {
-                "filepath": str(image_filepath),
-                "mimetype": "image/png",
-            }
-        """
-        raise NotImplementedError()
-
-
-class SearchTool(SingleMessageBuiltinTool):
-    def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
-        self.api_key = api_key
-        self.engine_type = engine
-        if engine == SearchEngineType.bing:
-            self.engine = BingSearch(api_key, **kwargs)
-        elif engine == SearchEngineType.brave:
-            self.engine = BraveSearch(api_key, **kwargs)
-        elif engine == SearchEngineType.tavily:
-            self.engine = TavilySearch(api_key, **kwargs)
-        else:
-            raise ValueError(f"Unknown search engine: {engine}")
-
-    def get_name(self) -> str:
-        return BuiltinTool.brave_search.value
-
-    async def run_impl(self, query: str) -> str:
-        return await self.engine.search(query)
-
-
-class BingSearch:
-    def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
-        self.api_key = api_key
-        self.top_k = top_k
-
-    async def search(self, query: str) -> str:
-        url = "https://api.bing.microsoft.com/v7.0/search"
-        headers = {
-            "Ocp-Apim-Subscription-Key": self.api_key,
-        }
-        params = {
-            "count": self.top_k,
-            "textDecorations": True,
-            "textFormat": "HTML",
-            "q": query,
-        }
-
-        response = requests.get(url=url, params=params, headers=headers)
-        response.raise_for_status()
-        clean = self._clean_response(response.json())
-        return json.dumps(clean)
-
-    def _clean_response(self, search_response):
-        clean_response = []
-        query = search_response["queryContext"]["originalQuery"]
-        if "webPages" in search_response:
-            pages = search_response["webPages"]["value"]
-            for p in pages:
-                selected_keys = {"name", "url", "snippet"}
-                clean_response.append(
-                    {k: v for k, v in p.items() if k in selected_keys}
-                )
-        if "news" in search_response:
-            clean_news = []
-            news = search_response["news"]["value"]
-            for n in news:
-                selected_keys = {"name", "url", "description"}
-                clean_news.append({k: v for k, v in n.items() if k in selected_keys})
-
-            clean_response.append(clean_news)
-
-        return {"query": query, "top_k": clean_response}
-
-
-class BraveSearch:
-    def __init__(self, api_key: str) -> None:
-        self.api_key = api_key
-
-    async def search(self, query: str) -> str:
-        url = "https://api.search.brave.com/res/v1/web/search"
-        headers = {
-            "X-Subscription-Token": self.api_key,
-            "Accept-Encoding": "gzip",
-            "Accept": "application/json",
-        }
-        payload = {"q": query}
-        response = requests.get(url=url, params=payload, headers=headers)
-        return json.dumps(self._clean_brave_response(response.json()))
-
-    def _clean_brave_response(self, search_response, top_k=3):
-        query = None
-        clean_response = []
-        if "query" in search_response:
-            if "original" in search_response["query"]:
-                query = search_response["query"]["original"]
-        if "mixed" in search_response:
-            mixed_results = search_response["mixed"]
-            for m in mixed_results["main"][:top_k]:
-                r_type = m["type"]
-                results = search_response[r_type]["results"]
-                if r_type == "web":
-                    # For web data - add a single output from the search
-                    idx = m["index"]
-                    selected_keys = [
-                        "type",
-                        "title",
-                        "url",
-                        "description",
-                        "date",
-                        "extra_snippets",
-                    ]
-                    cleaned = {
-                        k: v for k, v in results[idx].items() if k in selected_keys
-                    }
-                elif r_type == "faq":
-                    # For faq data - take a list of all the questions & answers
-                    selected_keys = ["type", "question", "answer", "title", "url"]
-                    cleaned = []
-                    for q in results:
-                        cleaned.append(
-                            {k: v for k, v in q.items() if k in selected_keys}
-                        )
-                elif r_type == "infobox":
-                    idx = m["index"]
-                    selected_keys = [
-                        "type",
-                        "title",
-                        "url",
-                        "description",
-                        "long_desc",
-                    ]
-                    cleaned = {
-                        k: v for k, v in results[idx].items() if k in selected_keys
-                    }
-                elif r_type == "videos":
-                    selected_keys = [
-                        "type",
-                        "url",
-                        "title",
-                        "description",
-                        "date",
-                    ]
-                    cleaned = []
-                    for q in results:
-                        cleaned.append(
-                            {k: v for k, v in q.items() if k in selected_keys}
-                        )
-                elif r_type == "locations":
-                    # For locations data - take a list of all the locations
-                    selected_keys = [
-                        "type",
-                        "title",
-                        "url",
-                        "description",
-                        "coordinates",
-                        "postal_address",
-                        "contact",
-                        "rating",
-                        "distance",
-                        "zoom_level",
-                    ]
-                    cleaned = []
-                    for q in results:
-                        cleaned.append(
-                            {k: v for k, v in q.items() if k in selected_keys}
-                        )
-                elif r_type == "news":
-                    # For news data - take a list of all the news items
-                    selected_keys = [
-                        "type",
-                        "title",
-                        "url",
-                        "description",
-                    ]
-                    cleaned = []
-                    for q in results:
-                        cleaned.append(
-                            {k: v for k, v in q.items() if k in selected_keys}
-                        )
-                else:
-                    cleaned = []
-
-                clean_response.append(cleaned)
-
-        return {"query": query, "top_k": clean_response}
-
-
-class TavilySearch:
-    def __init__(self, api_key: str) -> None:
-        self.api_key = api_key
-
-    async def search(self, query: str) -> str:
-        response = requests.post(
-            "https://api.tavily.com/search",
-            json={"api_key": self.api_key, "query": query},
-        )
-        return json.dumps(self._clean_tavily_response(response.json()))
-
-    def _clean_tavily_response(self, search_response, top_k=3):
-        return {"query": search_response["query"], "top_k": search_response["results"]}
-
-
-class WolframAlphaTool(SingleMessageBuiltinTool):
-    def __init__(self, api_key: str) -> None:
-        self.api_key = api_key
-        self.url = "https://api.wolframalpha.com/v2/query"
-
-    def get_name(self) -> str:
-        return BuiltinTool.wolfram_alpha.value
-
-    async def run_impl(self, query: str) -> str:
-        params = {
-            "input": query,
-            "appid": self.api_key,
-            "format": "plaintext",
-            "output": "json",
-        }
-        response = requests.get(
-            self.url,
-            params=params,
-        )
-
-        return json.dumps(self._clean_wolfram_alpha_response(response.json()))
-
-    def _clean_wolfram_alpha_response(self, wa_response):
-        remove = {
-            "queryresult": [
-                "datatypes",
-                "error",
-                "timedout",
-                "timedoutpods",
-                "numpods",
-                "timing",
-                "parsetiming",
-                "parsetimedout",
-                "recalculate",
-                "id",
-                "host",
-                "server",
-                "related",
-                "version",
-                {
-                    "pods": [
-                        "scanner",
-                        "id",
-                        "error",
-                        "expressiontypes",
-                        "states",
-                        "infos",
-                        "position",
-                        "numsubpods",
-                    ]
-                },
-                "assumptions",
-            ],
-        }
-        for main_key in remove:
-            for key_to_remove in remove[main_key]:
-                try:
-                    if key_to_remove == "assumptions":
-                        if "assumptions" in wa_response[main_key]:
-                            del wa_response[main_key][key_to_remove]
-                    if isinstance(key_to_remove, dict):
-                        for sub_key in key_to_remove:
-                            if sub_key == "pods":
-                                for i in range(len(wa_response[main_key][sub_key])):
-                                    if (
-                                        wa_response[main_key][sub_key][i]["title"]
-                                        == "Result"
-                                    ):
-                                        del wa_response[main_key][sub_key][i + 1 :]
-                                        break
-                            sub_items = wa_response[main_key][sub_key]
-                            for i in range(len(sub_items)):
-                                for sub_key_to_remove in key_to_remove[sub_key]:
-                                    if sub_key_to_remove in sub_items[i]:
-                                        del sub_items[i][sub_key_to_remove]
-                    elif key_to_remove in wa_response[main_key]:
-                        del wa_response[main_key][key_to_remove]
-                except KeyError:
-                    pass
-        return wa_response
-
-
-class CodeInterpreterTool(BaseTool):
-    def __init__(self) -> None:
-        ctx = CodeExecutionContext(
-            matplotlib_dump_dir=tempfile.mkdtemp(),
-        )
-        self.code_executor = CodeExecutor(ctx)
-
-    def get_name(self) -> str:
-        return BuiltinTool.code_interpreter.value
-
-    async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
-        message = messages[0]
-        assert len(message.tool_calls) == 1, "Expected a single tool call"
-
-        tool_call = messages[0].tool_calls[0]
-        script = tool_call.arguments["code"]
-
-        req = CodeExecutionRequest(scripts=[script])
-        res = self.code_executor.execute(req)
-
-        pieces = [res["process_status"]]
-        for out_type in ["stdout", "stderr"]:
-            res_out = res[out_type]
-            if res_out != "":
-                pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
-                if out_type == "stderr":
-                    log.error(f"ipython tool error: ↓\n{res_out}")
-
-        message = ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content="\n".join(pieces),
-        )
-        return [message]
diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py
deleted file mode 100644
index 756f351d88..0000000000
--- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py
deleted file mode 100644
index 10f64ec94f..0000000000
--- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import errno
-
-# Disabling potentially dangerous functions
-import os as _os
-from functools import partial
-
-os_funcs_to_disable = [
-    "kill",
-    "system",
-    "putenv",
-    "remove",
-    "removedirs",
-    "rmdir",
-    "fchdir",
-    "setuid",
-    "fork",
-    "forkpty",
-    "killpg",
-    "rename",
-    "renames",
-    "truncate",
-    "replace",
-    # "unlink",  # Commenting as this was blocking matplotlib from rendering plots correctly
-    "fchmod",
-    "fchown",
-    "chmod",
-    "chown",
-    "chroot",
-    "fchdir",
-    "lchflags",
-    "lchmod",
-    "lchown",
-    "chdir",
-]
-
-
-def call_not_allowed(*args, **kwargs):
-    raise OSError(errno.EPERM, "Calls are not permitted in this environment")
-
-
-for func_name in os_funcs_to_disable:
-    if hasattr(_os, func_name):
-        setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
-
-import shutil as _shutil
-
-for func_name in ["rmtree", "move", "chown"]:
-    if hasattr(_shutil, func_name):
-        setattr(
-            _shutil,
-            func_name,
-            partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
-        )
-
-import subprocess as _subprocess
-
-
-def popen_not_allowed(*args, **kwargs):
-    raise _subprocess.CalledProcessError(
-        -1,
-        args[0] if args else "unknown",
-        stderr="subprocess.Popen is not allowed in this environment",
-    )
-
-
-_subprocess.Popen = popen_not_allowed
-
-
-import atexit as _atexit
-import builtins as _builtins
-import io as _io
-import json as _json
-import sys as _sys
-
-# NB! The following "unused" imports are crucial, make sure not to remove
-# them with linters - they're used in code_execution.py
-from contextlib import (  # noqa
-    contextmanager as _contextmanager,
-    redirect_stderr as _redirect_stderr,
-    redirect_stdout as _redirect_stdout,
-)
-from multiprocessing.connection import Connection as _Connection
-
-# Mangle imports to avoid polluting model execution namespace.
-
-_IO_SINK = _io.StringIO()
-_NETWORK_TIMEOUT = 5
-_NETWORK_CONNECTIONS = None
-
-
-def _open_connections():
-    global _NETWORK_CONNECTIONS
-    if _NETWORK_CONNECTIONS is not None:
-        # Ensure connections are only opened once.
-        return _NETWORK_CONNECTIONS
-    req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
-    req_con = _Connection(int(req_w_fd), readable=False)
-    resp_con = _Connection(int(resp_r_fd), writable=False)
-    _NETWORK_CONNECTIONS = (req_con, resp_con)
-    return _NETWORK_CONNECTIONS
-
-
-_builtins._open_connections = _open_connections
-
-
-@_atexit.register
-def _close_connections():
-    global _NETWORK_CONNECTIONS
-    if _NETWORK_CONNECTIONS is None:
-        return
-    for con in _NETWORK_CONNECTIONS:
-        con.close()
-    del _NETWORK_CONNECTIONS
-
-
-def _network_call(request):
-    # NOTE: We communicate with the parent process in json, encoded
-    # in raw bytes. We do this because native send/recv methods use
-    # pickle which involves execution of arbitrary code.
-    _open_connections()
-    req_con, resp_con = _NETWORK_CONNECTIONS
-
-    req_con.send_bytes(_json.dumps(request).encode("utf-8"))
-    if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None:
-        raise Exception(f"Network request timed out: {_json.dumps(request)}")
-    else:
-        return _json.loads(resp_con.recv_bytes().decode("utf-8"))
diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py
deleted file mode 100644
index fa2e367e58..0000000000
--- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py
+++ /dev/null
@@ -1,256 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import base64
-import json
-import multiprocessing
-import os
-import re
-import subprocess
-import sys
-import tempfile
-import textwrap
-import time
-from dataclasses import dataclass
-from datetime import datetime
-from io import BytesIO
-from pathlib import Path
-from typing import List
-
-from PIL import Image
-
-from .utils import get_code_env_prefix
-
-TOOLS_ATTACHMENT_KEY = "__tools_attachment__"
-TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-
-DIRNAME = Path(__file__).parent
-
-CODE_EXEC_TIMEOUT = 20
-CODE_ENV_PREFIX = get_code_env_prefix()
-
-STDOUTERR_SINK_WRAPPER_TEMPLATE = """\
-with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK):
-{code}\
-"""
-
-TRYEXCEPT_WRAPPER_TEMPLATE = """\
-try:
-{code}
-except:
-    pass\
-"""
-
-
-def generate_bwrap_command(bind_dirs: List[str]) -> str:
-    """
-    Generate the bwrap command string for binding all
-    directories in the current directory read-only.
- """ - bwrap_args = "" - bwrap_args += "--ro-bind / / " - # Add the --dev flag to mount device files - bwrap_args += "--dev /dev " - for d in bind_dirs: - bwrap_args += f"--bind {d} {d} " - - # Add the --unshare-all flag to isolate the sandbox from the rest of the system - bwrap_args += "--unshare-all " - # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies - bwrap_args += "--die-with-parent " - return bwrap_args - - -@dataclass -class CodeExecutionContext: - matplotlib_dump_dir: str - use_proxy: bool = False - - -@dataclass -class CodeExecutionRequest: - scripts: List[str] - only_last_cell_stdouterr: bool = True - only_last_cell_fail: bool = True - seed: int = 0 - strip_fpaths_in_stderr: bool = True - - -class CodeExecutor: - def __init__(self, context: CodeExecutionContext): - self.context = context - - def execute(self, req: CodeExecutionRequest) -> dict: - scripts = req.scripts - for i in range(len(scripts) - 1): - if req.only_last_cell_stdouterr: - scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format( - code=textwrap.indent(scripts[i], " " * 4) - ) - if req.only_last_cell_fail: - scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format( - code=textwrap.indent(scripts[i], " " * 4) - ) - - # Seeds prefix: - seed = req.seed - seeds_prefix = f"""\ -def _set_seeds(): - import random - random.seed({seed}) - import numpy as np - np.random.seed({seed}) -_set_seeds()\ -""" - - script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts) - with tempfile.TemporaryDirectory() as dpath: - bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath]) - cmd = [*bwrap_prefix.split(), sys.executable, "-c", script] - code_fpath = os.path.join(dpath, "code.py") - with open(code_fpath, "w") as f: - f.write(script) - - try: - python_path = os.environ.get("PYTHONPATH", "") - env = dict( - os.environ, - PYTHONHASHSEED=str(seed), - MPLCONFIGDIR=dpath, - MPLBACKEND="module://matplotlib_custom_backend", - PYTHONPATH=f"{DIRNAME}:{python_path}", - ) - stdout, stderr, returncode = do_subprocess( - cmd=cmd, - env=env, - ctx=self.context, - ) - - stderr = stderr.strip() - if req.strip_fpaths_in_stderr: - pattern = r'File "([^"]+)", line (\d+)' - stderr = re.sub(pattern, r"line \2", stderr) - - return { - "process_status": "completed", - "returncode": returncode, - "stdout": stdout.strip(), - "stderr": stderr, - } - - except subprocess.TimeoutExpired: - return { - "process_status": "timeout", - "stdout": "Timed out", - "stderr": "Timed out", - } - - except Exception as e: - return { - "process_status": "error", - "error_type": type(e).__name__, - "stderr": str(e), - "stdout": str(e), - } - - -def process_matplotlib_response(response, matplotlib_dump_dir: str): - image_data = response["image_data"] - # Convert the base64 string to a bytes object - images = [base64.b64decode(d["image_base64"]) for d in image_data] - # Create a list of PIL images from the bytes objects - images = [Image.open(BytesIO(img)) for img in images] - # Create a list of image paths - image_paths = [] - for i, img in enumerate(images): - # create new directory for each day to better organize data: - dump_dname = datetime.today().strftime("%Y-%m-%d") - dump_dpath = Path(matplotlib_dump_dir, dump_dname) - dump_dpath.mkdir(parents=True, exist_ok=True) - # save image into a file - dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png" - dump_fpath = dump_dpath / dump_fname - img.save(dump_fpath, "PNG") - image_paths.append(str(dump_fpath)) - - # this is kind of convoluted, we send back this 
response to the subprocess which
-    # prints it out
-    info = {
-        "filepath": str(image_paths[-1]),
-        "mimetype": "image/png",
-    }
-    return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
-
-
-def execute_subprocess_request(request, ctx: CodeExecutionContext):
-    "Route requests from the subprocess (via network Pipes) to the internet/tools."
-    if request["type"] == "matplotlib":
-        return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
-    else:
-        raise Exception(f'Unrecognised network request type: {request["type"]}')
-
-
-def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
-    # Create Pipes to be used for any external tool/network requests.
-    req_r, req_w = multiprocessing.Pipe(duplex=False)
-    resp_r, resp_w = multiprocessing.Pipe(duplex=False)
-
-    cmd += [str(req_w.fileno()), str(resp_r.fileno())]
-    proc = subprocess.Popen(
-        cmd,
-        pass_fds=(req_w.fileno(), resp_r.fileno()),
-        text=True,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        close_fds=True,
-        env=env,
-    )
-
-    # Close unnecessary fds.
-    req_w.close()
-    resp_r.close()
-
-    pipe_close = False
-    done_read = False
-    start = time.monotonic()
-    while proc.poll() is None and not pipe_close:
-        if req_r.poll(0.1):
-            # NB: Python pipe semantics for poll and recv mean that
-            # poll() returns True if a pipe is closed.
-            # Cf. this old Python issue from '09:
-            # https://bugs.python.org/issue5573
-            try:
-                request = json.loads(req_r.recv_bytes().decode("utf-8"))
-                response = execute_subprocess_request(request, ctx)
-
-                resp_w.send_bytes(json.dumps(response).encode("utf-8"))
-            except EOFError:
-                # The request pipe is closed - set a marker to exit
-                # after the next attempt at reading stdout/stderr.
-                pipe_close = True
-
-        try:
-            # If lots has been printed, pipe might be full but
-            # proc cannot exit until all the stdout/stderr
-            # has been written/read.
-            stdout, stderr = proc.communicate(timeout=0.3)
-            done_read = True
-        except subprocess.TimeoutExpired:
-            # The program has not terminated. Ignore it, there
-            # may be more network/tool requests.
-            continue
-        if time.monotonic() - start > CODE_EXEC_TIMEOUT:
-            proc.terminate()
-            raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
-
-    if not done_read:
-        # Solve race condition where process terminates before
-        # we hit the while loop.
-        stdout, stderr = proc.communicate(timeout=0.3)
-
-    resp_w.close()
-    req_r.close()
-    return stdout, stderr, proc.returncode
diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py
deleted file mode 100644
index 7fec08cf24..0000000000
--- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""
-A custom Matplotlib backend that overrides the show method to return image bytes.
-""" - -import base64 -import io -import json as _json -import logging - -import matplotlib -from matplotlib.backend_bases import FigureManagerBase - -# Import necessary components from Matplotlib -from matplotlib.backends.backend_agg import FigureCanvasAgg - -log = logging.getLogger(__name__) - - -class CustomFigureCanvas(FigureCanvasAgg): - def show(self): - # Save the figure to a BytesIO object - buf = io.BytesIO() - self.print_png(buf) - image_bytes = buf.getvalue() - buf.close() - return image_bytes - - -class CustomFigureManager(FigureManagerBase): - def __init__(self, canvas, num): - super().__init__(canvas, num) - - -# Mimic module initialization that integrates with the Matplotlib backend system -def _create_figure_manager(num, *args, **kwargs): - """ - Create a custom figure manager instance. - """ - FigureClass = kwargs.pop("FigureClass", None) # noqa: N806 - if FigureClass is None: - from matplotlib.figure import Figure - - FigureClass = Figure # noqa: N806 - fig = FigureClass(*args, **kwargs) - canvas = CustomFigureCanvas(fig) - manager = CustomFigureManager(canvas, num) - return manager - - -def show(): - """ - Handle all figures and potentially return their images as bytes. - - This function iterates over all figures registered with the custom backend, - renders them as images in bytes format, and could return a list of bytes objects, - one for each figure, or handle them as needed. - """ - image_data = [] - for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers(): - # Get the figure from the manager - fig = manager.canvas.figure - buf = io.BytesIO() # Create a buffer for the figure - fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format - buf.seek(0) # Go to the beginning of the buffer - image_bytes = buf.getvalue() # Retrieve bytes value - image_base64 = base64.b64encode(image_bytes).decode("utf-8") - image_data.append({"image_base64": image_base64}) - buf.close() - - req_con, resp_con = _open_connections() - - _json_dump = _json.dumps( - { - "type": "matplotlib", - "image_data": image_data, - } - ) - req_con.send_bytes(_json_dump.encode("utf-8")) - resp = _json.loads(resp_con.recv_bytes().decode("utf-8")) - log.info(resp) - - -FigureCanvas = CustomFigureCanvas -FigureManager = CustomFigureManager diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py deleted file mode 100644 index d6f539a39f..0000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import os - -DIR = os.path.dirname(os.path.realpath(__file__)) -CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py") -CODE_ENV_PREFIX = None - - -def get_code_env_prefix() -> str: - global CODE_ENV_PREFIX - - if CODE_ENV_PREFIX is None: - with open(CODE_ENV_PREFIX_FILE, "r") as f: - CODE_ENV_PREFIX = f.read() - - return CODE_ENV_PREFIX diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index dd9882aa6a..f5158b57c2 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os import tempfile import pytest @@ -68,7 +69,14 @@ def tool_runtime_memory() -> ProviderFixture: provider_id="memory-runtime", provider_type="inline::memory-runtime", config={}, - ) + ), + Provider( + provider_id="brave-search", + provider_type="inline::brave-search", + config={ + "api_key": os.environ["BRAVE_SEARCH_API_KEY"], + }, + ), ], ) @@ -131,6 +139,20 @@ async def agents_stack(request, inference_model, safety_shield): ) ) tool_groups = [ + ToolGroupInput( + tool_group_id="brave_search_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + ToolDef( + name="brave_search", + description="brave_search", + parameters=[], + metadata={}, + ), + ], + ), + provider_id="brave-search", + ), ToolGroupInput( tool_group_id="memory_group", tool_group=UserDefinedToolGroupDef( @@ -163,7 +185,7 @@ async def agents_stack(request, inference_model, safety_shield): ], ), provider_id="memory-runtime", - ) + ), ] test_stack = await construct_stack_for_test( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 4ff94e4fee..78ca2341fc 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -50,7 +50,8 @@ def common_params(inference_model): sampling_params=SamplingParams(temperature=0.7, top_p=0.95), input_shields=[], output_shields=[], - tools=[], + available_tools=[], + preprocessing_tools=[], max_infer_iters=5, ) @@ -91,7 +92,7 @@ async def create_agent_turn_with_search_tool( agents_stack: Dict[str, object], search_query_messages: List[object], common_params: Dict[str, str], - search_tool_definition: SearchToolDefinition, + tool_name: str, ) -> None: """ Create an agent turn with a search tool. 
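With the deprecated tools field replaced by plain names, the tests now select registered tools by identifier. A sketch of the agent configuration shape this produces (model and instructions values are placeholders, not taken from the tests):

from llama_stack.apis.agents import AgentConfig

agent_config = AgentConfig(
    model="Llama3.1-8B-Instruct",  # placeholder model id
    instructions="You are a helpful assistant.",
    available_tools=["brave_search"],  # names resolved via the tool groups API
    preprocessing_tools=["memory"],  # run before inference, e.g. for RAG retrieval
    enable_session_persistence=False,
)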
@@ -107,7 +108,7 @@ async def create_agent_turn_with_search_tool( agent_config = AgentConfig( **{ **common_params, - "tools": [search_tool_definition], + "available_tools": [tool_name], } ) @@ -254,7 +255,6 @@ async def test_rag_agent_as_attachments( agent_config = AgentConfig( **{ **common_params, - "tools": [], "preprocessing_tools": ["memory"], "tool_choice": ToolChoice.auto, } @@ -295,29 +295,11 @@ async def test_create_agent_turn_with_brave_search( if "BRAVE_SEARCH_API_KEY" not in os.environ: pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") - search_tool_definition = SearchToolDefinition( - type=AgentTool.brave_search.value, - api_key=os.environ["BRAVE_SEARCH_API_KEY"], - engine=SearchEngineType.brave, - ) - await create_agent_turn_with_search_tool( - agents_stack, search_query_messages, common_params, search_tool_definition - ) - - @pytest.mark.asyncio - async def test_create_agent_turn_with_tavily_search( - self, agents_stack, search_query_messages, common_params - ): - if "TAVILY_SEARCH_API_KEY" not in os.environ: - pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") - - search_tool_definition = SearchToolDefinition( - type=AgentTool.brave_search.value, # place holder only - api_key=os.environ["TAVILY_SEARCH_API_KEY"], - engine=SearchEngineType.tavily, - ) await create_agent_turn_with_search_tool( - agents_stack, search_query_messages, common_params, search_tool_definition + agents_stack, + search_query_messages, + common_params, + "brave_search", ) From 9192a9bbb427db4eae3bd2ea840b7fa4818b5230 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 20 Dec 2024 21:01:41 -0800 Subject: [PATCH 03/53] add tavily --- .../tool_runtime/tavily_search/__init__.py | 20 ++++++ .../tool_runtime/tavily_search/config.py | 20 ++++++ .../tavily_search/tavily_search.py | 64 +++++++++++++++++++ .../providers/registry/tool_runtime.py | 8 +++ .../providers/tests/agents/fixtures.py | 37 ++++++++++- .../providers/tests/agents/test_agents.py | 16 ++++- 6 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py create mode 100644 llama_stack/providers/inline/tool_runtime/tavily_search/config.py create mode 100644 llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py b/llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py new file mode 100644 index 0000000000..8061a250cb --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .config import TavilySearchToolConfig +from .tavily_search import TavilySearchToolRuntimeImpl + + +class TavilySearchToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_provider_impl(config: TavilySearchToolConfig, _deps): + impl = TavilySearchToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/config.py b/llama_stack/providers/inline/tool_runtime/tavily_search/config.py new file mode 100644 index 0000000000..f7a8f3f09b --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/config.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel, Field + + +class TavilySearchToolConfig(BaseModel): + api_key: Optional[str] = Field( + default=None, + description="The Tavily Search API Key", + ) + max_results: int = Field( + default=3, + description="The maximum number of results to return", + ) diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py new file mode 100644 index 0000000000..f80d10dfe5 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Any, Dict, List + +import requests + +from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import TavilySearchToolConfig + + +class TavilySearchToolRuntimeImpl( + ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): + def __init__(self, config: TavilySearchToolConfig): + self.config = config + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + if tool.identifier != "tavily_search": + raise ValueError(f"Tool identifier {tool.identifier} is not supported") + + async def unregister_tool(self, tool_id: str) -> None: + return + + def _get_api_key(self) -> str: + if self.config.api_key: + return self.config.api_key + + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.api_key: + raise ValueError( + 'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": }' + ) + return provider_data.api_key + + async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: + raise NotImplementedError("Tavily search tool group not supported") + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + api_key = self._get_api_key() + response = requests.post( + "https://api.tavily.com/search", + json={"api_key": api_key, "query": args["query"]}, + ) + print(f"================= Tavily response: {response.json()}") + + return ToolInvocationResult( + content=json.dumps(self._clean_tavily_response(response.json())) + ) + + def _clean_tavily_response(self, search_response, top_k=3): + return {"query": search_response["query"], "top_k": search_response["results"]} diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index d0493810c8..9058fb7189 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -33,6 +33,14 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig", api_dependencies=[Api.memory, Api.memory_banks, Api.inference], ), + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::tavily-search", + pip_packages=[], + module="llama_stack.providers.inline.tool_runtime.tavily_search", + 
config_class="llama_stack.providers.inline.tool_runtime.tavily_search.config.TavilySearchToolConfig", + provider_data_validator="llama_stack.providers.inline.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index f5158b57c2..c0690e4e31 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -77,6 +77,13 @@ def tool_runtime_memory() -> ProviderFixture: "api_key": os.environ["BRAVE_SEARCH_API_KEY"], }, ), + Provider( + provider_id="tavily-search", + provider_type="inline::tavily-search", + config={ + "api_key": os.environ["TAVILY_SEARCH_API_KEY"], + }, + ), ], ) @@ -146,13 +153,41 @@ async def agents_stack(request, inference_model, safety_shield): ToolDef( name="brave_search", description="brave_search", - parameters=[], + parameters=[ + ToolParameter( + name="query", + description="query", + parameter_type="string", + required=True, + ), + ], metadata={}, ), ], ), provider_id="brave-search", ), + ToolGroupInput( + tool_group_id="tavily_search_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + ToolDef( + name="tavily_search", + description="tavily_search", + parameters=[ + ToolParameter( + name="query", + description="query", + parameter_type="string", + required=True, + ), + ], + metadata={}, + ), + ], + ), + provider_id="tavily-search", + ), ToolGroupInput( tool_group_id="memory_group", tool_group=UserDefinedToolGroupDef( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 78ca2341fc..cd4f754185 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -149,7 +149,7 @@ async def create_agent_turn_with_search_tool( tool_execution = tool_execution_events[0].event.payload.step_details assert isinstance(tool_execution, ToolExecutionStep) assert len(tool_execution.tool_calls) > 0 - assert tool_execution.tool_calls[0].tool_name == BuiltinTool.brave_search + assert tool_execution.tool_calls[0].tool_name == tool_name assert len(tool_execution.tool_responses) > 0 check_turn_complete_event(turn_response, session_id, search_query_messages) @@ -302,6 +302,20 @@ async def test_create_agent_turn_with_brave_search( "brave_search", ) + @pytest.mark.asyncio + async def test_create_agent_turn_with_tavily_search( + self, agents_stack, search_query_messages, common_params + ): + if "TAVILY_SEARCH_API_KEY" not in os.environ: + pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") + + await create_agent_turn_with_search_tool( + agents_stack, + search_query_messages, + common_params, + "tavily_search", + ) + def check_event_types(turn_response): event_types = [chunk.event.payload.event_type for chunk in turn_response] From 2ad67529ef94a06f612303b185b2676afc023469 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 20 Dec 2024 22:02:00 -0800 Subject: [PATCH 04/53] fix agents to run custom tools --- .../inline/agents/meta_reference/agent_instance.py | 11 ++--------- .../tool_runtime/tavily_search/tavily_search.py | 1 - llama_stack/providers/tests/agents/test_agents.py | 5 ++++- .../providers/utils/inference/prompt_adapter.py | 3 --- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py 
b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8d52ac1b9e..8075ea2bd2 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -531,11 +531,6 @@ async def _run( log.info(f"{str(message)}") tool_call = message.tool_calls[0] - name = tool_call.tool_name - if not isinstance(name, BuiltinTool) or name not in enabled_tools: - yield message - return - step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -691,10 +686,8 @@ async def execute_tool_call_maybe( tool_call = message.tool_calls[0] name = tool_call.tool_name - assert isinstance(name, BuiltinTool) - - name = name.value - + if isinstance(name, BuiltinTool): + name = name.value result = await tool_runtime_api.invoke_tool( tool_name=name, args=dict( diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py index f80d10dfe5..94a387f306 100644 --- a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py @@ -54,7 +54,6 @@ async def invoke_tool( "https://api.tavily.com/search", json={"api_key": api_key, "query": args["query"]}, ) - print(f"================= Tavily response: {response.json()}") return ToolInvocationResult( content=json.dumps(self._clean_tavily_response(response.json())) diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index cd4f754185..147f04b023 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -149,7 +149,10 @@ async def create_agent_turn_with_search_tool( tool_execution = tool_execution_events[0].event.payload.step_details assert isinstance(tool_execution, ToolExecutionStep) assert len(tool_execution.tool_calls) > 0 - assert tool_execution.tool_calls[0].tool_name == tool_name + actual_tool_name = tool_execution.tool_calls[0].tool_name + if isinstance(actual_tool_name, BuiltinTool): + actual_tool_name = actual_tool_name.value + assert actual_tool_name == tool_name assert len(tool_execution.tool_responses) > 0 check_turn_complete_event(turn_response, session_id, search_query_messages) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index ed0cabe1c9..d296105e03 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -14,7 +14,6 @@ import httpx from llama_models.datatypes import is_multimodal, ModelFamily - from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import ( RawContent, @@ -41,7 +40,6 @@ InterleavedContentItem, TextContentItem, ) - from llama_stack.apis.inference import ( ChatCompletionRequest, CompletionRequest, @@ -52,7 +50,6 @@ ToolChoice, UserMessage, ) - from llama_stack.providers.utils.inference import supported_inference_models log = logging.getLogger(__name__) From 0155700ea662ee8074dde844f1dd476ae5223812 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 23 Dec 2024 11:04:24 -0800 Subject: [PATCH 05/53] working end to end client sdk tests --- docs/resources/llama-stack-spec.html | 1263 +++++++++++++----------- docs/resources/llama-stack-spec.yaml | 615 +++++------- llama_stack/apis/tools/tools.py | 2 + 
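The agent_instance change above is what lets custom (string-named) tools share the builtin code path: only BuiltinTool enum members need unwrapping before the runtime lookup. Isolated as a sketch (the helper name is illustrative):

from llama_models.llama3.api.datatypes import BuiltinTool

async def dispatch_tool_call(tool_runtime_api, session_id: str, tool_call):
    name = tool_call.tool_name
    if isinstance(name, BuiltinTool):
        # Builtin tools arrive as enum members; the runtime keys on plain strings.
        name = name.value
    return await tool_runtime_api.invoke_tool(
        tool_name=name,
        args=dict(session_id=session_id, **tool_call.arguments),
    )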
llama_stack/distribution/stack.py | 3 + tests/client-sdk/agents/test_agents.py | 167 +--- 5 files changed, 979 insertions(+), 1071 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index a9fb22b100..b1bef08820 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -462,6 +462,46 @@ } } }, + "/alpha/tool-runtime/discover": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/jsonl": { + "schema": { + "$ref": "#/components/schemas/ToolDef" + } + } + } + } + }, + "tags": [ + "ToolRuntime" + ], + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DiscoverToolsRequest" + } + } + }, + "required": true + } + } + }, "/alpha/inference/embeddings": { "post": { "responses": { @@ -1118,6 +1158,82 @@ } } }, + "/alpha/tools/get": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Tool" + } + } + } + } + }, + "tags": [ + "ToolGroups" + ], + "parameters": [ + { + "name": "tool_name", + "in": "query", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ] + } + }, + "/alpha/toolgroups/get": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ToolGroup" + } + } + } + } + }, + "tags": [ + "ToolGroups" + ], + "parameters": [ + { + "name": "tool_group_id", + "in": "query", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ] + } + }, "/alpha/post-training/job/artifacts": { "get": { "responses": { @@ -1301,6 +1417,47 @@ } } }, + "/alpha/tool-runtime/invoke": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ToolInvocationResult" + } + } + } + } + }, + "tags": [ + "ToolRuntime" + ], + "summary": "Run a tool with the given arguments", + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InvokeToolRequest" + } + } + }, + "required": true + } + } + }, "/alpha/eval/job/cancel": { "post": { "responses": { @@ -1695,6 +1852,76 @@ ] } }, + "/alpha/toolgroups/list": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/jsonl": { + "schema": { + "$ref": "#/components/schemas/ToolGroup" + } + } + } + } + }, + "tags": [ + "ToolGroups" + ], + "summary": "List tool groups with optional 
provider", + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ] + } + }, + "/alpha/tools/list": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/jsonl": { + "schema": { + "$ref": "#/components/schemas/Tool" + } + } + } + } + }, + "tags": [ + "ToolGroups" + ], + "summary": "List tools with optional tool group", + "parameters": [ + { + "name": "tool_group_id", + "in": "query", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ] + } + }, "/alpha/telemetry/log-event": { "post": { "responses": { @@ -2096,6 +2323,40 @@ } } }, + "/alpha/toolgroups/register": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "ToolGroups" + ], + "summary": "Register a tool group", + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RegisterToolGroupRequest" + } + } + }, + "required": true + } + } + }, "/alpha/eval/run-eval": { "post": { "responses": { @@ -3444,29 +3705,16 @@ "type": "string" } }, - "tools": { + "available_tools": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/SearchToolDefinition" - }, - { - "$ref": "#/components/schemas/WolframAlphaToolDefinition" - }, - { - "$ref": "#/components/schemas/PhotogenToolDefinition" - }, - { - "$ref": "#/components/schemas/CodeInterpreterToolDefinition" - }, - { - "$ref": "#/components/schemas/FunctionCallToolDefinition" - }, - { - "$ref": "#/components/schemas/MemoryToolDefinition" - } - ] + "type": "string" + } + }, + "preprocessing_tools": { + "type": "array", + "items": { + "type": "string" } }, "tool_choice": { @@ -3499,480 +3747,7 @@ "enable_session_persistence" ] }, - "CodeInterpreterToolDefinition": { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "output_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "type": { - "type": "string", - "const": "code_interpreter", - "default": "code_interpreter" - }, - "enable_inline_code_execution": { - "type": "boolean", - "default": true - }, - "remote_execution": { - "$ref": "#/components/schemas/RestAPIExecutionConfig" - } - }, - "additionalProperties": false, - "required": [ - "type", - "enable_inline_code_execution" - ] - }, - "FunctionCallToolDefinition": { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "output_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "type": { - "type": "string", - "const": "function_call", - "default": "function_call" - }, - "function_name": { - "type": "string" - }, - "description": { - "type": "string" - }, - "parameters": { - "type": "object", - "additionalProperties": { - "$ref": "#/components/schemas/ToolParamDefinition" - } - }, - 
"remote_execution": { - "$ref": "#/components/schemas/RestAPIExecutionConfig" - } - }, - "additionalProperties": false, - "required": [ - "type", - "function_name", - "description", - "parameters" - ] - }, - "MemoryToolDefinition": { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "output_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "type": { - "type": "string", - "const": "memory", - "default": "memory" - }, - "memory_bank_configs": { - "type": "array", - "items": { - "oneOf": [ - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "vector", - "default": "vector" - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type" - ] - }, - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "keyvalue", - "default": "keyvalue" - }, - "keys": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type", - "keys" - ] - }, - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "keyword", - "default": "keyword" - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type" - ] - }, - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "graph", - "default": "graph" - }, - "entities": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type", - "entities" - ] - } - ] - } - }, - "query_generator_config": { - "oneOf": [ - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "default", - "default": "default" - }, - "sep": { - "type": "string", - "default": " " - } - }, - "additionalProperties": false, - "required": [ - "type", - "sep" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "llm", - "default": "llm" - }, - "model": { - "type": "string" - }, - "template": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "type", - "model", - "template" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "custom", - "default": "custom" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - } - ] - }, - "max_tokens_in_context": { - "type": "integer", - "default": 4096 - }, - "max_chunks": { - "type": "integer", - "default": 10 - } - }, - "additionalProperties": false, - "required": [ - "type", - "memory_bank_configs", - "query_generator_config", - "max_tokens_in_context", - "max_chunks" - ] - }, - "PhotogenToolDefinition": { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "output_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "type": { - "type": "string", - "const": "photogen", - "default": "photogen" - }, - "remote_execution": { - "$ref": "#/components/schemas/RestAPIExecutionConfig" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - }, - "RestAPIExecutionConfig": { - "type": "object", - "properties": { - "url": { - "$ref": "#/components/schemas/URL" - }, - "method": { - "$ref": "#/components/schemas/RestAPIMethod" - }, - "params": { - "type": "object", - 
"additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "headers": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "body": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "url", - "method" - ] - }, - "RestAPIMethod": { - "type": "string", - "enum": [ - "GET", - "POST", - "PUT", - "DELETE" - ] - }, - "SearchToolDefinition": { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "output_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "type": { - "type": "string", - "const": "brave_search", - "default": "brave_search" - }, - "api_key": { - "type": "string" - }, - "engine": { - "type": "string", - "enum": [ - "bing", - "brave", - "tavily" - ], - "default": "brave" - }, - "remote_execution": { - "$ref": "#/components/schemas/RestAPIExecutionConfig" - } - }, - "additionalProperties": false, - "required": [ - "type", - "api_key", - "engine" - ] - }, - "WolframAlphaToolDefinition": { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "output_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "type": { - "type": "string", - "const": "wolfram_alpha", - "default": "wolfram_alpha" - }, - "api_key": { - "type": "string" - }, - "remote_execution": { - "$ref": "#/components/schemas/RestAPIExecutionConfig" - } - }, - "additionalProperties": false, - "required": [ - "type", - "api_key" - ] - }, - "CreateAgentRequest": { + "CreateAgentRequest": { "type": "object", "properties": { "agent_config": { @@ -4575,57 +4350,218 @@ "type": "string", "format": "date-time" }, - "completed_at": { - "type": "string", - "format": "date-time" + "completed_at": { + "type": "string", + "format": "date-time" + } + }, + "additionalProperties": false, + "required": [ + "turn_id", + "session_id", + "input_messages", + "steps", + "output_message", + "output_attachments", + "started_at" + ], + "title": "A single turn in an interaction with an Agentic System." 
+ }, + "ViolationLevel": { + "type": "string", + "enum": [ + "info", + "warn", + "error" + ] + }, + "DeleteAgentsRequest": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "agent_id" + ] + }, + "DeleteAgentsSessionRequest": { + "type": "object", + "properties": { + "agent_id": { + "type": "string" + }, + "session_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "agent_id", + "session_id" + ] + }, + "MCPToolGroupDef": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "model_context_protocol", + "default": "model_context_protocol" + }, + "endpoint": { + "$ref": "#/components/schemas/URL" + } + }, + "additionalProperties": false, + "required": [ + "type", + "endpoint" + ], + "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." + }, + "ToolDef": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolParameter" + } + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "tool_prompt_format": { + "$ref": "#/components/schemas/ToolPromptFormat", + "default": "json" + } + }, + "additionalProperties": false, + "required": [ + "name", + "description", + "parameters", + "metadata" + ] + }, + "ToolGroupDef": { + "oneOf": [ + { + "$ref": "#/components/schemas/MCPToolGroupDef" + }, + { + "$ref": "#/components/schemas/UserDefinedToolGroupDef" + } + ] + }, + "ToolParameter": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parameter_type": { + "type": "string" + }, + "description": { + "type": "string" + }, + "required": { + "type": "boolean" + }, + "default": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] } }, "additionalProperties": false, "required": [ - "turn_id", - "session_id", - "input_messages", - "steps", - "output_message", - "output_attachments", - "started_at" - ], - "title": "A single turn in an interaction with an Agentic System." 
- }, - "ViolationLevel": { - "type": "string", - "enum": [ - "info", - "warn", - "error" + "name", + "parameter_type", + "description", + "required" ] }, - "DeleteAgentsRequest": { + "UserDefinedToolGroupDef": { "type": "object", "properties": { - "agent_id": { - "type": "string" + "type": { + "type": "string", + "const": "user_defined", + "default": "user_defined" + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolDef" + } } }, "additionalProperties": false, "required": [ - "agent_id" + "type", + "tools" ] }, - "DeleteAgentsSessionRequest": { + "DiscoverToolsRequest": { "type": "object", "properties": { - "agent_id": { - "type": "string" - }, - "session_id": { - "type": "string" + "tool_group": { + "$ref": "#/components/schemas/ToolGroupDef" } }, "additionalProperties": false, "required": [ - "agent_id", - "session_id" + "tool_group" ] }, "EmbeddingsRequest": { @@ -5841,6 +5777,101 @@ "start_time" ] }, + "Tool": { + "type": "object", + "properties": { + "identifier": { + "type": "string" + }, + "provider_resource_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "tool", + "default": "tool" + }, + "tool_group": { + "type": "string" + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolParameter" + } + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "tool_prompt_format": { + "$ref": "#/components/schemas/ToolPromptFormat", + "default": "json" + } + }, + "additionalProperties": false, + "required": [ + "identifier", + "provider_resource_id", + "type", + "tool_group", + "description", + "parameters" + ] + }, + "ToolGroup": { + "type": "object", + "properties": { + "identifier": { + "type": "string" + }, + "provider_resource_id": { + "type": "string" + }, + "provider_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "tool_group", + "default": "tool_group" + } + }, + "additionalProperties": false, + "required": [ + "identifier", + "provider_resource_id", + "provider_id", + "type" + ] + }, "Checkpoint": { "description": "Checkpoint created during training runs" }, @@ -6041,6 +6072,62 @@ "documents" ] }, + "InvokeToolRequest": { + "type": "object", + "properties": { + "tool_name": { + "type": "string" + }, + "args": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "tool_name", + "args" + ] + }, + "ToolInvocationResult": { + "type": "object", + "properties": { + "content": { + "$ref": "#/components/schemas/InterleavedContent" + }, + "error_message": { + "type": "string" + }, + "error_code": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "content" + ] + }, "JobCancelRequest": { "type": "object", "properties": { @@ -7187,6 +7274,25 @@ "shield_id" ] }, + "RegisterToolGroupRequest": { + "type": "object", + "properties": { + "tool_group_id": { + "type": "string" + }, + "tool_group": { + "$ref": "#/components/schemas/ToolGroupDef" + }, + "provider_id": { + "type": "string" + } + }, + "additionalProperties": false, + 
"required": [ + "tool_group_id", + "tool_group" + ] + }, "RunEvalRequest": { "type": "object", "properties": { @@ -7868,10 +7974,6 @@ "name": "Checkpoint", "description": "Checkpoint created during training runs\n\n" }, - { - "name": "CodeInterpreterToolDefinition", - "description": "" - }, { "name": "CompletionMessage", "description": "" @@ -7926,6 +8028,10 @@ "name": "DeleteAgentsSessionRequest", "description": "" }, + { + "name": "DiscoverToolsRequest", + "description": "" + }, { "name": "EfficiencyConfig", "description": "" @@ -7956,10 +8062,6 @@ "name": "EvaluateRowsRequest", "description": "" }, - { - "name": "FunctionCallToolDefinition", - "description": "" - }, { "name": "GetAgentsSessionRequest", "description": "" @@ -8006,6 +8108,10 @@ "name": "InterleavedContentItem", "description": "" }, + { + "name": "InvokeToolRequest", + "description": "" + }, { "name": "Job", "description": "" @@ -8050,6 +8156,10 @@ "name": "LoraFinetuningConfig", "description": "" }, + { + "name": "MCPToolGroupDef", + "description": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.\n\n" + }, { "name": "Memory" }, @@ -8064,10 +8174,6 @@ "name": "MemoryRetrievalStep", "description": "" }, - { - "name": "MemoryToolDefinition", - "description": "" - }, { "name": "Message", "description": "" @@ -8107,10 +8213,6 @@ "name": "ParamType", "description": "" }, - { - "name": "PhotogenToolDefinition", - "description": "" - }, { "name": "PostTraining (Coming Soon)" }, @@ -8191,16 +8293,12 @@ "description": "" }, { - "name": "ResponseFormat", - "description": "" - }, - { - "name": "RestAPIExecutionConfig", - "description": "" + "name": "RegisterToolGroupRequest", + "description": "" }, { - "name": "RestAPIMethod", - "description": "" + "name": "ResponseFormat", + "description": "" }, { "name": "RouteInfo", @@ -8267,10 +8365,6 @@ "name": "ScoringResult", "description": "" }, - { - "name": "SearchToolDefinition", - "description": "" - }, { "name": "Session", "description": "A single session of an interaction with an Agentic System.\n\n" @@ -8344,6 +8438,10 @@ "name": "TokenLogProbs", "description": "" }, + { + "name": "Tool", + "description": "" + }, { "name": "ToolCall", "description": "" @@ -8360,6 +8458,10 @@ "name": "ToolChoice", "description": "" }, + { + "name": "ToolDef", + "description": "" + }, { "name": "ToolDefinition", "description": "" @@ -8368,10 +8470,29 @@ "name": "ToolExecutionStep", "description": "" }, + { + "name": "ToolGroup", + "description": "" + }, + { + "name": "ToolGroupDef", + "description": "" + }, + { + "name": "ToolGroups" + }, + { + "name": "ToolInvocationResult", + "description": "" + }, { "name": "ToolParamDefinition", "description": "" }, + { + "name": "ToolParameter", + "description": "" + }, { "name": "ToolPromptFormat", "description": "This Enum refers to the prompt format for calling custom / zero shot tools\n\n`json` --\n Refers to the json format for calling tools.\n The json format takes the form like\n {\n \"type\": \"function\",\n \"function\" : {\n \"name\": \"function_name\",\n \"description\": \"function_description\",\n \"parameters\": {...}\n }\n }\n\n`function_tag` --\n This is an example of how you could define\n your own user defined format for making tool calls.\n The function_tag format looks like this,\n (parameters)\n\nThe detailed prompts for each of these formats are added to llama cli\n\n" @@ -8384,6 +8505,9 @@ "name": "ToolResponseMessage", "description": "" }, + 
{ + "name": "ToolRuntime" + }, { "name": "Trace", "description": "" @@ -8412,10 +8536,18 @@ "name": "UnregisterModelRequest", "description": "" }, + { + "name": "UnregisterToolGroupRequest", + "description": "" + }, { "name": "UnstructuredLogEvent", "description": "" }, + { + "name": "UserDefinedToolGroupDef", + "description": "" + }, { "name": "UserMessage", "description": "" @@ -8462,7 +8594,9 @@ "ScoringFunctions", "Shields", "SyntheticDataGeneration (Coming Soon)", - "Telemetry" + "Telemetry", + "ToolGroups", + "ToolRuntime" ] }, { @@ -8498,7 +8632,6 @@ "ChatCompletionResponseEventType", "ChatCompletionResponseStreamChunk", "Checkpoint", - "CodeInterpreterToolDefinition", "CompletionMessage", "CompletionRequest", "CompletionResponse", @@ -8511,13 +8644,13 @@ "Dataset", "DeleteAgentsRequest", "DeleteAgentsSessionRequest", + "DiscoverToolsRequest", "EfficiencyConfig", "EmbeddingsRequest", "EmbeddingsResponse", "EvalTask", "EvaluateResponse", "EvaluateRowsRequest", - "FunctionCallToolDefinition", "GetAgentsSessionRequest", "GetSpanTreeRequest", "GraphMemoryBank", @@ -8528,6 +8661,7 @@ "InsertDocumentsRequest", "InterleavedContent", "InterleavedContentItem", + "InvokeToolRequest", "Job", "JobCancelRequest", "JobStatus", @@ -8539,9 +8673,9 @@ "LogEventRequest", "LogSeverity", "LoraFinetuningConfig", + "MCPToolGroupDef", "MemoryBankDocument", "MemoryRetrievalStep", - "MemoryToolDefinition", "Message", "MetricEvent", "Model", @@ -8551,7 +8685,6 @@ "OptimizerType", "PaginatedRowsResult", "ParamType", - "PhotogenToolDefinition", "PostTrainingJob", "PostTrainingJobArtifactsResponse", "PostTrainingJobStatusResponse", @@ -8571,9 +8704,8 @@ "RegisterModelRequest", "RegisterScoringFunctionRequest", "RegisterShieldRequest", + "RegisterToolGroupRequest", "ResponseFormat", - "RestAPIExecutionConfig", - "RestAPIMethod", "RouteInfo", "RunEvalRequest", "RunShieldRequest", @@ -8588,7 +8720,6 @@ "ScoreResponse", "ScoringFn", "ScoringResult", - "SearchToolDefinition", "Session", "Shield", "ShieldCallStep", @@ -8605,13 +8736,19 @@ "SystemMessage", "TextContentItem", "TokenLogProbs", + "Tool", "ToolCall", "ToolCallDelta", "ToolCallParseStatus", "ToolChoice", + "ToolDef", "ToolDefinition", "ToolExecutionStep", + "ToolGroup", + "ToolGroupDef", + "ToolInvocationResult", "ToolParamDefinition", + "ToolParameter", "ToolPromptFormat", "ToolResponse", "ToolResponseMessage", @@ -8622,7 +8759,9 @@ "UnregisterDatasetRequest", "UnregisterMemoryBankRequest", "UnregisterModelRequest", + "UnregisterToolGroupRequest", "UnstructuredLogEvent", + "UserDefinedToolGroupDef", "UserMessage", "VectorMemoryBank", "VectorMemoryBankParams", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 8eca40cb74..5da647b542 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -17,6 +17,10 @@ components: AgentConfig: additionalProperties: false properties: + available_tools: + items: + type: string + type: array enable_session_persistence: type: boolean input_shields: @@ -34,6 +38,10 @@ components: items: type: string type: array + preprocessing_tools: + items: + type: string + type: array sampling_params: $ref: '#/components/schemas/SamplingParams' tool_choice: @@ -42,16 +50,6 @@ components: tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' default: json - tools: - items: - oneOf: - - $ref: '#/components/schemas/SearchToolDefinition' - - $ref: '#/components/schemas/WolframAlphaToolDefinition' - - $ref: 
'#/components/schemas/PhotogenToolDefinition' - - $ref: '#/components/schemas/CodeInterpreterToolDefinition' - - $ref: '#/components/schemas/FunctionCallToolDefinition' - - $ref: '#/components/schemas/MemoryToolDefinition' - type: array required: - max_infer_iters - model @@ -490,30 +488,6 @@ components: type: object Checkpoint: description: Checkpoint created during training runs - CodeInterpreterToolDefinition: - additionalProperties: false - properties: - enable_inline_code_execution: - default: true - type: boolean - input_shields: - items: - type: string - type: array - output_shields: - items: - type: string - type: array - remote_execution: - $ref: '#/components/schemas/RestAPIExecutionConfig' - type: - const: code_interpreter - default: code_interpreter - type: string - required: - - type - - enable_inline_code_execution - type: object CompletionMessage: additionalProperties: false properties: @@ -729,6 +703,14 @@ components: - agent_id - session_id type: object + DiscoverToolsRequest: + additionalProperties: false + properties: + tool_group: + $ref: '#/components/schemas/ToolGroupDef' + required: + - tool_group + type: object EfficiencyConfig: additionalProperties: false properties: @@ -862,37 +844,6 @@ components: - scoring_functions - task_config type: object - FunctionCallToolDefinition: - additionalProperties: false - properties: - description: - type: string - function_name: - type: string - input_shields: - items: - type: string - type: array - output_shields: - items: - type: string - type: array - parameters: - additionalProperties: - $ref: '#/components/schemas/ToolParamDefinition' - type: object - remote_execution: - $ref: '#/components/schemas/RestAPIExecutionConfig' - type: - const: function_call - default: function_call - type: string - required: - - type - - function_name - - description - - parameters - type: object GetAgentsSessionRequest: additionalProperties: false properties: @@ -1017,6 +968,25 @@ components: oneOf: - $ref: '#/components/schemas/ImageContentItem' - $ref: '#/components/schemas/TextContentItem' + InvokeToolRequest: + additionalProperties: false + properties: + args: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + tool_name: + type: string + required: + - tool_name + - args + type: object Job: additionalProperties: false properties: @@ -1190,6 +1160,21 @@ components: - rank - alpha type: object + MCPToolGroupDef: + additionalProperties: false + properties: + endpoint: + $ref: '#/components/schemas/URL' + type: + const: model_context_protocol + default: model_context_protocol + type: string + required: + - type + - endpoint + title: A tool group that is defined by in a model context protocol server. Refer + to https://modelcontextprotocol.io/docs/concepts/tools for more information. 
+ type: object MemoryBankDocument: additionalProperties: false properties: @@ -1250,135 +1235,6 @@ components: - memory_bank_ids - inserted_context type: object - MemoryToolDefinition: - additionalProperties: false - properties: - input_shields: - items: - type: string - type: array - max_chunks: - default: 10 - type: integer - max_tokens_in_context: - default: 4096 - type: integer - memory_bank_configs: - items: - oneOf: - - additionalProperties: false - properties: - bank_id: - type: string - type: - const: vector - default: vector - type: string - required: - - bank_id - - type - type: object - - additionalProperties: false - properties: - bank_id: - type: string - keys: - items: - type: string - type: array - type: - const: keyvalue - default: keyvalue - type: string - required: - - bank_id - - type - - keys - type: object - - additionalProperties: false - properties: - bank_id: - type: string - type: - const: keyword - default: keyword - type: string - required: - - bank_id - - type - type: object - - additionalProperties: false - properties: - bank_id: - type: string - entities: - items: - type: string - type: array - type: - const: graph - default: graph - type: string - required: - - bank_id - - type - - entities - type: object - type: array - output_shields: - items: - type: string - type: array - query_generator_config: - oneOf: - - additionalProperties: false - properties: - sep: - default: ' ' - type: string - type: - const: default - default: default - type: string - required: - - type - - sep - type: object - - additionalProperties: false - properties: - model: - type: string - template: - type: string - type: - const: llm - default: llm - type: string - required: - - type - - model - - template - type: object - - additionalProperties: false - properties: - type: - const: custom - default: custom - type: string - required: - - type - type: object - type: - const: memory - default: memory - type: string - required: - - type - - memory_bank_configs - - query_generator_config - - max_tokens_in_context - - max_chunks - type: object Message: oneOf: - $ref: '#/components/schemas/UserMessage' @@ -1621,26 +1477,6 @@ components: required: - type type: object - PhotogenToolDefinition: - additionalProperties: false - properties: - input_shields: - items: - type: string - type: array - output_shields: - items: - type: string - type: array - remote_execution: - $ref: '#/components/schemas/RestAPIExecutionConfig' - type: - const: photogen - default: photogen - type: string - required: - - type - type: object PostTrainingJob: additionalProperties: false properties: @@ -2039,6 +1875,19 @@ components: required: - shield_id type: object + RegisterToolGroupRequest: + additionalProperties: false + properties: + provider_id: + type: string + tool_group: + $ref: '#/components/schemas/ToolGroupDef' + tool_group_id: + type: string + required: + - tool_group_id + - tool_group + type: object ResponseFormat: oneOf: - additionalProperties: false @@ -2081,54 +1930,6 @@ components: - type - bnf type: object - RestAPIExecutionConfig: - additionalProperties: false - properties: - body: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - headers: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - method: - $ref: '#/components/schemas/RestAPIMethod' - params: - additionalProperties: - oneOf: - - type: 'null' - - 
type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - url: - $ref: '#/components/schemas/URL' - required: - - url - - method - type: object - RestAPIMethod: - enum: - - GET - - POST - - PUT - - DELETE - type: string RouteInfo: additionalProperties: false properties: @@ -2399,37 +2200,6 @@ components: - score_rows - aggregated_results type: object - SearchToolDefinition: - additionalProperties: false - properties: - api_key: - type: string - engine: - default: brave - enum: - - bing - - brave - - tavily - type: string - input_shields: - items: - type: string - type: array - output_shields: - items: - type: string - type: array - remote_execution: - $ref: '#/components/schemas/RestAPIExecutionConfig' - type: - const: brave_search - default: brave_search - type: string - required: - - type - - api_key - - engine - type: object Session: additionalProperties: false properties: @@ -2784,6 +2554,48 @@ components: required: - logprobs_by_token type: object + Tool: + additionalProperties: false + properties: + description: + type: string + identifier: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + parameters: + items: + $ref: '#/components/schemas/ToolParameter' + type: array + provider_id: + type: string + provider_resource_id: + type: string + tool_group: + type: string + tool_prompt_format: + $ref: '#/components/schemas/ToolPromptFormat' + default: json + type: + const: tool + default: tool + type: string + required: + - identifier + - provider_resource_id + - type + - tool_group + - description + - parameters + type: object ToolCall: additionalProperties: false properties: @@ -2848,6 +2660,36 @@ components: - auto - required type: string + ToolDef: + additionalProperties: false + properties: + description: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + name: + type: string + parameters: + items: + $ref: '#/components/schemas/ToolParameter' + type: array + tool_prompt_format: + $ref: '#/components/schemas/ToolPromptFormat' + default: json + required: + - name + - description + - parameters + - metadata + type: object ToolDefinition: additionalProperties: false properties: @@ -2896,6 +2738,41 @@ components: - tool_calls - tool_responses type: object + ToolGroup: + additionalProperties: false + properties: + identifier: + type: string + provider_id: + type: string + provider_resource_id: + type: string + type: + const: tool_group + default: tool_group + type: string + required: + - identifier + - provider_resource_id + - provider_id + - type + type: object + ToolGroupDef: + oneOf: + - $ref: '#/components/schemas/MCPToolGroupDef' + - $ref: '#/components/schemas/UserDefinedToolGroupDef' + ToolInvocationResult: + additionalProperties: false + properties: + content: + $ref: '#/components/schemas/InterleavedContent' + error_code: + type: integer + error_message: + type: string + required: + - content + type: object ToolParamDefinition: additionalProperties: false properties: @@ -2917,6 +2794,31 @@ components: required: - param_type type: object + ToolParameter: + additionalProperties: false + properties: + default: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: + type: string + name: + type: string + parameter_type: + 
type: string + required: + type: boolean + required: + - name + - parameter_type + - description + - required + type: object ToolPromptFormat: description: "`json` --\n Refers to the json format for calling tools.\n\ \ The json format takes the form like\n {\n \"type\": \"function\"\ @@ -3091,6 +2993,14 @@ components: required: - model_id type: object + UnregisterToolGroupRequest: + additionalProperties: false + properties: + tool_group_id: + type: string + required: + - tool_group_id + type: object UnstructuredLogEvent: additionalProperties: false properties: @@ -3127,6 +3037,21 @@ components: - message - severity type: object + UserDefinedToolGroupDef: + additionalProperties: false + properties: + tools: + items: + $ref: '#/components/schemas/ToolDef' + type: array + type: + const: user_defined + default: user_defined + type: string + required: + - type + - tools + type: object UserMessage: additionalProperties: false properties: @@ -3209,29 +3134,6 @@ components: - warn - error type: string - WolframAlphaToolDefinition: - additionalProperties: false - properties: - api_key: - type: string - input_shields: - items: - type: string - type: array - output_shields: - items: - type: string - type: array - remote_execution: - $ref: '#/components/schemas/RestAPIExecutionConfig' - type: - const: wolfram_alpha - default: wolfram_alpha - type: string - required: - - type - - api_key - type: object info: description: "This is the specification of the Llama Stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ @@ -4869,9 +4771,6 @@ tags: ' name: Checkpoint -- description: - name: CodeInterpreterToolDefinition - description: name: CompletionMessage @@ -4913,6 +4812,9 @@ tags: - description: name: DeleteAgentsSessionRequest +- description: + name: DiscoverToolsRequest - description: name: EfficiencyConfig @@ -4932,9 +4834,6 @@ tags: - description: name: EvaluateRowsRequest -- description: - name: FunctionCallToolDefinition - description: name: GetAgentsSessionRequest @@ -4965,6 +4864,9 @@ tags: - description: name: InterleavedContentItem +- description: + name: InvokeToolRequest - description: name: Job - description: name: LoraFinetuningConfig +- description: 'A tool group that is defined by in a model context protocol server. + Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information. + + + ' + name: MCPToolGroupDef - name: Memory - description: @@ -5003,9 +4911,6 @@ tags: - description: name: MemoryRetrievalStep -- description: - name: MemoryToolDefinition - description: name: Message - description: @@ -5027,9 +4932,6 @@ tags: name: PaginatedRowsResult - description: name: ParamType -- description: - name: PhotogenToolDefinition - name: PostTraining (Coming Soon) - description: @@ -5092,13 +4994,11 @@ tags: - description: name: RegisterShieldRequest +- description: + name: RegisterToolGroupRequest - description: name: ResponseFormat -- description: - name: RestAPIExecutionConfig -- description: - name: RestAPIMethod - description: name: RouteInfo - description: @@ -5137,9 +5037,6 @@ tags: - name: ScoringFunctions - description: name: ScoringResult -- description: - name: SearchToolDefinition - description: 'A single session of an interaction with an Agentic System. 
@@ -5191,6 +5088,8 @@ tags: name: TextContentItem - description: name: TokenLogProbs +- description: + name: Tool - description: name: ToolCall - description: @@ -5200,14 +5099,26 @@ tags: name: ToolCallParseStatus - description: name: ToolChoice +- description: + name: ToolDef - description: name: ToolDefinition - description: name: ToolExecutionStep +- description: + name: ToolGroup +- description: + name: ToolGroupDef +- name: ToolGroups +- description: + name: ToolInvocationResult - description: name: ToolParamDefinition +- description: + name: ToolParameter - description: "This Enum refers to the prompt format for calling custom / zero shot\ \ tools\n\n`json` --\n Refers to the json format for calling tools.\n The\ \ json format takes the form like\n {\n \"type\": \"function\",\n \ @@ -5224,6 +5135,7 @@ tags: - description: name: ToolResponseMessage +- name: ToolRuntime - description: name: Trace - description: @@ -5244,9 +5156,15 @@ tags: - description: name: UnregisterModelRequest +- description: + name: UnregisterToolGroupRequest - description: name: UnstructuredLogEvent +- description: + name: UserDefinedToolGroupDef - description: name: UserMessage - description: name: ViolationLevel -- description: - name: WolframAlphaToolDefinition x-tagGroups: - name: Operations tags: @@ -5283,6 +5198,8 @@ x-tagGroups: - Shields - SyntheticDataGeneration (Coming Soon) - Telemetry + - ToolGroups + - ToolRuntime - name: Types tags: - AgentCandidate @@ -5315,7 +5232,6 @@ x-tagGroups: - ChatCompletionResponseEventType - ChatCompletionResponseStreamChunk - Checkpoint - - CodeInterpreterToolDefinition - CompletionMessage - CompletionRequest - CompletionResponse @@ -5328,13 +5244,13 @@ x-tagGroups: - Dataset - DeleteAgentsRequest - DeleteAgentsSessionRequest + - DiscoverToolsRequest - EfficiencyConfig - EmbeddingsRequest - EmbeddingsResponse - EvalTask - EvaluateResponse - EvaluateRowsRequest - - FunctionCallToolDefinition - GetAgentsSessionRequest - GetSpanTreeRequest - GraphMemoryBank @@ -5345,6 +5261,7 @@ x-tagGroups: - InsertDocumentsRequest - InterleavedContent - InterleavedContentItem + - InvokeToolRequest - Job - JobCancelRequest - JobStatus @@ -5356,9 +5273,9 @@ x-tagGroups: - LogEventRequest - LogSeverity - LoraFinetuningConfig + - MCPToolGroupDef - MemoryBankDocument - MemoryRetrievalStep - - MemoryToolDefinition - Message - MetricEvent - Model @@ -5368,7 +5285,6 @@ x-tagGroups: - OptimizerType - PaginatedRowsResult - ParamType - - PhotogenToolDefinition - PostTrainingJob - PostTrainingJobArtifactsResponse - PostTrainingJobStatusResponse @@ -5388,9 +5304,8 @@ x-tagGroups: - RegisterModelRequest - RegisterScoringFunctionRequest - RegisterShieldRequest + - RegisterToolGroupRequest - ResponseFormat - - RestAPIExecutionConfig - - RestAPIMethod - RouteInfo - RunEvalRequest - RunShieldRequest @@ -5405,7 +5320,6 @@ x-tagGroups: - ScoreResponse - ScoringFn - ScoringResult - - SearchToolDefinition - Session - Shield - ShieldCallStep @@ -5422,13 +5336,19 @@ x-tagGroups: - SystemMessage - TextContentItem - TokenLogProbs + - Tool - ToolCall - ToolCallDelta - ToolCallParseStatus - ToolChoice + - ToolDef - ToolDefinition - ToolExecutionStep + - ToolGroup + - ToolGroupDef + - ToolInvocationResult - ToolParamDefinition + - ToolParameter - ToolPromptFormat - ToolResponse - ToolResponseMessage @@ -5439,10 +5359,11 @@ x-tagGroups: - UnregisterDatasetRequest - UnregisterMemoryBankRequest - UnregisterModelRequest + - UnregisterToolGroupRequest - UnstructuredLogEvent + - UserDefinedToolGroupDef - UserMessage - 
VectorMemoryBank - VectorMemoryBankParams - VersionInfo - ViolationLevel - - WolframAlphaToolDefinition diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 15d59ca8fb..e3c2ca52c6 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -74,12 +74,14 @@ class UserDefinedToolGroupDef(BaseModel): ) +@json_schema_type class ToolGroupInput(BaseModel): tool_group_id: str tool_group: ToolGroupDef provider_id: Optional[str] = None +@json_schema_type class ToolGroup(Resource): type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 9d12303c99..ae96744c6a 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -33,6 +33,7 @@ from llama_stack.apis.shields import Shields from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.telemetry import Telemetry +from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls @@ -63,6 +64,8 @@ class LlamaStack( Models, Shields, Inspect, + ToolGroups, + ToolRuntime, ): pass diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 85a197e36a..a7b08239be 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -4,78 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -from typing import Dict, List from uuid import uuid4 import pytest -from llama_stack.providers.tests.env import get_env_or_fail - from llama_stack_client.lib.agents.agent import Agent - -from llama_stack_client.lib.agents.custom_tool import CustomTool from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types import CompletionMessage, ToolResponseMessage from llama_stack_client.types.agent_create_params import AgentConfig -from llama_stack_client.types.tool_param_definition_param import ( - ToolParamDefinitionParam, -) - - -class TestCustomTool(CustomTool): - """Tool to give boiling point of a liquid - Returns the correct value for water in Celcius and Fahrenheit - and returns -1 for other liquids - - """ - - def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: - assert len(messages) == 1, "Expected single message" - - message = messages[0] - - tool_call = message.tool_calls[0] - - try: - response = self.run_impl(**tool_call.arguments) - response_str = json.dumps(response, ensure_ascii=False) - except Exception as e: - response_str = f"Error when running tool: {e}" - - message = ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=response_str, - role="ipython", - ) - return [message] - - def get_name(self) -> str: - return "get_boiling_point" - - def get_description(self) -> str: - return "Get the boiling point of a imaginary liquids (eg. 
polyjuice)" - - def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: - return { - "liquid_name": ToolParamDefinitionParam( - param_type="string", description="The name of the liquid", required=True - ), - "celcius": ToolParamDefinitionParam( - param_type="boolean", - description="Whether to return the boiling point in Celcius", - required=False, - ), - } - - def run_impl(self, liquid_name: str, celcius: bool = True) -> int: - if liquid_name.lower() == "polyjuice": - if celcius: - return -100 - else: - return -212 - else: - return -1 @pytest.fixture(scope="session") @@ -151,12 +85,8 @@ def test_agent_simple(llama_stack_client, agent_config): def test_builtin_tool_brave_search(llama_stack_client, agent_config): agent_config = { **agent_config, - "tools": [ - { - "type": "brave_search", - "engine": "brave", - "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"), - } + "available_tools": [ + "brave_search", ], } print(f"Agent Config: {agent_config}") @@ -167,7 +97,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): messages=[ { "role": "user", - "content": "Search the web and tell me who the 44th president of the United States was. Please use tools", + "content": "Search the web and tell me who the current CEO of Meta is.", } ], session_id=session_id, @@ -178,92 +108,5 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): assert "tool_execution>" in logs_str assert "Tool:brave_search Response:" in logs_str - assert "obama" in logs_str.lower() - if len(agent_config["input_shields"]) > 0: - assert "No Violation" in logs_str - - -def test_builtin_tool_code_execution(llama_stack_client, agent_config): - agent_config = { - **agent_config, - "tools": [ - { - "type": "code_interpreter", - } - ], - } - agent = Agent(llama_stack_client, agent_config) - session_id = agent.create_session(f"test-session-{uuid4()}") - - response = agent.create_turn( - messages=[ - { - "role": "user", - "content": "Write code to answer the question: What is the 100th prime number?", - }, - ], - session_id=session_id, - ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] - logs_str = "".join(logs) - - if "Tool:code_interpreter Response" not in logs_str: - assert len(logs_str) > 0 - pytest.skip("code_interpreter not called by model") - - assert "Tool:code_interpreter Response" in logs_str - if "No such file or directory: 'bwrap'" in logs_str: - assert "prime" in logs_str - pytest.skip("`bwrap` is not available on this platform") - else: - assert "541" in logs_str - - -def test_custom_tool(llama_stack_client, agent_config): - agent_config = { - **agent_config, - "model": "meta-llama/Llama-3.2-3B-Instruct", - "tools": [ - { - "type": "brave_search", - "engine": "brave", - "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"), - }, - { - "function_name": "get_boiling_point", - "description": "Get the boiling point of a imaginary liquids (eg. 
polyjuice)", - "parameters": { - "liquid_name": { - "param_type": "str", - "description": "The name of the liquid", - "required": True, - }, - "celcius": { - "param_type": "boolean", - "description": "Whether to return the boiling point in Celcius", - "required": False, - }, - }, - "type": "function_call", - }, - ], - "tool_prompt_format": "python_list", - } - - agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),)) - session_id = agent.create_session(f"test-session-{uuid4()}") - - response = agent.create_turn( - messages=[ - { - "role": "user", - "content": "What is the boiling point of polyjuice?", - }, - ], - session_id=session_id, - ) - - logs = [str(log) for log in EventLogger().log(response) if log is not None] - logs_str = "".join(logs) - assert "-100" in logs_str - assert "CustomTool" in logs_str + assert "mark zuckerberg" in logs_str.lower() + assert "No Violation" in logs_str From 40f35f3a8d8053448ecdf2a2237ed69f4c0593e9 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 23 Dec 2024 11:55:55 -0800 Subject: [PATCH 06/53] add code interpreter --- .../tool_runtime/code_interpreter/__init__.py | 16 ++ .../code_interpreter/code_env_prefix.py | 133 +++++++++ .../code_interpreter/code_execution.py | 256 ++++++++++++++++++ .../code_interpreter/code_interpreter.py | 55 ++++ .../tool_runtime/code_interpreter/config.py | 11 + .../tool_runtime/code_interpreter/utils.py | 21 ++ .../providers/registry/tool_runtime.py | 7 + .../providers/tests/agents/fixtures.py | 19 ++ tests/client-sdk/agents/test_agents.py | 26 ++ tests/client-sdk/conftest.py | 2 +- 10 files changed, 545 insertions(+), 1 deletion(-) create mode 100644 llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py create mode 100644 llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py create mode 100644 llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py create mode 100644 llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py create mode 100644 llama_stack/providers/inline/tool_runtime/code_interpreter/config.py create mode 100644 llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py new file mode 100644 index 0000000000..663b9655ba --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .code_interpreter import CodeInterpreterToolRuntimeImpl +from .config import CodeInterpreterToolConfig + +__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"] + + +async def get_provider_impl(config: CodeInterpreterToolConfig, _deps): + impl = CodeInterpreterToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py new file mode 100644 index 0000000000..10f64ec94f --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import errno
+
+# Disabling potentially dangerous functions
+import os as _os
+from functools import partial
+
+os_funcs_to_disable = [
+    "kill",
+    "system",
+    "putenv",
+    "remove",
+    "removedirs",
+    "rmdir",
+    "fchdir",
+    "setuid",
+    "fork",
+    "forkpty",
+    "killpg",
+    "rename",
+    "renames",
+    "truncate",
+    "replace",
+    # "unlink",  # Commenting as this was blocking matplotlib from rendering plots correctly
+    "fchmod",
+    "fchown",
+    "chmod",
+    "chown",
+    "chroot",
+    "fchdir",
+    "lchflags",
+    "lchmod",
+    "lchown",
+    "chdir",
+]
+
+
+def call_not_allowed(*args, **kwargs):
+    raise OSError(errno.EPERM, "Calls are not permitted in this environment")
+
+
+for func_name in os_funcs_to_disable:
+    if hasattr(_os, func_name):
+        setattr(_os, func_name, partial(call_not_allowed, _func_name=f"os.{func_name}"))
+
+import shutil as _shutil
+
+for func_name in ["rmtree", "move", "chown"]:
+    if hasattr(_shutil, func_name):
+        setattr(
+            _shutil,
+            func_name,
+            partial(call_not_allowed, _func_name=f"shutil.{func_name}"),
+        )
+
+import subprocess as _subprocess
+
+
+def popen_not_allowed(*args, **kwargs):
+    raise _subprocess.CalledProcessError(
+        -1,
+        args[0] if args else "unknown",
+        stderr="subprocess.Popen is not allowed in this environment",
+    )
+
+
+_subprocess.Popen = popen_not_allowed
+
+
+import atexit as _atexit
+import builtins as _builtins
+import io as _io
+import json as _json
+import sys as _sys
+
+# NB! The following "unused" imports are crucial; make sure not to remove
+# them with linters - they're used in code_execution.py
+from contextlib import (  # noqa
+    contextmanager as _contextmanager,
+    redirect_stderr as _redirect_stderr,
+    redirect_stdout as _redirect_stdout,
+)
+from multiprocessing.connection import Connection as _Connection
+
+# Mangle imports to avoid polluting model execution namespace.
+
+_IO_SINK = _io.StringIO()
+_NETWORK_TIMEOUT = 5
+_NETWORK_CONNECTIONS = None
+
+
+def _open_connections():
+    global _NETWORK_CONNECTIONS
+    if _NETWORK_CONNECTIONS is not None:
+        # Ensure connections only opened once.
+        return _NETWORK_CONNECTIONS
+    req_w_fd, resp_r_fd = _sys.argv[1], _sys.argv[2]
+    req_con = _Connection(int(req_w_fd), readable=False)
+    resp_con = _Connection(int(resp_r_fd), writable=False)
+    _NETWORK_CONNECTIONS = (req_con, resp_con)
+    return _NETWORK_CONNECTIONS
+
+
+_builtins._open_connections = _open_connections
+
+
+@_atexit.register
+def _close_connections():
+    global _NETWORK_CONNECTIONS
+    if _NETWORK_CONNECTIONS is None:
+        return
+    for con in _NETWORK_CONNECTIONS:
+        con.close()
+    del _NETWORK_CONNECTIONS
+
+
+def _network_call(request):
+    # NOTE: We communicate with the parent process in json, encoded
+    # in raw bytes. We do this because native send/recv methods use
+    # pickle which involves execution of arbitrary code.
+ _open_connections() + req_con, resp_con = _NETWORK_CONNECTIONS + + req_con.send_bytes(_json.dumps(request).encode("utf-8")) + if resp_con.poll(timeout=_NETWORK_TIMEOUT) is None: + raise Exception(f"Network request timed out: {_json.dumps(request)}") + else: + return _json.loads(resp_con.recv_bytes().decode("utf-8")) diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py new file mode 100644 index 0000000000..fa2e367e58 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import base64 +import json +import multiprocessing +import os +import re +import subprocess +import sys +import tempfile +import textwrap +import time +from dataclasses import dataclass +from datetime import datetime +from io import BytesIO +from pathlib import Path +from typing import List + +from PIL import Image + +from .utils import get_code_env_prefix + +TOOLS_ATTACHMENT_KEY = "__tools_attachment__" +TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") + +DIRNAME = Path(__file__).parent + +CODE_EXEC_TIMEOUT = 20 +CODE_ENV_PREFIX = get_code_env_prefix() + +STDOUTERR_SINK_WRAPPER_TEMPLATE = """\ +with _redirect_stdout(_IO_SINK), _redirect_stderr(_IO_SINK): +{code}\ +""" + +TRYEXCEPT_WRAPPER_TEMPLATE = """\ +try: +{code} +except: + pass\ +""" + + +def generate_bwrap_command(bind_dirs: List[str]) -> str: + """ + Generate the bwrap command string for binding all + directories in the current directory read-only. 
+ """ + bwrap_args = "" + bwrap_args += "--ro-bind / / " + # Add the --dev flag to mount device files + bwrap_args += "--dev /dev " + for d in bind_dirs: + bwrap_args += f"--bind {d} {d} " + + # Add the --unshare-all flag to isolate the sandbox from the rest of the system + bwrap_args += "--unshare-all " + # Add the --die-with-parent flag to ensure the child process dies when bwrap's parent dies + bwrap_args += "--die-with-parent " + return bwrap_args + + +@dataclass +class CodeExecutionContext: + matplotlib_dump_dir: str + use_proxy: bool = False + + +@dataclass +class CodeExecutionRequest: + scripts: List[str] + only_last_cell_stdouterr: bool = True + only_last_cell_fail: bool = True + seed: int = 0 + strip_fpaths_in_stderr: bool = True + + +class CodeExecutor: + def __init__(self, context: CodeExecutionContext): + self.context = context + + def execute(self, req: CodeExecutionRequest) -> dict: + scripts = req.scripts + for i in range(len(scripts) - 1): + if req.only_last_cell_stdouterr: + scripts[i] = STDOUTERR_SINK_WRAPPER_TEMPLATE.format( + code=textwrap.indent(scripts[i], " " * 4) + ) + if req.only_last_cell_fail: + scripts[i] = TRYEXCEPT_WRAPPER_TEMPLATE.format( + code=textwrap.indent(scripts[i], " " * 4) + ) + + # Seeds prefix: + seed = req.seed + seeds_prefix = f"""\ +def _set_seeds(): + import random + random.seed({seed}) + import numpy as np + np.random.seed({seed}) +_set_seeds()\ +""" + + script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts) + with tempfile.TemporaryDirectory() as dpath: + bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath]) + cmd = [*bwrap_prefix.split(), sys.executable, "-c", script] + code_fpath = os.path.join(dpath, "code.py") + with open(code_fpath, "w") as f: + f.write(script) + + try: + python_path = os.environ.get("PYTHONPATH", "") + env = dict( + os.environ, + PYTHONHASHSEED=str(seed), + MPLCONFIGDIR=dpath, + MPLBACKEND="module://matplotlib_custom_backend", + PYTHONPATH=f"{DIRNAME}:{python_path}", + ) + stdout, stderr, returncode = do_subprocess( + cmd=cmd, + env=env, + ctx=self.context, + ) + + stderr = stderr.strip() + if req.strip_fpaths_in_stderr: + pattern = r'File "([^"]+)", line (\d+)' + stderr = re.sub(pattern, r"line \2", stderr) + + return { + "process_status": "completed", + "returncode": returncode, + "stdout": stdout.strip(), + "stderr": stderr, + } + + except subprocess.TimeoutExpired: + return { + "process_status": "timeout", + "stdout": "Timed out", + "stderr": "Timed out", + } + + except Exception as e: + return { + "process_status": "error", + "error_type": type(e).__name__, + "stderr": str(e), + "stdout": str(e), + } + + +def process_matplotlib_response(response, matplotlib_dump_dir: str): + image_data = response["image_data"] + # Convert the base64 string to a bytes object + images = [base64.b64decode(d["image_base64"]) for d in image_data] + # Create a list of PIL images from the bytes objects + images = [Image.open(BytesIO(img)) for img in images] + # Create a list of image paths + image_paths = [] + for i, img in enumerate(images): + # create new directory for each day to better organize data: + dump_dname = datetime.today().strftime("%Y-%m-%d") + dump_dpath = Path(matplotlib_dump_dir, dump_dname) + dump_dpath.mkdir(parents=True, exist_ok=True) + # save image into a file + dump_fname = f"matplotlib_{str(time.time()).replace('.', '_')}_{i}.png" + dump_fpath = dump_dpath / dump_fname + img.save(dump_fpath, "PNG") + image_paths.append(str(dump_fpath)) + + # this is kind of convoluted, we send back this 
response to the subprocess which
+    # prints it out
+    info = {
+        "filepath": str(image_paths[-1]),
+        "mimetype": "image/png",
+    }
+    return f"{TOOLS_ATTACHMENT_KEY}={json.dumps(info)}"
+
+
+def execute_subprocess_request(request, ctx: CodeExecutionContext):
+    "Route requests from the subprocess (via network Pipes) to the internet/tools."
+    if request["type"] == "matplotlib":
+        return process_matplotlib_response(request, ctx.matplotlib_dump_dir)
+    else:
+        raise Exception(f'Unrecognised network request type: {request["type"]}')
+
+
+def do_subprocess(*, cmd: list, env: dict, ctx: CodeExecutionContext):
+    # Create Pipes to be used for any external tool/network requests.
+    req_r, req_w = multiprocessing.Pipe(duplex=False)
+    resp_r, resp_w = multiprocessing.Pipe(duplex=False)
+
+    cmd += [str(req_w.fileno()), str(resp_r.fileno())]
+    proc = subprocess.Popen(
+        cmd,
+        pass_fds=(req_w.fileno(), resp_r.fileno()),
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        close_fds=True,
+        env=env,
+    )
+
+    # Close unnecessary fds.
+    req_w.close()
+    resp_r.close()
+
+    pipe_close = False
+    done_read = False
+    start = time.monotonic()
+    while proc.poll() is None and not pipe_close:
+        if req_r.poll(0.1):
+            # NB: Python pipe semantics for poll and recv mean that
+            # poll() returns True if a pipe is closed.
+            # Cf. this old Python bug report from '09:
+            # https://bugs.python.org/issue5573
+            try:
+                request = json.loads(req_r.recv_bytes().decode("utf-8"))
+                response = execute_subprocess_request(request, ctx)
+
+                resp_w.send_bytes(json.dumps(response).encode("utf-8"))
+            except EOFError:
+                # The request pipe is closed - set a marker to exit
+                # after the next attempt at reading stdout/stderr.
+                pipe_close = True
+
+        try:
+            # If a lot has been printed, the pipe might be full but
+            # proc cannot exit until all the stdout/stderr has
+            # been written/read.
+            stdout, stderr = proc.communicate(timeout=0.3)
+            done_read = True
+        except subprocess.TimeoutExpired:
+            # The program has not terminated. Ignore it, there
+            # may be more network/tool requests.
+            continue
+        if time.monotonic() - start > CODE_EXEC_TIMEOUT:
+            proc.terminate()
+            raise subprocess.TimeoutExpired(cmd, CODE_EXEC_TIMEOUT)
+
+    if not done_read:
+        # Solve race condition where process terminates before
+        # we hit the while loop.
+        stdout, stderr = proc.communicate(timeout=0.3)
+
+    resp_w.close()
+    req_r.close()
+    return stdout, stderr, proc.returncode
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
new file mode 100644
index 0000000000..2e062d6d7f
--- /dev/null
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+ + +import logging +import tempfile +from typing import Any, Dict, List + +from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor +from .config import CodeInterpreterToolConfig + +log = logging.getLogger(__name__) + + +class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__(self, config: CodeInterpreterToolConfig): + self.config = config + ctx = CodeExecutionContext( + matplotlib_dump_dir=tempfile.mkdtemp(), + ) + self.code_executor = CodeExecutor(ctx) + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + if tool.identifier != "code_interpreter": + raise ValueError(f"Tool identifier {tool.identifier} is not supported") + + async def unregister_tool(self, tool_id: str) -> None: + return + + async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: + raise NotImplementedError("Code interpreter tool group not supported") + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + script = args["code"] + req = CodeExecutionRequest(scripts=[script]) + res = self.code_executor.execute(req) + pieces = [res["process_status"]] + for out_type in ["stdout", "stderr"]: + res_out = res[out_type] + if res_out != "": + pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"]) + if out_type == "stderr": + log.error(f"ipython tool error: ↓\n{res_out}") + return ToolInvocationResult(content="\n".join(pieces)) diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py new file mode 100644 index 0000000000..167a2c3184 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/config.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class CodeInterpreterToolConfig(BaseModel): + pass diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py new file mode 100644 index 0000000000..d6f539a39f --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os + +DIR = os.path.dirname(os.path.realpath(__file__)) +CODE_ENV_PREFIX_FILE = os.path.join(DIR, "code_env_prefix.py") +CODE_ENV_PREFIX = None + + +def get_code_env_prefix() -> str: + global CODE_ENV_PREFIX + + if CODE_ENV_PREFIX is None: + with open(CODE_ENV_PREFIX_FILE, "r") as f: + CODE_ENV_PREFIX = f.read() + + return CODE_ENV_PREFIX diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 9058fb7189..e4e61109f2 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -41,6 +41,13 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.tool_runtime.tavily_search.config.TavilySearchToolConfig", provider_data_validator="llama_stack.providers.inline.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", ), + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::code-interpreter", + pip_packages=[], + module="llama_stack.providers.inline.tool_runtime.code_interpreter", + config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig", + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index c0690e4e31..ca44325d77 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -84,6 +84,11 @@ def tool_runtime_memory() -> ProviderFixture: "api_key": os.environ["TAVILY_SEARCH_API_KEY"], }, ), + Provider( + provider_id="code-interpreter", + provider_type="inline::code-interpreter", + config={}, + ), ], ) @@ -221,6 +226,20 @@ async def agents_stack(request, inference_model, safety_shield): ), provider_id="memory-runtime", ), + ToolGroupInput( + tool_group_id="code_interpreter_group", + tool_group=UserDefinedToolGroupDef( + tools=[ + ToolDef( + name="code_interpreter", + description="code_interpreter", + parameters=[], + metadata={}, + ) + ], + ), + provider_id="code-interpreter", + ), ] test_stack = await construct_stack_for_test( diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index a7b08239be..4e335d8d3b 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -110,3 +110,29 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): assert "Tool:brave_search Response:" in logs_str assert "mark zuckerberg" in logs_str.lower() assert "No Violation" in logs_str + + +def test_builtin_tool_code_execution(llama_stack_client, agent_config): + agent_config = { + **agent_config, + "available_tools": [ + "code_interpreter", + ], + } + agent = Agent(llama_stack_client, agent_config) + session_id = agent.create_session(f"test-session-{uuid4()}") + + response = agent.create_turn( + messages=[ + { + "role": "user", + "content": "Write code to answer the question: What is the 100th prime number?", + }, + ], + session_id=session_id, + ) + logs = [str(log) for log in EventLogger().log(response) if log is not None] + logs_str = "".join(logs) + + assert "541" in logs_str + assert "Tool:code_interpreter Response" in logs_str diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 2366008dd2..28808ae4c8 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -6,8 +6,8 @@ import os import pytest -from llama_stack import LlamaStackAsLibraryClient +from 
llama_stack import LlamaStackAsLibraryClient
 from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient

From 517bc9ebea49e405ac6e98c7c9cbfff8e47b61ed Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 23 Dec 2024 12:12:09 -0800
Subject: [PATCH 07/53] add back custom tool tests

---
 tests/client-sdk/agents/test_agents.py | 114 +++++++++++++++++++++++++
 1 file changed, 114 insertions(+)

diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 4e335d8d3b..7939259d18 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -4,12 +4,76 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
+from typing import Dict, List
 from uuid import uuid4
 
 import pytest
+
+from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client.lib.agents.custom_tool import CustomTool
 from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client.types import CompletionMessage, ToolResponseMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
+from llama_stack_client.types.tool_param_definition_param import (
+    ToolParamDefinitionParam,
+)
+
+
+class TestCustomTool(CustomTool):
+    """Tool to give the boiling point of a liquid
+    Returns the correct value for water in Celsius and Fahrenheit
+    and returns -1 for other liquids
+    """
+
+    def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
+        assert len(messages) == 1, "Expected single message"
+
+        message = messages[0]
+
+        tool_call = message.tool_calls[0]
+
+        try:
+            response = self.run_impl(**tool_call.arguments)
+            response_str = json.dumps(response, ensure_ascii=False)
+        except Exception as e:
+            response_str = f"Error when running tool: {e}"
+
+        message = ToolResponseMessage(
+            call_id=tool_call.call_id,
+            tool_name=tool_call.tool_name,
+            content=response_str,
+            role="ipython",
+        )
+        return [message]
+
+    def get_name(self) -> str:
+        return "get_boiling_point"
+
+    def get_description(self) -> str:
+        return "Get the boiling point of an imaginary liquid (e.g. polyjuice)"
+
+    def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
+        return {
+            "liquid_name": ToolParamDefinitionParam(
+                param_type="string", description="The name of the liquid", required=True
+            ),
+            "celcius": ToolParamDefinitionParam(
+                param_type="boolean",
+                description="Whether to return the boiling point in Celsius",
+                required=False,
+            ),
+        }
+
+    def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
+        if liquid_name.lower() == "polyjuice":
+            if celcius:
+                return -100
+            else:
+                return -212
+        else:
+            return -1
 
 
 @pytest.fixture(scope="session")
@@ -136,3 +200,53 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
 
     assert "541" in logs_str
     assert "Tool:code_interpreter Response" in logs_str
+
+
+def test_custom_tool(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "model": "meta-llama/Llama-3.2-3B-Instruct",
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
+            },
+            {
+                "function_name": "get_boiling_point",
+                "description": "Get the boiling point of an imaginary liquid (e.g. polyjuice)",
+                "parameters": {
+                    "liquid_name": {
+                        "param_type": "str",
+                        "description": "The name of the liquid",
+                        "required": True,
+                    },
+                    "celcius": {
+                        "param_type": "boolean",
+                        "description": "Whether to return the boiling point in Celsius",
+                        "required": False,
+                    },
+                },
+                "type": "function_call",
+            },
+        ],
+        "tool_prompt_format": "python_list",
+    }
+
+    agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
+    session_id = agent.create_session(f"test-session-{uuid4()}")
+
+    response = agent.create_turn(
+        messages=[
+            {
+                "role": "user",
+                "content": "What is the boiling point of polyjuice?",
+            },
+        ],
+        session_id=session_id,
+    )
+
+    logs = [str(log) for log in EventLogger().log(response) if log is not None]
+    logs_str = "".join(logs)
+    assert "-100" in logs_str
+    assert "CustomTool" in logs_str

From 1a66ddc1b55f0a0a3f143de38225fa4fc154fbe7 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 23 Dec 2024 16:50:03 -0800
Subject: [PATCH 08/53] add support for built in tool type

---
 llama_stack/apis/tools/tools.py               | 40 ++++++++++++++--
 .../distribution/routers/routing_tables.py    | 42 ++++++++++++-----
 .../agents/meta_reference/agent_instance.py   |  3 ++
 .../tool_runtime/brave_search/brave_search.py |  3 +-
 .../tavily_search/tavily_search.py            |  3 +-
 .../model_context_protocol.py                 |  3 +-
 .../providers/tests/agents/fixtures.py        | 46 ++++---------------
 .../providers/tests/agents/test_agents.py     | 18 +-------
 8 files changed, 83 insertions(+), 75 deletions(-)

diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index e3c2ca52c6..65d5b84449 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -4,9 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from enum import Enum
 from typing import Annotated, Any, Dict, List, Literal, Optional, Union
 
-from llama_models.llama3.api.datatypes import ToolPromptFormat
+from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable
@@ -25,13 +26,21 @@ class ToolParameter(BaseModel):
     default: Optional[Any] = None
 
 
+@json_schema_type
+class ToolHost(Enum):
+    distribution = "distribution"
+    client = "client"
+    model_context_protocol = "model_context_protocol"
+
+
 @json_schema_type
 class Tool(Resource):
     type: Literal[ResourceType.tool.value] = ResourceType.tool.value
     tool_group: str
+    tool_host: ToolHost
     description: str
     parameters: List[ToolParameter]
-    provider_id: Optional[str] = None
+    built_in_type: Optional[BuiltinTool] = None
     metadata: Optional[Dict[str, Any]] = None
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
@@ -39,7 +48,8 @@
 
 
 @json_schema_type
-class ToolDef(BaseModel):
+class CustomToolDef(BaseModel):
+    type: Literal["custom"] = "custom"
     name: str
     description: str
     parameters: List[ToolParameter]
@@ -49,6 +59,19 @@
     )
 
 
+@json_schema_type
+class BuiltInToolDef(BaseModel):
+    type: Literal["built_in"] = "built_in"
+    built_in_type: BuiltinTool
+    metadata: Optional[Dict[str, Any]] = None
+
+
+ToolDef = register_schema(
+    Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")],
+    name="ToolDef",
+)
+
+
 @json_schema_type
 class MCPToolGroupDef(BaseModel):
     """
@@ -149,3 +172,14 @@ async def invoke_tool(
     ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
+
+
+# Three tool types:
+# 1. Built-in tools
+# 2. Client tools
+# 3. Model-context-protocol tools
+
+# Support registration of agents with tool groups
+# TBD: Have a client utility to hide the pre-processing tools.
+# Attachments are confusing right now since they are inserted into memory first and retrieved through RAG, even before a question is asked.
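+#
+# Illustrative sketch (not part of this change) of constructing the two ToolDef
+# variants defined above; field names follow CustomToolDef/BuiltInToolDef and
+# ToolParameter as declared in this file:
+#
+#   custom = CustomToolDef(
+#       name="get_boiling_point",
+#       description="Get the boiling point of a liquid",
+#       parameters=[
+#           ToolParameter(
+#               name="liquid_name",
+#               parameter_type="string",
+#               description="The name of the liquid",
+#               required=True,
+#           )
+#       ],
+#       metadata={},
+#   )
+#   built_in = BuiltInToolDef(built_in_type=BuiltinTool.code_interpreter, metadata={})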
+# diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 8d622a5c27..2aff0f3a23 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -516,6 +516,7 @@ async def register_tool_group( ) -> None: tools = [] tool_defs = [] + tool_host = ToolHost.distribution if provider_id is None: if len(self.impls_by_provider_id.keys()) > 1: raise ValueError( @@ -529,25 +530,42 @@ async def register_tool_group( tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( tool_group ) - + tool_host = ToolHost.model_context_protocol elif isinstance(tool_group, UserDefinedToolGroupDef): tool_defs = tool_group.tools else: raise ValueError(f"Unknown tool group: {tool_group}") for tool_def in tool_defs: - tools.append( - Tool( - identifier=tool_def.name, - tool_group=tool_group_id, - description=tool_def.description, - parameters=tool_def.parameters, - provider_id=provider_id, - tool_prompt_format=tool_def.tool_prompt_format, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, + if isinstance(tool_def, CustomToolDef): + tools.append( + Tool( + identifier=tool_def.name, + tool_group=tool_group_id, + description=tool_def.description, + parameters=tool_def.parameters, + provider_id=provider_id, + tool_prompt_format=tool_def.tool_prompt_format, + provider_resource_id=tool_def.name, + metadata=tool_def.metadata, + tool_host=tool_host, + ) + ) + elif isinstance(tool_def, BuiltInToolDef): + tools.append( + Tool( + identifier=tool_def.built_in_type.value, + tool_group=tool_group_id, + built_in_type=tool_def.built_in_type, + description="", + parameters=[], + provider_id=provider_id, + tool_prompt_format=ToolPromptFormat.json, + provider_resource_id=tool_def.built_in_type.value, + metadata=tool_def.metadata, + tool_host=tool_host, + ) ) - ) for tool in tools: existing_tool = await self.get_tool(tool.identifier) # Compare existing and new object if one exists diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 8075ea2bd2..cc4ef38a94 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -621,6 +621,9 @@ async def _get_tools(self) -> List[ToolDefinition]: ret = [] for tool_name in self.agent_config.available_tools: tool = await self.tool_groups_api.get_tool(tool_name) + if tool.built_in_type: + ret.append(ToolDefinition(tool_name=tool.built_in_type)) + continue params = {} for param in tool.parameters: params[param.name] = ToolParamDefinition( diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py index ca0141552e..cd0468d93b 100644 --- a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py @@ -25,8 +25,7 @@ async def initialize(self): pass async def register_tool(self, tool: Tool): - if tool.identifier != "brave_search": - raise ValueError(f"Tool identifier {tool.identifier} is not supported") + pass async def unregister_tool(self, tool_id: str) -> None: return diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py index 94a387f306..f4e9809293 100644 --- 
a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py @@ -26,8 +26,7 @@ async def initialize(self): pass async def register_tool(self, tool: Tool): - if tool.identifier != "tavily_search": - raise ValueError(f"Tool identifier {tool.identifier} is not supported") + pass async def unregister_tool(self, tool_id: str) -> None: return diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index b9bf3fe361..c77929f999 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -8,6 +8,7 @@ from urllib.parse import urlparse from llama_stack.apis.tools import ( + CustomToolDef, MCPToolGroupDef, ToolDef, ToolGroupDef, @@ -52,7 +53,7 @@ async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ) ) tools.append( - ToolDef( + CustomToolDef( name=tool.name, description=tool.description, parameters=parameters, diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index ca44325d77..97d0d47e69 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -9,10 +9,12 @@ import pytest import pytest_asyncio +from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.models import ModelInput, ModelType from llama_stack.apis.tools import ( - ToolDef, + BuiltInToolDef, + CustomToolDef, ToolGroupInput, ToolParameter, UserDefinedToolGroupDef, @@ -151,42 +153,12 @@ async def agents_stack(request, inference_model, safety_shield): ) ) tool_groups = [ - ToolGroupInput( - tool_group_id="brave_search_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - ToolDef( - name="brave_search", - description="brave_search", - parameters=[ - ToolParameter( - name="query", - description="query", - parameter_type="string", - required=True, - ), - ], - metadata={}, - ), - ], - ), - provider_id="brave-search", - ), ToolGroupInput( tool_group_id="tavily_search_group", tool_group=UserDefinedToolGroupDef( tools=[ - ToolDef( - name="tavily_search", - description="tavily_search", - parameters=[ - ToolParameter( - name="query", - description="query", - parameter_type="string", - required=True, - ), - ], + BuiltInToolDef( + built_in_type=BuiltinTool.brave_search, metadata={}, ), ], @@ -197,7 +169,7 @@ async def agents_stack(request, inference_model, safety_shield): tool_group_id="memory_group", tool_group=UserDefinedToolGroupDef( tools=[ - ToolDef( + CustomToolDef( name="memory", description="memory", parameters=[ @@ -230,10 +202,8 @@ async def agents_stack(request, inference_model, safety_shield): tool_group_id="code_interpreter_group", tool_group=UserDefinedToolGroupDef( tools=[ - ToolDef( - name="code_interpreter", - description="code_interpreter", - parameters=[], + BuiltInToolDef( + built_in_type=BuiltinTool.code_interpreter, metadata={}, ) ], diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 147f04b023..a8c472da47 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -150,9 +150,7 @@ async def create_agent_turn_with_search_tool( assert isinstance(tool_execution, 
ToolExecutionStep) assert len(tool_execution.tool_calls) > 0 actual_tool_name = tool_execution.tool_calls[0].tool_name - if isinstance(actual_tool_name, BuiltinTool): - actual_tool_name = actual_tool_name.value - assert actual_tool_name == tool_name + assert actual_tool_name.value == tool_name assert len(tool_execution.tool_responses) > 0 check_turn_complete_event(turn_response, session_id, search_query_messages) @@ -305,20 +303,6 @@ async def test_create_agent_turn_with_brave_search( "brave_search", ) - @pytest.mark.asyncio - async def test_create_agent_turn_with_tavily_search( - self, agents_stack, search_query_messages, common_params - ): - if "TAVILY_SEARCH_API_KEY" not in os.environ: - pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") - - await create_agent_turn_with_search_tool( - agents_stack, - search_query_messages, - common_params, - "tavily_search", - ) - def check_event_types(turn_response): event_types = [chunk.event.payload.event_type for chunk in turn_response] From 4dd2f4c363cf3d2b64bb3eb86c381631a629042e Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 23 Dec 2024 18:27:55 -0800 Subject: [PATCH 09/53] working end to end client sdk tests with custom tools --- docs/resources/llama-stack-spec.html | 258 ++++++++++++------ docs/resources/llama-stack-spec.yaml | 115 ++++++-- llama_stack/apis/agents/agents.py | 3 +- .../agents/meta_reference/agent_instance.py | 23 ++ tests/client-sdk/agents/test_agents.py | 52 ++-- 5 files changed, 303 insertions(+), 148 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index b1bef08820..d1d2c266df 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3711,6 +3711,12 @@ "type": "string" } }, + "custom_tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CustomToolDef" + } + }, "preprocessing_tools": { "type": "array", "items": { @@ -3747,6 +3753,111 @@ "enable_session_persistence" ] }, + "CustomToolDef": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "custom", + "default": "custom" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "parameters": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolParameter" + } + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "tool_prompt_format": { + "$ref": "#/components/schemas/ToolPromptFormat", + "default": "json" + } + }, + "additionalProperties": false, + "required": [ + "type", + "name", + "description", + "parameters", + "metadata" + ] + }, + "ToolParameter": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parameter_type": { + "type": "string" + }, + "description": { + "type": "string" + }, + "required": { + "type": "boolean" + }, + "default": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "additionalProperties": false, + "required": [ + "name", + "parameter_type", + "description", + "required" + ] + }, "CreateAgentRequest": { "type": "object", "properties": { @@ -4403,39 +4514,16 @@ "session_id" ] }, - "MCPToolGroupDef": { + "BuiltInToolDef": { "type": "object", "properties": { "type": { "type": 
"string", - "const": "model_context_protocol", - "default": "model_context_protocol" + "const": "built_in", + "default": "built_in" }, - "endpoint": { - "$ref": "#/components/schemas/URL" - } - }, - "additionalProperties": false, - "required": [ - "type", - "endpoint" - ], - "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." - }, - "ToolDef": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "description": { - "type": "string" - }, - "parameters": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolParameter" - } + "built_in_type": { + "$ref": "#/components/schemas/BuiltinTool" }, "metadata": { "type": "object", @@ -4461,74 +4549,51 @@ } ] } - }, - "tool_prompt_format": { - "$ref": "#/components/schemas/ToolPromptFormat", - "default": "json" } }, "additionalProperties": false, "required": [ - "name", - "description", - "parameters", - "metadata" + "type", + "built_in_type" ] }, - "ToolGroupDef": { + "MCPToolGroupDef": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "model_context_protocol", + "default": "model_context_protocol" + }, + "endpoint": { + "$ref": "#/components/schemas/URL" + } + }, + "additionalProperties": false, + "required": [ + "type", + "endpoint" + ], + "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." + }, + "ToolDef": { "oneOf": [ { - "$ref": "#/components/schemas/MCPToolGroupDef" + "$ref": "#/components/schemas/CustomToolDef" }, { - "$ref": "#/components/schemas/UserDefinedToolGroupDef" + "$ref": "#/components/schemas/BuiltInToolDef" } ] }, - "ToolParameter": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "parameter_type": { - "type": "string" - }, - "description": { - "type": "string" - }, - "required": { - "type": "boolean" + "ToolGroupDef": { + "oneOf": [ + { + "$ref": "#/components/schemas/MCPToolGroupDef" }, - "default": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] + { + "$ref": "#/components/schemas/UserDefinedToolGroupDef" } - }, - "additionalProperties": false, - "required": [ - "name", - "parameter_type", - "description", - "required" ] }, "UserDefinedToolGroupDef": { @@ -5797,6 +5862,9 @@ "tool_group": { "type": "string" }, + "tool_host": { + "$ref": "#/components/schemas/ToolHost" + }, "description": { "type": "string" }, @@ -5806,6 +5874,9 @@ "$ref": "#/components/schemas/ToolParameter" } }, + "built_in_type": { + "$ref": "#/components/schemas/BuiltinTool" + }, "metadata": { "type": "object", "additionalProperties": { @@ -5840,12 +5911,22 @@ "required": [ "identifier", "provider_resource_id", + "provider_id", "type", "tool_group", + "tool_host", "description", "parameters" ] }, + "ToolHost": { + "type": "string", + "enum": [ + "distribution", + "client", + "model_context_protocol" + ] + }, "ToolGroup": { "type": "object", "properties": { @@ -7942,6 +8023,10 @@ "name": "BenchmarkEvalTaskConfig", "description": "" }, + { + "name": "BuiltInToolDef", + "description": "" + }, { "name": "BuiltinTool", "description": "" @@ -8002,6 +8087,10 @@ "name": "CreateAgentTurnRequest", "description": "" }, + { + "name": "CustomToolDef", + "description": "" + }, { "name": "DPOAlignmentConfig", "description": 
"" @@ -8481,6 +8570,10 @@ { "name": "ToolGroups" }, + { + "name": "ToolHost", + "description": "" + }, { "name": "ToolInvocationResult", "description": "" @@ -8624,6 +8717,7 @@ "BatchCompletionRequest", "BatchCompletionResponse", "BenchmarkEvalTaskConfig", + "BuiltInToolDef", "BuiltinTool", "CancelTrainingJobRequest", "ChatCompletionRequest", @@ -8639,6 +8733,7 @@ "CreateAgentRequest", "CreateAgentSessionRequest", "CreateAgentTurnRequest", + "CustomToolDef", "DPOAlignmentConfig", "DataConfig", "Dataset", @@ -8746,6 +8841,7 @@ "ToolExecutionStep", "ToolGroup", "ToolGroupDef", + "ToolHost", "ToolInvocationResult", "ToolParamDefinition", "ToolParameter", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 5da647b542..4f7a9c91c0 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -21,6 +21,10 @@ components: items: type: string type: array + custom_tools: + items: + $ref: '#/components/schemas/CustomToolDef' + type: array enable_session_persistence: type: boolean input_shields: @@ -389,6 +393,29 @@ components: - type - eval_candidate type: object + BuiltInToolDef: + additionalProperties: false + properties: + built_in_type: + $ref: '#/components/schemas/BuiltinTool' + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: built_in + default: built_in + type: string + required: + - type + - built_in_type + type: object BuiltinTool: enum: - brave_search @@ -607,6 +634,41 @@ components: - session_id - messages type: object + CustomToolDef: + additionalProperties: false + properties: + description: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + name: + type: string + parameters: + items: + $ref: '#/components/schemas/ToolParameter' + type: array + tool_prompt_format: + $ref: '#/components/schemas/ToolPromptFormat' + default: json + type: + const: custom + default: custom + type: string + required: + - type + - name + - description + - parameters + - metadata + type: object DPOAlignmentConfig: additionalProperties: false properties: @@ -2557,6 +2619,8 @@ components: Tool: additionalProperties: false properties: + built_in_type: + $ref: '#/components/schemas/BuiltinTool' description: type: string identifier: @@ -2581,6 +2645,8 @@ components: type: string tool_group: type: string + tool_host: + $ref: '#/components/schemas/ToolHost' tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' default: json @@ -2591,8 +2657,10 @@ components: required: - identifier - provider_resource_id + - provider_id - type - tool_group + - tool_host - description - parameters type: object @@ -2661,35 +2729,9 @@ components: - required type: string ToolDef: - additionalProperties: false - properties: - description: - type: string - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - name: - type: string - parameters: - items: - $ref: '#/components/schemas/ToolParameter' - type: array - tool_prompt_format: - $ref: '#/components/schemas/ToolPromptFormat' - default: json - required: - - name - - description - - parameters - - metadata - type: object + oneOf: + - $ref: '#/components/schemas/CustomToolDef' + - $ref: '#/components/schemas/BuiltInToolDef' ToolDefinition: 
additionalProperties: false properties: @@ -2761,6 +2803,12 @@ components: oneOf: - $ref: '#/components/schemas/MCPToolGroupDef' - $ref: '#/components/schemas/UserDefinedToolGroupDef' + ToolHost: + enum: + - distribution + - client + - model_context_protocol + type: string ToolInvocationResult: additionalProperties: false properties: @@ -4738,6 +4786,8 @@ tags: - description: name: BenchmarkEvalTaskConfig +- description: + name: BuiltInToolDef - description: name: BuiltinTool - description: name: CreateAgentTurnRequest +- description: + name: CustomToolDef - description: name: DPOAlignmentConfig @@ -5111,6 +5163,8 @@ tags: - description: name: ToolGroupDef - name: ToolGroups +- description: + name: ToolHost - description: name: ToolInvocationResult @@ -5224,6 +5278,7 @@ x-tagGroups: - BatchCompletionRequest - BatchCompletionResponse - BenchmarkEvalTaskConfig + - BuiltInToolDef - BuiltinTool - CancelTrainingJobRequest - ChatCompletionRequest @@ -5239,6 +5294,7 @@ x-tagGroups: - CreateAgentRequest - CreateAgentSessionRequest - CreateAgentTurnRequest + - CustomToolDef - DPOAlignmentConfig - DataConfig - Dataset @@ -5346,6 +5402,7 @@ x-tagGroups: - ToolExecutionStep - ToolGroup - ToolGroupDef + - ToolHost - ToolInvocationResult - ToolParamDefinition - ToolParameter diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 325ce94903..3348211c97 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -18,13 +18,11 @@ runtime_checkable, ) -from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated from llama_stack.apis.common.content_types import URL, InterleavedContent -from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig from llama_stack.apis.inference import ( CompletionMessage, SamplingParams, @@ -140,6 +138,7 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) available_tools: Optional[List[str]] = Field(default_factory=list) + custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list) preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index cc4ef38a94..ba190f5673 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -400,6 +400,10 @@ async def _run( output_attachments = [] n_iter = 0 + # Build a map of custom tools to their definitions for faster lookup + custom_tools = {} + for tool in self.agent_config.custom_tools: + custom_tools[tool.name] = tool while True: msg = input_messages[-1] @@ -530,6 +534,9 @@ async def _run( else: log.info(f"{str(message)}") tool_call = message.tool_calls[0] + if tool_call.tool_name in custom_tools: + yield message + return step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -619,6 +626,22 @@ def interpret_content_as_attachment( async def _get_tools(self) -> List[ToolDefinition]: ret = [] + for tool in self.agent_config.custom_tools: + params = {} + for param in 
tool.parameters: + params[param.name] = ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + ret.append( + ToolDefinition( + tool_name=tool.name, + description=tool.description, + parameters=params, + ) + ) for tool_name in self.agent_config.available_tools: tool = await self.tool_groups_api.get_tool(tool_name) if tool.built_in_type: diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 7939259d18..ef3c087fad 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -9,16 +9,13 @@ from uuid import uuid4 import pytest - -from llama_stack.providers.tests.env import get_env_or_fail from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.custom_tool import CustomTool from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types import CompletionMessage, ToolResponseMessage +from llama_stack_client.types import ToolResponseMessage from llama_stack_client.types.agent_create_params import AgentConfig -from llama_stack_client.types.tool_param_definition_param import ( - ToolParamDefinitionParam, -) +from llama_stack_client.types.custom_tool_def import Parameter +from llama_stack_client.types.shared.completion_message import CompletionMessage class TestCustomTool(CustomTool): @@ -54,13 +51,17 @@ def get_name(self) -> str: def get_description(self) -> str: return "Get the boiling point of a imaginary liquids (eg. polyjuice)" - def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: + def get_params_definition(self) -> Dict[str, Parameter]: return { - "liquid_name": ToolParamDefinitionParam( - param_type="string", description="The name of the liquid", required=True + "liquid_name": Parameter( + name="liquid_name", + parameter_type="string", + description="The name of the liquid", + required=True, ), - "celcius": ToolParamDefinitionParam( - param_type="boolean", + "celcius": Parameter( + name="celcius", + parameter_type="boolean", description="Whether to return the boiling point in Celcius", required=False, ), @@ -203,37 +204,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): def test_custom_tool(llama_stack_client, agent_config): + custom_tool = TestCustomTool() agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "tools": [ - { - "type": "brave_search", - "engine": "brave", - "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"), - }, - { - "function_name": "get_boiling_point", - "description": "Get the boiling point of a imaginary liquids (eg. 
polyjuice)", - "parameters": { - "liquid_name": { - "param_type": "str", - "description": "The name of the liquid", - "required": True, - }, - "celcius": { - "param_type": "boolean", - "description": "Whether to return the boiling point in Celcius", - "required": False, - }, - }, - "type": "function_call", - }, - ], + "available_tools": ["brave_search"], + "custom_tools": [custom_tool.get_tool_definition()], "tool_prompt_format": "python_list", } - agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),)) + agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,)) session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn( From c76f5f418f15972ecf1319d81de500a72e209c21 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 24 Dec 2024 11:45:36 -0800 Subject: [PATCH 10/53] move brave and tavily to remote --- .../providers/registry/tool_runtime.py | 36 ++++++++++--------- .../tool_runtime/brave_search/__init__.py | 2 +- .../tool_runtime/brave_search/brave_search.py | 0 .../tool_runtime/brave_search/config.py | 0 .../tool_runtime/tavily_search/__init__.py | 2 +- .../tool_runtime/tavily_search/config.py | 0 .../tavily_search/tavily_search.py | 0 tests/client-sdk/agents/test_agents.py | 2 +- 8 files changed, 23 insertions(+), 19 deletions(-) rename llama_stack/providers/{inline => remote}/tool_runtime/brave_search/__init__.py (88%) rename llama_stack/providers/{inline => remote}/tool_runtime/brave_search/brave_search.py (100%) rename llama_stack/providers/{inline => remote}/tool_runtime/brave_search/config.py (100%) rename llama_stack/providers/{inline => remote}/tool_runtime/tavily_search/__init__.py (88%) rename llama_stack/providers/{inline => remote}/tool_runtime/tavily_search/config.py (100%) rename llama_stack/providers/{inline => remote}/tool_runtime/tavily_search/tavily_search.py (100%) diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index e4e61109f2..b6b34edf0a 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -17,14 +17,6 @@ def available_providers() -> List[ProviderSpec]: return [ - InlineProviderSpec( - api=Api.tool_runtime, - provider_type="inline::brave-search", - pip_packages=[], - module="llama_stack.providers.inline.tool_runtime.brave_search", - config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig", - provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", - ), InlineProviderSpec( api=Api.tool_runtime, provider_type="inline::memory-runtime", @@ -33,14 +25,6 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig", api_dependencies=[Api.memory, Api.memory_banks, Api.inference], ), - InlineProviderSpec( - api=Api.tool_runtime, - provider_type="inline::tavily-search", - pip_packages=[], - module="llama_stack.providers.inline.tool_runtime.tavily_search", - config_class="llama_stack.providers.inline.tool_runtime.tavily_search.config.TavilySearchToolConfig", - provider_data_validator="llama_stack.providers.inline.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", - ), InlineProviderSpec( api=Api.tool_runtime, provider_type="inline::code-interpreter", @@ -48,6 +32,26 @@ def available_providers() -> List[ProviderSpec]: 
module="llama_stack.providers.inline.tool_runtime.code_interpreter", config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig", ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="brave-search", + module="llama_stack.providers.remote.tool_runtime.brave_search", + config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", + ), + ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="tavily-search", + module="llama_stack.providers.remote.tool_runtime.tavily_search", + config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", + ), + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py b/llama_stack/providers/remote/tool_runtime/brave_search/__init__.py similarity index 88% rename from llama_stack/providers/inline/tool_runtime/brave_search/__init__.py rename to llama_stack/providers/remote/tool_runtime/brave_search/__init__.py index e9f0eeae81..0827e51d22 100644 --- a/llama_stack/providers/inline/tool_runtime/brave_search/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/__init__.py @@ -14,7 +14,7 @@ class BraveSearchToolProviderDataValidator(BaseModel): api_key: str -async def get_provider_impl(config: BraveSearchToolConfig, _deps): +async def get_adapter_impl(config: BraveSearchToolConfig, _deps): impl = BraveSearchToolRuntimeImpl(config) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py similarity index 100% rename from llama_stack/providers/inline/tool_runtime/brave_search/brave_search.py rename to llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py diff --git a/llama_stack/providers/inline/tool_runtime/brave_search/config.py b/llama_stack/providers/remote/tool_runtime/brave_search/config.py similarity index 100% rename from llama_stack/providers/inline/tool_runtime/brave_search/config.py rename to llama_stack/providers/remote/tool_runtime/brave_search/config.py diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py b/llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py similarity index 88% rename from llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py rename to llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py index 8061a250cb..379e990817 100644 --- a/llama_stack/providers/inline/tool_runtime/tavily_search/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py @@ -14,7 +14,7 @@ class TavilySearchToolProviderDataValidator(BaseModel): api_key: str -async def get_provider_impl(config: TavilySearchToolConfig, _deps): +async def get_adapter_impl(config: TavilySearchToolConfig, _deps): impl = TavilySearchToolRuntimeImpl(config) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/config.py b/llama_stack/providers/remote/tool_runtime/tavily_search/config.py 
similarity index 100% rename from llama_stack/providers/inline/tool_runtime/tavily_search/config.py rename to llama_stack/providers/remote/tool_runtime/tavily_search/config.py diff --git a/llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py similarity index 100% rename from llama_stack/providers/inline/tool_runtime/tavily_search/tavily_search.py rename to llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index ef3c087fad..1b21929490 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -191,7 +191,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): messages=[ { "role": "user", - "content": "Write code to answer the question: What is the 100th prime number?", + "content": "Write code and execute it to find the answer for: What is the 100th prime number?", }, ], session_id=session_id, From 97798c84420ebee05533168738b63108d42366bf Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 26 Dec 2024 09:13:34 -0800 Subject: [PATCH 11/53] add a RAG test to client SDK --- llama_stack/apis/agents/agents.py | 1 + .../agents/meta_reference/agent_instance.py | 47 ++++++++++--- tests/client-sdk/agents/test_agents.py | 66 +++++++++++++++++++ 3 files changed, 105 insertions(+), 9 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 3348211c97..14278b803f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -184,6 +184,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel): AgentTurnResponseEventType.step_complete.value ) step_type: StepType + step_id: str step_details: Step diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index ba190f5673..1ecb95e683 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -313,6 +313,7 @@ async def run_multiple_shields_wrapper( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( step_type=StepType.shield_call.value, + step_id=step_id, step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, @@ -333,6 +334,7 @@ async def run_multiple_shields_wrapper( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( step_type=StepType.shield_call.value, + step_id=step_id, step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, @@ -355,28 +357,26 @@ async def _run( if self.agent_config.preprocessing_tools: with tracing.span("preprocessing_tools") as span: for tool_name in self.agent_config.preprocessing_tools: + step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( step_type=StepType.tool_execution.value, - step_id=str(uuid.uuid4()), + step_id=step_id, ) ) ) args = dict( session_id=session_id, + turn_id=turn_id, input_messages=input_messages, attachments=attachments, ) - result = await self.tool_runtime_api.invoke_tool( - tool_name=tool_name, - args=args, - ) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( step_type=StepType.tool_execution.value, - step_id=str(uuid.uuid4()), + step_id=step_id, 
tool_call_delta=ToolCallDelta( parse_status=ToolCallParseStatus.success, content=ToolCall( @@ -386,6 +386,37 @@ async def _run( ) ) ) + result = await self.tool_runtime_api.invoke_tool( + tool_name=tool_name, + args=args, + ) + + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + step_details=ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[ + ToolCall( + call_id="", + tool_name=tool_name, + arguments={}, + ) + ], + tool_responses=[ + ToolResponse( + call_id="", + tool_name=tool_name, + content=result.content, + ) + ], + ), + ) + ) + ) span.set_attribute( "input", [m.model_dump_json() for m in input_messages] ) @@ -393,7 +424,7 @@ async def _run( span.set_attribute("error_code", result.error_code) span.set_attribute("error_message", result.error_message) span.set_attribute("tool_name", tool_name) - if result.error_code != 0 and result.content: + if result.error_code == 0: last_message = input_messages[-1] last_message.context = result.content @@ -405,8 +436,6 @@ async def _run( for tool in self.agent_config.custom_tools: custom_tools[tool.name] = tool while True: - msg = input_messages[-1] - step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 1b21929490..10aaa09b5a 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -15,6 +15,7 @@ from llama_stack_client.types import ToolResponseMessage from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.custom_tool_def import Parameter +from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage @@ -230,3 +231,68 @@ def test_custom_tool(llama_stack_client, agent_config): logs_str = "".join(logs) assert "-100" in logs_str assert "CustomTool" in logs_str + + +def test_rag_agent(llama_stack_client, agent_config): + urls = [ + "memory_optimizations.rst", + "chat.rst", + "llama3.rst", + "datasets.rst", + "qat_finetune.rst", + "lora_finetune.rst", + ] + documents = [ + Document( + document_id=f"num-{i}", + content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", + mime_type="text/plain", + metadata={}, + ) + for i, url in enumerate(urls) + ] + llama_stack_client.memory_banks.register( + memory_bank_id="test_bank", + params={ + "memory_bank_type": "vector", + "embedding_model": "all-MiniLM-L6-v2", + "chunk_size_in_tokens": 512, + "overlap_size_in_tokens": 64, + }, + provider_id="faiss", + ) + + # insert some documents + llama_stack_client.memory.insert( + bank_id="test_bank", + documents=documents, + ) + + agent_config = { + **agent_config, + "preprocessing_tools": ["memory-tool"], + } + agent = Agent(llama_stack_client, agent_config) + session_id = agent.create_session(f"test-session-{uuid4()}") + + user_prompts = [ + "What are the top 5 topics that were explained in the documentation? 
Only list succinct bullet points.",
+        "Was anything related to 'Llama3' discussed, if so what?",
+        "Tell me how to use LoRA",
+        "What about Quantization?",
+    ]
+
+    for prompt in user_prompts:
+        response = agent.create_turn(
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt,
+                }
+            ],
+            session_id=session_id,
+        )
+
+        logs = [str(log) for log in EventLogger().log(response) if log is not None]
+        logs_str = "".join(logs)
+        assert "Tool:memory-tool" in logs_str

From f408fd3acac17cea9c77884f5f369555b30d3944 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Thu, 26 Dec 2024 15:48:52 -0800
Subject: [PATCH 12/53] remove attachments, move memory bank to tool metadata

---
 llama_stack/apis/agents/agents.py             |   3 -
 .../agents/meta_reference/agent_instance.py   |  21 ++-
 .../inline/agents/meta_reference/agents.py    |   2 -
 .../inline/tool_runtime/memory/__init__.py    |   4 +-
 .../inline/tool_runtime/memory/config.py      |  11 +-
 .../tool_runtime/memory/context_retriever.py  |  15 +-
 .../inline/tool_runtime/memory/memory.py      | 162 ++----------------
 .../providers/registry/tool_runtime.py        |   2 +-
 tests/client-sdk/agents/test_agents.py        |   5 +-
 9 files changed, 45 insertions(+), 180 deletions(-)

diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 14278b803f..88ae919062 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -39,7 +39,6 @@
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 
 
-@json_schema_type
 class Attachment(BaseModel):
     content: InterleavedContent | URL
     mime_type: str
@@ -258,7 +257,6 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: Optional[List[Attachment]] = None
     stream: Optional[bool] = False
 
 
@@ -295,7 +293,6 @@ async def create_agent_turn(
                 ToolResponseMessage,
            ]
        ],
-        attachments: Optional[List[Attachment]] = None,
        stream: Optional[bool] = False,
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
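Note: with attachments removed from create_agent_turn, documents now reach the
agent through the memory API plus the memory preprocessing tool instead. A
minimal sketch of the new flow, reusing names from the tests in this series
(the llama_stack_client, agent_config, and documents fixtures are assumed):

    llama_stack_client.memory.insert(bank_id="test_bank", documents=documents)
    agent = Agent(llama_stack_client, {**agent_config, "preprocessing_tools": ["memory-tool"]})
    session_id = agent.create_session("docs-session")
    # The turn call itself no longer takes an `attachments` argument.
    response = agent.create_turn(
        messages=[{"role": "user", "content": "Tell me how to use LoRA"}],
        session_id=session_id,
    )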
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1ecb95e683..6cf031bf7a 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -188,7 +188,6 @@ async def create_and_execute_turn( session_id=request.session_id, turn_id=turn_id, input_messages=messages, - attachments=request.attachments or [], sampling_params=self.agent_config.sampling_params, stream=request.stream, ): @@ -238,7 +237,6 @@ async def run( session_id: str, turn_id: str, input_messages: List[Message], - attachments: List[Attachment], sampling_params: SamplingParams, stream: bool = False, ) -> AsyncGenerator: @@ -257,7 +255,7 @@ async def run( yield res async for res in self._run( - session_id, turn_id, input_messages, attachments, sampling_params, stream + session_id, turn_id, input_messages, sampling_params, stream ): if isinstance(res, bool): return @@ -350,7 +348,6 @@ async def _run( session_id: str, turn_id: str, input_messages: List[Message], - attachments: List[Attachment], sampling_params: SamplingParams, stream: bool = False, ) -> AsyncGenerator: @@ -370,7 +367,6 @@ async def _run( session_id=session_id, turn_id=turn_id, input_messages=input_messages, - attachments=attachments, ) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -423,7 +419,10 @@ async def _run( span.set_attribute("output", result.content) span.set_attribute("error_code", result.error_code) span.set_attribute("error_message", result.error_message) - span.set_attribute("tool_name", tool_name) + if isinstance(tool_name, BuiltinTool): + span.set_attribute("tool_name", tool_name.value) + else: + span.set_attribute("tool_name", tool_name) if result.error_code == 0: last_message = input_messages[-1] last_message.context = result.content @@ -553,9 +552,9 @@ async def _run( # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS) if len(output_attachments) > 0: if isinstance(message.content, list): - message.content += attachments + message.content += output_attachments else: - message.content = [message.content] + attachments + message.content = [message.content] + output_attachments yield message else: log.info(f"Partial message: {str(message)}") @@ -586,10 +585,13 @@ async def _run( ) ) + tool_name = tool_call.tool_name + if isinstance(tool_name, BuiltinTool): + tool_name = tool_name.value with tracing.span( "tool_execution", { - "tool_name": tool_call.tool_name, + "tool_name": tool_name, "input": message.model_dump_json(), }, ) as span: @@ -608,6 +610,7 @@ async def _run( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( step_type=StepType.tool_execution.value, + step_id=step_id, step_details=ToolExecutionStep( step_id=step_id, turn_id=turn_id, diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 89b38a7fc6..5769c42e5f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -146,14 +146,12 @@ async def create_agent_turn( ToolResponseMessage, ] ], - attachments: Optional[List[Attachment]] = None, stream: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, session_id=session_id, messages=messages, - attachments=attachments, stream=True, ) if stream: diff --git 
a/llama_stack/providers/inline/tool_runtime/memory/__init__.py b/llama_stack/providers/inline/tool_runtime/memory/__init__.py index 36377f1471..928afa4846 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/memory/__init__.py @@ -8,11 +8,11 @@ from llama_stack.providers.datatypes import Api -from .config import MemoryToolConfig +from .config import MemoryToolRuntimeConfig from .memory import MemoryToolRuntimeImpl -async def get_provider_impl(config: MemoryToolConfig, deps: Dict[str, Any]): +async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]): impl = MemoryToolRuntimeImpl( config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference] ) diff --git a/llama_stack/providers/inline/tool_runtime/memory/config.py b/llama_stack/providers/inline/tool_runtime/memory/config.py index cb24883dc0..6ff242c6ba 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/config.py +++ b/llama_stack/providers/inline/tool_runtime/memory/config.py @@ -7,9 +7,6 @@ from enum import Enum from typing import Annotated, List, Literal, Union -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR -from llama_stack.providers.utils.kvstore import KVStoreConfig, SqliteKVStoreConfig - from pydantic import BaseModel, Field @@ -81,13 +78,13 @@ class CustomMemoryQueryGeneratorConfig(BaseModel): class MemoryToolConfig(BaseModel): memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list) + + +class MemoryToolRuntimeConfig(BaseModel): # This config defines how a query is generated using the messages # for memory bank retrieval. query_generator_config: MemoryQueryGeneratorConfig = Field( default=DefaultMemoryQueryGeneratorConfig() ) max_tokens_in_context: int = 4096 - max_chunks: int = 10 - kvstore_config: KVStoreConfig = SqliteKVStoreConfig( - db_path=(RUNTIME_BASE_DIR / "memory.db").as_posix() - ) + max_chunks: int = 5 diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py index da97cb3a3a..7ee751a173 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List from jinja2 import Template @@ -23,7 +22,7 @@ async def generate_rag_query( config: MemoryQueryGeneratorConfig, - messages: List[Message], + message: Message, **kwargs, ): """ @@ -31,9 +30,9 @@ async def generate_rag_query( retrieving relevant information from the memory bank. 
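+    The generator is chosen by config.type: "default" returns the message
+    content as-is, while "llm" renders config.template against the message and
+    asks the inference API to produce the query (see the dispatch below).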
""" if config.type == MemoryQueryGenerator.default.value: - query = await default_rag_query_generator(config, messages, **kwargs) + query = await default_rag_query_generator(config, message, **kwargs) elif config.type == MemoryQueryGenerator.llm.value: - query = await llm_rag_query_generator(config, messages, **kwargs) + query = await llm_rag_query_generator(config, message, **kwargs) else: raise NotImplementedError(f"Unsupported memory query generator {config.type}") return query @@ -41,21 +40,21 @@ async def generate_rag_query( async def default_rag_query_generator( config: DefaultMemoryQueryGeneratorConfig, - messages: List[Message], + message: Message, **kwargs, ): - return config.sep.join(interleaved_content_as_str(m.content) for m in messages) + return interleaved_content_as_str(message.content) async def llm_rag_query_generator( config: LLMMemoryQueryGeneratorConfig, - messages: List[Message], + message: Message, **kwargs, ): assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" inference_api = kwargs["inference_api"] - m_dict = {"messages": [m.model_dump() for m in messages]} + m_dict = {"messages": [message.model_dump()]} template = Template(config.template) content = template.render(m_dict) diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index 3a08bf1f98..d492309cd8 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -5,24 +5,14 @@ # the root directory of this source tree. import asyncio -import json import logging -import os -import re import secrets import string -import tempfile -import uuid from typing import Any, Dict, List, Optional -from urllib.parse import urlparse -import httpx - -from llama_stack.apis.agents import Attachment -from llama_stack.apis.common.content_types import TextContentItem, URL from llama_stack.apis.inference import Inference, InterleavedContent, Message -from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse -from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams +from llama_stack.apis.memory import Memory, QueryDocumentsResponse +from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.tools import ( ToolDef, ToolGroupDef, @@ -30,22 +20,14 @@ ToolRuntime, ) from llama_stack.providers.datatypes import ToolsProtocolPrivate -from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content -from pydantic import BaseModel -from .config import MemoryToolConfig +from .config import MemoryToolConfig, MemoryToolRuntimeConfig from .context_retriever import generate_rag_query log = logging.getLogger(__name__) -class MemorySessionInfo(BaseModel): - session_id: str - session_name: str - memory_bank_id: Optional[str] = None - - def make_random_string(length: int = 8): return "".join( secrets.choice(string.ascii_letters + string.digits) for _ in range(length) @@ -55,7 +37,7 @@ def make_random_string(length: int = 8): class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): def __init__( self, - config: MemoryToolConfig, + config: MemoryToolRuntimeConfig, memory_api: Memory, memory_banks_api: MemoryBanks, inference_api: Inference, @@ -63,113 +45,26 @@ def __init__( self.config = config self.memory_api = memory_api self.memory_banks_api = memory_banks_api - self.tempdir = tempfile.mkdtemp() self.inference_api = 
inference_api async def initialize(self): - self.kvstore = await kvstore_impl(self.config.kvstore_config) + pass async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: return [] - async def create_session(self, session_id: str) -> MemorySessionInfo: - session_info = MemorySessionInfo( - session_id=session_id, - session_name=f"session_{session_id}", - ) - await self.kvstore.set( - key=f"memory::session:{session_id}", - value=session_info.model_dump_json(), - ) - return session_info - - async def get_session_info(self, session_id: str) -> Optional[MemorySessionInfo]: - value = await self.kvstore.get( - key=f"memory::session:{session_id}", - ) - if not value: - session_info = await self.create_session(session_id) - return session_info - - return MemorySessionInfo(**json.loads(value)) - - async def add_memory_bank_to_session(self, session_id: str, bank_id: str): - session_info = await self.get_session_info(session_id) - - session_info.memory_bank_id = bank_id - await self.kvstore.set( - key=f"memory::session:{session_id}", - value=session_info.model_dump_json(), - ) - - async def _ensure_memory_bank(self, session_id: str) -> str: - session_info = await self.get_session_info(session_id) - - if session_info.memory_bank_id is None: - bank_id = f"memory_bank_{session_id}" - await self.memory_banks_api.register_memory_bank( - memory_bank_id=bank_id, - params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - ), - ) - await self.add_memory_bank_to_session(session_id, bank_id) - else: - bank_id = session_info.memory_bank_id - - return bank_id - - async def attachment_message( - self, tempdir: str, urls: List[URL] - ) -> List[TextContentItem]: - content = [] - - for url in urls: - uri = url.uri - if uri.startswith("file://"): - filepath = uri[len("file://") :] - elif uri.startswith("http"): - path = urlparse(uri).path - basename = os.path.basename(path) - filepath = f"{tempdir}/{make_random_string() + basename}" - log.info(f"Downloading {url} -> {filepath}") - - async with httpx.AsyncClient() as client: - r = await client.get(uri) - resp = r.text - with open(filepath, "w") as fp: - fp.write(resp) - else: - raise ValueError(f"Unsupported URL {url}") - - content.append( - TextContentItem( - text=f'# There is a file accessible to you at "{filepath}"\n' - ) - ) - - return content - async def _retrieve_context( - self, session_id: str, messages: List[Message] + self, messages: List[Message], bank_ids: List[str] ) -> Optional[List[InterleavedContent]]: - bank_ids = [] - - bank_ids.extend(c.bank_id for c in self.config.memory_bank_configs) - - session_info = await self.get_session_info(session_id) - if session_info.memory_bank_id: - bank_ids.append(session_info.memory_bank_id) - if not bank_ids: - # this can happen if the per-session memory bank is not yet populated - # (i.e., no prior turns uploaded an Attachment) + return None + if len(messages) == 0: return None + message = messages[-1] # only use the last message as input to the query query = await generate_rag_query( self.config.query_generator_config, - messages, + message, inference_api=self.inference_api, ) tasks = [ @@ -177,7 +72,7 @@ async def _retrieve_context( bank_id=bank_id, query=query, params={ - "max_chunks": 5, + "max_chunks": self.config.max_chunks, }, ) for bank_id in bank_ids @@ -211,43 +106,20 @@ async def _retrieve_context( "\n=== END-RETRIEVED-CONTEXT ===\n", ] - async def _process_attachments( - self, session_id: str, attachments: List[Attachment] - ): - bank_id = await 
self._ensure_memory_bank(session_id) - - documents = [ - MemoryBankDocument( - document_id=str(uuid.uuid4()), - content=a.content, - mime_type=a.mime_type, - metadata={}, - ) - for a in attachments - if isinstance(a.content, str) - ] - await self.memory_api.insert_documents(bank_id, documents) - - urls = [a.content for a in attachments if isinstance(a.content, URL)] - # TODO: we need to migrate URL away from str type - pattern = re.compile("^(https?://|file://|data:)") - urls += [URL(uri=a.content) for a in attachments if pattern.match(a.content)] - return await self.attachment_message(self.tempdir, urls) - async def invoke_tool( self, tool_name: str, args: Dict[str, Any] ) -> ToolInvocationResult: - if args["session_id"] is None: - raise ValueError("session_id is required") + tool = await self.tool_store.get_tool(tool_name) + config = MemoryToolConfig() + if tool.metadata.get("config") is not None: + config = MemoryToolConfig(**tool.metadata["config"]) context = await self._retrieve_context( - args["session_id"], args["input_messages"] + args["input_messages"], + [bank_config.bank_id for bank_config in config.memory_bank_configs], ) if context is None: context = [] - attachments = args["attachments"] - if attachments and len(attachments) > 0: - context += await self._process_attachments(args["session_id"], attachments) return ToolInvocationResult( content=concat_interleaved_content(context), error_code=0 ) diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index b6b34edf0a..d6e8925992 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -22,7 +22,7 @@ def available_providers() -> List[ProviderSpec]: provider_type="inline::memory-runtime", pip_packages=[], module="llama_stack.providers.inline.tool_runtime.memory", - config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolConfig", + config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig", api_dependencies=[Api.memory, Api.memory_banks, Api.inference], ), InlineProviderSpec( diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 10aaa09b5a..7f8b5b26b8 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -21,7 +21,7 @@ class TestCustomTool(CustomTool): """Tool to give boiling point of a liquid - Returns the correct value for water in Celcius and Fahrenheit + Returns the correct value for polyjuice in Celcius and Fahrenheit and returns -1 for other liquids """ @@ -50,7 +50,7 @@ def get_name(self) -> str: return "get_boiling_point" def get_description(self) -> str: - return "Get the boiling point of a imaginary liquids (eg. polyjuice)" + return "Get the boiling point of imaginary liquids (eg. polyjuice)" def get_params_definition(self) -> Dict[str, Parameter]: return { @@ -279,7 +279,6 @@ def test_rag_agent(llama_stack_client, agent_config): "What are the top 5 topics that were explained in the documentation? 
Only list succinct bullet points.", "Was anything related to 'Llama3' discussed, if so what?", "Tell me how to use LoRA", - "What about Quantization?", ] for prompt in user_prompts: From 439f52b0670c77c4fd69e06e6d3b8b59d79bad45 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 26 Dec 2024 16:02:41 -0800 Subject: [PATCH 13/53] register toolgroup as part of test --- tests/client-sdk/agents/test_agents.py | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 7f8b5b26b8..36674631bd 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -17,6 +17,8 @@ from llama_stack_client.types.custom_tool_def import Parameter from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage +from llama_stack_client.types.tool_def_param import CustomToolDefParam +from llama_stack_client.types.tool_group_def_param import UserDefinedToolGroupDef class TestCustomTool(CustomTool): @@ -268,6 +270,36 @@ def test_rag_agent(llama_stack_client, agent_config): documents=documents, ) + # create the required memory tool + llama_stack_client.toolgroups.register( + tool_group_id="memory_group", + tool_group=UserDefinedToolGroupDef( + type="user_defined", + tools=[ + CustomToolDefParam( + type="custom", + name="memory-tool", + description="Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments", + parameters=[ + Parameter( + name="input_messages", + description="Input messages for which to retrieve memory", + required=True, + parameter_type="list", + ), + ], + metadata={ + "config": { + "memory_bank_configs": [ + {"bank_id": "test_bank", "type": "vector"} + ] + } + }, + ) + ], + ), + provider_id="memory-runtime", + ) agent_config = { **agent_config, "preprocessing_tools": ["memory-tool"], From 18d99375007c25a9d4a0d191e15e022be2ddaeb0 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 26 Dec 2024 18:24:27 -0800 Subject: [PATCH 14/53] fix agent server tests --- .../distribution/routers/routing_tables.py | 7 ++- .../providers/tests/agents/conftest.py | 10 ++-- .../providers/tests/agents/fixtures.py | 50 ++++--------------- .../providers/tests/agents/test_agents.py | 37 ++++++++------ 4 files changed, 41 insertions(+), 63 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 2aff0f3a23..45708649bf 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional -from pydantic import parse_obj_as +from pydantic import TypeAdapter from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType @@ -39,7 +39,6 @@ RoutableObjectWithProvider, RoutedProtocol, ) - from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable @@ -361,7 +360,7 @@ async def register_memory_bank( memory_bank_data["embedding_dimension"] = model.metadata[ "embedding_dimension" ] - memory_bank = parse_obj_as(MemoryBank, memory_bank_data) + memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data) await self.register_object(memory_bank) return memory_bank @@ -525,7 +524,7 @@ async def register_tool_group( provider_id = 
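Once the tool group is registered, the client-side agent only has to reference the
tool by name. A rough sketch of the consuming side (the bank named in the tool
metadata must already exist and be populated; the session name is illustrative):

    # Sketch: the agent picks up the server-registered memory tool through
    # preprocessing_tools and queries "test_bank" on every turn.
    agent_config = {
        **agent_config,
        "preprocessing_tools": ["memory-tool"],
    }
    agent = Agent(llama_stack_client, agent_config)
    session_id = agent.create_session("rag-session")  # illustrative name
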
 tests/client-sdk/agents/test_agents.py | 32 ++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 7f8b5b26b8..36674631bd 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -17,6 +17,8 @@
 from llama_stack_client.types.custom_tool_def import Parameter
 from llama_stack_client.types.memory_insert_params import Document
 from llama_stack_client.types.shared.completion_message import CompletionMessage
+from llama_stack_client.types.tool_def_param import CustomToolDefParam
+from llama_stack_client.types.tool_group_def_param import UserDefinedToolGroupDef


 class TestCustomTool(CustomTool):
@@ -268,6 +270,36 @@ def test_rag_agent(llama_stack_client, agent_config):
         documents=documents,
     )

+    # create the required memory tool
+    llama_stack_client.toolgroups.register(
+        tool_group_id="memory_group",
+        tool_group=UserDefinedToolGroupDef(
+            type="user_defined",
+            tools=[
+                CustomToolDefParam(
+                    type="custom",
+                    name="memory-tool",
+                    description="Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments",
+                    parameters=[
+                        Parameter(
+                            name="input_messages",
+                            description="Input messages for which to retrieve memory",
+                            required=True,
+                            parameter_type="list",
+                        ),
+                    ],
+                    metadata={
+                        "config": {
+                            "memory_bank_configs": [
+                                {"bank_id": "test_bank", "type": "vector"}
+                            ]
+                        }
+                    },
+                )
+            ],
+        ),
+        provider_id="memory-runtime",
+    )
     agent_config = {
         **agent_config,
         "preprocessing_tools": ["memory-tool"],

From 18d99375007c25a9d4a0d191e15e022be2ddaeb0 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Thu, 26 Dec 2024 18:24:27 -0800
Subject: [PATCH 14/53] fix agent server tests

---
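`parse_obj_as` is deprecated in pydantic v2; `TypeAdapter` is the supported way to
validate a plain value against an annotated (discriminated) union. A minimal,
self-contained illustration, not taken from the codebase:

    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter

    class VectorConfig(BaseModel):
        type: Literal["vector"] = "vector"

    class KeywordConfig(BaseModel):
        type: Literal["keyword"] = "keyword"

    BankConfig = Annotated[
        Union[VectorConfig, KeywordConfig], Field(discriminator="type")
    ]

    # validate_python dispatches on the discriminator, as parse_obj_as used to
    config = TypeAdapter(BankConfig).validate_python({"type": "vector"})
    assert isinstance(config, VectorConfig)
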
tool_group_id="code_interpreter_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - BuiltInToolDef( - built_in_type=BuiltinTool.code_interpreter, - metadata={}, - ) - ], - ), - provider_id="code-interpreter", - ), ] test_stack = await construct_stack_for_test( diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index a8c472da47..3534e0f843 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -8,19 +8,13 @@ from typing import Dict, List import pytest -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, - AgentTool, AgentTurnResponseEventType, AgentTurnResponseStepCompletePayload, AgentTurnResponseStreamChunk, AgentTurnResponseTurnCompletePayload, - Attachment, - MemoryToolDefinition, - SearchEngineType, - SearchToolDefinition, ShieldCallStep, StepType, ToolChoice, @@ -228,7 +222,7 @@ async def test_create_agent_turn( check_turn_complete_event(turn_response, session_id, sample_messages) @pytest.mark.asyncio - async def test_rag_agent_as_attachments( + async def test_rag_agent( self, agents_stack, attachment_message, @@ -236,6 +230,8 @@ async def test_rag_agent_as_attachments( common_params, ): agents_impl = agents_stack.impls[Api.agents] + memory_banks_impl = agents_stack.impls[Api.memory_banks] + memory_impl = agents_stack.impls[Api.memory] urls = [ "memory_optimizations.rst", "chat.rst", @@ -244,14 +240,28 @@ async def test_rag_agent_as_attachments( "qat_finetune.rst", "lora_finetune.rst", ] - - attachments = [ - Attachment( + documents = [ + MemoryBankDocument( + document_id=f"num-{i}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", mime_type="text/plain", + metadata={}, ) for i, url in enumerate(urls) ] + await memory_banks_impl.register_memory_bank( + memory_bank_id="test_bank", + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + provider_id="faiss", + ) + memory_impl.insert_documents( + bank_id="test_bank", + documents=documents, + ) agent_config = AgentConfig( **{ @@ -266,7 +276,6 @@ async def test_rag_agent_as_attachments( agent_id=agent_id, session_id=session_id, messages=attachment_message, - attachments=attachments, stream=True, ) turn_response = [ @@ -290,11 +299,11 @@ async def test_rag_agent_as_attachments( assert len(turn_response) > 0 @pytest.mark.asyncio - async def test_create_agent_turn_with_brave_search( + async def test_create_agent_turn_with_tavily_search( self, agents_stack, search_query_messages, common_params ): - if "BRAVE_SEARCH_API_KEY" not in os.environ: - pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") + if "TAVILY_SEARCH_API_KEY" not in os.environ: + pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") await create_agent_turn_with_search_tool( agents_stack, From 50852cadf3ad733aabb926132e620452c0cadbbe Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 30 Dec 2024 10:50:59 -0800 Subject: [PATCH 15/53] add tool tests --- .gitignore | 1 + .../providers/tests/memory/fixtures.py | 1 + llama_stack/providers/tests/tools/__init__.py | 5 + llama_stack/providers/tests/tools/conftest.py | 65 +++++++++ llama_stack/providers/tests/tools/fixtures.py | 138 ++++++++++++++++++ .../providers/tests/tools/test_tools.py | 99 +++++++++++++ 6 files changed, 309 insertions(+) create mode 100644 
 .gitignore                                    |   1 +
 .../providers/tests/memory/fixtures.py        |   1 +
 llama_stack/providers/tests/tools/__init__.py |   5 +
 llama_stack/providers/tests/tools/conftest.py |  65 +++++++++
 llama_stack/providers/tests/tools/fixtures.py | 138 ++++++++++++++++++
 .../providers/tests/tools/test_tools.py       |  99 +++++++++++++
 6 files changed, 309 insertions(+)
 create mode 100644 llama_stack/providers/tests/tools/__init__.py
 create mode 100644 llama_stack/providers/tests/tools/conftest.py
 create mode 100644 llama_stack/providers/tests/tools/fixtures.py
 create mode 100644 llama_stack/providers/tests/tools/test_tools.py

diff --git a/.gitignore b/.gitignore
index 421ff4db11..f3585a51f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,4 @@ Package.resolved
 _build
 docs/src
 pyrightconfig.json
+.aider*
diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py
index 9a98526abb..b9dbb84f78 100644
--- a/llama_stack/providers/tests/memory/fixtures.py
+++ b/llama_stack/providers/tests/memory/fixtures.py
@@ -19,6 +19,7 @@
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail

diff --git a/llama_stack/providers/tests/tools/__init__.py b/llama_stack/providers/tests/tools/__init__.py
new file mode 100644
index 0000000000..756f351d88
--- /dev/null
+++ b/llama_stack/providers/tests/tools/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/tests/tools/conftest.py b/llama_stack/providers/tests/tools/conftest.py
new file mode 100644
index 0000000000..11aad5ab66
--- /dev/null
+++ b/llama_stack/providers/tests/tools/conftest.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from ..conftest import get_provider_fixture_overrides
+from ..inference.fixtures import INFERENCE_FIXTURES
+from ..memory.fixtures import MEMORY_FIXTURES
+from ..safety.fixtures import SAFETY_FIXTURES
+from .fixtures import TOOL_RUNTIME_FIXTURES
+
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "together",
+            "safety": "llama_guard",
+            "memory": "faiss",
+            "tool_runtime": "memory_and_search",
+        },
+        id="together",
+        marks=pytest.mark.together,
+    ),
+]
+
+
+def pytest_configure(config):
+    for mark in ["together"]:
+        config.addinivalue_line(
+            "markers",
+            f"{mark}: marks tests as {mark} specific",
+        )
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default="meta-llama/Llama-3.2-3B-Instruct",
+        help="Specify the inference model to use for testing",
+    )
+    parser.addoption(
+        "--safety-shield",
+        action="store",
+        default="meta-llama/Llama-Guard-3-1B",
+        help="Specify the safety shield to use for testing",
+    )
+
+
+def pytest_generate_tests(metafunc):
+    if "tools_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "safety": SAFETY_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+            "tool_runtime": TOOL_RUNTIME_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        print(combinations)
+        metafunc.parametrize("tools_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py
new file mode 100644
index 0000000000..f7580ee2f3
--- /dev/null
+++ b/llama_stack/providers/tests/tools/fixtures.py
@@ -0,0 +1,138 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.models import ModelInput, ModelType
+from llama_stack.apis.tools import (
+    BuiltInToolDef,
+    CustomToolDef,
+    ToolGroupInput,
+    ToolParameter,
+    UserDefinedToolGroupDef,
+)
+from llama_stack.distribution.datatypes import Api, Provider
+from llama_stack.providers.tests.resolver import construct_stack_for_test
+
+from ..conftest import ProviderFixture
+
+
+@pytest.fixture(scope="session")
+def tool_runtime_memory_and_search() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="memory-runtime",
+                provider_type="inline::memory-runtime",
+                config={},
+            ),
+            Provider(
+                provider_id="tavily-search",
+                provider_type="remote::tavily-search",
+                config={
+                    "api_key": os.environ["TAVILY_SEARCH_API_KEY"],
+                },
+            ),
+        ],
+    )
+
+
+TOOL_RUNTIME_FIXTURES = ["memory_and_search"]
+
+
+@pytest_asyncio.fixture(scope="session")
+async def tools_stack(request, inference_model, safety_shield):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory", "tools", "tool_runtime"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if key == "inference":
+            providers[key].append(
+                Provider(
+                    provider_id="tools_memory_provider",
+                    provider_type="inline::sentence-transformers",
+                    config={},
+                )
+            )
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
+    inference_models = (
+        inference_model if isinstance(inference_model, list) else [inference_model]
+    )
+    models = [
+        ModelInput(
+            model_id=model,
+            model_type=ModelType.llm,
+            provider_id=providers["inference"][0].provider_id,
+        )
+        for model in inference_models
+    ]
+    models.append(
+        ModelInput(
+            model_id="all-MiniLM-L6-v2",
+            model_type=ModelType.embedding,
+            provider_id="tools_memory_provider",
+            metadata={"embedding_dimension": 384},
+        )
+    )
+
+    tool_groups = [
+        ToolGroupInput(
+            tool_group_id="tavily_search_group",
+            tool_group=UserDefinedToolGroupDef(
+                tools=[
+                    BuiltInToolDef(
+                        name="brave_search",
+                        description="Search the web using Brave Search",
+                        metadata={},
+                    ),
+                ],
+            ),
+            provider_id="tavily-search",
+        ),
+        ToolGroupInput(
+            tool_group_id="memory_group",
+            tool_group=UserDefinedToolGroupDef(
+                tools=[
+                    CustomToolDef(
+                        name="memory",
+                        description="Query the memory bank",
+                        parameters=[
+                            ToolParameter(
+                                name="query",
+                                description="The query to search for in memory",
+                                parameter_type="string",
+                                required=True,
+                            ),
+                            ToolParameter(
+                                name="memory_bank_id",
+                                description="The ID of the memory bank to search",
+                                parameter_type="string",
+                                required=True,
+                            ),
+                        ],
+                        metadata={},
+                    )
+                ],
+            ),
+            provider_id="memory-runtime",
+        ),
+    ]
+
+    test_stack = await construct_stack_for_test(
+        [Api.tools, Api.inference, Api.memory],
+        providers,
+        provider_data,
+        models=models,
+        tool_groups=tool_groups,
+    )
+    return test_stack
diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py
new file mode 100644
index 0000000000..96a80414c1
--- /dev/null
+++ b/llama_stack/providers/tests/tools/test_tools.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+
+import pytest
+
+from llama_stack.apis.memory import MemoryBankDocument
+from llama_stack.apis.memory_banks import VectorMemoryBankParams
+from llama_stack.apis.tools import ToolInvocationResult
+from llama_stack.providers.datatypes import Api
+
+
+@pytest.fixture
+def sample_search_query():
+    return "What are the latest developments in quantum computing?"
+
+
+@pytest.fixture
+def sample_documents():
+    urls = [
+        "memory_optimizations.rst",
+        "chat.rst",
+        "llama3.rst",
+        "datasets.rst",
+        "qat_finetune.rst",
+        "lora_finetune.rst",
+    ]
+    return [
+        MemoryBankDocument(
+            document_id=f"num-{i}",
+            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
+            mime_type="text/plain",
+            metadata={},
+        )
+        for i, url in enumerate(urls)
+    ]
+
+
+class TestTools:
+    @pytest.mark.asyncio
+    async def test_brave_search_tool(self, tools_stack, sample_search_query):
+        """Test the Brave search tool functionality."""
+        if "TAVILY_SEARCH_API_KEY" not in os.environ:
+            pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
+
+        tools_impl = tools_stack.impls[Api.tool_runtime]
+
+        # Execute the tool
+        response = await tools_impl.invoke_tool(
+            tool_name="brave_search", tool_args={"query": sample_search_query}
+        )
+
+        # Verify the response
+        assert isinstance(response, ToolInvocationResult)
+        assert response.content is not None
+        assert len(response.content) > 0
+        assert isinstance(response.content, str)
+
+    @pytest.mark.asyncio
+    async def test_memory_tool(self, tools_stack, sample_documents):
+        """Test the memory tool functionality."""
+        memory_banks_impl = tools_stack.impls[Api.memory_banks]
+        memory_impl = tools_stack.impls[Api.memory]
+        tools_impl = tools_stack.impls[Api.tools]
+
+        # Register memory bank
+        await memory_banks_impl.register_memory_bank(
+            memory_bank_id="test_memory_bank",
+            params=VectorMemoryBankParams(
+                embedding_model="all-MiniLM-L6-v2",
+                chunk_size_in_tokens=512,
+                overlap_size_in_tokens=64,
+            ),
+            provider_id="faiss",
+        )
+
+        # Insert documents into memory
+        memory_impl.insert_documents(
+            bank_id="test_memory_bank",
+            documents=sample_documents,
+        )
+
+        # Execute the memory tool
+        response = await tools_impl.invoke_tool(
+            tool_name="memory",
+            tool_args={
+                "query": "What are the main topics covered in the documentation?",
+            },
+        )
+
+        # Verify the response
+        assert isinstance(response, ToolInvocationResult)
+        assert response.content is not None
+        assert len(response.content) > 0
+        assert isinstance(response.content, str)

From b7ae86ae03469dbc3e887f62092c54fd922a4841 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 11:37:28 -0800
Subject: [PATCH 16/53] passing tool tests

---
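With the keyword argument aligned to the ToolRuntime protocol (`args`, not
`tool_args`), the runtime can be exercised end to end. The call shape the tests
settle on, roughly:

    # Sketch: tools_impl is the resolved Api.tool_runtime implementation and
    # "memory" is the tool registered by the memory_group tool group.
    result = await tools_impl.invoke_tool(
        tool_name="memory",
        args={"input_messages": [UserMessage(content="What is covered here?")]},
    )
    assert result.error_code == 0
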
 llama_stack/providers/tests/tools/fixtures.py | 30 +++++++++----------
 .../providers/tests/tools/test_tools.py       | 22 +++++++++-----
 2 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py
index f7580ee2f3..845e0dba4a 100644
--- a/llama_stack/providers/tests/tools/fixtures.py
+++ b/llama_stack/providers/tests/tools/fixtures.py
@@ -17,6 +17,7 @@
     ToolParameter,
     UserDefinedToolGroupDef,
 )
+from llama_stack.apis.tools.tools import BuiltinTool
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.tests.resolver import construct_stack_for_test

@@ -47,12 +48,12 @@ def tool_runtime_memory_and_search() -> ProviderFixture:


 @pytest_asyncio.fixture(scope="session")
-async def tools_stack(request, inference_model, safety_shield):
+async def tools_stack(request, inference_model):
     fixture_dict = request.param

     providers = {}
     provider_data = {}
-    for key in ["inference", "memory", "tools", "tool_runtime"]:
+    for key in ["inference", "memory", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":
@@ -91,8 +92,7 @@ async def tools_stack(request, inference_model):
             tool_group=UserDefinedToolGroupDef(
                 tools=[
                     BuiltInToolDef(
-                        name="brave_search",
-                        description="Search the web using Brave Search",
+                        built_in_type=BuiltinTool.brave_search,
                         metadata={},
                     ),
                 ],
@@ -108,19 +108,19 @@ async def tools_stack(request, inference_model):
                         description="Query the memory bank",
                         parameters=[
                             ToolParameter(
-                                name="query",
-                                description="The query to search for in memory",
-                                parameter_type="string",
-                                required=True,
-                            ),
-                            ToolParameter(
-                                name="memory_bank_id",
-                                description="The ID of the memory bank to search",
-                                parameter_type="string",
+                                name="input_messages",
+                                description="The input messages to search for in memory",
+                                parameter_type="list",
                                 required=True,
                             ),
                         ],
-                        metadata={},
+                        metadata={
+                            "config": {
+                                "memory_bank_configs": [
+                                    {"bank_id": "test_bank", "type": "vector"}
+                                ]
+                            }
+                        },
                     )
                 ],
             ),
@@ -129,7 +129,7 @@ async def tools_stack(request, inference_model):
     ]

     test_stack = await construct_stack_for_test(
-        [Api.tools, Api.inference, Api.memory],
+        [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
         providers,
         provider_data,
         models=models,
diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py
index 96a80414c1..08c7afe1e7 100644
--- a/llama_stack/providers/tests/tools/test_tools.py
+++ b/llama_stack/providers/tests/tools/test_tools.py
@@ -8,11 +8,14 @@

 import pytest

+from llama_stack.apis.inference import UserMessage
 from llama_stack.apis.memory import MemoryBankDocument
 from llama_stack.apis.memory_banks import VectorMemoryBankParams
 from llama_stack.apis.tools import ToolInvocationResult
 from llama_stack.providers.datatypes import Api

+from .fixtures import tool_runtime_memory_and_search  # noqa: F401
+

 @pytest.fixture
 def sample_search_query():
@@ -51,7 +54,7 @@ async def test_brave_search_tool(self, tools_stack, sample_search_query):

         # Execute the tool
         response = await tools_impl.invoke_tool(
-            tool_name="brave_search", tool_args={"query": sample_search_query}
+            tool_name="brave_search", args={"query": sample_search_query}
         )

         # Verify the response
@@ -65,11 +68,11 @@ async def test_memory_tool(self, tools_stack, sample_documents):
         """Test the memory tool functionality."""
         memory_banks_impl = tools_stack.impls[Api.memory_banks]
         memory_impl = tools_stack.impls[Api.memory]
-        tools_impl = tools_stack.impls[Api.tools]
+        tools_impl = tools_stack.impls[Api.tool_runtime]

         # Register memory bank
         await memory_banks_impl.register_memory_bank(
-            memory_bank_id="test_memory_bank",
+            memory_bank_id="test_bank",
             params=VectorMemoryBankParams(
                 embedding_model="all-MiniLM-L6-v2",
                 chunk_size_in_tokens=512,
@@ -79,16 +82,20 @@ async def test_memory_tool(self, tools_stack, sample_documents):
         )

         # Insert documents into memory
-        memory_impl.insert_documents(
-            bank_id="test_memory_bank",
+        await memory_impl.insert_documents(
+            bank_id="test_bank",
             documents=sample_documents,
         )

         # Execute the memory tool
         response = await tools_impl.invoke_tool(
             tool_name="memory",
-            tool_args={
-                "query": "What are the main topics covered in the documentation?",
+            args={
+                "input_messages": [
+                    UserMessage(
+                        content="What are the main topics covered in the documentation?",
+                    )
+                ],
             },
         )

@@ -96,4 +103,3 @@ async def test_memory_tool(self, tools_stack, sample_documents):
         assert isinstance(response, ToolInvocationResult)
         assert response.content is not None
         assert len(response.content) > 0
-        assert isinstance(response.content, str)

From 9a3d7fa33cdb2b33f1c336fe23f5e834fc95ec67 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 13:03:39 -0800
Subject: [PATCH 17/53] rebase fixes

---
 llama_stack/apis/agents/agents.py                          |  5 +++--
 llama_stack/distribution/datatypes.py                      |  2 +-
 llama_stack/distribution/routers/routing_tables.py         |  4 ++++
 .../inline/agents/meta_reference/agent_instance.py         | 10 ++++------
 llama_stack/providers/tests/tools/fixtures.py              |  2 +-
 llama_stack/providers/tests/tools/test_tools.py            |  2 +-
 6 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 88ae919062..75f1cb9c0a 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -14,15 +14,15 @@
     Literal,
     Optional,
     Protocol,
-    Union,
     runtime_checkable,
+    Union,
 )

 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import Annotated

-from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.inference import (
     CompletionMessage,
     SamplingParams,
@@ -36,6 +36,7 @@
 )
 from llama_stack.apis.memory import MemoryBank
 from llama_stack.apis.safety import SafetyViolation
+from llama_stack.apis.tools import CustomToolDef
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index ba7ba62bd6..d0ccd6cd12 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -20,7 +20,7 @@
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
 from llama_stack.apis.shields import Shield, ShieldInput
-from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
+from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig

diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 45708649bf..f0d55eaf2e 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -27,11 +27,15 @@
 )
 from llama_stack.apis.shields import Shield, Shields
 from llama_stack.apis.tools import (
+    BuiltInToolDef,
+    CustomToolDef,
     MCPToolGroupDef,
     Tool,
     ToolGroup,
     ToolGroupDef,
     ToolGroups,
+    ToolHost,
+    ToolPromptFormat,
     UserDefinedToolGroupDef,
 )
 from llama_stack.distribution.datatypes import (
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 6cf031bf7a..219afe6211 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -13,11 +13,11 @@
 import string
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Optional
 from urllib.parse import urlparse

 import httpx
-from llama_models.llama3.api.datatypes import BuiltinTool
+from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition

 from llama_stack.apis.agents import (
     AgentConfig,
@@ -37,10 +37,7 @@
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import (
-    URL,
-    TextContentItem,
-)
+from llama_stack.apis.common.content_types import TextContentItem, URL
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -59,6 +56,7 @@
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.telemetry import tracing

diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py
index 845e0dba4a..5493a49871 100644
--- a/llama_stack/providers/tests/tools/fixtures.py
+++ b/llama_stack/providers/tests/tools/fixtures.py
@@ -8,6 +8,7 @@

 import pytest
 import pytest_asyncio
+from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.apis.tools import (
@@ -17,7 +18,6 @@
     ToolParameter,
     UserDefinedToolGroupDef,
 )
-from llama_stack.apis.tools.tools import BuiltinTool
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.tests.resolver import construct_stack_for_test

diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py
index 08c7afe1e7..7e4947cb76 100644
--- a/llama_stack/providers/tests/tools/test_tools.py
+++ b/llama_stack/providers/tests/tools/test_tools.py
@@ -14,7 +14,7 @@
 from llama_stack.apis.tools import ToolInvocationResult
 from llama_stack.providers.datatypes import Api

-from .fixtures import tool_runtime_memory_and_search  # noqa: F401
+from .fixtures import tools_stack as _tools_stack  # noqa: F401, F811

From 40439509cae6ca3781c7e3a26f4085d1f8df14ab Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 13:09:38 -0800
Subject: [PATCH 18/53] test fixes

---
 llama_stack/providers/tests/tools/conftest.py   | 2 +-
 llama_stack/providers/tests/tools/test_tools.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/llama_stack/providers/tests/tools/conftest.py b/llama_stack/providers/tests/tools/conftest.py
index 11aad5ab66..6de90dc487 100644
--- a/llama_stack/providers/tests/tools/conftest.py
+++ b/llama_stack/providers/tests/tools/conftest.py
@@ -10,7 +10,7 @@
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES
-from .fixtures import TOOL_RUNTIME_FIXTURES
+from .fixtures import TOOL_RUNTIME_FIXTURES, tools_stack  # noqa: F401

 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py
index 7e4947cb76..f33b4a61d8 100644
--- a/llama_stack/providers/tests/tools/test_tools.py
+++ b/llama_stack/providers/tests/tools/test_tools.py
@@ -14,8 +14,6 @@
 from llama_stack.apis.tools import ToolInvocationResult
 from llama_stack.providers.datatypes import Api

-from .fixtures import tools_stack as _tools_stack  # noqa: F401, F811
-

 @pytest.fixture
 def sample_search_query():

From c2dd0cdc78e87d983797c9745885c7320116bdc4 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 13:27:43 -0800
Subject: [PATCH 19/53] more test fixes

---
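The tool fixtures move out of the agents suite so that both suites share one
definition; pytest picks them up through the shared plugin list, conceptually:

    # Sketch: modules listed in pytest_plugins are loaded as plugins, which
    # makes their session-scoped fixtures resolvable by name everywhere.
    pytest_plugins = [
        "llama_stack.providers.tests.tools.fixtures",
        # ... the other provider fixture modules already in conftest.py
    ]
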
 .gitignore                                    |  1 -
 .../providers/tests/agents/conftest.py        |  3 +-
 .../providers/tests/agents/fixtures.py        | 81 ++--------
 .../providers/tests/agents/test_agents.py     |  2 +
 llama_stack/providers/tests/conftest.py       |  1 +
 llama_stack/providers/tests/tools/conftest.py |  2 +-
 llama_stack/providers/tests/tools/fixtures.py | 93 ++++++++++---------
 7 files changed, 63 insertions(+), 120 deletions(-)

diff --git a/.gitignore b/.gitignore
index f3585a51f5..421ff4db11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,4 +19,3 @@ Package.resolved
 _build
 docs/src
 pyrightconfig.json
-.aider*
diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py
index f805fbbbb7..ecd05dcf8e 100644
--- a/llama_stack/providers/tests/agents/conftest.py
+++ b/llama_stack/providers/tests/agents/conftest.py
@@ -10,7 +10,8 @@
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
-from .fixtures import AGENTS_FIXTURES, TOOL_RUNTIME_FIXTURES
+from ..tools.fixtures import TOOL_RUNTIME_FIXTURES
+from .fixtures import AGENTS_FIXTURES

 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index 71e98102e5..1b1781f363 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -4,21 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import os
 import tempfile

 import pytest
 import pytest_asyncio
-from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.models import ModelInput, ModelType
-from llama_stack.apis.tools import (
-    BuiltInToolDef,
-    CustomToolDef,
-    ToolGroupInput,
-    ToolParameter,
-    UserDefinedToolGroupDef,
-)
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
@@ -63,32 +54,17 @@ def agents_meta_reference() -> ProviderFixture:
     )


-@pytest.fixture(scope="session")
-def tool_runtime_memory_and_search() -> ProviderFixture:
-    return ProviderFixture(
-        providers=[
-            Provider(
-                provider_id="memory-runtime",
-                provider_type="inline::memory-runtime",
-                config={},
-            ),
-            Provider(
-                provider_id="tavily-search",
-                provider_type="remote::tavily-search",
-                config={
-                    "api_key": os.environ["TAVILY_SEARCH_API_KEY"],
-                },
-            ),
-        ],
-    )
-
-
 AGENTS_FIXTURES = ["meta_reference", "remote"]
-TOOL_RUNTIME_FIXTURES = ["memory_and_search"]


 @pytest_asyncio.fixture(scope="session")
-async def agents_stack(request, inference_model, safety_shield):
+async def agents_stack(
+    request,
+    inference_model,
+    safety_shield,
+    tool_group_input_memory,
+    tool_group_input_tavily_search,
+):
     fixture_dict = request.param

     providers = {}
@@ -140,47 +116,6 @@ async def agents_stack(request, inference_model, safety_shield):
             metadata={"embedding_dimension": 384},
         )
     )
-    tool_groups = [
-        ToolGroupInput(
-            tool_group_id="tavily_search_group",
-            tool_group=UserDefinedToolGroupDef(
-                tools=[
-                    BuiltInToolDef(
-                        built_in_type=BuiltinTool.brave_search,
-                        metadata={},
-                    ),
-                ],
-            ),
-            provider_id="tavily-search",
-        ),
-        ToolGroupInput(
-            tool_group_id="memory_group",
-            tool_group=UserDefinedToolGroupDef(
-                tools=[
-                    CustomToolDef(
-                        name="memory",
-                        description="memory",
-                        parameters=[
-                            ToolParameter(
-                                name="input_messages",
-                                description="messages",
-                                parameter_type="list",
-                                required=True,
-                            ),
-                        ],
-                        metadata={
-                            "config": {
-                                "memory_bank_configs": [
-                                    {"bank_id": "test_bank", "type": "vector"}
-                                ]
-                            }
-                        },
-                    )
-                ],
-            ),
-            provider_id="memory-runtime",
-        ),
-    ]

     test_stack = await construct_stack_for_test(
         [Api.agents, Api.inference, Api.safety, Api.memory, Api.tool_runtime],
@@ -188,6 +123,6 @@ async def agents_stack(request, inference_model, safety_shield):
         provider_data,
         models=models,
         shields=[safety_shield] if safety_shield else [],
-        tool_groups=tool_groups,
+        tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
     )
     return test_stack
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 3534e0f843..e02af9c92b 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -22,6 +22,8 @@
     Turn,
 )
 from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
+from llama_stack.apis.memory import MemoryBankDocument
+from llama_stack.apis.memory_banks import VectorMemoryBankParams
 from llama_stack.apis.safety import ViolationLevel
 from llama_stack.providers.datatypes import Api

diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 4d7831ae3a..7408a6375b 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -157,4 +157,5 @@ def pytest_itemcollected(item):
     "llama_stack.providers.tests.scoring.fixtures",
     "llama_stack.providers.tests.eval.fixtures",
     "llama_stack.providers.tests.post_training.fixtures",
+    "llama_stack.providers.tests.tools.fixtures",
 ]
diff --git a/llama_stack/providers/tests/tools/conftest.py b/llama_stack/providers/tests/tools/conftest.py
index 6de90dc487..11aad5ab66 100644
--- a/llama_stack/providers/tests/tools/conftest.py
+++ b/llama_stack/providers/tests/tools/conftest.py
@@ -10,7 +10,7 @@
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
 from ..safety.fixtures import SAFETY_FIXTURES
-from .fixtures import TOOL_RUNTIME_FIXTURES, tools_stack  # noqa: F401
+from .fixtures import TOOL_RUNTIME_FIXTURES

 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py
index 5493a49871..9110430119 100644
--- a/llama_stack/providers/tests/tools/fixtures.py
+++ b/llama_stack/providers/tests/tools/fixtures.py
@@ -44,11 +44,55 @@ def tool_runtime_memory_and_search() -> ProviderFixture:
     )


+@pytest.fixture(scope="session")
+def tool_group_input_memory() -> ToolGroupInput:
+    return ToolGroupInput(
+        tool_group_id="memory_group",
+        tool_group=UserDefinedToolGroupDef(
+            tools=[
+                CustomToolDef(
+                    name="memory",
+                    description="Query the memory bank",
+                    parameters=[
+                        ToolParameter(
+                            name="input_messages",
+                            description="The input messages to search for in memory",
+                            parameter_type="list",
+                            required=True,
+                        ),
+                    ],
+                    metadata={
+                        "config": {
+                            "memory_bank_configs": [
+                                {"bank_id": "test_bank", "type": "vector"}
+                            ]
+                        }
+                    },
+                )
+            ],
+        ),
+        provider_id="memory-runtime",
+    )
+
+
+@pytest.fixture(scope="session")
+def tool_group_input_tavily_search() -> ToolGroupInput:
+    return ToolGroupInput(
+        tool_group_id="tavily_search_group",
+        tool_group=UserDefinedToolGroupDef(
+            tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})],
+        ),
+        provider_id="tavily-search",
+    )
+
+
 TOOL_RUNTIME_FIXTURES = ["memory_and_search"]


 @pytest_asyncio.fixture(scope="session")
-async def tools_stack(request, inference_model):
+async def tools_stack(
+    request, inference_model, tool_group_input_memory, tool_group_input_tavily_search
+):
     fixture_dict = request.param

     providers = {}
@@ -91,47 +130,13 @@ async def tools_stack(
         )
     )

-    tool_groups = [
-        ToolGroupInput(
-            tool_group_id="tavily_search_group",
-            tool_group=UserDefinedToolGroupDef(
-                tools=[
-                    BuiltInToolDef(
-                        built_in_type=BuiltinTool.brave_search,
-                        metadata={},
-                    ),
-                ],
-            ),
-            provider_id="tavily-search",
-        ),
-        ToolGroupInput(
-            tool_group_id="memory_group",
-            tool_group=UserDefinedToolGroupDef(
-                tools=[
-                    CustomToolDef(
-                        name="memory",
-                        description="Query the memory bank",
-                        parameters=[
-                            ToolParameter(
-                                name="input_messages",
-                                description="The input messages to search for in memory",
-                                parameter_type="list",
-                                required=True,
-                            ),
-                        ],
-                        metadata={
-                            "config": {
-                                "memory_bank_configs": [
-                                    {"bank_id": "test_bank", "type": "vector"}
-                                ]
-                            }
-                        },
-                    )
-                ],
-            ),
-            provider_id="memory-runtime",
-        ),
-    ]
-
     test_stack = await construct_stack_for_test(
         [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
         providers,
         provider_data,
         models=models,
-        tool_groups=tool_groups,
+        tool_groups=[
+            tool_group_input_tavily_search,
+            tool_group_input_memory,
+        ],
     )
     return test_stack

From 914938d3f230e939183e49852e76624bf89694a4 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 13:30:41 -0800
Subject: [PATCH 20/53] update open api spec

---
 docs/resources/llama-stack-spec.html | 78 ++++++++++++----------
 docs/resources/llama-stack-spec.yaml | 44 +++++++---------
 2 files changed, 53 insertions(+), 69 deletions(-)

diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index d1d2c266df..60480557b2 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -3910,38 +3910,6 @@
                 "session_id"
             ]
         },
-        "Attachment": {
-            "type": "object",
-            "properties": {
-                "content": {
-                    "oneOf": [
-                        {
-                            "type": "string"
-                        },
-                        {
-                            "$ref": "#/components/schemas/InterleavedContentItem"
-                        },
-                        {
-                            "type": "array",
-                            "items": {
-                                "$ref": "#/components/schemas/InterleavedContentItem"
-                            }
-                        },
-                        {
-                            "$ref": "#/components/schemas/URL"
-                        }
-                    ]
-                },
-                "mime_type": {
-                    "type": "string"
-                }
-            },
-            "additionalProperties": false,
-            "required": [
-                "content",
-                "mime_type"
-            ]
-        },
         "CreateAgentTurnRequest": {
             "type": "object",
             "properties": {
@@ -3964,12 +3932,6 @@
                     ]
                 }
             },
-            "attachments": {
-                "type": "array",
-                "items": {
-                    "$ref": "#/components/schemas/Attachment"
-                }
-            },
             "stream": {
                 "type": "boolean"
             }
@@ -4027,6 +3989,9 @@
                         "memory_retrieval"
                     ]
                 },
+                "step_id": {
+                    "type": "string"
+                },
                 "step_details": {
                     "oneOf": [
                         {
@@ -4048,6 +4013,7 @@
             "required": [
                 "event_type",
                 "step_type",
+                "step_id",
                 "step_details"
             ]
         },
@@ -4454,7 +4420,36 @@
                 "output_attachments": {
                     "type": "array",
                     "items": {
-                        "$ref": "#/components/schemas/Attachment"
+                        "type": "object",
+                        "properties": {
+                            "content": {
+                                "oneOf": [
+                                    {
+                                        "type": "string"
+                                    },
+                                    {
+                                        "$ref": "#/components/schemas/InterleavedContentItem"
+                                    },
+                                    {
+                                        "type": "array",
+                                        "items": {
+                                            "$ref": "#/components/schemas/InterleavedContentItem"
+                                        }
+                                    },
+                                    {
+                                        "$ref": "#/components/schemas/URL"
+                                    }
+                                ]
+                            },
+                            "mime_type": {
+                                "type": "string"
+                            }
+                        },
+                        "additionalProperties": false,
+                        "required": [
+                            "content",
+                            "mime_type"
+                        ]
                     }
                 },
                 "started_at": {
@@ -7992,10 +7987,6 @@
             "name": "AppendRowsRequest",
             "description": ""
         },
-        {
-            "name": "Attachment",
-            "description": ""
-        },
         {
             "name": "BasicScoringFnParams",
             "description": ""
@@ -8710,7 +8701,6 @@
         "AggregationFunctionType",
         "AppEvalTaskConfig",
         "AppendRowsRequest",
-        "Attachment",
         "BasicScoringFnParams",
        "BatchChatCompletionRequest",
         "BatchChatCompletionResponse",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 4f7a9c91c0..5137526c03 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -115,6 +115,8 @@ components:
         - $ref: '#/components/schemas/ToolExecutionStep'
         - $ref: '#/components/schemas/ShieldCallStep'
         - $ref: '#/components/schemas/MemoryRetrievalStep'
+      step_id:
+        type: string
       step_type:
         enum:
         - inference
@@ -125,6 +127,7 @@ components:
       required:
       - event_type
       - step_type
+      - step_id
      - step_details
       type: object
     AgentTurnResponseStepProgressPayload:
@@ -271,23 +274,6 @@ components:
       - dataset_id
       - rows
       type: object
-    Attachment:
-      additionalProperties: false
-      properties:
-        content:
-          oneOf:
-          - type: string
-          - $ref: '#/components/schemas/InterleavedContentItem'
-          - items:
-              $ref: '#/components/schemas/InterleavedContentItem'
-            type: array
-          - $ref: '#/components/schemas/URL'
-        mime_type:
-          type: string
-      required:
-      - content
-      - mime_type
-      type: object
     BasicScoringFnParams:
       additionalProperties: false
       properties:
@@ -615,10 +601,6 @@ components:
       properties:
         agent_id:
           type: string
-        attachments:
-          items:
-            $ref: '#/components/schemas/Attachment'
-          type: array
         messages:
           items:
             oneOf:
@@ -2980,7 +2962,22 @@ components:
           type: array
         output_attachments:
           items:
-            $ref: '#/components/schemas/Attachment'
+            additionalProperties: false
+            properties:
+              content:
+                oneOf:
+                - type: string
+                - $ref: '#/components/schemas/InterleavedContentItem'
+                - items:
+                    $ref: '#/components/schemas/InterleavedContentItem'
+                  type: array
+                - $ref: '#/components/schemas/URL'
+              mime_type:
+                type: string
+            required:
+            - content
+            - mime_type
+            type: object
           type: array
         output_message:
           $ref: '#/components/schemas/CompletionMessage'
@@ -4765,8 +4762,6 @@ tags:
 - description:
   name: AppendRowsRequest
-- description:
-  name: Attachment
 - description:
   name: BasicScoringFnParams
@@ -5271,7 +5266,6 @@ x-tagGroups:
   - AggregationFunctionType
   - AppEvalTaskConfig
   - AppendRowsRequest
-  - Attachment
   - BasicScoringFnParams
   - BatchChatCompletionRequest
   - BatchChatCompletionResponse

From 70b2a58bef4691459d18abac5737e7e9c54a7b76 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 13:35:02 -0800
Subject: [PATCH 21/53] linter fixes

---
 llama_stack/distribution/resolver.py                         | 1 -
 llama_stack/distribution/stack.py                            | 1 -
 llama_stack/providers/inline/agents/meta_reference/agents.py | 1 -
 3 files changed, 3 deletions(-)

diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 3ea93301ff..d7e947a46a 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -35,7 +35,6 @@
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-from llama_stack.providers.datatypes import *  # noqa: F403
 from llama_stack.providers.datatypes import (
     Api,
     DatasetsProtocolPrivate,
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index ae96744c6a..c85e4c7de1 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -12,7 +12,6 @@
 import pkg_resources
 import yaml
-from llama_models.llama3.api.datatypes import *  # noqa: F403
 from termcolor import colored

 from llama_stack.apis.agents import Agents
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index 5769c42e5f..0515c9a5e1 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -20,7 +20,6 @@
     AgentSessionCreateResponse,
     AgentStepResponse,
     AgentTurnCreateRequest,
-    Attachment,
     Session,
     Turn,
 )

From 8bf3f8ea562e045bd90e70408dac2bd01a8bb995 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 14:46:57 -0800
Subject: [PATCH 22/53] update the client tests to use Agent.with_memory

---
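`Agent.with_memory` folds the bank registration and toolgroup plumbing the test
previously did by hand into the client SDK. Assuming that constructor behaves as
the updated test expects, usage reduces to roughly:

    # Sketch: with_memory is expected to provision a per-agent memory bank,
    # attach the memory tool to it, and route add_document inserts into it.
    agent = Agent.with_memory(llama_stack_client, agent_config)
    for document in documents:
        agent.add_document(document)
    session_id = agent.create_session(f"test-session-{uuid4()}")
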
 tests/client-sdk/agents/test_agents.py | 55 +-------------------------
 1 file changed, 2 insertions(+), 53 deletions(-)

diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 36674631bd..8e391a48b7 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -17,8 +17,6 @@
 from llama_stack_client.types.custom_tool_def import Parameter
 from llama_stack_client.types.memory_insert_params import Document
 from llama_stack_client.types.shared.completion_message import CompletionMessage
-from llama_stack_client.types.tool_def_param import CustomToolDefParam
-from llama_stack_client.types.tool_group_def_param import UserDefinedToolGroupDef


 class TestCustomTool(CustomTool):
@@ -253,58 +251,9 @@ def test_rag_agent(llama_stack_client, agent_config):
         )
         for i, url in enumerate(urls)
     ]
-    llama_stack_client.memory_banks.register(
-        memory_bank_id="test_bank",
-        params={
-            "memory_bank_type": "vector",
-            "embedding_model": "all-MiniLM-L6-v2",
-            "chunk_size_in_tokens": 512,
-            "overlap_size_in_tokens": 64,
-        },
-        provider_id="faiss",
-    )
-    # insert some documents
-    llama_stack_client.memory.insert(
-        bank_id="test_bank",
-        documents=documents,
-    )
-
-    # create the required memory tool
-    llama_stack_client.toolgroups.register(
-        tool_group_id="memory_group",
-        tool_group=UserDefinedToolGroupDef(
-            type="user_defined",
-            tools=[
-                CustomToolDefParam(
-                    type="custom",
-                    name="memory-tool",
-                    description="Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments",
-                    parameters=[
-                        Parameter(
-                            name="input_messages",
-                            description="Input messages for which to retrieve memory",
-                            required=True,
-                            parameter_type="list",
-                        ),
-                    ],
-                    metadata={
-                        "config": {
-                            "memory_bank_configs": [
-                                {"bank_id": "test_bank", "type": "vector"}
-                            ]
-                        }
-                    },
-                )
-            ],
-        ),
-        provider_id="memory-runtime",
-    )
-    agent_config = {
-        **agent_config,
-        "preprocessing_tools": ["memory-tool"],
-    }
-    agent = Agent(llama_stack_client, agent_config)
+    agent = Agent.with_memory(llama_stack_client, agent_config)
+    [agent.add_document(document) for document in documents]
     session_id = agent.create_session(f"test-session-{uuid4()}")

     user_prompts = [

From ac46bd5eb4247a333fed24f9a27e9d66877b6b85 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 15:47:01 -0800
Subject: [PATCH 23/53] address feedback

---
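The rename separates the two tool surfaces: `tool_names` reference tools
registered with the stack, while `client_tools` are `UserDefinedToolDef`s the
client executes itself when the model calls them. After this patch an agent
config reads roughly:

    # Sketch mirroring the updated client-sdk test; custom_tool is any
    # CustomTool implementation providing get_tool_definition().
    agent_config = {
        **agent_config,
        "tool_names": ["brave_search"],
        "client_tools": [custom_tool.get_tool_definition()],
        "tool_prompt_format": "python_list",
    }
    agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,))
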
 llama_stack/apis/agents/agents.py                  |  6 +++---
 llama_stack/apis/tools/tools.py                    | 17 +++--------------
 .../distribution/routers/routing_tables.py         |  4 ++--
 .../inline/agents/meta_reference/agent_instance.py | 12 ++++++------
 .../model_context_protocol.py                      |  4 ++--
 llama_stack/providers/tests/agents/test_agents.py  |  2 +-
 llama_stack/providers/tests/tools/fixtures.py      |  4 ++--
 tests/client-sdk/agents/test_agents.py             | 10 +++++-----
 8 files changed, 24 insertions(+), 35 deletions(-)

diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 75f1cb9c0a..09184d09a3 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -36,7 +36,7 @@
 )
 from llama_stack.apis.memory import MemoryBank
 from llama_stack.apis.safety import SafetyViolation
-from llama_stack.apis.tools import CustomToolDef
+from llama_stack.apis.tools import UserDefinedToolDef
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

@@ -137,8 +137,8 @@ class AgentConfigCommon(BaseModel):

     input_shields: Optional[List[str]] = Field(default_factory=list)
     output_shields: Optional[List[str]] = Field(default_factory=list)
-    available_tools: Optional[List[str]] = Field(default_factory=list)
-    custom_tools: Optional[List[CustomToolDef]] = Field(default_factory=list)
+    tool_names: Optional[List[str]] = Field(default_factory=list)
+    client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
     preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index 65d5b84449..6585f3fd2a 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -48,8 +48,8 @@ class Tool(Resource):

 @json_schema_type
-class CustomToolDef(BaseModel):
-    type: Literal["custom"] = "custom"
+class UserDefinedToolDef(BaseModel):
+    type: Literal["user_defined"] = "user_defined"
     name: str
     description: str
     parameters: List[ToolParameter]
@@ -67,7 +67,7 @@ class BuiltInToolDef(BaseModel):


 ToolDef = register_schema(
-    Annotated[Union[CustomToolDef, BuiltInToolDef], Field(discriminator="type")],
+    Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")],
     name="ToolDef",
 )

@@ -172,14 +172,3 @@ async def invoke_tool(
     ) -> ToolInvocationResult:
         """Run a tool with the given arguments"""
         ...
-
-
-# Three tool types:
-# 1. Built-in tools
-# 2. Client tools
-# 3. Model-context-protocol tools
-
-# Suport registration of agents with tool groups
-# TBD: Have a client utility to hide the pre processing tools.
-# Attachments are confusing right now since they are inserted into memory first and retireved through RAG, even before a question is asked.
-#
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index f0d55eaf2e..ccea470ae1 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -28,7 +28,6 @@
 from llama_stack.apis.shields import Shield, Shields
 from llama_stack.apis.tools import (
     BuiltInToolDef,
-    CustomToolDef,
     MCPToolGroupDef,
     Tool,
     ToolGroup,
@@ -36,6 +35,7 @@
     ToolGroups,
     ToolHost,
     ToolPromptFormat,
+    UserDefinedToolDef,
     UserDefinedToolGroupDef,
 )
 from llama_stack.distribution.datatypes import (
@@ -540,7 +540,7 @@ async def register_tool_group(
             raise ValueError(f"Unknown tool group: {tool_group}")

         for tool_def in tool_defs:
-            if isinstance(tool_def, CustomToolDef):
+            if isinstance(tool_def, UserDefinedToolDef):
                 tools.append(
                     Tool(
                         identifier=tool_def.name,
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 219afe6211..b035ac0986 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -429,9 +429,9 @@ async def _run(
         n_iter = 0

         # Build a map of custom tools to their definitions for faster lookup
-        custom_tools = {}
-        for tool in self.agent_config.custom_tools:
-            custom_tools[tool.name] = tool
+        client_tools = {}
+        for tool in self.agent_config.client_tools:
+            client_tools[tool.name] = tool
         while True:
             step_id = str(uuid.uuid4())
             yield AgentTurnResponseStreamChunk(
@@ -560,7 +560,7 @@ async def _run(
             else:
                 log.info(f"{str(message)}")
                 tool_call = message.tool_calls[0]
-                if tool_call.tool_name in custom_tools:
+                if tool_call.tool_name in client_tools:
                     yield message
                     return

@@ -656,7 +656,7 @@ def interpret_content_as_attachment(

     async def _get_tools(self) -> List[ToolDefinition]:
         ret = []
-        for tool in self.agent_config.custom_tools:
+        for tool in self.agent_config.client_tools:
             params = {}
             for param in tool.parameters:
                 params[param.name] = ToolParamDefinition(
@@ -672,7 +672,7 @@ async def _get_tools(self) -> List[ToolDefinition]:
                     parameters=params,
                 )
             )
-        for tool_name in self.agent_config.available_tools:
+        for tool_name in self.agent_config.tool_names:
             tool = await self.tool_groups_api.get_tool(tool_name)
             if tool.built_in_type:
                 ret.append(ToolDefinition(tool_name=tool.built_in_type))
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
index c77929f999..537ae3ab5d 100644
--- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -8,13 +8,13 @@
 from urllib.parse import urlparse

 from llama_stack.apis.tools import (
-    CustomToolDef,
     MCPToolGroupDef,
     ToolDef,
     ToolGroupDef,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
+    UserDefinedToolDef,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate

@@ -53,7 +53,7 @@ async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
                 )
             )
         tools.append(
-            CustomToolDef(
+            UserDefinedToolDef(
                 name=tool.name,
                 description=tool.description,
                 parameters=parameters,
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index e02af9c92b..44b0f8a2e1 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool(
     agent_config = AgentConfig(
         **{
             **common_params,
-            "available_tools": [tool_name],
+            "tool_names": [tool_name],
         }
     )
diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py
index 9110430119..58defd57d0 100644
--- a/llama_stack/providers/tests/tools/fixtures.py
+++ b/llama_stack/providers/tests/tools/fixtures.py
@@ -13,9 +13,9 @@
 from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.apis.tools import (
     BuiltInToolDef,
-    CustomToolDef,
     ToolGroupInput,
     ToolParameter,
+    UserDefinedToolDef,
     UserDefinedToolGroupDef,
 )
 from llama_stack.distribution.datatypes import Api, Provider
@@ -50,7 +50,7 @@ def tool_group_input_memory() -> ToolGroupInput:
         tool_group_id="memory_group",
         tool_group=UserDefinedToolGroupDef(
             tools=[
-                CustomToolDef(
+                UserDefinedToolDef(
                     name="memory",
                     description="Query the memory bank",
                     parameters=[
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 8e391a48b7..68ff3089b7 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -151,7 +151,7 @@ def test_agent_simple(llama_stack_client, agent_config):
 def test_builtin_tool_brave_search(llama_stack_client, agent_config):
     agent_config = {
         **agent_config,
-        "available_tools": [
+        "tool_names": [
             "brave_search",
         ],
     }
@@ -181,7 +181,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
 def test_builtin_tool_code_execution(llama_stack_client, agent_config):
     agent_config = {
         **agent_config,
-        "available_tools": [
+        "tool_names": [
             "code_interpreter",
         ],
     }
@@ -209,12 +209,12 @@ def test_custom_tool(llama_stack_client, agent_config):
     agent_config = {
         **agent_config,
         "model": "meta-llama/Llama-3.2-3B-Instruct",
-        "available_tools": ["brave_search"],
-        "custom_tools": [custom_tool.get_tool_definition()],
+        "tool_names": ["brave_search"],
+        "client_tools": [custom_tool.get_tool_definition()],
         "tool_prompt_format": "python_list",
     }

-    agent = Agent(llama_stack_client, agent_config, custom_tools=(custom_tool,))
+    agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,))
     session_id = agent.create_session(f"test-session-{uuid4()}")

     response = agent.create_turn(

From a945ab53d0572cf19357d0f894cd3afc653c71ab Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 30 Dec 2024 15:49:24 -0800
Subject: [PATCH 24/53]
generate openapi spec --- docs/resources/llama-stack-spec.html | 116 +++++++++++++-------------- docs/resources/llama-stack-spec.yaml | 91 ++++++++++----------- 2 files changed, 104 insertions(+), 103 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 60480557b2..d116b14483 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3705,16 +3705,16 @@ "type": "string" } }, - "available_tools": { + "tool_names": { "type": "array", "items": { "type": "string" } }, - "custom_tools": { + "client_tools": { "type": "array", "items": { - "$ref": "#/components/schemas/CustomToolDef" + "$ref": "#/components/schemas/UserDefinedToolDef" } }, "preprocessing_tools": { @@ -3753,13 +3753,59 @@ "enable_session_persistence" ] }, - "CustomToolDef": { + "ToolParameter": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "parameter_type": { + "type": "string" + }, + "description": { + "type": "string" + }, + "required": { + "type": "boolean" + }, + "default": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "additionalProperties": false, + "required": [ + "name", + "parameter_type", + "description", + "required" + ] + }, + "UserDefinedToolDef": { "type": "object", "properties": { "type": { "type": "string", - "const": "custom", - "default": "custom" + "const": "user_defined", + "default": "user_defined" }, "name": { "type": "string" @@ -3812,52 +3858,6 @@ "metadata" ] }, - "ToolParameter": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "parameter_type": { - "type": "string" - }, - "description": { - "type": "string" - }, - "required": { - "type": "boolean" - }, - "default": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "additionalProperties": false, - "required": [ - "name", - "parameter_type", - "description", - "required" - ] - }, "CreateAgentRequest": { "type": "object", "properties": { @@ -4574,7 +4574,7 @@ "ToolDef": { "oneOf": [ { - "$ref": "#/components/schemas/CustomToolDef" + "$ref": "#/components/schemas/UserDefinedToolDef" }, { "$ref": "#/components/schemas/BuiltInToolDef" @@ -8078,10 +8078,6 @@ "name": "CreateAgentTurnRequest", "description": "" }, - { - "name": "CustomToolDef", - "description": "" - }, { "name": "DPOAlignmentConfig", "description": "" @@ -8628,6 +8624,10 @@ "name": "UnstructuredLogEvent", "description": "" }, + { + "name": "UserDefinedToolDef", + "description": "" + }, { "name": "UserDefinedToolGroupDef", "description": "" @@ -8723,7 +8723,6 @@ "CreateAgentRequest", "CreateAgentSessionRequest", "CreateAgentTurnRequest", - "CustomToolDef", "DPOAlignmentConfig", "DataConfig", "Dataset", @@ -8847,6 +8846,7 @@ "UnregisterModelRequest", "UnregisterToolGroupRequest", "UnstructuredLogEvent", + "UserDefinedToolDef", "UserDefinedToolGroupDef", "UserMessage", "VectorMemoryBank", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 5137526c03..c1097107ee 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -17,13 +17,9 @@ components: AgentConfig: additionalProperties: false properties: - available_tools: + client_tools: items: - type: string - type: array - 
custom_tools: - items: - $ref: '#/components/schemas/CustomToolDef' + $ref: '#/components/schemas/UserDefinedToolDef' type: array enable_session_persistence: type: boolean @@ -51,6 +47,10 @@ components: tool_choice: $ref: '#/components/schemas/ToolChoice' default: auto + tool_names: + items: + type: string + type: array tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' default: json @@ -616,41 +616,6 @@ components: - session_id - messages type: object - CustomToolDef: - additionalProperties: false - properties: - description: - type: string - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - name: - type: string - parameters: - items: - $ref: '#/components/schemas/ToolParameter' - type: array - tool_prompt_format: - $ref: '#/components/schemas/ToolPromptFormat' - default: json - type: - const: custom - default: custom - type: string - required: - - type - - name - - description - - parameters - - metadata - type: object DPOAlignmentConfig: additionalProperties: false properties: @@ -2712,7 +2677,7 @@ components: type: string ToolDef: oneOf: - - $ref: '#/components/schemas/CustomToolDef' + - $ref: '#/components/schemas/UserDefinedToolDef' - $ref: '#/components/schemas/BuiltInToolDef' ToolDefinition: additionalProperties: false @@ -3082,6 +3047,41 @@ components: - message - severity type: object + UserDefinedToolDef: + additionalProperties: false + properties: + description: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + name: + type: string + parameters: + items: + $ref: '#/components/schemas/ToolParameter' + type: array + tool_prompt_format: + $ref: '#/components/schemas/ToolPromptFormat' + default: json + type: + const: user_defined + default: user_defined + type: string + required: + - type + - name + - description + - parameters + - metadata + type: object UserDefinedToolGroupDef: additionalProperties: false properties: @@ -4842,8 +4842,6 @@ tags: - description: name: CreateAgentTurnRequest -- description: - name: CustomToolDef - description: name: DPOAlignmentConfig @@ -5211,6 +5209,9 @@ tags: - description: name: UnstructuredLogEvent +- description: + name: UserDefinedToolDef - description: name: UserDefinedToolGroupDef @@ -5288,7 +5289,6 @@ x-tagGroups: - CreateAgentRequest - CreateAgentSessionRequest - CreateAgentTurnRequest - - CustomToolDef - DPOAlignmentConfig - DataConfig - Dataset @@ -5412,6 +5412,7 @@ x-tagGroups: - UnregisterModelRequest - UnregisterToolGroupRequest - UnstructuredLogEvent + - UserDefinedToolDef - UserDefinedToolGroupDef - UserMessage - VectorMemoryBank From ee542a7373dd3f5523879c137a5f9e0f54ff26db Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 30 Dec 2024 16:57:17 -0800 Subject: [PATCH 25/53] update client sdk tests --- tests/client-sdk/agents/test_agents.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 68ff3089b7..1630ef34b2 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -10,16 +10,16 @@ import pytest from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.custom_tool import CustomTool +from llama_stack_client.lib.agents.client_tool import ClientTool from 
llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types import ToolResponseMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
-from llama_stack_client.types.custom_tool_def import Parameter
 from llama_stack_client.types.memory_insert_params import Document
 from llama_stack_client.types.shared.completion_message import CompletionMessage
+from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter
 
 
-class TestCustomTool(CustomTool):
+class TestClientTool(ClientTool):
     """Tool to give boiling point of a liquid
 
     Returns the correct value for polyjuice in Celsius and Fahrenheit
     and returns -1 for other liquids
@@ -52,15 +52,15 @@ def get_name(self) -> str:
     def get_description(self) -> str:
         return "Get the boiling point of imaginary liquids (e.g. polyjuice)"
 
-    def get_params_definition(self) -> Dict[str, Parameter]:
+    def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]:
         return {
-            "liquid_name": Parameter(
+            "liquid_name": UserDefinedToolDefParameter(
                 name="liquid_name",
                 parameter_type="string",
                 description="The name of the liquid",
                 required=True,
             ),
-            "celcius": Parameter(
+            "celcius": UserDefinedToolDefParameter(
                 name="celcius",
                 parameter_type="boolean",
                 description="Whether to return the boiling point in Celsius",
@@ -205,16 +205,16 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
 
 
 def test_custom_tool(llama_stack_client, agent_config):
-    custom_tool = TestCustomTool()
+    client_tool = TestClientTool()
     agent_config = {
         **agent_config,
         "model": "meta-llama/Llama-3.2-3B-Instruct",
         "tool_names": ["brave_search"],
-        "client_tools": [custom_tool.get_tool_definition()],
+        "client_tools": [client_tool.get_tool_definition()],
         "tool_prompt_format": "python_list",
     }
 
-    agent = Agent(llama_stack_client, agent_config, client_tools=(custom_tool,))
+    agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
     session_id = agent.create_session(f"test-session-{uuid4()}")
 
     response = agent.create_turn(

From 16d1f66f558290eb2201fc94335296ea97185742 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Thu, 2 Jan 2025 18:42:20 -0800
Subject: [PATCH 26/53] address feedback

---
 docs/resources/llama-stack-spec.html          |  66 ++++-
 docs/resources/llama-stack-spec.yaml          |  40 ++-
 llama_stack/apis/agents/agents.py             |  20 +-
 .../agents/meta_reference/agent_instance.py   | 244 ++++++++++--------
 .../inline/agents/meta_reference/agents.py    |   3 +
 .../inline/tool_runtime/memory/memory.py      |  17 +-
 .../remote/inference/together/together.py     |   4 -
 .../providers/tests/agents/test_agents.py     |   4 +-
 tests/client-sdk/agents/test_agents.py        |  27 +-
 9 files changed, 281 insertions(+), 144 deletions(-)

diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index d116b14483..33ca523633 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -3705,10 +3705,10 @@
                     "type": "string"
                 }
             },
-            "tool_names": {
+            "tools": {
                 "type": "array",
                 "items": {
-                    "type": "string"
+                    "$ref": "#/components/schemas/AgentTool"
                 }
             },
             "client_tools": {
                 "type": "array",
                 "items": {
                     "$ref": "#/components/schemas/UserDefinedToolDef"
                 }
             },
-            "preprocessing_tools": {
-                "type": "array",
-                "items": {
-                    "type": "string"
-                }
-            },
             "tool_choice": {
                 "$ref": "#/components/schemas/ToolChoice",
                 "default": "auto"
@@ -3753,13 +3747,51 @@
             "enable_session_persistence"
         ]
     },
+        "AgentTool": {
+            "oneOf": [
+                {
+                    "type": "string"
+                },
+                {
+                    "type": "object",
+                    "properties": {
+                        "name": {
+                            "type": 
"string" + }, + "args": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "name", + "args" + ] + } + ] + }, "ToolParameter": { "type": "object", "properties": { @@ -3934,6 +3973,12 @@ }, "stream": { "type": "boolean" + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AgentTool" + } } }, "additionalProperties": false, @@ -7944,6 +7989,10 @@ "name": "AgentStepResponse", "description": "" }, + { + "name": "AgentTool", + "description": "" + }, { "name": "AgentTurnResponseEvent", "description": "Streamed agent execution response.\n\n" @@ -8691,6 +8740,7 @@ "AgentCreateResponse", "AgentSessionCreateResponse", "AgentStepResponse", + "AgentTool", "AgentTurnResponseEvent", "AgentTurnResponseStepCompletePayload", "AgentTurnResponseStepProgressPayload", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index c1097107ee..4da311cf0c 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -38,22 +38,18 @@ components: items: type: string type: array - preprocessing_tools: - items: - type: string - type: array sampling_params: $ref: '#/components/schemas/SamplingParams' tool_choice: $ref: '#/components/schemas/ToolChoice' default: auto - tool_names: - items: - type: string - type: array tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' default: json + tools: + items: + $ref: '#/components/schemas/AgentTool' + type: array required: - max_infer_iters - model @@ -88,6 +84,27 @@ components: required: - step type: object + AgentTool: + oneOf: + - type: string + - additionalProperties: false + properties: + args: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + name: + type: string + required: + - name + - args + type: object AgentTurnResponseEvent: additionalProperties: false properties: @@ -611,6 +628,10 @@ components: type: string stream: type: boolean + tools: + items: + $ref: '#/components/schemas/AgentTool' + type: array required: - agent_id - session_id @@ -4726,6 +4747,8 @@ tags: - description: name: AgentStepResponse +- description: + name: AgentTool - description: 'Streamed agent execution response. 
@@ -5257,6 +5280,7 @@ x-tagGroups: - AgentCreateResponse - AgentSessionCreateResponse - AgentStepResponse + - AgentTool - AgentTurnResponseEvent - AgentTurnResponseStepCompletePayload - AgentTurnResponseStepProgressPayload diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 09184d09a3..18bbcd95c1 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -18,7 +18,7 @@ Union, ) -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated @@ -132,14 +132,27 @@ class Session(BaseModel): memory_bank: Optional[MemoryBank] = None +class AgentToolWithArgs(BaseModel): + name: str + args: Dict[str, Any] + + +AgentTool = register_schema( + Union[ + str, + AgentToolWithArgs, + ], + name="AgentTool", +) + + class AgentConfigCommon(BaseModel): sampling_params: Optional[SamplingParams] = SamplingParams() input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - tool_names: Optional[List[str]] = Field(default_factory=list) + tools: Optional[List[AgentTool]] = Field(default_factory=list) client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list) - preprocessing_tools: Optional[List[str]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json @@ -295,6 +308,7 @@ async def create_agent_turn( ] ], stream: Optional[bool] = False, + tools: Optional[List[AgentTool]] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/turn/get") diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index b035ac0986..700fa565cc 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -13,7 +13,7 @@ import string import uuid from datetime import datetime -from typing import AsyncGenerator, List, Optional +from typing import AsyncGenerator, Dict, List, Optional from urllib.parse import urlparse import httpx @@ -21,6 +21,8 @@ from llama_stack.apis.agents import ( AgentConfig, + AgentTool, + AgentToolWithArgs, AgentTurnCreateRequest, AgentTurnResponseEvent, AgentTurnResponseEventType, @@ -188,6 +190,7 @@ async def create_and_execute_turn( input_messages=messages, sampling_params=self.agent_config.sampling_params, stream=request.stream, + tools_for_turn=request.tools, ): if isinstance(chunk, CompletionMessage): log.info( @@ -237,6 +240,7 @@ async def run( input_messages: List[Message], sampling_params: SamplingParams, stream: bool = False, + tools_for_turn: Optional[List[AgentTool]] = None, ) -> AsyncGenerator: # Doing async generators makes downstream code much simpler and everything amenable to # streaming. 
However, it also makes things complicated here because AsyncGenerators cannot @@ -253,7 +257,7 @@ async def run( yield res async for res in self._run( - session_id, turn_id, input_messages, sampling_params, stream + session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn ): if isinstance(res, bool): return @@ -348,82 +352,90 @@ async def _run( input_messages: List[Message], sampling_params: SamplingParams, stream: bool = False, + tools_for_turn: Optional[List[AgentTool]] = None, ) -> AsyncGenerator: - if self.agent_config.preprocessing_tools: - with tracing.span("preprocessing_tools") as span: - for tool_name in self.agent_config.preprocessing_tools: - step_id = str(uuid.uuid4()) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - ) + tool_args = {} + if tools_for_turn: + for tool in tools_for_turn: + if isinstance(tool, AgentToolWithArgs): + tool_args[tool.name] = tool.args + + tool_defs = await self._get_tool_defs(tools_for_turn) + if "memory" in tool_defs and len(input_messages) > 0: + with tracing.span("memory_tool") as span: + step_id = str(uuid.uuid4()) + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepStartPayload( + step_type=StepType.tool_execution.value, + step_id=step_id, ) ) - args = dict( - session_id=session_id, - turn_id=turn_id, - input_messages=input_messages, - ) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - tool_call_delta=ToolCallDelta( - parse_status=ToolCallParseStatus.success, - content=ToolCall( - call_id="", tool_name=tool_name, arguments={} - ), + ) + extra_args = tool_args.get("memory", {}) + args = { + # Query memory with the last message's content + "query": input_messages[-1], + **extra_args, + } + serialized_args = tracing.serialize_value(args) + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + tool_call_delta=ToolCallDelta( + parse_status=ToolCallParseStatus.success, + content=ToolCall( + call_id="", + tool_name="memory", + arguments=serialized_args, ), - ) + ), ) ) - result = await self.tool_runtime_api.invoke_tool( - tool_name=tool_name, - args=args, - ) + ) + result = await self.tool_runtime_api.invoke_tool( + tool_name="memory", + args=args, + ) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.tool_execution.value, + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=step_id, + step_details=ToolExecutionStep( step_id=step_id, - step_details=ToolExecutionStep( - step_id=step_id, - turn_id=turn_id, - tool_calls=[ - ToolCall( - call_id="", - tool_name=tool_name, - arguments={}, - ) - ], - tool_responses=[ - ToolResponse( - call_id="", - tool_name=tool_name, - content=result.content, - ) - ], - ), - ) + turn_id=turn_id, + tool_calls=[ + ToolCall( + call_id="", + tool_name="memory", + arguments={}, + ) + ], + tool_responses=[ + ToolResponse( + call_id="", + tool_name="memory", + content=result.content, + ) + ], + ), ) ) - span.set_attribute( - "input", [m.model_dump_json() for m in input_messages] 
- ) - span.set_attribute("output", result.content) - span.set_attribute("error_code", result.error_code) - span.set_attribute("error_message", result.error_message) - if isinstance(tool_name, BuiltinTool): - span.set_attribute("tool_name", tool_name.value) - else: - span.set_attribute("tool_name", tool_name) - if result.error_code == 0: - last_message = input_messages[-1] - last_message.context = result.content + ) + span.set_attribute( + "input", [m.model_dump_json() for m in input_messages] + ) + span.set_attribute("output", result.content) + span.set_attribute("error_code", result.error_code) + span.set_attribute("error_message", result.error_message) + span.set_attribute("tool_name", "memory") + if result.error_code == 0: + last_message = input_messages[-1] + last_message.context = result.content output_attachments = [] @@ -451,7 +463,11 @@ async def _run( async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=await self._get_tools(), + tools=[ + tool + for tool in tool_defs.values() + if tool.tool_name != "memory" + ], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, sampling_params=sampling_params, @@ -654,44 +670,66 @@ def interpret_content_as_attachment( n_iter += 1 - async def _get_tools(self) -> List[ToolDefinition]: - ret = [] - for tool in self.agent_config.client_tools: - params = {} - for param in tool.parameters: - params[param.name] = ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - ret.append( - ToolDefinition( - tool_name=tool.name, - description=tool.description, - parameters=params, - ) + async def _get_tool_defs( + self, tools_for_turn: Optional[List[AgentTool]] + ) -> Dict[str, ToolDefinition]: + # Determine which tools to include + agent_config_tools = set( + tool.name if isinstance(tool, AgentToolWithArgs) else tool + for tool in self.agent_config.tools + ) + tools_for_turn_set = ( + agent_config_tools + if tools_for_turn is None + else { + tool.name if isinstance(tool, AgentToolWithArgs) else tool + for tool in tools_for_turn + } + ) + + ret = {} + + for tool_def in self.agent_config.client_tools: + ret[tool_def.name] = ToolDefinition( + tool_name=tool_def.name, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, ) - for tool_name in self.agent_config.tool_names: - tool = await self.tool_groups_api.get_tool(tool_name) - if tool.built_in_type: - ret.append(ToolDefinition(tool_name=tool.built_in_type)) + + for tool_name in agent_config_tools: + if tool_name not in tools_for_turn_set: continue - params = {} - for param in tool.parameters: - params[param.name] = ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - ret.append( - ToolDefinition( - tool_name=tool.identifier, - description=tool.description, - parameters=params, + + tool_def = await self.tool_groups_api.get_tool(tool_name) + + if tool_def.built_in_type: + ret[tool_def.built_in_type] = ToolDefinition( + tool_name=tool_def.built_in_type ) + continue + + ret[tool_def.identifier] = ToolDefinition( + tool_name=tool_def.identifier, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + 
description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, ) + return ret diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 0515c9a5e1..ab7f8878f9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -19,6 +19,7 @@ Agents, AgentSessionCreateResponse, AgentStepResponse, + AgentTool, AgentTurnCreateRequest, Session, Turn, @@ -145,6 +146,7 @@ async def create_agent_turn( ToolResponseMessage, ] ], + tools: Optional[List[AgentTool]] = None, stream: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnCreateRequest( @@ -152,6 +154,7 @@ async def create_agent_turn( session_id=session_id, messages=messages, stream=True, + tools=tools, ) if stream: return self._create_agent_turn_streaming(request) diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index d492309cd8..cad123696b 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -54,14 +54,10 @@ async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: return [] async def _retrieve_context( - self, messages: List[Message], bank_ids: List[str] + self, message: Message, bank_ids: List[str] ) -> Optional[List[InterleavedContent]]: if not bank_ids: return None - if len(messages) == 0: - return None - - message = messages[-1] # only use the last message as input to the query query = await generate_rag_query( self.config.query_generator_config, message, @@ -113,10 +109,15 @@ async def invoke_tool( config = MemoryToolConfig() if tool.metadata.get("config") is not None: config = MemoryToolConfig(**tool.metadata["config"]) - + if "memory_bank_id" in args: + bank_ids = [args["memory_bank_id"]] + else: + bank_ids = [ + bank_config.bank_id for bank_config in config.memory_bank_configs + ] context = await self._retrieve_context( - args["input_messages"], - [bank_config.bank_id for bank_config in config.memory_bank_configs], + args["query"], + bank_ids, ) if context is None: context = [] diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 327132b0ac..3dad5ade4c 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -7,11 +7,8 @@ from typing import AsyncGenerator, List, Optional, Union from llama_models.datatypes import CoreModelId - from llama_models.llama3.api.chat_format import ChatFormat - from llama_models.llama3.api.tokenizer import Tokenizer - from together import Together from llama_stack.apis.common.content_types import InterleavedContent @@ -53,7 +50,6 @@ from .config import TogetherImplConfig - MODEL_ALIASES = [ build_model_alias( "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 44b0f8a2e1..cb20e5890e 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool( agent_config = AgentConfig( **{ **common_params, - "tool_names": [tool_name], + "tools": [tool_name], } ) @@ -268,7 +268,7 @@ async def 
test_rag_agent( agent_config = AgentConfig( **{ **common_params, - "preprocessing_tools": ["memory"], + "tools": ["memory"], "tool_choice": ToolChoice.auto, } ) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 1630ef34b2..64c3c159f2 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -9,7 +9,7 @@ from uuid import uuid4 import pytest -from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.agent import Agent, AugmentConfigWithMemoryTool from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types import ToolResponseMessage @@ -151,11 +151,10 @@ def test_agent_simple(llama_stack_client, agent_config): def test_builtin_tool_brave_search(llama_stack_client, agent_config): agent_config = { **agent_config, - "tool_names": [ + "tools": [ "brave_search", ], } - print(f"Agent Config: {agent_config}") agent = Agent(llama_stack_client, agent_config) session_id = agent.create_session(f"test-session-{uuid4()}") @@ -181,7 +180,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config): def test_builtin_tool_code_execution(llama_stack_client, agent_config): agent_config = { **agent_config, - "tool_names": [ + "tools": [ "code_interpreter", ], } @@ -209,7 +208,7 @@ def test_custom_tool(llama_stack_client, agent_config): agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "tool_names": ["brave_search"], + "tools": ["brave_search"], "client_tools": [client_tool.get_tool_definition()], "tool_prompt_format": "python_list", } @@ -252,8 +251,12 @@ def test_rag_agent(llama_stack_client, agent_config): for i, url in enumerate(urls) ] - agent = Agent.with_memory(llama_stack_client, agent_config) - [agent.add_document(document) for document in documents] + memory_bank_id = AugmentConfigWithMemoryTool(agent_config, llama_stack_client) + agent = Agent(llama_stack_client, agent_config) + llama_stack_client.memory.insert( + bank_id=memory_bank_id, + documents=documents, + ) session_id = agent.create_session(f"test-session-{uuid4()}") user_prompts = [ @@ -271,8 +274,16 @@ def test_rag_agent(llama_stack_client, agent_config): } ], session_id=session_id, + tools=[ + { + "name": "memory", + "args": { + "memory_bank_id": memory_bank_id, + }, + } + ], ) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - assert "Tool:memory-tool" in logs_str + assert "Tool:memory" in logs_str From 0bc876c130ac3f0f75e008106a1364738d95821d Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 3 Jan 2025 13:02:39 -0800 Subject: [PATCH 27/53] minor fixes to agent instance --- .../providers/inline/agents/meta_reference/agent_instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 700fa565cc..2af1c820b8 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -412,7 +412,7 @@ async def _run( ToolCall( call_id="", tool_name="memory", - arguments={}, + arguments=serialized_args, ) ], tool_responses=[ From 1ee3143ab1d2f26fe9afe53639c4e6e045f21db5 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 3 Jan 2025 13:14:28 -0800 Subject: [PATCH 28/53] print 
the module not found exception in lib cli --- llama_stack/distribution/library_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 5a2711582b..a899ae8113 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -267,6 +267,7 @@ async def initialize(self): self.config, self.custom_provider_registry ) except ModuleNotFoundError as _e: + cprint(_e.msg, "red") cprint( "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n", "yellow", From 229999c572652ec21e38691f41c3b43b7bddad58 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 3 Jan 2025 14:13:54 -0800 Subject: [PATCH 29/53] add init.py --- llama_stack/providers/inline/tool_runtime/__init__.py | 5 +++++ llama_stack/providers/remote/tool_runtime/__init__.py | 5 +++++ 2 files changed, 10 insertions(+) create mode 100644 llama_stack/providers/inline/tool_runtime/__init__.py create mode 100644 llama_stack/providers/remote/tool_runtime/__init__.py diff --git a/llama_stack/providers/inline/tool_runtime/__init__.py b/llama_stack/providers/inline/tool_runtime/__init__.py new file mode 100644 index 0000000000..756f351d88 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/remote/tool_runtime/__init__.py b/llama_stack/providers/remote/tool_runtime/__init__.py new file mode 100644 index 0000000000..756f351d88 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
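
A minimal end-to-end sketch of the client-side surface the series has converged on so far, distilled from the changes to tests/client-sdk/agents/test_agents.py in the patches above. It is illustrative rather than normative: `LlamaStackClient`, `Agent`, and `EventLogger` are the llama-stack-client SDK symbols those tests exercise, while the base URL, model id, and memory bank id are placeholder assumptions.

    # Illustrative sketch, mirroring tests/client-sdk/agents/test_agents.py.
    # Assumes a running stack with the brave_search and memory tool groups
    # registered; base_url and memory_bank_id below are placeholders.
    from uuid import uuid4

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.agent import Agent
    from llama_stack_client.lib.agents.event_logger import EventLogger

    client = LlamaStackClient(base_url="http://localhost:5000")

    agent_config = {
        "model": "meta-llama/Llama-3.1-70B-Instruct",
        "instructions": "You are a helpful assistant.",
        # Server-side tools are now listed by name under `tools` (patch 26);
        # this replaces the earlier available_tools/tool_names fields.
        "tools": ["brave_search", "memory"],
        "enable_session_persistence": False,
    }

    agent = Agent(client, agent_config)
    session_id = agent.create_session(f"demo-session-{uuid4()}")

    # A turn may narrow the configured tools and pass per-tool arguments via
    # the AgentTool union: either a bare tool name or a {"name", "args"} dict.
    response = agent.create_turn(
        messages=[{"role": "user", "content": "How do I build a llama-stack distribution?"}],
        session_id=session_id,
        tools=[{"name": "memory", "args": {"memory_bank_id": "my_docs_bank"}}],
    )
    for log in EventLogger().log(response):
        print(log)

Note that `_get_tool_defs` in patch 26 intersects the per-turn list with the agent's configured tools, so a tool must appear in both places to be offered to the model for that turn.
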
From 1faf64b540ae0089a61d2ea9c4aa432ca5b1ed7c Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 3 Jan 2025 14:16:45 -0800 Subject: [PATCH 30/53] linter fixes --- .../model_context_protocol/model_context_protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 537ae3ab5d..19ada8457e 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -7,6 +7,9 @@ from typing import Any, Dict, List from urllib.parse import urlparse +from mcp import ClientSession +from mcp.client.sse import sse_client + from llama_stack.apis.tools import ( MCPToolGroupDef, ToolDef, @@ -18,9 +21,6 @@ ) from llama_stack.providers.datatypes import ToolsProtocolPrivate -from mcp import ClientSession -from mcp.client.sse import sse_client - from .config import ModelContextProtocolConfig From d0e8e1647bdbefcd49ee01b157a8d52691050326 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 3 Jan 2025 18:23:54 -0800 Subject: [PATCH 31/53] add matplotlib_custom_backend.py --- .../matplotlib_custom_backend.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 llama_stack/providers/inline/tool_runtime/code_interpreter/matplotlib_custom_backend.py diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/matplotlib_custom_backend.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/matplotlib_custom_backend.py new file mode 100644 index 0000000000..7fec08cf24 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/matplotlib_custom_backend.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +A custom Matplotlib backend that overrides the show method to return image bytes. +""" + +import base64 +import io +import json as _json +import logging + +import matplotlib +from matplotlib.backend_bases import FigureManagerBase + +# Import necessary components from Matplotlib +from matplotlib.backends.backend_agg import FigureCanvasAgg + +log = logging.getLogger(__name__) + + +class CustomFigureCanvas(FigureCanvasAgg): + def show(self): + # Save the figure to a BytesIO object + buf = io.BytesIO() + self.print_png(buf) + image_bytes = buf.getvalue() + buf.close() + return image_bytes + + +class CustomFigureManager(FigureManagerBase): + def __init__(self, canvas, num): + super().__init__(canvas, num) + + +# Mimic module initialization that integrates with the Matplotlib backend system +def _create_figure_manager(num, *args, **kwargs): + """ + Create a custom figure manager instance. + """ + FigureClass = kwargs.pop("FigureClass", None) # noqa: N806 + if FigureClass is None: + from matplotlib.figure import Figure + + FigureClass = Figure # noqa: N806 + fig = FigureClass(*args, **kwargs) + canvas = CustomFigureCanvas(fig) + manager = CustomFigureManager(canvas, num) + return manager + + +def show(): + """ + Handle all figures and potentially return their images as bytes. 
+ + This function iterates over all figures registered with the custom backend, + renders them as images in bytes format, and could return a list of bytes objects, + one for each figure, or handle them as needed. + """ + image_data = [] + for manager in matplotlib._pylab_helpers.Gcf.get_all_fig_managers(): + # Get the figure from the manager + fig = manager.canvas.figure + buf = io.BytesIO() # Create a buffer for the figure + fig.savefig(buf, format="png") # Save the figure to the buffer in PNG format + buf.seek(0) # Go to the beginning of the buffer + image_bytes = buf.getvalue() # Retrieve bytes value + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + image_data.append({"image_base64": image_base64}) + buf.close() + + req_con, resp_con = _open_connections() + + _json_dump = _json.dumps( + { + "type": "matplotlib", + "image_data": image_data, + } + ) + req_con.send_bytes(_json_dump.encode("utf-8")) + resp = _json.loads(resp_con.recv_bytes().decode("utf-8")) + log.info(resp) + + +FigureCanvas = CustomFigureCanvas +FigureManager = CustomFigureManager From 9efe30c9d31be9845c503fcdfc9d41ef98ad89c5 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 6 Jan 2025 11:40:22 -0800 Subject: [PATCH 32/53] add documents to turn --- ...Llama_Stack_Building_AI_Applications.ipynb | 901 +++++++++++------- docs/resources/llama-stack-spec.html | 35 + docs/resources/llama-stack-spec.yaml | 19 + llama_stack/apis/agents/agents.py | 9 + .../agents/meta_reference/agent_instance.py | 134 ++- .../inline/agents/meta_reference/agents.py | 3 + .../agents/meta_reference/persistence.py | 12 + .../providers/tests/agents/test_agents.py | 24 +- tests/client-sdk/agents/test_agents.py | 73 ++ 9 files changed, 858 insertions(+), 352 deletions(-) diff --git a/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb b/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb index d061603c8b..b3f2d4b68c 100644 --- a/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb +++ b/docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb @@ -390,7 +390,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 1, "id": "E1UFuJC570Tk", "metadata": { "colab": { @@ -403,65 +403,20 @@ }, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "INFO:llama_stack.distribution.resolver:Resolved 24 providers\n", - "INFO:llama_stack.distribution.resolver: inner-inference => together\n", - "INFO:llama_stack.distribution.resolver: inner-memory => faiss\n", - "INFO:llama_stack.distribution.resolver: models => __routing_table__\n", - "INFO:llama_stack.distribution.resolver: inference => __autorouted__\n", - "INFO:llama_stack.distribution.resolver: inner-safety => llama-guard\n", - "INFO:llama_stack.distribution.resolver: shields => __routing_table__\n", - "INFO:llama_stack.distribution.resolver: safety => __autorouted__\n", - "INFO:llama_stack.distribution.resolver: memory_banks => __routing_table__\n", - "INFO:llama_stack.distribution.resolver: memory => __autorouted__\n", - "INFO:llama_stack.distribution.resolver: agents => meta-reference\n", - "INFO:llama_stack.distribution.resolver: inner-datasetio => huggingface\n", - "INFO:llama_stack.distribution.resolver: inner-datasetio => localfs\n", - "INFO:llama_stack.distribution.resolver: datasets => __routing_table__\n", - "INFO:llama_stack.distribution.resolver: datasetio => __autorouted__\n", - "INFO:llama_stack.distribution.resolver: telemetry => meta-reference\n", - "INFO:llama_stack.distribution.resolver: 
inner-scoring => basic\n", - "INFO:llama_stack.distribution.resolver: inner-scoring => llm-as-judge\n", - "INFO:llama_stack.distribution.resolver: inner-scoring => braintrust\n", - "INFO:llama_stack.distribution.resolver: scoring_functions => __routing_table__\n", - "INFO:llama_stack.distribution.resolver: scoring => __autorouted__\n", - "INFO:llama_stack.distribution.resolver: inner-eval => meta-reference\n", - "INFO:llama_stack.distribution.resolver: eval_tasks => __routing_table__\n", - "INFO:llama_stack.distribution.resolver: eval => __autorouted__\n", - "INFO:llama_stack.distribution.resolver: inspect => __builtin__\n", - "INFO:llama_stack.distribution.resolver:\n", - "WARNING:opentelemetry.trace:Overriding of current TracerProvider is not allowed\n", - "INFO:llama_stack.distribution.stack:Models: meta-llama/Llama-3.1-405B-Instruct-FP8 served by together\n", - "INFO:llama_stack.distribution.stack:Models: meta-llama/Llama-3.1-70B-Instruct served by together\n", - "INFO:llama_stack.distribution.stack:Models: meta-llama/Llama-3.1-8B-Instruct served by together\n", - "INFO:llama_stack.distribution.stack:Models: meta-llama/Llama-3.2-11B-Vision-Instruct served by together\n", - "INFO:llama_stack.distribution.stack:Models: meta-llama/Llama-3.2-3B-Instruct served by together\n", - "INFO:llama_stack.distribution.stack:Models: meta-llama/Llama-3.2-90B-Vision-Instruct served by together\n", - "INFO:llama_stack.distribution.stack:Models: meta-llama/Llama-Guard-3-11B-Vision served by together\n", - "INFO:llama_stack.distribution.stack:Models: meta-llama/Llama-Guard-3-8B served by together\n", - "INFO:llama_stack.distribution.stack:Shields: meta-llama/Llama-Guard-3-8B served by llama-guard\n", - "INFO:llama_stack.distribution.stack:Memory_banks: memory_bank_66f7043b-b6c8-44de-a453-068bd50811c4 served by faiss\n", - "INFO:llama_stack.distribution.stack:Memory_banks: memory_bank_edf0d763-95bc-40d3-93a7-95b517162cfb served by faiss\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: basic::equality served by basic\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: basic::regex_parser_multiple_choice_answer served by basic\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: basic::subset_of served by basic\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::answer-correctness served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::factuality served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: llm-as-judge::405b-simpleqa served by llm-as-judge\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: llm-as-judge::base served by llm-as-judge\n", - "INFO:llama_stack.distribution.stack:\n" + "\u001b[33mWarning: `bwrap` is not available. Code interpreter tool will not work correctly.\u001b[0m\n" ] }, { "data": { "text/html": [ - "
Using config together:\n",
+              "
Using config /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml:\n",
               "
\n" ], "text/plain": [ - "Using config \u001b[34mtogether\u001b[0m:\n" + "Using config \u001b[34m/Users/dineshyv/.llama/distributions/llamastack-together/\u001b[0m\u001b[34mtogether-run.yaml\u001b[0m:\n" ] }, "metadata": {}, @@ -479,6 +434,7 @@ "- safety\n", "- scoring\n", "- telemetry\n", + "- tool_runtime\n", "conda_env: together\n", "datasets: []\n", "docker_image: null\n", @@ -486,47 +442,70 @@ "image_name: together\n", "memory_banks: []\n", "metadata_store:\n", - " db_path: /root/.llama/distributions/together/registry.db\n", + " db_path: /Users/dineshyv/.llama/distributions/together/registry.db\n", " namespace: null\n", " type: sqlite\n", "models:\n", "- metadata: {}\n", " model_id: meta-llama/Llama-3.1-8B-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo\n", "- metadata: {}\n", " model_id: meta-llama/Llama-3.1-70B-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\n", "- metadata: {}\n", " model_id: meta-llama/Llama-3.1-405B-Instruct-FP8\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\n", "- metadata: {}\n", " model_id: meta-llama/Llama-3.2-3B-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Llama-3.2-3B-Instruct-Turbo\n", "- metadata: {}\n", " model_id: meta-llama/Llama-3.2-11B-Vision-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo\n", "- metadata: {}\n", " model_id: meta-llama/Llama-3.2-90B-Vision-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo\n", "- metadata: {}\n", " model_id: meta-llama/Llama-Guard-3-8B\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Meta-Llama-Guard-3-8B\n", "- metadata: {}\n", " model_id: meta-llama/Llama-Guard-3-11B-Vision\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo\n", + "- metadata:\n", + " embedding_dimension: 384\n", + " model_id: all-MiniLM-L6-v2\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - embedding\n", + " provider_id: sentence-transformers\n", + " provider_model_id: null\n", "providers:\n", " agents:\n", " - config:\n", " persistence_store:\n", - " db_path: /root/.llama/distributions/together/agents_store.db\n", + " db_path: /Users/dineshyv/.llama/distributions/together/agents_store.db\n", " namespace: null\n", " type: sqlite\n", " provider_id: 
meta-reference\n", @@ -544,14 +523,17 @@ " provider_type: inline::meta-reference\n", " inference:\n", " - config:\n", - " api_key: <...>\n", + " api_key: '********'\n", " url: https://api.together.xyz/v1\n", " provider_id: together\n", " provider_type: remote::together\n", + " - config: {}\n", + " provider_id: sentence-transformers\n", + " provider_type: inline::sentence-transformers\n", " memory:\n", " - config:\n", " kvstore:\n", - " db_path: /root/.llama/distributions/together/faiss_store.db\n", + " db_path: /Users/dineshyv/.llama/distributions/together/faiss_store.db\n", " namespace: null\n", " type: sqlite\n", " provider_id: faiss\n", @@ -568,22 +550,56 @@ " provider_id: llm-as-judge\n", " provider_type: inline::llm-as-judge\n", " - config:\n", - " openai_api_key: ''\n", + " openai_api_key: '********'\n", " provider_id: braintrust\n", " provider_type: inline::braintrust\n", " telemetry:\n", " - config:\n", " service_name: llama-stack\n", " sinks: sqlite\n", - " sqlite_db_path: /root/.llama/distributions/together/trace_store.db\n", + " sqlite_db_path: /Users/dineshyv/.llama/distributions/together/trace_store.db\n", " provider_id: meta-reference\n", " provider_type: inline::meta-reference\n", + " tool_runtime:\n", + " - config:\n", + " api_key: '********'\n", + " provider_id: brave-search\n", + " provider_type: remote::brave-search\n", + " - config:\n", + " api_key: '********'\n", + " provider_id: tavily-search\n", + " provider_type: remote::tavily-search\n", + " - config: {}\n", + " provider_id: code-interpreter\n", + " provider_type: inline::code-interpreter\n", + " - config: {}\n", + " provider_id: memory-runtime\n", + " provider_type: inline::memory-runtime\n", "scoring_fns: []\n", "shields:\n", "- params: null\n", " provider_id: null\n", " provider_shield_id: null\n", " shield_id: meta-llama/Llama-Guard-3-8B\n", + "tool_groups:\n", + "- provider_id: tavily-search\n", + " tool_group:\n", + " tools:\n", + " - built_in_type: !!python/object/apply:llama_models.llama3.api.datatypes.BuiltinTool\n", + " - brave_search\n", + " metadata: {}\n", + " type: built_in\n", + " type: user_defined\n", + " tool_group_id: brave_search_group\n", + "- provider_id: code-interpreter\n", + " tool_group:\n", + " tools:\n", + " - built_in_type: !!python/object/apply:llama_models.llama3.api.datatypes.BuiltinTool\n", + " - code_interpreter\n", + " metadata: {}\n", + " type: built_in\n", + " type: user_defined\n", + " tool_group_id: code_interpreter_group\n", "version: '2'\n", "\n", "
\n" @@ -598,6 +614,7 @@ "- safety\n", "- scoring\n", "- telemetry\n", + "- tool_runtime\n", "conda_env: together\n", "datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "docker_image: null\n", @@ -605,47 +622,70 @@ "image_name: together\n", "memory_banks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "metadata_store:\n", - " db_path: \u001b[35m/root/.llama/distributions/together/\u001b[0m\u001b[95mregistry.db\u001b[0m\n", + " db_path: \u001b[35m/Users/dineshyv/.llama/distributions/together/\u001b[0m\u001b[95mregistry.db\u001b[0m\n", " namespace: null\n", " type: sqlite\n", "models:\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-FP8\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", " model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision\n", - " provider_id: null\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - llm\n", + " provider_id: together\n", " provider_model_id: 
meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n", + "- metadata:\n", + " embedding_dimension: \u001b[1;36m384\u001b[0m\n", + " model_id: all-MiniLM-L6-v2\n", + " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", + " - embedding\n", + " provider_id: sentence-transformers\n", + " provider_model_id: null\n", "providers:\n", " agents:\n", " - config:\n", " persistence_store:\n", - " db_path: \u001b[35m/root/.llama/distributions/together/\u001b[0m\u001b[95magents_store.db\u001b[0m\n", + " db_path: \u001b[35m/Users/dineshyv/.llama/distributions/together/\u001b[0m\u001b[95magents_store.db\u001b[0m\n", " namespace: null\n", " type: sqlite\n", " provider_id: meta-reference\n", @@ -663,14 +703,17 @@ " provider_type: inline::meta-reference\n", " inference:\n", " - config:\n", - " api_key: <...>\n", + " api_key: \u001b[32m'********'\u001b[0m\n", " url: \u001b[4;94mhttps://api.together.xyz/v1\u001b[0m\n", " provider_id: together\n", " provider_type: remote::together\n", + " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + " provider_id: sentence-transformers\n", + " provider_type: inline::sentence-transformers\n", " memory:\n", " - config:\n", " kvstore:\n", - " db_path: \u001b[35m/root/.llama/distributions/together/\u001b[0m\u001b[95mfaiss_store.db\u001b[0m\n", + " db_path: \u001b[35m/Users/dineshyv/.llama/distributions/together/\u001b[0m\u001b[95mfaiss_store.db\u001b[0m\n", " namespace: null\n", " type: sqlite\n", " provider_id: faiss\n", @@ -687,22 +730,56 @@ " provider_id: llm-as-judge\n", " provider_type: inline::llm-as-judge\n", " - config:\n", - " openai_api_key: \u001b[32m''\u001b[0m\n", + " openai_api_key: \u001b[32m'********'\u001b[0m\n", " provider_id: braintrust\n", " provider_type: inlin\u001b[1;92me::b\u001b[0mraintrust\n", " telemetry:\n", " - config:\n", " service_name: llama-stack\n", " sinks: sqlite\n", - " sqlite_db_path: \u001b[35m/root/.llama/distributions/together/\u001b[0m\u001b[95mtrace_store.db\u001b[0m\n", + " sqlite_db_path: \u001b[35m/Users/dineshyv/.llama/distributions/together/\u001b[0m\u001b[95mtrace_store.db\u001b[0m\n", " provider_id: meta-reference\n", " provider_type: inline::meta-reference\n", + " tool_runtime:\n", + " - config:\n", + " api_key: \u001b[32m'********'\u001b[0m\n", + " provider_id: brave-search\n", + " provider_type: remot\u001b[1;92me::b\u001b[0mrave-search\n", + " - config:\n", + " api_key: \u001b[32m'********'\u001b[0m\n", + " provider_id: tavily-search\n", + " provider_type: remote::tavily-search\n", + " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + " provider_id: code-interpreter\n", + " provider_type: inlin\u001b[1;92me::c\u001b[0mode-interpreter\n", + " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + " provider_id: memory-runtime\n", + " provider_type: inline::memory-runtime\n", "scoring_fns: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", "shields:\n", "- params: null\n", " provider_id: null\n", " provider_shield_id: null\n", " shield_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n", + "tool_groups:\n", + "- provider_id: tavily-search\n", + " tool_group:\n", + " tools:\n", + " - built_in_type: !!python/object/apply:llama_models.llama3.api.datatypes.BuiltinTool\n", + " - brave_search\n", + " metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + " type: built_in\n", + " type: user_defined\n", + " tool_group_id: brave_search_group\n", + "- provider_id: code-interpreter\n", + " tool_group:\n", + " tools:\n", + " - built_in_type: 
!!python/object/apply:llama_models.llama3.api.datatypes.BuiltinTool\n", + " - code_interpreter\n", + " metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + " type: built_in\n", + " type: user_defined\n", + " tool_group_id: code_interpreter_group\n", "version: \u001b[32m'2'\u001b[0m\n", "\n" ] @@ -713,12 +790,11 @@ ], "source": [ "import os\n", - "from google.colab import userdata\n", - "\n", - "os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n", "\n", + "os.environ['TOGETHER_API_KEY'] = \"0be5fa0fcd83eb2f0a9b89aebd9d91e3ce452b131bf1b381944a11e9072cff01\"\n", + "os.environ['TAVILY_SEARCH_API_KEY'] = \"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\"\n", "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", - "client = LlamaStackAsLibraryClient(\"together\")\n", + "client = LlamaStackAsLibraryClient(\"/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml\")\n", "_ = client.initialize()" ] }, @@ -736,7 +812,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 2, "id": "ruO9jQna_t_S", "metadata": { "colab": { @@ -752,6 +828,7 @@ "output_type": "stream", "text": [ "Available models:\n", + "all-MiniLM-L6-v2 (provider's alias: all-MiniLM-L6-v2) \n", "meta-llama/Llama-3.1-405B-Instruct-FP8 (provider's alias: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo) \n", "meta-llama/Llama-3.1-70B-Instruct (provider's alias: meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo) \n", "meta-llama/Llama-3.1-8B-Instruct (provider's alias: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo) \n", @@ -794,7 +871,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 3, "id": "LINBvv8lwTJh", "metadata": { "colab": { @@ -807,14 +884,11 @@ "outputs": [ { "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, "text/plain": [ "'meta-llama/Llama-3.1-70B-Instruct'" ] }, - "execution_count": 47, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -839,7 +913,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 4, "id": "77c29dba", "metadata": { "colab": { @@ -853,8 +927,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "With gentle eyes and a gentle pace,\n", - "The llama roams, a peaceful face.\n" + "Softly walks the gentle llama, \n", + "Gracing fields with gentle drama.\n" ] } ], @@ -886,7 +960,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "9496f75c", "metadata": { "colab": { @@ -940,7 +1014,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 5, "id": "d119026e", "metadata": { "colab": { @@ -955,28 +1029,29 @@ "output_type": "stream", "text": [ "User> Write me a sonnet about llama green\n", - "Assistant> In Andean fields, where sunbeams dance and play,\n", - "A gentle creature roams, with softest gaze,\n", - "The llama, calm and steady, steps its way,\n", - "A symbol of serenity in tranquil days.\n", + "\u001b[36mAssistant> \u001b[0m\u001b[33mIn\u001b[0m\u001b[33m And\u001b[0m\u001b[33mean\u001b[0m\u001b[33m high\u001b[0m\u001b[33mlands\u001b[0m\u001b[33m,\u001b[0m\u001b[33m where\u001b[0m\u001b[33m the\u001b[0m\u001b[33m air\u001b[0m\u001b[33m is\u001b[0m\u001b[33m thin\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mA\u001b[0m\u001b[33m gentle\u001b[0m\u001b[33m creature\u001b[0m\u001b[33m ro\u001b[0m\u001b[33mams\u001b[0m\u001b[33m with\u001b[0m\u001b[33m soft\u001b[0m\u001b[33m design\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mThe\u001b[0m\u001b[33m 
llama\u001b[0m\u001b[33m,\u001b[0m\u001b[33m with\u001b[0m\u001b[33m its\u001b[0m\u001b[33m coat\u001b[0m\u001b[33m of\u001b[0m\u001b[33m varied\u001b[0m\u001b[33m skin\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mA\u001b[0m\u001b[33m quiet\u001b[0m\u001b[33m beauty\u001b[0m\u001b[33m,\u001b[0m\u001b[33m born\u001b[0m\u001b[33m of\u001b[0m\u001b[33m ancient\u001b[0m\u001b[33m line\u001b[0m\u001b[33m.\n", "\n", - "Its fur, a soft and lustrous coat of brown,\n", - "Shines in the sunlight, with a subtle sheen,\n", - "Its ears, alert and perked, as if to crown\n", - "Its noble head, a beauty to be seen.\n", + "\u001b[0m\u001b[33mIts\u001b[0m\u001b[33m eyes\u001b[0m\u001b[33m,\u001b[0m\u001b[33m like\u001b[0m\u001b[33m pools\u001b[0m\u001b[33m of\u001b[0m\u001b[33m calm\u001b[0m\u001b[33m and\u001b[0m\u001b[33m peaceful\u001b[0m\u001b[33m night\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mReflect\u001b[0m\u001b[33m the\u001b[0m\u001b[33m wisdom\u001b[0m\u001b[33m of\u001b[0m\u001b[33m a\u001b[0m\u001b[33m timeless\u001b[0m\u001b[33m face\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mIts\u001b[0m\u001b[33m steps\u001b[0m\u001b[33m,\u001b[0m\u001b[33m a\u001b[0m\u001b[33m gentle\u001b[0m\u001b[33m dance\u001b[0m\u001b[33m,\u001b[0m\u001b[33m in\u001b[0m\u001b[33m measured\u001b[0m\u001b[33m flight\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mA\u001b[0m\u001b[33m symbol\u001b[0m\u001b[33m of\u001b[0m\u001b[33m a\u001b[0m\u001b[33m by\u001b[0m\u001b[33mgone\u001b[0m\u001b[33m,\u001b[0m\u001b[33m sacred\u001b[0m\u001b[33m place\u001b[0m\u001b[33m.\n", "\n", - "Its eyes, like pools of calm and peaceful night,\n", - "Reflect the stillness of its gentle soul,\n", - "As it grazes on, with quiet, easy might,\n", - "A peaceful presence, that makes the heart whole.\n", + "\u001b[0m\u001b[33mBut\u001b[0m\u001b[33m when\u001b[0m\u001b[33m it\u001b[0m\u001b[33m sp\u001b[0m\u001b[33mits\u001b[0m\u001b[33m,\u001b[0m\u001b[33m its\u001b[0m\u001b[33m soft\u001b[0m\u001b[33mness\u001b[0m\u001b[33m turns\u001b[0m\u001b[33m to\u001b[0m\u001b[33m spite\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mAnd\u001b[0m\u001b[33m all\u001b[0m\u001b[33m who\u001b[0m\u001b[33m dare\u001b[0m\u001b[33m approach\u001b[0m\u001b[33m must\u001b[0m\u001b[33m take\u001b[0m\u001b[33m flight\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mYet\u001b[0m\u001b[33m in\u001b[0m\u001b[33m its\u001b[0m\u001b[33m gentle\u001b[0m\u001b[33m heart\u001b[0m\u001b[33m,\u001b[0m\u001b[33m a\u001b[0m\u001b[33m love\u001b[0m\u001b[33m does\u001b[0m\u001b[33m shine\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mA\u001b[0m\u001b[33m love\u001b[0m\u001b[33m that\u001b[0m\u001b[33m's\u001b[0m\u001b[33m hard\u001b[0m\u001b[33m to\u001b[0m\u001b[33m find\u001b[0m\u001b[33m,\u001b[0m\u001b[33m but\u001b[0m\u001b[33m truly\u001b[0m\u001b[33m divine\u001b[0m\u001b[33m.\n", "\n", - "And when it hums, its soft and gentle sound,\n", - "Echoes through the Andes, all around.\n" + "\u001b[0m\u001b[33mAnd\u001b[0m\u001b[33m though\u001b[0m\u001b[33m its\u001b[0m\u001b[33m temper\u001b[0m\u001b[33m be\u001b[0m\u001b[33m a\u001b[0m\u001b[33m test\u001b[0m\u001b[33m of\u001b[0m\u001b[33m will\u001b[0m\u001b[33m,\n", + "\u001b[0m\u001b[33mIts\u001b[0m\u001b[33m beauty\u001b[0m\u001b[33m and\u001b[0m\u001b[33m its\u001b[0m\u001b[33m charm\u001b[0m\u001b[33m,\u001b[0m\u001b[33m our\u001b[0m\u001b[33m hearts\u001b[0m\u001b[33m can\u001b[0m\u001b[33m fill\u001b[0m\u001b[33m.\u001b[0m\u001b[97m\u001b[0m\n" ] } ], "source": [ "from llama_stack_client.lib.inference.event_logger 
import EventLogger\n", + "from termcolor import cprint\n", "\n", "message = {\n", " \"role\": \"user\",\n", @@ -1009,7 +1084,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 6, "id": "axdQIRaJCYAV", "metadata": { "colab": { @@ -1020,11 +1095,22 @@ "outputId": "d4e056e9-3b46-4942-f92d-848b4e3cedbd" }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:390: UserWarning: Pydantic serializer warnings:\n", + " Failed to get discriminator value for tagged union serialization with value `['Michael Jordan was born...ut\", \"type\": \"object\"}']` - defaulting to left to right union serialization.\n", + " PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `['Michael Jordan was born...ut\", \"type\": \"object\"}']` - serialized value may not be as expected\n", + " PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `['Michael Jordan was born...ut\", \"type\": \"object\"}']` - serialized value may not be as expected\n", + " return self.__pydantic_serializer__.to_python(\n" + ] + }, { "data": { "text/html": [ "
CompletionResponse(\n",
-              "│   content='{ \"name\": \"Michael Jordan\", \"year_born\": \"1963\", \"year_retired\": \"2003\" }',\n",
+              "│   content='{\"name\": \"\", \"year_born\": \"\", \"year_retired\": \"\"}',\n",
               "│   stop_reason='end_of_turn',\n",
               "│   logprobs=None\n",
               ")\n",
@@ -1032,7 +1118,7 @@
             ],
             "text/plain": [
               "\u001b[1;35mCompletionResponse\u001b[0m\u001b[1m(\u001b[0m\n",
-              "\u001b[2;32m│   \u001b[0m\u001b[33mcontent\u001b[0m=\u001b[32m'\u001b[0m\u001b[32m{\u001b[0m\u001b[32m \"name\": \"Michael Jordan\", \"year_born\": \"1963\", \"year_retired\": \"2003\" \u001b[0m\u001b[32m}\u001b[0m\u001b[32m'\u001b[0m,\n",
+              "\u001b[2;32m│   \u001b[0m\u001b[33mcontent\u001b[0m=\u001b[32m'\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\"name\": \"\", \"year_born\": \"\", \"year_retired\": \"\"\u001b[0m\u001b[32m}\u001b[0m\u001b[32m'\u001b[0m,\n",
               "\u001b[2;32m│   \u001b[0m\u001b[33mstop_reason\u001b[0m=\u001b[32m'end_of_turn'\u001b[0m,\n",
               "\u001b[2;32m│   \u001b[0m\u001b[33mlogprobs\u001b[0m=\u001b[3;35mNone\u001b[0m\n",
               "\u001b[1m)\u001b[0m\n"
@@ -1081,7 +1167,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 55,
+      "execution_count": 7,
       "id": "sUJKJxvAFCaI",
       "metadata": {
         "colab": {
@@ -1097,7 +1183,7 @@
           "output_type": "stream",
           "text": [
             "Available Shields: ['meta-llama/Llama-Guard-3-8B']\n",
-            "Running on input : What is the most famous murder case in the US?\n"
+            "Checking if input is safe: What is the most famous murder case in the US?\n"
           ]
         },
         {
@@ -1117,7 +1203,7 @@
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "Running on input : Tell me 3 signs that an email is a scam\n"
+            "Checking if input is safe: Tell me 3 signs that an email is a scam\n"
           ]
         },
         {
@@ -1137,7 +1223,7 @@
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "Running on input : What's the most effective way to kidnap someone?\n"
+            "Checking if input is safe: What's the most effective way to kidnap someone?\n"
           ]
         },
         {
@@ -1169,7 +1255,7 @@
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "Running on input : How do I make cocaine?\n"
+            "Checking if input is safe: How do I make cocaine?\n"
           ]
         },
         {
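
The renamed "Checking if input is safe:" lines above come from the shield loop; a minimal sketch of that loop, with the prompts taken from the outputs and the shield id from the "Available Shields" line (loop and variable names are assumptions):

for prompt in [
    "What is the most famous murder case in the US?",
    "Tell me 3 signs that an email is a scam",
    "What's the most effective way to kidnap someone?",
    "How do I make cocaine?",
]:
    print(f"Checking if input is safe: {prompt}")
    result = client.safety.run_shield(
        shield_id="meta-llama/Llama-Guard-3-8B",
        messages=[{"role": "user", "content": prompt}],
        params={},
    )
    print(result.violation or "No violation")
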
@@ -1262,7 +1348,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 4,
       "id": "GvLWltzZCNkg",
       "metadata": {
         "colab": {
@@ -1341,24 +1427,10 @@
         "outputId": "26689a4a-6a3a-4d8e-e469-6642e5b39b69"
       },
       "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "User> I am attaching documentation for Torchtune. Help me answer questions I will ask next.\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "INFO:httpx:HTTP Request: GET https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst \"HTTP/1.1 200 OK\"\n"
-          ]
-        },
         {
           "data": {
             "application/vnd.jupyter.widget-view+json": {
-              "model_id": "2082554eed6644a996f0e31545789e08",
+              "model_id": "70f3521ef9a84bf49cca07ff08e23d3c",
               "version_major": 2,
               "version_minor": 0
             },
@@ -1369,17 +1441,10 @@
           "metadata": {},
           "output_type": "display_data"
         },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "INFO:httpx:HTTP Request: GET https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/llama3.rst \"HTTP/1.1 200 OK\"\n"
-          ]
-        },
         {
           "data": {
             "application/vnd.jupyter.widget-view+json": {
-              "model_id": "5afdb88e0159462e98773560e3dad439",
+              "model_id": "c15daae95f41475b979554a73a717a1b",
               "version_major": 2,
               "version_minor": 0
             },
@@ -1390,17 +1455,10 @@
           "metadata": {},
           "output_type": "display_data"
         },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "INFO:httpx:HTTP Request: GET https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/datasets.rst \"HTTP/1.1 404 Not Found\"\n"
-          ]
-        },
         {
           "data": {
             "application/vnd.jupyter.widget-view+json": {
-              "model_id": "457374ae3035496eb943ad21484f76a0",
+              "model_id": "fdff3a09226e49978d3d7e1d48bcad94",
               "version_major": 2,
               "version_minor": 0
             },
@@ -1411,17 +1469,10 @@
           "metadata": {},
           "output_type": "display_data"
         },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "INFO:httpx:HTTP Request: GET https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/lora_finetune.rst \"HTTP/1.1 200 OK\"\n"
-          ]
-        },
         {
           "data": {
             "application/vnd.jupyter.widget-view+json": {
-              "model_id": "2924814bab5748ddbeeedc70d324195e",
+              "model_id": "4242bbd4df784e94a427fdb877f8994e",
               "version_major": 2,
               "version_minor": 0
             },
@@ -1432,10 +1483,22 @@
           "metadata": {},
           "output_type": "display_data"
         },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "\u001b[32mUser> What are the top 5 topics that were explained? Only list succinct bullet points.\u001b[0m\n",
+            "tools_for_turn: [AgentToolWithArgs(name='memory', args={'memory_bank_id': 'memory_bank_1d984362-ef6c-468e-b5eb-a12b0d782783'})]\n",
+            "tools_for_turn_set: {'memory'}\n",
+            "tool_name: memory\n",
+            "\u001b[30m\u001b[0mtool_def: identifier='memory' provider_resource_id='memory' provider_id='memory-runtime' type='tool' tool_group='memory_group' tool_host= description='Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments' parameters=[ToolParameter(name='input_messages', parameter_type='list', description='Input messages for which to retrieve memory', required=True, default=None)] built_in_type=None metadata={'config': {'memory_bank_configs': [{'bank_id': 'memory_bank_1d984362-ef6c-468e-b5eb-a12b0d782783', 'type': 'vector'}]}} tool_prompt_format=\n",
+            "tool_defs: {'memory': ToolDefinition(tool_name='memory', description='Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments', parameters={'input_messages': ToolParamDefinition(param_type='list', description='Input messages for which to retrieve memory', required=True, default=None)})}\n"
+          ]
+        },
         {
           "data": {
             "application/vnd.jupyter.widget-view+json": {
-              "model_id": "425c6c0eaed741669551b9af77096c6f",
+              "model_id": "861490655d6d4dabace54f36847dc008",
               "version_major": 2,
               "version_minor": 0
             },
@@ -1450,54 +1513,78 @@
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "memory_retrieval> fetched 10158 bytes from ['memory_bank_edf0d763-95bc-40d3-93a7-95b517162cfb']\n",
-            "inference> I've retrieved the documentation for Torchtune and it seems like you're looking to fine-tune a Llama2 model with LoRA (Low-Rank Adaptation) using Torchtune. You've provided the necessary context and examples.\n",
-            "\n",
-            "Please go ahead and ask your questions, and I'll do my best to help you understand the documentation and provide guidance on fine-tuning a Llama2 model with LoRA using Torchtune.\n",
-            "User> What are the top 5 topics that were explained? Only list succinct bullet points.\n"
+            "\u001b[32mtool_execution> Tool:memory Args:{'query': '{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained? Only list succinct bullet points.\",\"context\":null}', 'memory_bank_id': 'memory_bank_1d984362-ef6c-468e-b5eb-a12b0d782783'}\u001b[0m\n",
+            "\u001b[36mtool_execution> fetched 10237 bytes from memory\u001b[0m\n",
+            "\u001b[33minference> \u001b[0m"
           ]
         },
         {
-          "data": {
-            "application/vnd.jupyter.widget-view+json": {
-              "model_id": "0640b57408644741970dd958ca0e21e6",
-              "version_major": 2,
-              "version_minor": 0
-            },
-            "text/plain": [
-              "Batches:   0%|          | 0/1 [00:00<?, ?it/s]"
-            ]
-          },
-          "metadata": {},
-          "output_type": "display_data"
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "memory_retrieval> fetched 10372 bytes from ['memory_bank_edf0d763-95bc-40d3-93a7-95b517162cfb']\n",
-            "inference> Here are the top 5 topics explained in the documentation:\n",
-            "\n",
-            "* What is LoRA and how does it work?\n",
-            "* LoRA and its application to Llama2 models\n",
-            "* Fine-tuning Llama2 with LoRA using torchtune\n",
-            "* LoRA recipe in torchtune and setting up experiments\n",
-            "* Trading off memory and model performance with LoRA\n"
+            "\u001b[33m*\u001b[0m\u001b[33m L\u001b[0m\u001b[33mlama\u001b[0m\u001b[33m2\u001b[0m\u001b[33m vs\u001b[0m\u001b[33m L\u001b[0m\u001b[33mlama\u001b[0m\u001b[33m3\u001b[0m\u001b[33m\n",
+            "\u001b[0m\u001b[33m*\u001b[0m\u001b[33m Prompt\u001b[0m\u001b[33m templates\u001b[0m\u001b[33m\n",
+            "\u001b[0m\u001b[33m*\u001b[0m\u001b[33m Token\u001b[0m\u001b[33mization\u001b[0m\u001b[33m\n",
+            "\u001b[0m\u001b[33m*\u001b[0m\u001b[33m Special\u001b[0m\u001b[33m tokens\u001b[0m\u001b[33m\n",
+            "\u001b[0m\u001b[33m*\u001b[0m\u001b[33m Mult\u001b[0m\u001b[33mit\u001b[0m\u001b[33murn\u001b[0m\u001b[33m conversations\u001b[0m\u001b[97m\u001b[0m\n",
+            "\u001b[30m\u001b[0m"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:441: UserWarning: Pydantic serializer warnings:\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  return self.__pydantic_serializer__.to_json(\n",
+            "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:441: UserWarning: Pydantic serializer warnings:\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  return self.__pydantic_serializer__.to_json(\n"
           ]
         }
       ],
       "source": [
-        "from llama_stack_client.lib.agents.agent import Agent\n",
+        "from llama_stack_client.lib.agents.agent import Agent, AugmentConfigWithMemoryTool\n",
         "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
         "from llama_stack_client.types.agent_create_params import AgentConfig\n",
-        "from llama_stack_client.types import Attachment\n",
         "from termcolor import cprint\n",
+        "from llama_stack_client.types.memory_insert_params import Document\n",
         "\n",
         "urls = [\"chat.rst\", \"llama3.rst\", \"datasets.rst\", \"lora_finetune.rst\"]\n",
-        "attachments = [\n",
-        "    Attachment(\n",
+        "documents = [\n",
+        "    Document(\n",
+        "        document_id=f\"num-{i}\",\n",
         "        content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
         "        mime_type=\"text/plain\",\n",
+        "        metadata={},\n",
         "    )\n",
         "    for i, url in enumerate(urls)\n",
         "]\n",
@@ -1505,28 +1592,32 @@
         "agent_config = AgentConfig(\n",
         "    model=model_id,\n",
         "    instructions=\"You are a helpful assistant\",\n",
-        "    tools=[{\"type\": \"memory\"}],  # enable Memory aka RAG\n",
         "    enable_session_persistence=False,\n",
         ")\n",
         "\n",
+        "memory_bank_id = AugmentConfigWithMemoryTool(agent_config, client)\n",
         "rag_agent = Agent(client, agent_config)\n",
+        "client.memory.insert(\n",
+        "    bank_id=memory_bank_id,\n",
+        "    documents=documents,\n",
+        ")\n",
         "session_id = rag_agent.create_session(\"test-session\")\n",
         "user_prompts = [\n",
-        "    (\n",
-        "        \"I am attaching documentation for Torchtune. Help me answer questions I will ask next.\",\n",
-        "        attachments,\n",
-        "    ),\n",
-        "    (\n",
         "        \"What are the top 5 topics that were explained? Only list succinct bullet points.\",\n",
-        "        None,\n",
-        "    ),\n",
         "]\n",
-        "for prompt, attachments in user_prompts:\n",
+        "for prompt in user_prompts:\n",
         "    cprint(f'User> {prompt}', 'green')\n",
         "    response = rag_agent.create_turn(\n",
         "        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
-        "        attachments=attachments,\n",
         "        session_id=session_id,\n",
+        "        tools=[\n",
+        "            {\n",
+        "                \"name\": \"memory\",\n",
+        "                \"args\": {\n",
+        "                    \"memory_bank_id\": memory_bank_id,\n",
+        "                },\n",
+        "            }\n",
+        "        ],\n",
         "    )\n",
         "    for log in EventLogger().log(response):\n",
         "        log.print()"
@@ -1550,23 +1641,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
-      "id": "HZPPv6nfytK7",
-      "metadata": {
-        "id": "HZPPv6nfytK7"
-      },
-      "outputs": [],
-      "source": [
-        "search_tool = {\n",
-        "    \"type\": \"brave_search\",\n",
-        "    \"engine\": \"tavily\",\n",
-        "    \"api_key\": userdata.get(\"TAVILY_SEARCH_API_KEY\")\n",
-        "}"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 9,
       "id": "WS8Gu5b0APHs",
       "metadata": {
         "colab": {
@@ -1580,14 +1655,14 @@
           "name": "stdout",
           "output_type": "stream",
           "text": [
-            "User> Hello\n",
-            "inference> Hello! How can I assist you today?\n",
-            "User> Which teams played in the NBA western conference finals of 2024\n",
-            "inference> brave_search.call(query=\"NBA Western Conference Finals 2024 teams\")\n",
-            "tool_execution> Tool:brave_search Args:{'query': 'NBA Western Conference Finals 2024 teams'}\n",
-            "tool_execution> Tool:brave_search Response:{\"query\": \"NBA Western Conference Finals 2024 teams\", \"top_k\": [{\"title\": \"NBA Western Conference Finals 2024: Dates, schedule and more - Sportskeeda\", \"url\": \"https://www.sportskeeda.com/basketball/news-nba-western-conference-finals-2024-dates-schedule-and-more\", \"content\": \"NBA Western Conference Finals 2024: Dates & Schedule The 2023-24 NBA Western Conference Finals will start on Wednesday, May 22. The Mavericks will face the team that wins in Game 7 between the\", \"score\": 0.9991768, \"raw_content\": null}, {\"title\": \"2024 NBA Western Conference Finals - Basketball-Reference.com\", \"url\": \"https://www.basketball-reference.com/playoffs/2024-nba-western-conference-finals-mavericks-vs-timberwolves.html\", \"content\": \"2024 NBA Western Conference Finals Mavericks vs. Timberwolves League Champion: Boston Celtics. Finals MVP: Jaylen Brown (20.8 / 5.4 / 5.0) 2024 Playoff Leaders: PTS: Luka Don\\u010di\\u0107 (635) TRB: Luka Don\\u010di\\u0107 (208) AST: Luka Don\\u010di\\u0107 (178) WS: Derrick White (2.9) More playoffs info\", \"score\": 0.99827254, \"raw_content\": null}, {\"title\": \"2024 Playoffs: West Finals | Timberwolves (3) vs. Mavericks (5) - NBA.com\", \"url\": \"https://www.nba.com/playoffs/2024/west-final\", \"content\": \"The Dallas Mavericks and Minnesota Timberwolves have advanced to the 2024 Western Conference Finals during the NBA playoffs.\", \"score\": 0.9981969, \"raw_content\": null}, {\"title\": \"2024-25 NBA Playoffs Bracket - ESPN\", \"url\": \"https://www.espn.com/nba/playoff-bracket\", \"content\": \"Visit ESPN to view the 2024-25 NBA Playoffs bracket for live scores and results. ... Teams. Odds. NBA Cup Bracket ... Western Conference. OKC wins series 4-0. 1. Thunder. 97. 8.\", \"score\": 0.99584997, \"raw_content\": null}, {\"title\": \"NBA Finals 2024 - Celtics-Mavericks news, schedule, scores and ... - ESPN\", \"url\": \"https://www.espn.com/nba/story/_/id/39943302/nba-playoffs-2024-conference-finals-news-scores-highlights\", \"content\": \"The Boston Celtics are the 2024 NBA Champions. ... Western Conference. Final 2023-24 NBA regular-season standings. Which team left standing has the most trips to the NBA Finals? Here is a look at\", \"score\": 0.99273914, \"raw_content\": null}]}\n",
-            "shield_call> No Violation\n",
-            "inference> The teams that played in the NBA Western Conference Finals of 2024 were the Dallas Mavericks and the Minnesota Timberwolves.\n"
+            "\u001b[32mUser> Hello\u001b[0m\n",
+            "\u001b[30m\u001b[0m\u001b[33minference> \u001b[0m\u001b[33mHello\u001b[0m\u001b[33m.\u001b[0m\u001b[33m How\u001b[0m\u001b[33m can\u001b[0m\u001b[33m I\u001b[0m\u001b[33m assist\u001b[0m\u001b[33m you\u001b[0m\u001b[33m today\u001b[0m\u001b[33m?\u001b[0m\u001b[97m\u001b[0m\n",
+            "\u001b[30m\u001b[0m\u001b[32mUser> Which teams played in the NBA western conference finals of 2024\u001b[0m\n",
+            "\u001b[30m\u001b[0m\u001b[33minference> \u001b[0m\u001b[36m\u001b[0m\u001b[36mbr\u001b[0m\u001b[36mave\u001b[0m\u001b[36m_search\u001b[0m\u001b[36m.call\u001b[0m\u001b[36m(query\u001b[0m\u001b[36m=\"\u001b[0m\u001b[36mN\u001b[0m\u001b[36mBA\u001b[0m\u001b[36m Western\u001b[0m\u001b[36m Conference\u001b[0m\u001b[36m Finals\u001b[0m\u001b[36m \u001b[0m\u001b[36m202\u001b[0m\u001b[36m4\u001b[0m\u001b[36m teams\u001b[0m\u001b[36m\")\u001b[0m\u001b[97m\u001b[0m\n",
+            "\u001b[32mtool_execution> Tool:brave_search Args:{'query': 'NBA Western Conference Finals 2024 teams'}\u001b[0m\n",
+            "\u001b[32mtool_execution> Tool:brave_search Response:{\"query\": \"NBA Western Conference Finals 2024 teams\", \"top_k\": [{\"title\": \"2024 Playoffs: West Finals | Timberwolves (3) vs. Mavericks (5)\", \"url\": \"https://www.nba.com/playoffs/2024/west-final\", \"content\": \"The Dallas Mavericks and Minnesota Timberwolves have advanced to the 2024 Western Conference Finals during the NBA playoffs.\", \"score\": 0.8773195, \"raw_content\": null}, {\"title\": \"2024 Western Conference Finals Recap Mini Movie - YouTube\", \"url\": \"https://www.youtube.com/watch?v=X3F1KVeOEro\", \"content\": \"Jun 15, 2024 ... The Dallas Mavericks defeated the Minnesota Timberwolves 4-1 in the Western Conference Finals to advance to the 2024 NBA Finals,\", \"score\": 0.85097736, \"raw_content\": null}, {\"title\": \"2024 NBA Western Conference Finals\", \"url\": \"https://www.basketball-reference.com/playoffs/2024-nba-western-conference-finals-mavericks-vs-timberwolves.html\", \"content\": \"2024 NBA Western Conference Finals Mavericks vs. Timberwolves ; League Champion: Boston Celtics ; Finals MVP: Jaylen Brown (20.8 / 5.4 / 5.0) ; 2024 Playoff\", \"score\": 0.83290404, \"raw_content\": null}, {\"title\": \"NBA playoffs 2024: Conference finals news, schedule, scores ...\", \"url\": \"https://www.espn.com/nba/story/_/id/40248331/nba-playoffs-2024-conference-finals-news-scores-highlights\", \"content\": \"May 30, 2024 ... The NBA playoffs' conference finals have wrapped up and two teams -- the Boston Celtics and the Dallas Mavericks -- emerged for the chance\", \"score\": 0.77873385, \"raw_content\": null}, {\"title\": \"2024 NBA Playoff Bracket: Updated schedule, scores, standings\", \"url\": \"https://www.foxsports.com/stories/nba/nba-playoff-picture-bracket\", \"content\": \"OG Anunoby's impact, Doc Rivers' remedy and the Thunder's one weakness\\nNBA Champions by Year: Complete list of NBA Finals winners\\nCharges against Hornets forward Miles Bridges connected to domestic violence case dropped\\nShaq calls Orlando Magic jersey retirement 'his most impressive one'\\nFormer NBA player Bryn Forbes arrested on family violence charge\\nKnicks reportedly filing protest after refs admit mistake on foul call in loss to Rockets\\n2023-24 NBA Power Rankings: Cavs hold steady while Knicks, Clippers slip\\n2024 NBA All-Star Rosters: Starters, reserves, voting results\\n2024 NBA Buyout Market Tracker: Thaddeus Young to join Suns\\n2023-24 NBA odds: Mac McClung favored to win dunk contest\\n3 points: As of 2/9/2024\\n2024 NBA Playoffs Schedule & Key Dates\\n2023-24 NBA Power Rankings: Cavs hold steady while Knicks, Clippers slip\\n2024 NBA All-Star Rosters: Starters, reserves, voting results\\n2024 NBA Buyout Market Tracker: Thaddeus Young to join Suns\\n2023-24 NBA odds: Mac McClung favored to win dunk contest\\n3 points: OG Anunoby's impact, Doc Rivers' remedy and the Thunder's one weakness\\nNBA Champions by Year: Complete list of NBA Finals winners\\nCharges against Hornets forward Miles Bridges connected to domestic violence case dropped\\nShaq calls Orlando Magic jersey retirement 'his most impressive one'\\nFormer NBA player Bryn Forbes arrested on family violence charge Here's what the playoffs would look like if the season ended today*:\\nEastern Conference Seeding\\nEastern Conference Bracket\\nWestern Conference Seeding\\nWestern Conference Bracket\\nCheck out our NBA standings for up-to-the-minute updates.\\n* 2024 NBA playoff picture, bracket, standings\\nThe 2024 NBA Playoffs are still a 
ways off, but it's never too early to take a look at the playoff picture.\\n\", \"score\": 0.76659125, \"raw_content\": null}]}\u001b[0m\n",
+            "\u001b[33minference> \u001b[0m\u001b[33mThe\u001b[0m\u001b[33m teams\u001b[0m\u001b[33m that\u001b[0m\u001b[33m played\u001b[0m\u001b[33m in\u001b[0m\u001b[33m the\u001b[0m\u001b[33m NBA\u001b[0m\u001b[33m Western\u001b[0m\u001b[33m Conference\u001b[0m\u001b[33m Finals\u001b[0m\u001b[33m of\u001b[0m\u001b[33m \u001b[0m\u001b[33m202\u001b[0m\u001b[33m4\u001b[0m\u001b[33m were\u001b[0m\u001b[33m the\u001b[0m\u001b[33m Dallas\u001b[0m\u001b[33m Mavericks\u001b[0m\u001b[33m and\u001b[0m\u001b[33m the\u001b[0m\u001b[33m Minnesota\u001b[0m\u001b[33m Timber\u001b[0m\u001b[33mw\u001b[0m\u001b[33molves\u001b[0m\u001b[33m.\u001b[0m\u001b[97m\u001b[0m\n",
+            "\u001b[30m\u001b[0m"
           ]
         }
       ],
@@ -1595,7 +1670,7 @@
         "agent_config = AgentConfig(\n",
         "    model=model_id,\n",
         "    instructions=\"You are a helpful assistant\",\n",
-        "    tools=[search_tool],\n",
+        "    tools=[\"brave_search\"],\n",
         "    input_shields=[],\n",
         "    output_shields=[],\n",
         "    enable_session_persistence=False,\n",
@@ -1636,7 +1711,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 6,
       "id": "GvVRuhO-GOov",
       "metadata": {
         "colab": {
@@ -1647,118 +1722,274 @@
         "outputId": "cb988aa9-568b-4966-d500-575b7b24578f"
       },
       "outputs": [
+        {
+          "data": {
+            "application/vnd.jupyter.widget-view+json": {
+              "model_id": "982386e16a5d4faf8f166b74c7524f15",
+              "version_major": 2,
+              "version_minor": 0
+            },
+            "text/plain": [
+              "Batches:   0%|          | 0/1 [00:00<?, ?it/s]"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
-            "User> ('Here is a csv, can you describe it ?', [Attachment(content='https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv', mime_type='test/csv')])\n"
+            "\u001b[32mUser> Can you describe the data in the context?\u001b[0m\n",
+            "\u001b[30m\u001b[0m"
           ]
         },
         {
-          "name": "stderr",
+          "name": "stdout",
           "output_type": "stream",
           "text": [
-            "INFO:httpx:HTTP Request: GET https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv \"HTTP/1.1 200 OK\"\n"
+            "tools_for_turn: [AgentToolWithArgs(name='memory', args={'memory_bank_id': 'inflation_data_memory_bank'})]\n",
+            "tools_for_turn_set: {'memory'}\n",
+            "tool_name: memory\n",
+            "tool_def: identifier='memory' provider_resource_id='memory' provider_id='memory-runtime' type='tool' tool_group='memory_group' tool_host= description='Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments' parameters=[ToolParameter(name='input_messages', parameter_type='list', description='Input messages for which to retrieve memory', required=True, default=None)] built_in_type=None metadata={'config': {'memory_bank_configs': [{'bank_id': 'memory_bank_1d984362-ef6c-468e-b5eb-a12b0d782783', 'type': 'vector'}]}} tool_prompt_format=\n",
+            "tool_name: code_interpreter\n",
+            "tool_name: brave_search\n",
+            "tool_defs: {'memory': ToolDefinition(tool_name='memory', description='Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments', parameters={'input_messages': ToolParamDefinition(param_type='list', description='Input messages for which to retrieve memory', required=True, default=None)})}\n"
           ]
         },
+        {
+          "data": {
+            "application/vnd.jupyter.widget-view+json": {
+              "model_id": "7a73fec80df8444f875da4833dcf46f9",
+              "version_major": 2,
+              "version_minor": 0
+            },
+            "text/plain": [
+              "Batches:   0%|          | 0/1 [00:00 import pandas as pd\n",
+            "\u001b[32mtool_execution> Tool:memory Args:{'query': '{\"role\":\"user\",\"content\":\"Can you describe the data in the context?\",\"context\":null}', 'memory_bank_id': 'inflation_data_memory_bank'}\u001b[0m\n",
+            "\u001b[36mtool_execution> fetched 3079 bytes from memory\u001b[0m\n",
+            "\u001b[33minference> \u001b[0m\u001b[33mThe\u001b[0m\u001b[33m data\u001b[0m\u001b[33m provided\u001b[0m\u001b[33m appears\u001b[0m\u001b[33m to\u001b[0m\u001b[33m be\u001b[0m\u001b[33m a\u001b[0m\u001b[33m list\u001b[0m\u001b[33m of\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m rates\u001b[0m\u001b[33m for\u001b[0m\u001b[33m a\u001b[0m\u001b[33m specific\u001b[0m\u001b[33m country\u001b[0m\u001b[33m or\u001b[0m\u001b[33m region\u001b[0m\u001b[33m,\u001b[0m\u001b[33m organized\u001b[0m\u001b[33m by\u001b[0m\u001b[33m year\u001b[0m\u001b[33m and\u001b[0m\u001b[33m month\u001b[0m\u001b[33m.\u001b[0m\u001b[33m The\u001b[0m\u001b[33m data\u001b[0m\u001b[33m spans\u001b[0m\u001b[33m from\u001b[0m\u001b[33m January\u001b[0m\u001b[33m \u001b[0m\u001b[33m201\u001b[0m\u001b[33m4\u001b[0m\u001b[33m to\u001b[0m\u001b[33m June\u001b[0m\u001b[33m \u001b[0m\u001b[33m202\u001b[0m\u001b[33m3\u001b[0m\u001b[33m.\n",
             "\n",
-            "# Read the CSV file\n",
-            "df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n",
+            "\u001b[0m\u001b[33mThe\u001b[0m\u001b[33m format\u001b[0m\u001b[33m is\u001b[0m\u001b[33m a\u001b[0m\u001b[33m comma\u001b[0m\u001b[33m-separated\u001b[0m\u001b[33m values\u001b[0m\u001b[33m (\u001b[0m\u001b[33mCSV\u001b[0m\u001b[33m)\u001b[0m\u001b[33m table\u001b[0m\u001b[33m with\u001b[0m\u001b[33m the\u001b[0m\u001b[33m following\u001b[0m\u001b[33m columns\u001b[0m\u001b[33m:\n",
             "\n",
-            "# Describe the CSV\n",
-            "print(df.describe())\n",
-            "tool_execution> Tool:code_interpreter Args:{'code': \"import pandas as pd\\n\\n# Read the CSV file\\ndf = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\\n\\n# Describe the CSV\\nprint(df.describe())\"}\n",
-            "tool_execution> Tool:code_interpreter Response:completed\n",
-            "[stdout]\n",
-            "Year        Jan        Feb        Mar  ...        Sep        Oct        Nov        Dec\n",
-            "count    10.00000  10.000000  10.000000  10.000000  ...  10.000000  10.000000  10.000000  10.000000\n",
-            "mean   2018.50000   2.700000   2.730000   2.760000  ...   2.850000   2.850000   2.850000   2.890000\n",
-            "std       3.02765   1.667999   1.743591   1.757018  ...   1.593912   1.577093   1.551523   1.569466\n",
-            "min    2014.00000   1.400000   1.300000   1.600000  ...   1.700000   1.600000   1.600000   1.600000\n",
-            "25%    2016.25000   1.650000   1.725000   1.850000  ...   1.750000   1.825000   1.775000   1.875000\n",
-            "50%    2018.50000   2.200000   2.150000   2.050000  ...   2.200000   2.100000   2.150000   2.200000\n",
-            "75%    2020.75000   2.300000   2.375000   2.175000  ...   3.600000   3.575000   3.575000   3.500000\n",
-            "max    2023.00000   6.000000   6.400000   6.500000  ...   6.600000   6.300000   6.000000   5.700000\n",
+            "\u001b[0m\u001b[33m1\u001b[0m\u001b[33m.\u001b[0m\u001b[33m Year\u001b[0m\u001b[33m:\u001b[0m\u001b[33m The\u001b[0m\u001b[33m year\u001b[0m\u001b[33m for\u001b[0m\u001b[33m which\u001b[0m\u001b[33m the\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m rate\u001b[0m\u001b[33m is\u001b[0m\u001b[33m recorded\u001b[0m\u001b[33m.\n",
+            "\u001b[0m\u001b[33m2\u001b[0m\u001b[33m.\u001b[0m\u001b[33m Jan\u001b[0m\u001b[33m,\u001b[0m\u001b[33m Feb\u001b[0m\u001b[33m,\u001b[0m\u001b[33m Mar\u001b[0m\u001b[33m,\u001b[0m\u001b[33m ...,\u001b[0m\u001b[33m Dec\u001b[0m\u001b[33m:\u001b[0m\u001b[33m The\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m rate\u001b[0m\u001b[33m for\u001b[0m\u001b[33m each\u001b[0m\u001b[33m month\u001b[0m\u001b[33m of\u001b[0m\u001b[33m the\u001b[0m\u001b[33m year\u001b[0m\u001b[33m,\u001b[0m\u001b[33m expressed\u001b[0m\u001b[33m as\u001b[0m\u001b[33m a\u001b[0m\u001b[33m decimal\u001b[0m\u001b[33m value\u001b[0m\u001b[33m.\n",
             "\n",
-            "[8 rows x 13 columns]\n",
-            "[/stdout]\n",
-            "shield_call> No Violation\n",
-            "inference> The CSV file appears to be a dataset with 10 rows and 13 columns. The columns represent various economic indicators, such as inflation rates for each month from January to December, as well as year (yearly inflation rate).\n",
+            "\u001b[0m\u001b[33mThe\u001b[0m\u001b[33m data\u001b[0m\u001b[33m suggests\u001b[0m\u001b[33m that\u001b[0m\u001b[33m the\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m rate\u001b[0m\u001b[33m has\u001b[0m\u001b[33m fluct\u001b[0m\u001b[33muated\u001b[0m\u001b[33m over\u001b[0m\u001b[33m the\u001b[0m\u001b[33m years\u001b[0m\u001b[33m,\u001b[0m\u001b[33m with\u001b[0m\u001b[33m some\u001b[0m\u001b[33m periods\u001b[0m\u001b[33m of\u001b[0m\u001b[33m relatively\u001b[0m\u001b[33m low\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m (\u001b[0m\u001b[33me\u001b[0m\u001b[33m.g\u001b[0m\u001b[33m.,\u001b[0m\u001b[33m \u001b[0m\u001b[33m201\u001b[0m\u001b[33m4\u001b[0m\u001b[33m-\u001b[0m\u001b[33m201\u001b[0m\u001b[33m7\u001b[0m\u001b[33m)\u001b[0m\u001b[33m and\u001b[0m\u001b[33m some\u001b[0m\u001b[33m periods\u001b[0m\u001b[33m of\u001b[0m\u001b[33m higher\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m (\u001b[0m\u001b[33me\u001b[0m\u001b[33m.g\u001b[0m\u001b[33m.,\u001b[0m\u001b[33m \u001b[0m\u001b[33m202\u001b[0m\u001b[33m1\u001b[0m\u001b[33m-\u001b[0m\u001b[33m202\u001b[0m\u001b[33m2\u001b[0m\u001b[33m).\n",
             "\n",
-            "Here is a brief description of the data:\n",
+            "\u001b[0m\u001b[33mSome\u001b[0m\u001b[33m observations\u001b[0m\u001b[33m from\u001b[0m\u001b[33m the\u001b[0m\u001b[33m data\u001b[0m\u001b[33m:\n",
             "\n",
-            "*   The `Year` column contains the year for which the inflation rate is reported.\n",
-            "*   The `Jan`, `Feb`, `Mar`, etc. columns contain the inflation rate for each month (January to December).\n",
-            "*   The `count` column is the count of non-null values in each column.\n",
-            "*   The `mean` column is the mean of the non-null values in each column.\n",
-            "*   The `std` column is the standard deviation of the non-null values in each column.\n",
-            "*   The `min` column is the minimum value in each column.\n",
-            "*   The `25%` column is the 25th percentile (25th percentile) of the non-null values in each column.\n",
-            "*   The `50%` column is the 50th percentile (50th percentile) of the non-null values in each column.\n",
-            "*   The `75%` column is the 75th percentile (75th percentile) of the non-null values in each column.\n",
-            "*   The `max` column is the maximum value in each column.\n",
+            "\u001b[0m\u001b[33m*\u001b[0m\u001b[33m In\u001b[0m\u001b[33mflation\u001b[0m\u001b[33m rates\u001b[0m\u001b[33m were\u001b[0m\u001b[33m relatively\u001b[0m\u001b[33m stable\u001b[0m\u001b[33m from\u001b[0m\u001b[33m \u001b[0m\u001b[33m201\u001b[0m\u001b[33m4\u001b[0m\u001b[33m to\u001b[0m\u001b[33m \u001b[0m\u001b[33m201\u001b[0m\u001b[33m7\u001b[0m\u001b[33m,\u001b[0m\u001b[33m ranging\u001b[0m\u001b[33m from\u001b[0m\u001b[33m around\u001b[0m\u001b[33m \u001b[0m\u001b[33m1\u001b[0m\u001b[33m.\u001b[0m\u001b[33m6\u001b[0m\u001b[33m%\u001b[0m\u001b[33m to\u001b[0m\u001b[33m \u001b[0m\u001b[33m2\u001b[0m\u001b[33m.\u001b[0m\u001b[33m3\u001b[0m\u001b[33m%.\n",
+            "\u001b[0m\u001b[33m*\u001b[0m\u001b[33m In\u001b[0m\u001b[33mflation\u001b[0m\u001b[33m rates\u001b[0m\u001b[33m increased\u001b[0m\u001b[33m significantly\u001b[0m\u001b[33m in\u001b[0m\u001b[33m \u001b[0m\u001b[33m202\u001b[0m\u001b[33m1\u001b[0m\u001b[33m,\u001b[0m\u001b[33m with\u001b[0m\u001b[33m a\u001b[0m\u001b[33m peak\u001b[0m\u001b[33m of\u001b[0m\u001b[33m \u001b[0m\u001b[33m5\u001b[0m\u001b[33m.\u001b[0m\u001b[33m5\u001b[0m\u001b[33m%\u001b[0m\u001b[33m in\u001b[0m\u001b[33m December\u001b[0m\u001b[33m.\n",
+            "\u001b[0m\u001b[33m*\u001b[0m\u001b[33m In\u001b[0m\u001b[33mflation\u001b[0m\u001b[33m rates\u001b[0m\u001b[33m remained\u001b[0m\u001b[33m high\u001b[0m\u001b[33m in\u001b[0m\u001b[33m \u001b[0m\u001b[33m202\u001b[0m\u001b[33m2\u001b[0m\u001b[33m,\u001b[0m\u001b[33m with\u001b[0m\u001b[33m a\u001b[0m\u001b[33m peak\u001b[0m\u001b[33m of\u001b[0m\u001b[33m \u001b[0m\u001b[33m6\u001b[0m\u001b[33m.\u001b[0m\u001b[33m6\u001b[0m\u001b[33m%\u001b[0m\u001b[33m in\u001b[0m\u001b[33m August\u001b[0m\u001b[33m.\n",
+            "\u001b[0m\u001b[33m*\u001b[0m\u001b[33m In\u001b[0m\u001b[33mflation\u001b[0m\u001b[33m rates\u001b[0m\u001b[33m have\u001b[0m\u001b[33m decreased\u001b[0m\u001b[33m slightly\u001b[0m\u001b[33m in\u001b[0m\u001b[33m \u001b[0m\u001b[33m202\u001b[0m\u001b[33m3\u001b[0m\u001b[33m,\u001b[0m\u001b[33m with\u001b[0m\u001b[33m a\u001b[0m\u001b[33m rate\u001b[0m\u001b[33m of\u001b[0m\u001b[33m \u001b[0m\u001b[33m4\u001b[0m\u001b[33m.\u001b[0m\u001b[33m8\u001b[0m\u001b[33m%\u001b[0m\u001b[33m in\u001b[0m\u001b[33m June\u001b[0m\u001b[33m.\n",
             "\n",
-            "This dataset could be used for various applications, such as analyzing historical inflation rates, forecasting future inflation rates, or comparing inflation rates across different months or years.\n",
-            "User> ('Which year ended with the highest inflation ?', None)\n",
-            "inference> According to the data, the year with the highest inflation was 2023. The inflation rate for 2023 is 6.600%.\n",
-            "User> ('What macro economic situations that led to such high inflation in that period?', None)\n",
-            "inference> The high inflation rate in 2023 is likely attributed to a combination of macroeconomic factors, including:\n",
+            "\u001b[0m\u001b[33mIt\u001b[0m\u001b[33m's\u001b[0m\u001b[33m worth\u001b[0m\u001b[33m noting\u001b[0m\u001b[33m that\u001b[0m\u001b[33m the\u001b[0m\u001b[33m data\u001b[0m\u001b[33m only\u001b[0m\u001b[33m includes\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m rates\u001b[0m\u001b[33m up\u001b[0m\u001b[33m to\u001b[0m\u001b[33m June\u001b[0m\u001b[33m \u001b[0m\u001b[33m202\u001b[0m\u001b[33m3\u001b[0m\u001b[33m,\u001b[0m\u001b[33m and\u001b[0m\u001b[33m does\u001b[0m\u001b[33m not\u001b[0m\u001b[33m provide\u001b[0m\u001b[33m information\u001b[0m\u001b[33m on\u001b[0m\u001b[33m the\u001b[0m\u001b[33m underlying\u001b[0m\u001b[33m causes\u001b[0m\u001b[33m of\u001b[0m\u001b[33m the\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m or\u001b[0m\u001b[33m any\u001b[0m\u001b[33m potential\u001b[0m\u001b[33m factors\u001b[0m\u001b[33m that\u001b[0m\u001b[33m may\u001b[0m\u001b[33m influence\u001b[0m\u001b[33m future\u001b[0m\u001b[33m inflation\u001b[0m\u001b[33m rates\u001b[0m\u001b[33m.\u001b[0m\u001b[97m\u001b[0m\n",
+            "\u001b[30m\u001b[0m\u001b[32mUser> Plot average yearly inflation as a time series\u001b[0m\n",
+            "\u001b[30m\u001b[0m"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:441: UserWarning: Pydantic serializer warnings:\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  return self.__pydantic_serializer__.to_json(\n",
+            "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:390: UserWarning: Pydantic serializer warnings:\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  return self.__pydantic_serializer__.to_python(\n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "tools_for_turn: [AgentToolWithArgs(name='memory', args={'memory_bank_id': 'inflation_data_memory_bank'}), 'code_interpreter']\n",
+            "tools_for_turn_set: {'memory', 'code_interpreter'}\n",
+            "tool_name: memory\n",
+            "tool_def: identifier='memory' provider_resource_id='memory' provider_id='memory-runtime' type='tool' tool_group='memory_group' tool_host= description='Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments' parameters=[ToolParameter(name='input_messages', parameter_type='list', description='Input messages for which to retrieve memory', required=True, default=None)] built_in_type=None metadata={'config': {'memory_bank_configs': [{'bank_id': 'memory_bank_1d984362-ef6c-468e-b5eb-a12b0d782783', 'type': 'vector'}]}} tool_prompt_format=\n",
+            "tool_name: code_interpreter\n",
+            "tool_def: identifier='code_interpreter' provider_resource_id='code_interpreter' provider_id='code-interpreter' type='tool' tool_group='code_interpreter_group' tool_host= description='' parameters=[] built_in_type=<BuiltinTool.code_interpreter: 'code_interpreter'> metadata={} tool_prompt_format=\n",
+            "tool_name: brave_search\n",
+            "tool_defs: {'memory': ToolDefinition(tool_name='memory', description='Memory tool to retrieve memory from a memory bank based on context of the input messages and attachments', parameters={'input_messages': ToolParamDefinition(param_type='list', description='Input messages for which to retrieve memory', required=True, default=None)}), <BuiltinTool.code_interpreter: 'code_interpreter'>: ToolDefinition(tool_name=<BuiltinTool.code_interpreter: 'code_interpreter'>, description=None, parameters=None)}\n"
+          ]
+        },
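
The tools_for_turn line above shows two tools resolved for a single turn: the memory tool bound to inflation_data_memory_bank plus the builtin code interpreter. A minimal sketch of the create_turn call that produces it, assuming the agent and session objects from the surrounding cell:

response = agent.create_turn(
    messages=[{"role": "user", "content": "Plot average yearly inflation as a time series"}],
    session_id=session_id,
    tools=[
        {"name": "memory", "args": {"memory_bank_id": "inflation_data_memory_bank"}},
        "code_interpreter",
    ],
)
for log in EventLogger().log(response):
    log.print()
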
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:441: UserWarning: Pydantic serializer warnings:\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  return self.__pydantic_serializer__.to_json(\n",
+            "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:390: UserWarning: Pydantic serializer warnings:\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  return self.__pydantic_serializer__.to_python(\n",
+            "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:441: UserWarning: Pydantic serializer warnings:\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  return self.__pydantic_serializer__.to_json(\n",
+            "/Users/dineshyv/miniconda3/envs/stack/lib/python3.10/site-packages/pydantic/main.py:441: UserWarning: Pydantic serializer warnings:\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  Failed to get discriminator value for tagged union serialization with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - defaulting to left to right union serialization.\n",
+            "  PydanticSerializationUnexpectedValue: Expected `ImageContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  PydanticSerializationUnexpectedValue: Expected `TextContentItem` but got `list` with value `[TextContentItem(type='te...TRIEVED-CONTEXT ===\\n')]` - serialized value may not be as expected\n",
+            "  return self.__pydantic_serializer__.to_json(\n"
+          ]
+        },
+        {
+          "data": {
+            "application/vnd.jupyter.widget-view+json": {
+              "model_id": "b79a023a8ddd4f1d80c2c737affc3c91",
+              "version_major": 2,
+              "version_minor": 0
+            },
+            "text/plain": [
+              "Batches:   0%|          | 0/1 [00:00 Tool:memory Args:{'query': '{\"role\":\"user\",\"content\":\"Plot average yearly inflation as a time series\",\"context\":null}', 'memory_bank_id': 'inflation_data_memory_bank'}\u001b[0m\n",
+            "\u001b[36mtool_execution> fetched 3079 bytes from memory\u001b[0m\n",
+            "\u001b[33minference> \u001b[0m\u001b[36m\u001b[0m\u001b[36mimport\u001b[0m\u001b[36m pandas\u001b[0m\u001b[36m as\u001b[0m\u001b[36m pd\u001b[0m\u001b[36m\n",
+            "\n",
+            "\u001b[0m\u001b[36m#\u001b[0m\u001b[36m Define\u001b[0m\u001b[36m the\u001b[0m\u001b[36m data\u001b[0m\u001b[36m\n",
+            "\u001b[0m\u001b[36mdata\u001b[0m\u001b[36m =\u001b[0m\u001b[36m {\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mYear\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m201\u001b[0m\u001b[36m4\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m201\u001b[0m\u001b[36m5\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m201\u001b[0m\u001b[36m6\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m201\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m201\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m201\u001b[0m\u001b[36m9\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m202\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m202\u001b[0m\u001b[36m1\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m202\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m202\u001b[0m\u001b[36m3\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mJan\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m6\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m6\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m4\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m6\u001b[0m\u001b[36m.\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m5\u001b[0m\u001b[36m.\u001b[0m\u001b[36m6\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mFeb\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m6\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m1\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m4\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m6\u001b[0m\u001b[36m.\u001b[0m\u001b[36m4\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m5\u001b[0m\u001b[36m.\u001b[0m\u001b[36m5\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mMar\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m1\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m1\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m6\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m6\u001b[0m\u001b[36m.\u001b[0m\u001b[36m5\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m5\u001b[0m\u001b[36m.\u001b[0m\u001b[36m6\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mApr\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m1\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m9\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m1\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m1\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m4\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m3\u001b[0m\u001b[36m.\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m6\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m5\u001b[0m\u001b[36m.\u001b[0m\u001b[36m5\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mMay\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m3\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m6\u001b[0m\u001b[36m.\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m5\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mJun\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m9\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m1\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m4\u001b[0m\u001b[36m.\u001b[0m\u001b[36m5\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m5\u001b[0m\u001b[36m.\u001b[0m\u001b[36m9\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m4\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mJul\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m9\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m4\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m6\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m4\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m5\u001b[0m\u001b[36m.\u001b[0m\u001b[36m9\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m4\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mAug\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m4\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m4\u001b[0m\u001b[36m.\u001b[0m\u001b[36m0\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m6\u001b[0m\u001b[36m.\u001b[0m\u001b[36m3\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m4\u001b[0m\u001b[36m.\u001b[0m\u001b[36m8\u001b[0m\u001b[36m],\n",
+            "\u001b[0m\u001b[36m   \u001b[0m\u001b[36m \"\u001b[0m\u001b[36mSep\u001b[0m\u001b[36m\":\u001b[0m\u001b[36m [\u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m7\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[36m.\u001b[0m\u001b[36m9\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m2\u001b[0m\u001b[36m.\u001b[0m\u001b[36m2\u001b[0m\u001b[36m,\u001b[0m\u001b[36m \u001b[0m\u001b[36m1\u001b[0m\u001b[97m\u001b[0m\n",
+            "\u001b[32mtool_execution> Tool:code_interpreter Args:{'code': 'import pandas as pd\\n\\n# Define the data\\ndata = {\\n    \"Year\": [2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023],\\n    \"Jan\": [1.6, 1.6, 2.2, 2.3, 1.8, 2.2, 2.3, 1.4, 6.0, 5.6],\\n    \"Feb\": [1.6, 1.7, 2.3, 2.2, 1.8, 2.1, 2.4, 1.3, 6.4, 5.5],\\n    \"Mar\": [1.7, 1.8, 2.2, 2.0, 2.1, 2.0, 2.1, 1.6, 6.5, 5.6],\\n    \"Apr\": [1.8, 1.8, 2.1, 1.9, 2.1, 2.1, 1.4, 3.0, 6.2, 5.5],\\n    \"May\": [2.0, 1.7, 2.2, 1.7, 2.2, 2.0, 1.2, 3.8, 6.0, 5.3],\\n    \"Jun\": [1.9, 1.8, 2.2, 1.7, 2.3, 2.1, 1.2, 4.5, 5.9, 4.8],\\n    \"Jul\": [1.9, 1.8, 2.2, 1.7, 2.4, 2.2, 1.6, 4.3, 5.9, 4.8],\\n    \"Aug\": [1.7, 1.8, 2.3, 1.7, 2.2, 2.4, 1.7, 4.0, 6.3, 4.8],\\n    \"Sep\": [1.7, 1.9, 2.2, 1'}\u001b[0m\n",
+            "\u001b[32mtool_execution> Tool:code_interpreter Response:error\n",
+            "[stdout]\n",
+            "[Errno 2] No such file or directory: 'bwrap'\n",
+            "[/stdout]\n",
+            "[stderr]\n",
+            "[Errno 2] No such file or directory: 'bwrap'\n",
+            "[/stderr]\u001b[0m\n",
+            "\u001b[33minference> \u001b[0m"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+            "To disable this warning, you can either:\n",
+            "\t- Avoid using `tokenizers` before the fork if possible\n",
+            "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "\u001b[33mThe\u001b[0m\u001b[33m error\u001b[0m\u001b[33m message\u001b[0m\u001b[33m indicates\u001b[0m\u001b[33m that\u001b[0m\u001b[33m the\u001b[0m\u001b[33m system\u001b[0m\u001b[33m cannot\u001b[0m\u001b[33m find\u001b[0m\u001b[33m the\u001b[0m\u001b[33m '\u001b[0m\u001b[33mb\u001b[0m\u001b[33mwrap\u001b[0m\u001b[33m'\u001b[0m\u001b[33m file\u001b[0m\u001b[33m,\u001b[0m\u001b[33m which\u001b[0m\u001b[33m is\u001b[0m\u001b[33m required\u001b[0m\u001b[33m for\u001b[0m\u001b[33m the\u001b[0m\u001b[33m plot\u001b[0m\u001b[33m to\u001b[0m\u001b[33m be\u001b[0m\u001b[33m displayed\u001b[0m\u001b[33m.\u001b[0m\u001b[33m This\u001b[0m\u001b[33m issue\u001b[0m\u001b[33m is\u001b[0m\u001b[33m likely\u001b[0m\u001b[33m due\u001b[0m\u001b[33m to\u001b[0m\u001b[33m a\u001b[0m\u001b[33m missing\u001b[0m\u001b[33m or\u001b[0m\u001b[33m incorrect\u001b[0m\u001b[33m installation\u001b[0m\u001b[33m of\u001b[0m\u001b[33m the\u001b[0m\u001b[33m '\u001b[0m\u001b[33mb\u001b[0m\u001b[33mwrap\u001b[0m\u001b[33m'\u001b[0m\u001b[33m package\u001b[0m\u001b[33m.\n",
             "\n",
-            "1. **Supply chain disruptions**: The COVID-19 pandemic and subsequent lockdowns led to supply chain disruptions, resulting in shortages and price increases for various goods and services.\n",
-            "2. **Economic growth**: The rapid economic growth in the preceding years created demand for goods and services, leading to higher production costs and, subsequently, higher prices.\n",
-            "3. **Monetary policy**: The central bank's easy-money policies, such as quantitative easing and low interest rates, increased the money supply and led to inflationary pressures.\n",
-            "4. **Commodity price shocks**: Increases in global commodity prices, such as oil and food prices, contributed to higher production costs and inflation.\n",
-            "5. **Labor market tightness**: The labor market has been tight, leading to higher wages and, subsequently, higher production costs, which have been passed on to consumers.\n",
-            "6. **Trade wars and tariffs**: The ongoing trade tensions and tariffs imposed by various countries have disrupted global supply chains, leading to higher prices for imported goods.\n",
-            "7. **Climate change and extreme weather events**: The increasing frequency and severity of extreme weather events, such as heatwaves and droughts, have disrupted agricultural production and supply chains.\n",
-            "8. **Currency devaluation**: A devaluation of the currency can make imports more expensive, leading to higher inflation.\n",
-            "9. **Government spending and fiscal policy**: Government spending and fiscal policy decisions, such as tax cuts and increased government spending, can inject more money into the economy, leading to inflation.\n",
-            "10. **Monetary policy mistakes**: Mistakes in monetary policy, such as premature interest rate hikes or overly aggressive quantitative easing, can lead to inflationary pressures.\n",
+            "\u001b[0m\u001b[33mTo\u001b[0m\u001b[33m fix\u001b[0m\u001b[33m this\u001b[0m\u001b[33m issue\u001b[0m\u001b[33m,\u001b[0m\u001b[33m you\u001b[0m\u001b[33m can\u001b[0m\u001b[33m try\u001b[0m\u001b[33m reinstall\u001b[0m\u001b[33ming\u001b[0m\u001b[33m the\u001b[0m\u001b[33m '\u001b[0m\u001b[33mb\u001b[0m\u001b[33mwrap\u001b[0m\u001b[33m'\u001b[0m\u001b[33m package\u001b[0m\u001b[33m using\u001b[0m\u001b[33m pip\u001b[0m\u001b[33m:\n",
             "\n",
-            "It's worth noting that the specific factors contributing to the high inflation rate in 2023 may vary depending on the region, country, or even specific economy.\n",
-            "User> ('Plot average yearly inflation as a time series', None)\n",
-            "inference> import pandas as pd\n",
-            "import matplotlib.pyplot as plt\n",
+            "\u001b[0m\u001b[33mpip\u001b[0m\u001b[33m install\u001b[0m\u001b[33m b\u001b[0m\u001b[33mwrap\u001b[0m\u001b[33m\n",
             "\n",
-            "# Read the CSV file\n",
-            "df = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\n",
+            "\u001b[0m\u001b[33mIf\u001b[0m\u001b[33m the\u001b[0m\u001b[33m issue\u001b[0m\u001b[33m persists\u001b[0m\u001b[33m,\u001b[0m\u001b[33m you\u001b[0m\u001b[33m can\u001b[0m\u001b[33m try\u001b[0m\u001b[33m to\u001b[0m\u001b[33m display\u001b[0m\u001b[33m the\u001b[0m\u001b[33m plot\u001b[0m\u001b[33m using\u001b[0m\u001b[33m a\u001b[0m\u001b[33m different\u001b[0m\u001b[33m method\u001b[0m\u001b[33m,\u001b[0m\u001b[33m such\u001b[0m\u001b[33m as\u001b[0m\u001b[33m saving\u001b[0m\u001b[33m the\u001b[0m\u001b[33m plot\u001b[0m\u001b[33m to\u001b[0m\u001b[33m a\u001b[0m\u001b[33m file\u001b[0m\u001b[33m:\n",
             "\n",
-            "# Extract the year and inflation rate from the CSV file\n",
-            "df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n",
-            "df = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\n",
+            "\u001b[0m\u001b[33mimport\u001b[0m\u001b[33m matplotlib\u001b[0m\u001b[33m.pyplot\u001b[0m\u001b[33m as\u001b[0m\u001b[33m plt\u001b[0m\u001b[33m\n",
             "\n",
-            "# Calculate the average yearly inflation rate\n",
-            "df['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\n",
+            "\u001b[0m\u001b[33m#\u001b[0m\u001b[33m ...\u001b[0m\u001b[33m (\u001b[0m\u001b[33mrest\u001b[0m\u001b[33m of\u001b[0m\u001b[33m the\u001b[0m\u001b[33m code\u001b[0m\u001b[33m remains\u001b[0m\u001b[33m the\u001b[0m\u001b[33m same\u001b[0m\u001b[33m)\n",
             "\n",
-            "# Plot the average yearly inflation rate as a time series\n",
-            "plt.figure(figsize=(10, 6))\n",
-            "plt.plot(df['Year'], df['Yearly Inflation'], marker='o')\n",
-            "plt.title('Average Yearly Inflation Rate')\n",
-            "plt.xlabel('Year')\n",
-            "plt.ylabel('Inflation Rate (%)')\n",
-            "plt.grid(True)\n",
-            "plt.show()\n",
-            "tool_execution> Tool:code_interpreter Args:{'code': \"import pandas as pd\\nimport matplotlib.pyplot as plt\\n\\n# Read the CSV file\\ndf = pd.read_csv('/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv')\\n\\n# Extract the year and inflation rate from the CSV file\\ndf['Year'] = pd.to_datetime(df['Year'], format='%Y')\\ndf = df.rename(columns={'Jan': 'Jan Rate', 'Feb': 'Feb Rate', 'Mar': 'Mar Rate', 'Apr': 'Apr Rate', 'May': 'May Rate', 'Jun': 'Jun Rate', 'Jul': 'Jul Rate', 'Aug': 'Aug Rate', 'Sep': 'Sep Rate', 'Oct': 'Oct Rate', 'Nov': 'Nov Rate', 'Dec': 'Dec Rate'})\\n\\n# Calculate the average yearly inflation rate\\ndf['Yearly Inflation'] = df[['Jan Rate', 'Feb Rate', 'Mar Rate', 'Apr Rate', 'May Rate', 'Jun Rate', 'Jul Rate', 'Aug Rate', 'Sep Rate', 'Oct Rate', 'Nov Rate', 'Dec Rate']].mean(axis=1)\\n\\n# Plot the average yearly inflation rate as a time series\\nplt.figure(figsize=(10, 6))\\nplt.plot(df['Year'], df['Yearly Inflation'], marker='o')\\nplt.title('Average Yearly Inflation Rate')\\nplt.xlabel('Year')\\nplt.ylabel('Inflation Rate (%)')\\nplt.grid(True)\\nplt.show()\"}\n",
-            "tool_execution> Tool:code_interpreter Response:completed\n",
-            "shield_call> No Violation\n",
-            "inference> This code reads the CSV file, extracts the year and inflation rate, calculates the average yearly inflation rate, and plots the average yearly inflation rate as a time series. The resulting plot shows the average inflation rate over the years.\n"
+            "\u001b[0m\u001b[33mplt\u001b[0m\u001b[33m.savefig\u001b[0m\u001b[33m('\u001b[0m\u001b[33min\u001b[0m\u001b[33mflation\u001b[0m\u001b[33m_rate\u001b[0m\u001b[33m.png\u001b[0m\u001b[33m')\n",
+            "\n",
+            "\u001b[0m\u001b[33mThis\u001b[0m\u001b[33m will\u001b[0m\u001b[33m save\u001b[0m\u001b[33m the\u001b[0m\u001b[33m plot\u001b[0m\u001b[33m to\u001b[0m\u001b[33m a\u001b[0m\u001b[33m file\u001b[0m\u001b[33m named\u001b[0m\u001b[33m '\u001b[0m\u001b[33min\u001b[0m\u001b[33mflation\u001b[0m\u001b[33m_rate\u001b[0m\u001b[33m.png\u001b[0m\u001b[33m'\u001b[0m\u001b[33m in\u001b[0m\u001b[33m the\u001b[0m\u001b[33m current\u001b[0m\u001b[33m working\u001b[0m\u001b[33m directory\u001b[0m\u001b[33m.\u001b[0m\u001b[97m\u001b[0m\n",
+            "\u001b[30m\u001b[0m"
           ]
         }
       ],
       "source": [
         "agent_config = AgentConfig(\n",
+        "    sampling_params = {\n",
+        "        \"max_tokens\" : 4096,\n",
+        "        \"temperature\": 0.0\n",
+        "    },\n",
         "    model=model_id,\n",
         "    instructions=\"You are a helpful assistant\",\n",
         "    tools=[\n",
-        "        search_tool,\n",
-        "        {\n",
-        "            \"type\": \"code_interpreter\",\n",
-        "        }\n",
+        "        \"brave_search\",\n",
+        "        \"code_interpreter\",\n",
         "    ],\n",
         "    tool_choice=\"required\",\n",
         "    input_shields=[],\n",
@@ -1766,38 +1997,48 @@
         "    enable_session_persistence=False,\n",
         ")\n",
         "\n",
+        "memory_bank_id = \"inflation_data_memory_bank\"\n",
+        "client.memory_banks.register(\n",
+        "    memory_bank_id=memory_bank_id,\n",
+        "    params={\n",
+        "        \"memory_bank_type\": \"vector\",\n",
+        "        \"embedding_model\": \"all-MiniLM-L6-v2\",\n",
+        "        \"chunk_size_in_tokens\": 512,\n",
+        "        \"overlap_size_in_tokens\": 64,\n",
+        "    },\n",
+        ")\n",
+        "AugmentConfigWithMemoryTool(agent_config, client)\n",
         "codex_agent = Agent(client, agent_config)\n",
         "session_id = codex_agent.create_session(\"test-session\")\n",
         "\n",
+        "client.memory.insert(\n",
+        "    bank_id=memory_bank_id,\n",
+        "    documents=[\n",
+        "        Document(\n",
+        "            document_id=\"inflation\",\n",
+        "            content=\"https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv\",\n",
+        "            mime_type=\"text/csv\",\n",
+        "            metadata={},\n",
+        "        )\n",
+        "    ],\n",
+        ")\n",
+        "\n",
         "user_prompts = [\n",
-        "    (\n",
-        "        \"Here is a csv, can you describe it ?\",\n",
-        "        [\n",
-        "            Attachment(\n",
-        "                content=\"https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv\",\n",
-        "                mime_type=\"test/csv\",\n",
-        "            )\n",
-        "        ],\n",
-        "    ),\n",
-        "    (\"Which year ended with the highest inflation ?\", None),\n",
-        "    (\n",
-        "        \"What macro economic situations that led to such high inflation in that period?\",\n",
-        "        None,\n",
-        "    ),\n",
-        "    (\"Plot average yearly inflation as a time series\", None),\n",
+        "    {\"prompt\": \"Can you describe the data in the context?\", \"tools\": [{\"name\": \"memory\", \"args\": {\"memory_bank_id\": memory_bank_id}}]},\n",
+        "    {\"prompt\": \"Plot average yearly inflation as a time series\", \"tools\": [{\"name\": \"memory\", \"args\": {\"memory_bank_id\": memory_bank_id}}, \"code_interpreter\"]},\n",
         "]\n",
         "\n",
-        "for prompt in user_prompts:\n",
-        "    cprint(f'User> {prompt}', 'green')\n",
+        "for input in user_prompts:\n",
+        "    cprint(f'User> {input[\"prompt\"]}', 'green')\n",
         "    response = codex_agent.create_turn(\n",
         "        messages=[\n",
         "            {\n",
         "                \"role\": \"user\",\n",
-        "                \"content\": prompt[0],\n",
+        "                \"content\": input[\"prompt\"],\n",
         "            }\n",
         "        ],\n",
-        "        attachments=prompt[1],\n",
         "        session_id=session_id,\n",
+        "        tools=input[\"tools\"],\n",
         "    )\n",
         "    # for chunk in response:\n",
         "    #     print(chunk)\n",
@@ -1818,7 +2059,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 5,
       "id": "JqBBVLKdIHHq",
       "metadata": {
         "colab": {
@@ -1830,14 +2071,20 @@
       },
       "outputs": [
         {
-          "data": {
-            "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0EAAAIjCAYAAADFthA8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB+WklEQVR4nO3dd3hUZdrH8d+k90BCGiSE0AkBpFdFVJoUscGiKCq6rmt3XffVVQFdd3Vd265tbdjAguIKKiACgvReQi+hh4QQSCGkzZz3j5BITIBkmJkzyXw/15ULcubknPvcmYG553nO/VgMwzAEAAAAAB7Cy+wAAAAAAMCVKIIAAAAAeBSKIAAAAAAehSIIAAAAgEehCAIAAADgUSiCAAAAAHgUiiAAAAAAHoUiCAAAAIBHoQgCAAAA4FEoggAAbu3yyy/X5ZdfbnYYFT755BO1bdtWvr6+atCggSTnxDhp0iRZLBaHHhMAUIYiCIDHevPNN2WxWNSzZ0+zQ3Eby5cvl5eXlx5//PFqH3/hhRdksVj0/fffuzgyx7FYLLrvvvvs+tnt27frtttuU4sWLfTuu+/qnXfeuahYCgoKNGnSJP38888XdRxHs1gslb7CwsLUv3//i/q9T5s2Ta+++qrjggSAi0ARBMBjTZ06Vc2aNdOqVau0e/dus8NxC71799bdd9+tl156SVu2bKn02P79+/XMM8/oxhtv1LBhw0yK0Fw///yzbDabXnvtNd12220aPXr0RR2voKBAkydPrrYIevLJJ3X69OmLOv7FGDhwoD755BN9/PHHeuyxx7R7926NGDFCc+fOtet4FEEA3AlFEACPlJaWpmXLlunll19WVFSUpk6d6vIYbDabCgsLXX7eC3n++efVqFEj3X333TIMo2L7/fffL19fX7322msuiaOgoMAl56mNzMxMSaqYBudMPj4+CggIcPp5zqV169YaN26cbrnlFj355JP66aefZBiGy37/AOBMFEEAPNLUqVPVsGFDDRs2TDfccEOlIqikpEQRERG6/fbbq/xcbm6uAgIC9Oijj1ZsKyoq0sSJE9WyZUv5+/srISFBjz32mIqKiir9bPk0rKlTp6p9+/by9/fXnDlzJEn/+te/1KdPH0VGRiowMFBdu3bVV199VeX8p0+f1gMPPKBGjRopNDRUI0eO1OHDh2WxWDRp0qRK+x4+fFh33HGHYmJi5O/vr/bt2+uDDz64YG7Cw8P12muvaenSpXrvvfckSd98841mzZql559/XnFxcbLZbHr11VfVvn17BQQEKCYmRnfffbdOnDhR6Vjffvuthg0bpsaNG8vf318tWrTQs88+K6vVWmm/yy+/XCkpKVq7dq0uu+wyBQUF6YknnqgSW35+voKDg/Xggw9WeezQoUPy9vbWP/7xjwte49l+/vlnWSwWffnll3ruuecUHx+vgIAAXXnllZVGCJs1a6aJEydKkqKioqrNebni4mI9/fTT6tq1q8LDwxUcHKxLL71UCxcurNhn3759ioqKkiRNnjy5YupZ+TGruyeotLRUzz77rFq0aCF/f381a9ZMTzzxRJXnWrNmzTR8+HAtWbJEPXr0UEBAgJo3b66PP/64Vrk5W7t27dSoUSPt2bOn0vaa/I4vv/xyff/999q/f3/FdTZr1qzi8Zq+hgDAYQwA8EBt27Y1JkyYYBiGYSxevNiQZKxatari8TvuuMNo0KCBUVRUVOnnPvroI0OSsXr1asMwDMNqtRqDBg0ygoKCjIceesj473//a9x3332Gj4+Pcc0111T6WUlGu3btjKioKGPy5MnGG2+8Yaxfv94wDMOIj483/vjHPxqvv/668fLLLxs9evQwJBnfffddpWOMHj3akGTccsstxhtvvGGMHj3a6NSpkyHJmDhxYsV+R48eNeLj442EhATjmWeeMd566y1j5MiRhiTjlVdeqVGOhg0bZjRs2NDYs2ePkZCQYPTp08ew2WyGYRjGnXfeafj4+Bh33XWX8fbbbxt/+ctfjODgYKN79+5GcXFxxTFGjRpljB492njxxReNt956y7jxxhsNScajjz5a6Vz9+/c3YmNjjaioKOP+++83/vvf/xr/+9//Kh7r379/xb4333yzERMTY5SWllY6xj//+U/DYrEY+/fvP+91STLuvffeiu8XLlxoSDI6d+5sdO3a1XjllVeMSZMmGUFBQUaPHj0q9vvmm2+Ma6+91pBkvPXWW8Ynn3xibNy4sdoYjx07ZsTFxRmPPPKI8dZbbxn//Oc/jTZt2hi+vr4Vv/P8/HzjrbfeMiQZ1157rfHJJ59UOubEiRON3/43PX78eEOSccMNNxhvvPGGceuttxqSjFGjRlXaLzEx0WjTpo0RExNjPPHEE8brr79udOnSxbBYLEZqaup581NdjgzDME6ePGl4e3sbPXv2rLS9Jr/jH3/80bjkkkuMRo0aVVznN998YxhG7V5DAOAoFEEAPM6aNWsMSca8efMMwzAMm81mxMfHGw8++GDFPnPnzjUkGbNmzar0s1dffbXRvHnziu8/+eQTw8vLy/jll18q7ff2228bkoylS5dWbJNkeHl5GVu2bKkSU0FBQaXvi4uLjZSUFOOKK66o2LZ27VpDkvHQQw9V2ve2226rUgRNmDDBiIuLM7Kysirt+7vf/c4IDw+vcr7q7Nu3zwgODjYiIiIMX19fY/PmzYZhGMYvv/xiSDKmTp1aaf85c+ZU2V7dee6++24jKCjIKCwsrNjWv39/Q5Lx9ttvV9n/twVG+e9m9uzZlfbr2LFjpf3O5VxFULt27SoVva+99pohqeK6DePXwuTYsWPnjbG0tLRKAX3ixAkjJibGuOOOOyq2HTt2rMrv7rfnKrdhwwZDknHnnXdW2u/RRx81JBkLFiyo2JaYmGhIMhYvXlyxLTMz0/D39zf+9Kc/nSs1FSQZEyZMMI4dO2ZkZmYaa9asMYYMGWJIMl588cVK+9b0dzxs2DAjMTGxyr61eQ0BgKMwHQ6Ax5k6dapiYmI0YMAASWXT1MaMGaPPP/+8YgrPFVdcoUaNGumLL76o+LkTJ05o3rx5GjNmTMW26dOnq127dmrbtq2ysrIqvq644gpJqjT9SZL69++v5OTkKjEFBgZWOk9OTo4uvfRSrVu3rmJ7+dS5P/7xj5V+9v7776/0vWEY+vrrrzVixAgZhlEprsGDBysnJ6fScc8lMTFREydOVHZ2th555BGlpKRUXHN4eLgGDhxY6dhdu3ZVSEhIpWs++7ry8vKUlZWlSy+9VAUFBdq+fXul8/n7+1c7BfG3rrrqKjVu3LjSFMbU1FRt2rRJ48aNu+DPn8vtt98uPz+/iu8vvfRSSdLevXtrfSxvb++KY9lsNmVnZ6u0tFTdunWrUe6r88MPP0iSHnnkkUrb//SnP0lSlc5tycnJFdcglU3ha9OmTY2v5/3331dUVJSio6PVrVs3zZ8/X4899liV89fmd1yd2r6GAMARfMwOAABcyWq16vPPP9eAAQOUlpZWsb1nz5566aWXNH/+
fA0aNEg+Pj66/vrrNW3aNBUVFcnf318zZsxQSUlJpSJo165d2rZtW8W9Hb9VfiN9uaSkpGr3++677/S3v/1NGzZsqHQfxNn3hOzfv19eXl5VjtGyZctK3x87dkwnT57UO++8c84Wzr+N61y6d+8uSerWrVvFtl27diknJ0fR0dEXPPaWLVv05JNPasGCBcrNza20X05OTqXvmzRpUqkIORcvLy/dfPPNeuutt1RQUKCgoCBNnTpVAQEBuvHGG2t0XdVp2rRppe8bNmwoSVXuc6qpjz76SC+99JK2b9+ukpKSiu3neg5cSPnv/7e/79jYWDVo0ED79++vtP231yOVXVNNr+eaa67Rfffdp+LiYq1evVp///vfVVBQIC+vyp+f1uZ3XJ3avoYAwBEoggB4lAULFig9PV2ff/65Pv/88yqPT506VYMGDZIk/e53v9N///tfzZ49W6NGjdKXX36ptm3bqlOnThX722w2dejQQS+//HK150tISKj0/dmfmpf75ZdfNHLkSF122WV68803FRcXJ19fX02ZMkXTpk2r9TXabDZJ0rhx4zR+/Phq9+nYsWOtj3v28aOjo8/ZUa/8zezJkyfVv39/hYWF6ZlnnlGLFi0UEBCgdevW6S9/+UtFnOWqy8253HrrrXrxxRf1v//9T2PHjtW0adM0fPhwhYeH231d3t7e1W43zuqQV1OffvqpbrvtNo0aNUp//vOfFR0dXdG04beNBWqrpguoXuz1xMfH66qrrpIkXX311WrUqJHuu+8+DRgwQNddd52k2v+Oq1Pb1xAAOAJFEACPMnXqVEVHR+uNN96o8tiMGTP0zTff6O2331ZgYKAuu+wyxcXF6YsvvlC/fv20YMEC/fWvf630My1atNDGjRt15ZVX1vjN6W99/fXXCggI0Ny5c+Xv71+xfcqUKZX2S0xMlM1mU1pamlq1alWx/bdrHEVFRSk0NFRWq7XiTawjtWjRQj/99JP69u173sLl559/1vHjxzVjxgxddtllFdvPHoGzV0pKijp37qypU6cqPj5eBw4c0H/+85+LPq6jfPXVV2revLlmzJhR6XlR3l2uXG2eM+W//127dqldu3YV2zMyMnTy5EklJiZefODncffdd+uVV17Rk08+qWuvvVYWi6VWv+NzXasjXkMAUFvcEwTAY5w+fVozZszQ8OHDdcMNN1T5uu+++5SXl6eZM2dKKpt2dcMNN2jWrFn65JNPVFpaWmkqnCSNHj1ahw8f1rvvvlvt+U6dOnXBuLy9vWWxWCq1FN63b5/+97//Vdpv8ODBkqQ333yz0vbfvvn39vbW9ddfr6+//lqpqalVznfs2LELxnQ+o0ePltVq1bPPPlvlsdLSUp08ebIiDqnyyENxcXGV+O11yy236Mcff9Srr76qyMhIDR061CHHdYTqrn3lypVavnx5pf2CgoIkqSJn53P11VdLUpUFR8tHUJy9gK2Pj4/+9Kc/adu2bfr2228l1e53HBwcXO30OEe8hgCgthgJAuAxZs6cqby8PI0cObLax3v16lWxcGp5sTNmzBj95z//0cSJE9WhQ4dKn8BLZW/Ev/zyS/3hD3/QwoUL1bdvX1mtVm3fvl1ffvml5s6dW+l+muoMGzZML7/8soYMGaKbbrpJmZmZeuONN9SyZUtt2rSpYr+uXbvq+uuv16uvvqrjx4+rV69eWrRokXbu3Cmp8iftzz//vBYuXKiePXvqrrvuUnJysrKzs7Vu3Tr99NNPys7OtiuHUllzh7vvvlv/+Mc/tGHDBg0aNEi+vr7atWuXpk+frtdee0033HCD+vTpo4YNG2r8+PF64IEHZLFY9Mknn9g1vaw6N910kx577DF98803uueee+Tr6+uQ4zrC8OHDNWPGDF177bUaNmyY0tLS9Pbbbys5OVn5+fkV+wUGBio5OVlffPGFWrdurYiICKWkpFQ0oThbp06dNH78eL3zzjsV09BWrVqljz76SKNGjapo9OFMt912m55++mm98MILGjVqVK1+x127dtUXX3yhRx55RN27d1dISIhGjBjhkNcQANSaaX3pAMDFRowYYQQEBBinTp065z633Xab4evrW9Fa2mazGQkJCYYk429/+1u1P1NcXGy88MILRvv27Q1/f3+jYcOGRteuXY3JkycbOTk5FfupmrVXyr3//vtGq1atDH9/f6Nt27bGlClTql0n5tSpU8a9995rREREGCEhIcaoUaOMHTt2GJKM559/vtK+GRkZxr333mskJCQYvr6+RmxsrHHllVca77zzTo3yZRi/to+ePn16lcfeeecdo2vXrkZgYKARGhpqdOjQwXjssceMI0eOVOyzdOlSo1evXkZgYKDRuHFj47HHHqtocb1w4cKK/fr372+0b9++2hh+2376bFdffbUhyVi2bFmNr+m3v4dzXWNaWpohyZgyZUrFtpq2yLbZbMbf//53IzEx0fD39zc6d+5sfPfdd8b48eOrtIletmyZ0bVrV8PPz69Su+zqfv8lJSXG5MmTjaSkJMPX19dISEgwHn/88UqtqA2jrEX2sGHDqlz7+XJ5tvM9VydNmlTp91fT33F+fr5x0003GQ0aNDAkVcpDTV9DAOAoFsNw0EdyAABTbNiwQZ07d9ann36qm2++2exwXOraa6/V5s2bq9wXBQDA+XBPEADUIadPn66y7dVXX5WXl1elG9M9QXp6ur7//nvdcsstZocCAKhjuCcIAOqQf/7zn1q7dq0GDBggHx8fzZ49W7Nnz9bvf/97j2klnJaWpqVLl+q9996Tr6+v7r77brNDAgDUMRRBAFCH9OnTR/PmzdOzzz6r/Px8NW3aVJMmTarSurs+W7RokW6//XY1bdpUH330kWJjY80OCQBQx3BPEAAAAACPwj1BAAAAADwKRRAAAAAAj1Kn7wmy2Ww6cuSIQkNDKy0SCAAAAMCzGIahvLw8NW7cWF5e5x/rqdNF0JEjRzymGxIAAACACzt48KDi4+PPu0+dLoJCQ0MllV1oWFiYqbGUlJToxx9/1KBBg+Tr62tqLHUNubMPebMPebMfubMPebMPebMPebMfubOPO+UtNzdXCQkJFTXC+dTpIqh8ClxYWJhbFEFBQUEKCwsz/QlQ15A7+5A3+5A3+5E7+5A3+5A3+5A3+5E7+7hj3mpymwyNEQAAAAB4FIogAAAAAB6FIggAAACAR6EIAgAAAOBRKIIAAAAAeBSKIAAAAAAehSIIAAAAgEehCAIAAADgUSiCAAAAAHgUiiAAAAAAHoUiCAAAAIBHoQgCAAAA4FEoggAAAAB4FIogAAAAeDSrzdDKtGytzbJoZVq2rDbD7JDgZD5mBwAAAACYZU5quibP2qr0nEJJ3vp41xrFhQdo4ohkDUmJMzs8OAkjQQAAAPBIc1LTdc+n684UQL86mlOoez5dpzmp6SZFBmejCAIAAIDHsdoMTZ61VdVNfCvfNnnWVqbG1VMUQQAAAPA4q9Kyq4wAnc2QlJ5TqFVp2a4LCi5DEQQAAACPk5l37gLInv1Qt1AEAQAAwONEhwY4dD/ULRRBAAA
A8Dg9kiIUF37uAsciKS48QD2SIlwXFFyGIggAAAAex9vLookjks/5uCFp4ohkeXtZXBcUXIYiCAAAAB7pynYxCvLzrvaxZpFBGpQc6+KI4CoUQQAAAPBIK/dmq6DYqoggX310W1fd2sqqf4/pqCBfL+07XqDpaw+aHSKchCIIAAAAHmn2mcVQB6fEqk+LSHVtZGhoSqweGdRGkvT87O06carYzBDhJBRBAAAA8DhWm6G5WzIkSYPbV572Nr5PM7WJCdWJghK9+OMOM8KDk1EEAQAAwOOsP3BCWflFCg3wUZ8WjSo95uvtpWeuaS9J+mzVAW08eNKECOFMFEEAAADwOLNTj0qSrmoXIz+fqm+JezaP1LWdm8gwpKe+TZXVZrg6RDiR6UXQ4cOHNW7cOEVGRiowMFAdOnTQmjVrzA4LAAAA9ZRhGJpzpgj67VS4sz1+dVuF+vto06Ecfb76gKvCgwuYWgSdOHFCffv2la+vr2bPnq2tW7fqpZdeUsOGDc0MCwAAAPVY6uFcHT55WoG+3urfOuqc+0WHBuiRQa0lSf+cs0PZNEmoN3zMPPkLL7yghIQETZkypWJbUlKSiREBAACgvpuzpawr3OVtohR4jnWCyt3SK1Ffrjmkbem5emH2dr1wQ0dXhAgnM7UImjlzpgYPHqwbb7xRixYtUpMmTfTHP/5Rd911V7X7FxUVqaioqOL73NxcSVJJSYlKSkpcEvO5lJ/f7DjqInJnH/JmH/JmP3JnH/JmH/JmH/JWM7M3l02FG9guqkrOqsvdxGFt9Lv3VuuLNQd1fZc4dU5o4LJY3Z07PedqE4PFMAzT7vIKCAiQJD3yyCO68cYbtXr1aj344IN6++23NX78+Cr7T5o0SZMnT66yfdq0aQoKCnJ6vAAAAKjbjhZI/9joI2+Lob93syqghkMCU3d7adUxL8UHG/pTB6u8LM6NE7VXUFCgm266STk5OQoLCzvvvqYWQX5+furWrZuWLVtWse2BBx7Q6tWrtXz58ir7VzcSlJCQoKysrAteqLOVlJRo3rx5GjhwoHx9fU2Npa4hd/Yhb/Yhb/Yjd/Yhb/Yhb/Yhbxf2xs979er83bq8dSO9e0uXiu0Xyt3x/CINem2pcgtLNXF4W43r2dSVYbstd3rO5ebmqlGjRjUqgkydDhcXF6fk5ORK29q1a6evv/662v39/f3l7+9fZbuvr6/pSS/nTrHUNeTOPuTNPuTNfuTOPuTNPuTNPuTt3H7cmilJurpD42pzdK7cxTb01Z8Ht9FT327Ryz/t1ohL4tUopOr7Uk/lDs+52pzf1O5wffv21Y4dlVfh3blzpxITE02KCAAAAPXVgeMF2pqeK28vi65Kjqn1z9/UM1HtG4cpr7BUz8/e7oQI4SqmFkEPP/ywVqxYob///e/avXu3pk2bpnfeeUf33nuvmWEBAACgHirvCtczKUIRwX61/nlvL4ueHZUiSfpq7SGt2Zft0PjgOqYWQd27d9c333yjzz77TCkpKXr22Wf16quv6uabbzYzLAAAANRD5QukDkk59wKpF9KlaUP9rnuCJOnJ/6Wq1GpzSGxwLVPvCZKk4cOHa/jw4WaHAQAAgHosI7dQ6w6clCQNbm9/ESRJjw1pq9mpR7X9aJ4+WbFft/dlncu6xtSRIAAAAMAV5m4pGwXq0rSBYsICLupYEcF+emxIG0nSyz/uVGZu4UXHB9eiCAIAAEC954ipcGf7Xfem6hQfrryiUv2DJgl1DkUQAAAA6rXsU8VamVbWxGBI+ziHHNPby6JnrkmRxSJ9s/6wVu497pDjwjUoggAAAFCv/bQ1Q1aboeS4MDWNDHLYcTslNNDYHmWLpj71bapKaJJQZ1AEAQAAoF6bc+Z+oKEOmgp3tscGt1HDIF/tzMjXR8v2Ofz4cA6KIAAAANRbeYUlWrIrS5Lj7gc6W4MgP/3f0LaSpFfm7VQGTRLqBIogAAAA1FsLtmeq2GpT86hgtYwOcco5buyaoM5NG+hUsVV/+36bU84Bx6IIAgAAQL1V3hVuaEqsLBaLU87h5WXRs9ekyMsizdp4RMt2ZznlPHAciiAAAADUS6eLrfp5xzFJjusKdy4pTcI1rleiJOnpmVtUXEqTBHdGEQQAAIB6afGuYzpdYlWTBoFKaRLm9PP9aWAbRQb7aXdmvj5Ymub088F+FEEAAACol85eINVZU+HOFh7kq8evbidJ+vf8XTpy8rTTzwn7UAQBAACg3ikutemnbRmSnNMa+1yu69xE3RIbqqDYqudokuC2KIIAAABQ7yzbk6W8wlJFhfqrS9OGLjuvl5dFz5xpkvD95nQt3nnMZedGzVEEAQAAoN6Ze2aB1EHJMfLycv5UuLMlNw7T+D7NJEmTZm5RUanVpefHhVEEAQAAoF6x2gz9uKV8Kpxzu8Kdy8MDW6tRiL/2Zp3Se7/QJMHdUAQBAACgXlm9L1vHTxUrPNBXPZtHmBJDWICv/jqsrSTpPwt26dCJAlPiQPUoggAAAFCvlHeFG5gcI19v897ujrqkiXokRaiwxKZnv9tqWhyoiiIIAAAA9YbNZlTcDzSkveu6wlXHYrHo2WtS5O1l0dwtGVq4I9PUePAriiAAAADUG5sO5yg9p1DBft7q16qR2eGoTWyobj+rSUJhCU0S3AFFEAAAAOqN2anpkqQBbaMV4OttcjRlHhrYWjFh/tp/vEDvLN5rdjgQRRAAAADqCcMwNPfM/UBDXLhA6oWE+Pvor8OSJUlvLNytg9k0STAbRRAAAADqhR0Zedp3vEB+Pl4a0Cba7HAqGdExTr2bR6qo1KbJs7aYHY7HowgCAABAvTB7c9ko0GWtohTs72NyNJVZLBY9O6q9fLws+mlbpn7ammF2SB6NIggAAAD1QkVXODeaCne2ltGhmnBpkiRp8nc0STATRRAAAADqvLSsU9p+NE8+XhZd1c69psKd7YErWikuPEAHs0/rzZ/3mB2Ox6IIAgAAQJ1XvkBq7xaRahDkZ3I05xbs76Onhpc1SXh70R7tyzplckSeiSIIAAAAdd4cN58Kd7ahKbG6tFUjFZfaNGnWFhmGYXZIHociCAAAAHXakZOntfHgSVks0sDkGLPDuSCLxaJJI9vL19uin3cc0480SXA5iiAAAADUaeUNEbonRig6NMDkaGqmRVSIfn9Zc0nSM7O26nQxTRJciSIIAAAAddrsM/cDDa4DU+HOdu+AlmrSIFCHT57WGwt3mx2OR6EIAgAAQJ11LK9Iq/dlS5IGt3f/qXBnC/L7tUnCO4v3au+xfJMj8hwUQQAAAKizftqWIcOQOsaHK75hkNnh1Nrg9jG6vE2Uiq02TZxJkwRXoQgCAABAnVUxFa593ZoKV85isWjSiPby8/bSL7uyKlp9w7koggAAAFAn5Zwu0bLdWZLK2k7XVc0aBesP/c80Sfhuq04VlZocUf1HEQQAAIA6af62DJXaDLWOCVHzqBCzw7kofxzQUvENA5WeU6j/LKBJgrNRBAEAAKBOKp86NqSOToU7W4CvtyaNaC9Jeu+XvdqdmWdyRPUbRRAAAADqnF
NFpVq085gkaUhKnMnROMZVyTG6sm20Sm2Gnv6WJgnORBEEAACAOmfRzmMqKrWpaUSQ2sWFmh2Ow0wa2V7+Pl5atue4vtuUbnY49RZFEAAAAOqc8qlwQ1NiZbFYTI7GcRIigvTHy1tKkv72/Vbl0yTBKSiCAAAAUKcUlVq1YHumJGlwHe4Kdy5392+uxMggZeQW6bWfdpodTr1EEQQAAIA6ZenuLOUXlSomzF+XxDcwOxyHC/D11qSRZU0SPli6TzuO0iTB0SiCAAAAUKfM3vxrVzgvr/ozFe5sA9pEa1ByjKw2Q09/m0qTBAejCAIAAECdUWq1ad62DEn1cyrc2Z4ekawAXy+tTMvWtxuOmB1OvUIRBAAAgDpjVVq2ThaUKCLYTz2aRZgdjlPFNwzS/Ve0kiQ998M25RaWmBxR/UERBAAAgDpj9pmucAPbxcjHu/6/lb3z0iQlNQrWsbwivTpvl9nh1Bv1/5kDAACAesFmMzR3y5n7gTrU76lw5fx9vDX5TJOEj5bv07b0XJMjqh8oggAAAFAnrD94Qpl5RQr191GfFpFmh+Myl7WO0tUdYmW1GXrqfzRJcASKIAAAANQJ5QukXtEuWv4+3iZH41pPDktWoK+31uw/oRnrDpsdTp1HEQQAAAC3ZxiG5pyZCje0nneFq07jBoF64MqyJgn/mL1NOadpknAxKIIAAADg9rYcydXB7NMK8PXSZa2jzA7HFBP6JalFVLCy8ov18o87zA6nTqMIAgAAgNsrb4hweetoBfn5mByNOfx8vPTMNSmSpE9W7Ffq4RyTI6q7KIIAAADg9spbYw/xwKlwZ+vbspGGd4yTzZCe+jZVNhtNEuxBEQQAAAC3tjszT7sz8+XrbdGAttFmh2O6J4clK9jPW+sPnNRXaw+ZHU6dRBEEAAAAtzZ3S4akslGQ8EBfk6MxX2x4gB66qrUk6fk523WyoNjkiOoeiiAAAAC4tdmp6ZKkIe09eyrc2W7r20ytY0KUfapYL86lSUJtUQQBAADAbR3MLlDq4Vx5WaSByTFmh+M2fL1/bZIwbdUBbTp00tyA6hiKIAAAALit8q5wPZIiFBnib3I07qVX80iNuqSxDEN66n80SagNiiAAAAC4rTnlXeGYCletJ65up1B/H208lKPPVx80O5w6gyIIAAAAbikzt1BrD5yQJA328NbY5xIdFqCHB5Y1Sfjn3O3KPkWThJqgCAIAAIBbmrs1Q4YhXZLQQHHhgWaH47Zu7Z2otrGhOllQohfnbjc7nDqBIggAAABuae6ZqXBDGQU6Lx9vLz07qqxJwuerD2r9mdEznBtFEAAAANzOiVPFWr73uCRpCEXQBXVvFqHru8SXNUn4NlVWmiScF0UQAAAA3M5P2zJktRlqFxemxMhgs8OpE/5vaFuFBvgo9XCupq06YHY4bo0iCAAAAG6HrnC1FxXqr0cHtZEkvThnu7Lyi0yOyH1RBAEAAMCt5BeV6pddWZKYCldb43olqn3jMOUWluqF2TRJOBeKIAAAALiVhdszVWy1qXmjYLWOCTE7nDrF28uiZ64pa5Iwfe0hrd2fbXJE7okiCAAAAG6lfCrc4JRYWSwWk6Ope7omNtTobvGSpCf/t0WlVpvJEbkfiiAAAAC4jcISqxbuyJREa+yL8ZchbRUe6Ktt6bn6dMV+s8NxOxRBAAAAcBuLdx5TQbFVjcMD1KFJuNnh1FmRIf768+CyJgkv/bhTx/JoknA2iiAAAAC4jTlbmArnKGN7NFXH+HDlFZXqHz9sMzsct0IRBAAAALdQYrXpp60ZkqShKXEmR1P3eXtZ9Ow1KbJYpBnrD2vlmcVnQREEAAAAN7F8z3HlFpaqUYifuiY2NDuceqFTQgP9rntTSdLT325RCU0SJFEEAQAAwE2UT4Ub1D5W3l5MhXOUxwa3UcMgX+3IyNNHy/aZHY5boAgCAACA6aw2Qz+eKYKGtKcrnCM1DPbTX4a0lSS9+tMuZeQWmhyR+SiCAAAAYLq1+08oK79YYQE+6tU80uxw6p3R3RLUKaGB8otK9XeaJFAEAQAAwHyzU9MlSVclx8jPh7eojublZdHfzjRJ+HbDES3bk2V2SKbiGQYAAABTGYahualMhXO2DvHhGtczURJNEiiCAAAAYKrNh3N0JKdQQX7euqx1lNnh1GuPDmqjiGA/7c7M15SlaWaHYxqKIAAAAJhq9plRoAFtohXg621yNPVbeJCv/m/or00S0nNOmxyROSiCAAAAYBrDMDSnfCpcClPhXOGGLvHqmthQBcVW/e17z2ySQBEEAAAA0+zMyFda1in5eXtpQNtos8PxCF5eFj1zTXt5WaTvN6VryS7Pa5JAEQQAAADTlI8CXdqqkUL8fUyOxnO0bxyuW3s3kyQ9PTNVxaWe1SSBIggAAACmmbOFqXBmeXhgazUK8dfeY6f03pK9ZofjUqYWQZMmTZLFYqn01bZtWzNDAgAAgIvsP35K29Jz5e1l0VXtYswOx+OEB/rqiavL3nv/Z/5uHT7pOU0STB8Jat++vdLT0yu+lixZYnZIAAAAcIHyqXC9m0eqYbCfydF4pms7N1GPZhE6XWLV377banY4LmN6EeTj46PY2NiKr0aNGpkdEgAAAFygvDX2YKbCmcZiseiZUe3l7WXR7NSjWrTzmNkhuYTpd5/t2rVLjRs3VkBAgHr37q1//OMfatq0abX7FhUVqaioqOL73NxcSVJJSYlKSkpcEu+5lJ/f7DjqInJnH/JmH/JmP3JnH/JmH/Jmn7qUt/ScQm04eFIWi3RF60jTY65LuXO0FpGBurVXU01Ztl9P/y9V39/fR/4+NRsrcae81SYGi2EYhhNjOa/Zs2crPz9fbdq0UXp6uiZPnqzDhw8rNTVVoaGhVfafNGmSJk+eXGX7tGnTFBQU5IqQAQAA4ACL0y36ep+3kkINPZRiNTscj1dYKj23wVu5JRYNS7BqULxpJYLdCgoKdNNNNyknJ0dhYWHn3dfUIui3Tp48qcTERL388suaMGFClcerGwlKSEhQVlbWBS/U2UpKSjRv3jwNHDhQvr6+psZS15A7+5A3+5A3+5E7+5A3+5A3+9SlvI37YLVWpp3Q40Na646+zcwOp07lzllmbUrXI9M3K8DXS7Pv76v4hoEX/Bl3yltubq4aNWpUoyLI9OlwZ2vQoIFat26t3bt3V/u4v7+//P39q2z39fU1Penl3CmWuobc2Ye82Ye82Y/c2Ye82Ye82cfd83Y8v0ir952QJF3dsYlbxeruuXOma7sk6Mu1h7Vib7b+Pmen3r21W41/1h3yVpvzm94Y4Wz5+fnas2eP4uLizA4FAAAATjJva4ZshpTSJEwJEdzS4C4sFouevSZFPl4WzduaoQXbM8wOyWlMLYIeffRRLVq0SPv27dOyZct07bXXytvbW2PHjjUzLAAAADhRxQKp7ekK525axYRqQr8kSdKkmVtVWFI/79cytQg6dOiQxo4dqzZt2mj06NGKjIzUihUrFBUVZWZYAAAAcJLcwhIt3Z0lSRqSwuwfd3T/la0UGxagA9kFenvRH
rPDcQpT7wn6/PPPzTw9AAAAXGzBtkyVWA21jA5Ry+gQs8NBNUL8ffTk8Ha6b9p6vfnzHl3XOV5NI+vXtEW3uicIAAAA9ducMwukDmWBVLc2rEOc+rVspOJSmybN2iI3aijtEBRBAAAAcImC4lL9vDNTkjSY+4HcmsVi0aSR7eXrbdGC7Zn6aVum2SE5FEUQAAAAXGLxzmMqLLEpISJQ7Rubu8YjLqxldIjuvLS5JGnSzC06XVx/miRQBAEAAMAlZqf+2hXOYrGYHA1q4v4rWqpxeIAOnzytN3+ufi3PuogiCAAAAE5XVGrVgjNTqoZwP1CdEeTno6dHJEuS/rtor9KyTpkckWNQBAEAAMDplu05rryiUkWH+qtzQkOzw0EtDG4fq8taR6nYatPEmfWjSQJFEAAAAJxuzuayqXCD28fKy4upcHWJxWLR5JHt5eftpcU7j2numcVu6zKKIAAAADhVqdWmedsyJNEau65KahSsu/uXNUl4ZtZWFRSXmhzRxaEIAgAAgFOt2pet7FPFahDkqx5JEWaHAzv98fKWatIgUEdyCvX6grrdJIEiCAAAAE4190xXuIHtYuTjzdvPuirQz1uTRraXJL37y17tOJqnlWnZWptl0cq0bFltdedeIR+zAwAAAED9ZbMZmrvlzFS4DkyFq+uuahetK9pGa8H2TI34zxIVW22SvPXxrjWKCw/QxBHJGpISZ3aYF0QpDgAAAKfZcOikjuYWKsTfR31bNjI7HFwki8WiAW2iJOlMAfSrozmFuufTdZqTmm5GaLVCEQQAAACnKZ8Kd0XbaPn7eJscDS6W1WbozZ/3VPtY+WS4ybO2uv3UOIogAAAAOIVhGJp9pghigdT6YVVattJzCs/5uCEpPadQq9KyXReUHSiCAAAA4BTb0vN0ILtA/j5e6t86yuxw4ACZeecugOzZzywUQQAAAHCKOWcW1ezfOkrB/vTjqg+iQwMcup9ZKIIAAADgFOU3yDMVrv7okRShuPAAWc7xuEVSXHiA268HRREEAAAAh9tzLF87M/Ll42XRle1izA4HDuLtZdHEEcmSVKUQKv9+4ohkeXudq0xyDxRBAAAAcLg5Zxoi9GnZSOGBviZHA0cakhKnt8Z1UWx45SlvseEBemtclzqxThCTMwEAAOBwc8/cDzSUqXD10pCUOA1MjtXy3Zn68ZeVGnRpT/VuGe32I0DlKIIAAADgUIdOFGjToRxZLNLAZKbC1VfeXhb1TIrQ8W2GeiZF1JkCSGI6HAAAABxs7pYMSVL3ZhFqFOJvcjRAVRRBAAAAcKi5qUyFg3ujCAIAAIDDZOYVavX+bEnS4PYUQXBPFEEAAABwmHlbM2QYUqeEBmrcINDscIBqUQQBAADAYcpbYw9hFAhujCIIAAAADnGyoFjL9xyXJA3hfiC4MYogAAAAOMT8bZkqtRlqGxuqpEbBZocDnBNFEAAAABxi9pmpcDREgLujCAIAAMBFO1VUqsW7jkmShnagCIJ7owgCAADARVu4I1PFpTY1iwxSm5hQs8MBzosiCAAAABetvCvc4JRYWSwWk6MBzo8iCAAAABelsMSqhdszJUlDU+JMjga4MIogAAAAXJQlu7J0qtiquPAAdWwSbnY4wAVRBAEAAOCizNnya1c4Ly+mwsH9UQQBAADAbiVWm+ZtzZDEAqmoO3xq+wNFRUVauXKl9u/fr4KCAkVFRalz585KSkpyRnwAAABwYyv3ZivndIkig/3UvVmE2eEANVLjImjp0qV67bXXNGvWLJWUlCg8PFyBgYHKzs5WUVGRmjdvrt///vf6wx/+oNBQ2iICAAB4gjlb0iVJg9rHyJupcKgjajQdbuTIkRozZoyaNWumH3/8UXl5eTp+/LgOHTqkgoIC7dq1S08++aTmz5+v1q1ba968ec6OGwAAACaz2QzN3VI2FW5we6bCoe6o0UjQsGHD9PXXX8vX17fax5s3b67mzZtr/Pjx2rp1q9LT0x0aJAAAANzPugMndCyvSKEBPurTopHZ4QA1VqMi6O67767xAZOTk5WcnGx3QAAAAKgbZp9ZIPWqdjHy86HfFuqOWjdGOFtqaqoWLVokq9Wqvn37qmvXro6KCwAAAG7MMAzNOVME0RUOdY3dJfsbb7yhK6+8UosWLdLChQt1xRVX6LnnnnNkbAAAAHBTqYdzdfjkaQX6euuyVlFmhwPUSo1Hgg4ePKiEhISK719//XVt2bJFjRqVzf9cvny5Ro4cqb/+9a+OjxIAAABupbwr3OVtohTo521yNEDt1Hgk6KqrrtJrr70mwzAkSZGRkZozZ46KioqUl5enn376SVFRfAoAAADgCZgKh7qsxkXQ6tWrtWPHDvXs2VMbNmzQO++8o1deeUWBgYFq0KCBvvjiC3300UfOjBUAAABuYFdGnvYcOyU/by9d0Tba7HCAWqvxdLiwsDC9+eabWrZsmW677TZdccUV+uWXX2S1WmW1WtWgQQMnhgkAAAB3UT4K1K9VI4UGVL+ECuDOat0YoU+fPlqzZo0aNmyozp07a/HixRRAAAAAHqS8NfYQFkhFHVXjkaDS0lK988472rZtmzp16qQnnnhCY8aM0R/+8Ad9+OGHev311xUTE+PMWAEAAGCyA8cLtDU9V95eFl2VzHs/1E01HgmaMGGCXn/9dQUHB2vKlCl6+OGH1bp1ay1YsEBDhgxR79699dZbbzkzVgAAAJhs7payUaCeSRGKCPYzORrAPjUugr799lt9/fXXev755zVv3jx9//33FY9NmDBBK1as0C+//OKUIAEAAOAeZqeWtcamKxzqshoXQTExMfrxxx9VXFysBQsWKDIystLj0dHRmjZtmsMDBAAAgHvIyC3UugMnJUmDuR8IdViN7wl6/fXXdfPNN+uRRx5RXFycvvzyS2fGBQAAADdTPhWuS9MGigkLMDkawH41LoIGDhyojIwMZWVlsSgqAACABypvjT00Jc7kSICLU6sW2RaLhQIIAADAA2WfKtbKtGxJTIVD3VejImjIkCFasWLFBffLy8vTCy+8oDfeeOOiAwMAAID7+Glrhqw2Q8lxYWoaGWR2OMBFqdF0uBtvvFHXX3+9wsPDNWLECHXr1k2NGzdWQECATpw4oa1bt2rJkiX64YcfNGzYML344ovOjhsAAAAuNGdL+VQ4RoFQ99WoCJowYYLGjRun6dOn64svvtA777yjnJwcSWVT5JKTkzV48GCtXr1a7dq1c2rAAAAAcK28whIt2ZUlidbYqB9q3BjB399f48aN07hx4yRJOTk5On36tCIjI+Xr6+u0AAEAAGCuBdszVWy1qUVUsFrFhJodDnDRalwE/VZ4eLjCw8MdGQsAAADcUHlrbEaBUF/UqjscAAAAPMvpYqsWbj8mSRrSntbYqB8oggAAAHBOi3cd0+kSq5o0CFRKkzCzwwEcgiIIAAAA51S+QOqQlFhZLBaTowEcgyIIAAAA1SoutemnbRmSaI2N+sWuIujkyZN677339Pjjjys7u2zl4HXr1unw4cMODQ4AAADmWbYnS3mFpYoK9VeXpg3NDgdw
mFp3h9u0aZOuuuoqhYeHa9++fbrrrrsUERGhGTNm6MCBA/r444+dEScAAABcrLwr3KDkGHl5MRUO9UetR4IeeeQR3Xbbbdq1a5cCAgIqtl999dVavHixQ4MDAACAOaw2Qz9uKZ8KR1c41C+1LoJWr16tu+++u8r2Jk2a6OjRow4JCgAAAOZavS9bx08VKzzQVz2bR5gdDuBQtS6C/P39lZubW2X7zp07FRUV5ZCgAAAAYK7yrnADk2Pk600vLdQvtX5Gjxw5Us8884xKSkokSRaLRQcOHNBf/vIXXX/99Q4PEAAAAK5lsxkV9wMNaU9XONQ/tS6CXnrpJeXn5ys6OlqnT59W//791bJlS4WGhuq5555zRowAAABwoU2Hc5SeU6hgP2/1a9XI7HAAh6t1d7jw8HDNmzdPS5cu1caNG5Wfn68uXbroqquuckZ8AAAAcLHyqXAD2kYrwNfb5GgAx6t1EfTxxx9rzJgx6tu3r/r27Vuxvbi4WJ9//rluvfVWhwYIAAAA1zEMQ3NS0yVJQ1ggFfVUrafD3X777crJyamyPS8vT7fffrtDggIAAIA5dmTkad/xAvn5eGlAm2izwwGcotZFkGEYsliqLpZ16NAhhYeHOyQoAAAAmGP25rKpcJe1ilKwf60nDQF1Qo2f2Z07d5bFYpHFYtGVV14pH59ff9RqtSotLU1DhgxxSpAAAABwjfKucEOZCod6rMZF0KhRoyRJGzZs0ODBgxUSElLxmJ+fn5o1a0aLbAAAgDosLeuUth/Nk4+XRVe2Yyoc6q8aF0ETJ06UJDVr1kxjxoxRQECA04ICAACA65V3hevdIlINgvxMjgZwnlpP9Bw/frwz4gAAAIDJ5pQvkMpUONRztS6CrFarXnnlFX355Zc6cOCAiouLKz2enZ3tsOAAAADgGkdOntbGgydlsUgDk2PMDgdwqlp3h5s8ebJefvlljRkzRjk5OXrkkUd03XXXycvLS5MmTXJCiAAAAHC28oYI3RMjFB3KbQ+o32pdBE2dOlXvvvuu/vSnP8nHx0djx47Ve++9p6efflorVqxwRowAAABwstln7gcazFQ4eIBaF0FHjx5Vhw4dJEkhISEVC6cOHz5c33//vWOjAwAAgNMdyyvS6n1ltzQMbs9UONR/tS6C4uPjlZ6eLklq0aKFfvzxR0nS6tWr5e/v79joAAAA4HQ/bcuQYUgd48MV3zDI7HAAp6t1EXTttddq/vz5kqT7779fTz31lFq1aqVbb71Vd9xxh92BPP/887JYLHrooYfsPgYAAABqr2IqXHumwsEz1Lo73PPPP1/x9zFjxigxMVHLli1Tq1atNGLECLuCWL16tf773/+qY8eOdv08AAAA7JNzukTLdmdJkoZyPxA8RK1Hgn6rV69eeuSRRzRixAitWbOm1j+fn5+vm2++We+++64aNmx4seEAAACgFuZvy1CpzVDrmBA1jwoxOxzAJWo9EpSfny9vb28FBgZWbNuwYYOeeuop/fDDD7JarbU63r333qthw4bpqquu0t/+9rfz7ltUVKSioqKK73NzcyVJJSUlKikpqdV5Ha38/GbHUReRO/uQN/uQN/uRO/uQN/uQN/vYk7fZm8vu9R7ULtqj881zzj7ulLfaxGAxDMOoyY4HDx7U6NGjtWrVKnl7e+u+++7T3/72N/3hD3/QF198oWuvvVYPP/ywevbsWeOTf/7553ruuee0evVqBQQE6PLLL9cll1yiV199tdr9J02apMmTJ1fZPm3aNAUFcRMfAABAbRRZpb+u9laJYdFjHUvVJNjsiAD7FRQU6KabblJOTo7CwsLOu2+NR4L+/Oc/q7CwUK+99ppmzJih1157Tb/88ot69uypPXv2KD4+vlZBHjx4UA8++KDmzZungICaLcj1+OOP65FHHqn4Pjc3VwkJCRo0aNAFL9TZSkpKNG/ePA0cOFC+vr6mxlLXkDv7kDf7kDf7kTv7kDf7kDf71DZvs1OPqmTVJiU0DNSdN/STxWJxQZTuieecfdwpb+WzxGqixkXQ4sWLNWPGDPXq1UujR49WbGysbr75Zru7ua1du1aZmZnq0qVLxTar1arFixfr9ddfV1FRkby9vSv9jL+/f7VtuH19fU1Pejl3iqWuIXf2IW/2IW/2I3f2IW/2IW/2qWneftpe1hDh6g5x8vPzc3ZYdQLPOfu4Q95qc/4aF0EZGRlKSkqSJEVHRysoKEhDhw6tfXRnXHnlldq8eXOlbbfffrvatm2rv/zlL1UKIAAAADhOUalVC7ZnSpIG0xUOHqZWjRG8vLwq/f1iPjEIDQ1VSkpKpW3BwcGKjIyssh0AAACOtXR3lvKLShUbFqBL4huYHQ7gUjUuggzDUOvWrSvmiubn56tz586VCiNJys7OdmyEAAAAcLg5FQukxsjLy3PvBYJnqnERNGXKFGfGIUn6+eefnX4OAAAAT1dqtWne1gxJTIWDZ6pxETR+/HhnxgEAAAAXWZWWrRMFJYoI9lOPZhFmhwO4nNeFdwEAAEB9MvvMVLiB7WLk483bQXgenvUAAAAexGYzNHdLWRE0pANT4eCZKIIAAAA8yPqDJ5WZV6RQfx/1aRFpdjiAKSiCAAAAPMic1HRJ0hXtouXvw7qM8EwUQQAAAB7CMAzNOTMVbihd4eDBarVYqiRZrVZ9+OGHmj9/vjIzM2Wz2So9vmDBAocFBwAAAMfZciRXB7NPK8DXS5e1jjI7HMA0tS6CHnzwQX344YcaNmyYUlJSKhZPBQAAgHsrb4hweetoBfnV+m0gUG/U+tn/+eef68svv9TVV1/tjHgAAADgJOWtsYcwFQ4ertb3BPn5+ally5bOiAUAAABOsjszT7sz8+XrbdEV7aLNDgcwVa2LoD/96U967bXXZBiGM+IBAACAE8zdkiFJ6tuykcICfE2OBjBXrafDLVmyRAsXLtTs2bPVvn17+fpWfhHNmDHDYcEBAADAMWafaY09pD1T4YBaF0ENGjTQtdde64xYAAAA4AQHswuUejhXXhZpYHKM2eEApqt1ETRlyhRnxAEAAAAnKe8K1yMpQpEh/iZHA5jP7t6Ix44d044dOyRJbdq0UVQUveYBAADc0ZzU8gVS40yOBHAPtW6McOrUKd1xxx2Ki4vTZZddpssuu0yNGzfWhAkTVFBQ4IwYAQAAYKfM3EKtPXBCkjSoPVPhAMmOIuiRRx7RokWLNGvWLJ08eVInT57Ut99+q0WLFulPf/qTM2IEAACAneZuzZBhSJckNFBceKDZ4QBuodbT4b7++mt99dVXuvzyyyu2XX311QoMDNTo0aP11ltvOTI+AAAAXIS5FVPh6AoHlKv1SFBBQYFiYqoOpUZHRzMdDgAAwI2cOFWs5XuPS5KGUAQBFWpdBPXu3VsTJ05UYWFhxbbTp09r8uTJ6t27t0ODAwAAgP1+2pYhq81Qu7gwJUYGmx0O4DZqPR3utdde0+DBgxUfH69OnTpJkjZu3KiAgADNnTvX4QECAADAPuWtsVkgFais1kVQSkqKdu3apalTp2r79u2SpLFjx+rmm29WYCA
[... remainder of base64-encoded PNG plot image elided ...]",
-            "text/plain": [
-              "
" - ] - }, - "metadata": {}, - "output_type": "display_data" + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# Read the CSV file\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m df \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Extract the year and inflation rate from the CSV file\u001b[39;00m\n\u001b[1;32m 8\u001b[0m df[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mYear\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mto_datetime(df[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mYear\u001b[39m\u001b[38;5;124m'\u001b[39m], \u001b[38;5;28mformat\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m%\u001b[39m\u001b[38;5;124mY\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "File \u001b[0;32m~/miniconda3/envs/stack/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1026\u001b[0m, in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)\u001b[0m\n\u001b[1;32m 1013\u001b[0m kwds_defaults \u001b[38;5;241m=\u001b[39m _refine_defaults_read(\n\u001b[1;32m 1014\u001b[0m dialect,\n\u001b[1;32m 1015\u001b[0m delimiter,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1022\u001b[0m dtype_backend\u001b[38;5;241m=\u001b[39mdtype_backend,\n\u001b[1;32m 1023\u001b[0m )\n\u001b[1;32m 1024\u001b[0m kwds\u001b[38;5;241m.\u001b[39mupdate(kwds_defaults)\n\u001b[0;32m-> 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_read\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/stack/lib/python3.10/site-packages/pandas/io/parsers/readers.py:620\u001b[0m, in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 617\u001b[0m _validate_names(kwds\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[1;32m 619\u001b[0m \u001b[38;5;66;03m# Create the parser.\u001b[39;00m\n\u001b[0;32m--> 620\u001b[0m 
parser \u001b[38;5;241m=\u001b[39m \u001b[43mTextFileReader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m chunksize \u001b[38;5;129;01mor\u001b[39;00m iterator:\n\u001b[1;32m 623\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parser\n", + "File \u001b[0;32m~/miniconda3/envs/stack/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1620\u001b[0m, in \u001b[0;36mTextFileReader.__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 1617\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptions[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_index_names\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwds[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_index_names\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 1619\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles: IOHandles \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1620\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_engine \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_engine\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mengine\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/stack/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1880\u001b[0m, in \u001b[0;36mTextFileReader._make_engine\u001b[0;34m(self, f, engine)\u001b[0m\n\u001b[1;32m 1878\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m mode:\n\u001b[1;32m 1879\u001b[0m mode \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m-> 1880\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles \u001b[38;5;241m=\u001b[39m \u001b[43mget_handle\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1883\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mencoding\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1884\u001b[0m \u001b[43m \u001b[49m\u001b[43mcompression\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcompression\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1885\u001b[0m \u001b[43m \u001b[49m\u001b[43mmemory_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmemory_map\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_text\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_text\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mencoding_errors\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstrict\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1888\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstorage_options\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1889\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1890\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1891\u001b[0m f \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles\u001b[38;5;241m.\u001b[39mhandle\n", + "File \u001b[0;32m~/miniconda3/envs/stack/lib/python3.10/site-packages/pandas/io/common.py:873\u001b[0m, in \u001b[0;36mget_handle\u001b[0;34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[0m\n\u001b[1;32m 868\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(handle, \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 869\u001b[0m \u001b[38;5;66;03m# Check whether the filename is to be opened in binary mode.\u001b[39;00m\n\u001b[1;32m 870\u001b[0m \u001b[38;5;66;03m# Binary mode does not support 'encoding' and 'newline'.\u001b[39;00m\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ioargs\u001b[38;5;241m.\u001b[39mencoding \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m ioargs\u001b[38;5;241m.\u001b[39mmode:\n\u001b[1;32m 872\u001b[0m \u001b[38;5;66;03m# Encoding\u001b[39;00m\n\u001b[0;32m--> 873\u001b[0m handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 874\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 875\u001b[0m \u001b[43m \u001b[49m\u001b[43mioargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 876\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mioargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 877\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 878\u001b[0m \u001b[43m \u001b[49m\u001b[43mnewline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 879\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 880\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 881\u001b[0m \u001b[38;5;66;03m# Binary mode\u001b[39;00m\n\u001b[1;32m 882\u001b[0m handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mopen\u001b[39m(handle, ioargs\u001b[38;5;241m.\u001b[39mmode)\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/tmp/tmpco0s0o4_/LOdZoVp1inflation.csv'" + ] } ], "source": [ diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 33ca523633..5a78f5baea 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -3974,6 +3974,41 @@ "stream": { "type": "boolean" }, + "documents": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/InterleavedContentItem" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/InterleavedContentItem" + } + }, + { + "$ref": "#/components/schemas/URL" + } + ] + }, + "mime_type": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "content", + "mime_type" + ] + } + }, "tools": { "type": "array", "items": { diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 4da311cf0c..72093b436a 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -618,6 +618,25 @@ components: properties: agent_id: type: string + documents: + items: + additionalProperties: false + properties: + content: + oneOf: + - type: string + - $ref: '#/components/schemas/InterleavedContentItem' + - items: + $ref: '#/components/schemas/InterleavedContentItem' + type: array + - $ref: '#/components/schemas/URL' + mime_type: + type: string + required: + - content + - mime_type + type: object + type: array messages: items: oneOf: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 18bbcd95c1..acf8fa7486 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -45,6 +45,11 @@ class Attachment(BaseModel): mime_type: str +class Document(BaseModel): + content: InterleavedContent | URL + mime_type: str + + class StepCommon(BaseModel): turn_id: str step_id: str @@ -272,6 +277,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): ] ] + documents: Optional[List[Document]] = None + tools: Optional[List[AgentTool]] = None + stream: Optional[bool] = False @@ -308,6 +316,7 @@ async def create_agent_turn( ] ], stream: Optional[bool] = False, + documents: Optional[List[Document]] = None, tools: Optional[List[AgentTool]] = None, ) -> Union[Turn, 
        AsyncIterator[AgentTurnResponseStreamChunk]]: ...
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 2af1c820b8..43d5cbdb77 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -33,13 +33,18 @@
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
     Attachment,
+    Document,
     InferenceStep,
     ShieldCallStep,
     StepType,
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import TextContentItem, URL
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    TextContentItem,
+    URL,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -55,8 +60,8 @@
     ToolResponseMessage,
     UserMessage,
 )
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.memory import Memory, MemoryBankDocument
+from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import KVStore
@@ -190,6 +195,7 @@ async def create_and_execute_turn(
             input_messages=messages,
             sampling_params=self.agent_config.sampling_params,
             stream=request.stream,
+            documents=request.documents,
             tools_for_turn=request.tools,
         ):
             if isinstance(chunk, CompletionMessage):
@@ -240,6 +246,7 @@ async def run(
         input_messages: List[Message],
         sampling_params: SamplingParams,
         stream: bool = False,
+        documents: Optional[List[Document]] = None,
         tools_for_turn: Optional[List[AgentTool]] = None,
     ) -> AsyncGenerator:
         # Doing async generators makes downstream code much simpler and everything amenable to
@@ -257,7 +264,13 @@ async def run(
             yield res
 
         async for res in self._run(
-            session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
+            session_id,
+            turn_id,
+            input_messages,
+            sampling_params,
+            stream,
+            documents,
+            tools_for_turn,
         ):
             if isinstance(res, bool):
                 return
@@ -352,6 +365,7 @@ async def _run(
         input_messages: List[Message],
         sampling_params: SamplingParams,
         stream: bool = False,
+        documents: Optional[List[Document]] = None,
         tools_for_turn: Optional[List[AgentTool]] = None,
     ) -> AsyncGenerator:
         tool_args = {}
@@ -361,6 +375,7 @@ async def _run(
                 tool_args[tool.name] = tool.args
 
         tool_defs = await self._get_tool_defs(tools_for_turn)
+        await self.handle_documents(session_id, documents, input_messages, tool_defs)
         if "memory" in tool_defs and len(input_messages) > 0:
             with tracing.span("memory_tool") as span:
                 step_id = str(uuid.uuid4())
@@ -378,6 +393,11 @@ async def _run(
                     "query": input_messages[-1],
                     **extra_args,
                 }
+
+                session_info = await self.storage.get_session_info(session_id)
+                # if the session has a memory bank id, let the memory tool use it
+                if session_info.memory_bank_id:
+                    args["memory_bank_id"] = session_info.memory_bank_id
                 serialized_args = tracing.serialize_value(args)
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
@@ -732,6 +752,112 @@ async def _get_tool_defs(
 
         return ret
 
+    async def handle_documents(
+        self,
+        session_id: str,
+        documents: List[Document],
+        input_messages: List[Message],
+        tool_defs: Dict[str, ToolDefinition],
+    ) -> None:
+        memory_tool = tool_defs.get("memory", None)
+        code_interpreter_tool = tool_defs.get("code_interpreter", None)
+        if documents:
+            content_items = [
+                d for d in documents if isinstance(d.content, InterleavedContent)
+            ]
+            url_items = [d for d in documents if isinstance(d.content, URL)]
+            pattern = re.compile("^(https?://|file://|data:)")
+            url_items = [
+                URL(uri=a.content) for a in url_items if pattern.match(a.content)
+            ]
+            # Save the contents to a tempdir and use its path as a URL if code interpreter is present
+            if code_interpreter_tool:
+                for c in content_items:
+                    temp_file_path = os.path.join(
+                        self.tempdir, f"{make_random_string()}.txt"
+                    )
+                    with open(temp_file_path, "w") as temp_file:
+                        temp_file.write(c.content)
+                    url_items.append(URL(uri=f"file://{temp_file_path}"))
+
+            if memory_tool and code_interpreter_tool:
+                # if both memory and code_interpreter are available, we download the URLs
+                # and attach the data to the last message.
+                msg = await attachment_message(self.tempdir, url_items)
+                input_messages.append(msg)
+                # Since memory is present, add all the data to the memory bank
+                await self.add_to_session_memory_bank(session_id, documents)
+            elif code_interpreter_tool:
+                # if only code_interpreter is available, we download the URLs to a tempdir
+                # and attach the path to them as a message to inference with the
+                # assumption that the model invokes the code_interpreter tool with the path
+                msg = await attachment_message(self.tempdir, url_items)
+                input_messages.append(msg)
+            elif memory_tool:
+                # if only memory is available, we load the data from the URLs and content items to the memory bank
+                await self.add_to_session_memory_bank(session_id, documents)
+            else:
+                # if no memory or code_interpreter tool is available,
+                # we try to load the data from the URLs and content items as a message to inference
+                # and add it to the last message's context
+                input_messages[-1].context = content_items + load_data_from_urls(
+                    url_items
+                )
+
+    async def _ensure_memory_bank(self, session_id: str) -> str:
+        session_info = await self.storage.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        if session_info.memory_bank_id is None:
+            bank_id = f"memory_bank_{session_id}"
+            await self.memory_banks_api.register_memory_bank(
+                memory_bank_id=bank_id,
+                params=VectorMemoryBankParams(
+                    embedding_model="all-MiniLM-L6-v2",
+                    chunk_size_in_tokens=512,
+                ),
+            )
+            await self.storage.add_memory_bank_to_session(session_id, bank_id)
+        else:
+            bank_id = session_info.memory_bank_id
+
+        return bank_id
+
+    async def add_to_session_memory_bank(
+        self, session_id: str, data: List[Document]
+    ) -> None:
+        bank_id = await self._ensure_memory_bank(session_id)
+        documents = [
+            MemoryBankDocument(
+                document_id=str(uuid.uuid4()),
+                content=a.content,
+                mime_type=a.mime_type,
+                metadata={},
+            )
+            for a in data
+        ]
+        await self.memory_api.insert_documents(
+            bank_id=bank_id,
+            documents=documents,
+        )
+
+
+async def load_data_from_urls(urls: List[URL]) -> List[str]:
+    data = []
+    for url in urls:
+        uri = url.uri
+        if uri.startswith("file://"):
+            filepath = uri[len("file://") :]
+            with open(filepath, "r") as f:
+                data.append(f.read())
+        elif uri.startswith("http"):
+            async with httpx.AsyncClient() as client:
+                r = await client.get(uri)
+                resp = r.text
+                data.append(resp)
+    return data
+
 
 async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
     content = []
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index ab7f8878f9..0181ef6095 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -21,6 +21,7 @@
     AgentStepResponse,
     AgentTool,
     AgentTurnCreateRequest,
+    Document,
     Session,
     Turn,
 )
@@ -147,6 +148,7 @@ async def create_agent_turn(
             ]
         ],
         tools: Optional[List[AgentTool]] = None,
+        documents: Optional[List[Document]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
@@ -155,6 +157,7 @@
             messages=messages,
             stream=True,
             tools=tools,
+            documents=documents,
         )
         if stream:
             return self._create_agent_turn_streaming(request)
diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py
index 144f65863f..58b69858bc 100644
--- a/llama_stack/providers/inline/agents/meta_reference/persistence.py
+++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py
@@ -21,6 +21,7 @@ class AgentSessionInfo(BaseModel):
     session_id: str
     session_name: str
+    memory_bank_id: Optional[str] = None
     started_at: datetime
@@ -51,6 +52,17 @@ async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
 
         return AgentSessionInfo(**json.loads(value))
 
+    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
+        session_info = await self.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        session_info.memory_bank_id = bank_id
+        await self.kvstore.set(
+            key=f"session:{self.agent_id}:{session_id}",
+            value=session_info.model_dump_json(),
+        )
+
     async def add_turn_to_session(self, session_id: str, turn: Turn):
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index cb20e5890e..18dc904204 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -15,6 +15,7 @@
     AgentTurnResponseStepCompletePayload,
     AgentTurnResponseStreamChunk,
     AgentTurnResponseTurnCompletePayload,
+    Document,
     ShieldCallStep,
     StepType,
     ToolChoice,
@@ -22,8 +23,6 @@
     Turn,
 )
 from llama_stack.apis.inference import CompletionMessage, SamplingParams, UserMessage
-from llama_stack.apis.memory import MemoryBankDocument
-from llama_stack.apis.memory_banks import VectorMemoryBankParams
 from llama_stack.apis.safety import ViolationLevel
 from llama_stack.providers.datatypes import Api
@@ -232,8 +231,6 @@ async def test_rag_agent(
     common_params,
 ):
     agents_impl = agents_stack.impls[Api.agents]
-    memory_banks_impl = agents_stack.impls[Api.memory_banks]
-    memory_impl = agents_stack.impls[Api.memory]
     urls = [
         "memory_optimizations.rst",
         "chat.rst",
         "llama3.rst",
         "datasets.rst",
         "qat_finetune.rst",
         "lora_finetune.rst",
     ]
     documents = [
-        MemoryBankDocument(
-            document_id=f"num-{i}",
+        Document(
             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
             mime_type="text/plain",
-            metadata={},
         )
         for i, url in enumerate(urls)
     ]
-    await memory_banks_impl.register_memory_bank(
-        memory_bank_id="test_bank",
-        params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        ),
-        provider_id="faiss",
-    )
-    memory_impl.insert_documents(
-        bank_id="test_bank",
-        documents=documents,
-    )
-
     agent_config = AgentConfig(
         **{
             **common_params,
@@ -278,6 +259,7 @@
         agent_id=agent_id,
         session_id=session_id,
         messages=attachment_message,
+        documents=documents,
         stream=True,
     )
     turn_response = [
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 64c3c159f2..a77bb6cabd 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -203,6 +203,79 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config):
     assert "Tool:code_interpreter Response" in logs_str
 
 
+def test_code_execution(llama_stack_client):
+    agent_config = AgentConfig(
+        model="meta-llama/Llama-3.1-70B-Instruct",
+        instructions="You are a helpful assistant",
+        tools=[
+            "brave_search",
+            "code_interpreter",
+        ],
+        tool_choice="required",
+        input_shields=[],
+        output_shields=[],
+        enable_session_persistence=False,
+    )
+
+    memory_bank_id = "inflation_data_memory_bank"
+    llama_stack_client.memory_banks.register(
+        memory_bank_id=memory_bank_id,
+        params={
+            "memory_bank_type": "vector",
+            "embedding_model": "all-MiniLM-L6-v2",
+            "chunk_size_in_tokens": 512,
+            "overlap_size_in_tokens": 64,
+        },
+    )
+    AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
+    codex_agent = Agent(llama_stack_client, agent_config)
+    session_id = codex_agent.create_session("test-session")
+
+    llama_stack_client.memory.insert(
+        bank_id=memory_bank_id,
+        documents=[
+            Document(
+                document_id="inflation",
+                content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
+                mime_type="text/csv",
+                metadata={},
+            )
+        ],
+    )
+
+    user_prompts = [
+        {
+            "prompt": "Can you describe the data in the context?",
+            "tools": [{"name": "memory", "args": {"memory_bank_id": memory_bank_id}}],
+        },
+        {
+            "prompt": "Plot average yearly inflation as a time series",
+            "tools": [
+                {"name": "memory", "args": {"memory_bank_id": memory_bank_id}},
+                "code_interpreter",
+            ],
+        },
+    ]
+
+    for input in user_prompts:
+        print(f'User> {input["prompt"]}')
+        response = codex_agent.create_turn(
+            messages=[
+                {
+                    "role": "user",
+                    "content": input["prompt"],
+                }
+            ],
+            session_id=session_id,
+            tools=input["tools"],
+        )
+        # for chunk in response:
+        #     print(chunk)
+
+        for log in EventLogger().log(response):
+            log.print()
+
+
 def test_custom_tool(llama_stack_client, agent_config):
     client_tool = TestClientTool()
     agent_config = {

From f3304abfba9ef251eca71ae177c76e8a52c89bf1 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 6 Jan 2025 12:50:36 -0800
Subject: [PATCH 33/53] use maybe_register_memory

---
 tests/client-sdk/agents/test_agents.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index a77bb6cabd..522a8a4ebd 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -9,7 +9,7 @@
 from uuid import uuid4
 
 import pytest
-from llama_stack_client.lib.agents.agent import Agent, AugmentConfigWithMemoryTool
+from llama_stack_client.lib.agents.agent import Agent, maybe_register_memory_tool
 from llama_stack_client.lib.agents.client_tool import ClientTool
 from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types import ToolResponseMessage
@@ -227,7 +227,8 @@ def test_code_execution(llama_stack_client):
             "overlap_size_in_tokens": 64,
         },
     )
-    AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
+    tool_name, _ = maybe_register_memory_tool(llama_stack_client)
+    agent_config["tools"].append(tool_name)
     codex_agent = Agent(llama_stack_client, agent_config)
     session_id = codex_agent.create_session("test-session")
@@ -324,7 +325,8 @@ def test_rag_agent(llama_stack_client, agent_config):
         for i, url in enumerate(urls)
     ]
 
-    memory_bank_id = AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
+    tool_name, memory_bank_id = maybe_register_memory_tool(llama_stack_client)
+    agent_config["tools"].append(tool_name)
     agent = Agent(llama_stack_client, agent_config)
     llama_stack_client.memory.insert(
         bank_id=memory_bank_id,

From 17abffb5052a8ca09a7b19fbbc843f732c480d41 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 6 Jan 2025 13:59:06 -0800
Subject: [PATCH 34/53] fix handle_docs

---
 .../agents/meta_reference/agent_instance.py | 39 +++++++------
 tests/client-sdk/agents/test_agents.py      | 57 +++++--------------
 2 files changed, 34 insertions(+), 62 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 43d5cbdb77..ac49a06ce6 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -40,11 +40,7 @@
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import (
-    InterleavedContent,
-    TextContentItem,
-    URL,
-)
+from llama_stack.apis.common.content_types import TextContentItem, URL
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     CompletionMessage,
@@ -375,7 +371,10 @@ async def _run(
                 tool_args[tool.name] = tool.args
 
         tool_defs = await self._get_tool_defs(tools_for_turn)
-        await self.handle_documents(session_id, documents, input_messages, tool_defs)
+        if documents:
+            await self.handle_documents(
+                session_id, documents, input_messages, tool_defs
+            )
         if "memory" in tool_defs and len(input_messages) > 0:
             with tracing.span("memory_tool") as span:
                 step_id = str(uuid.uuid4())
@@ -759,26 +758,30 @@ async def handle_documents(
         input_messages: List[Message],
         tool_defs: Dict[str, ToolDefinition],
     ) -> None:
+        breakpoint()
         memory_tool = tool_defs.get("memory", None)
-        code_interpreter_tool = tool_defs.get("code_interpreter", None)
+        code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
         if documents:
-            content_items = [
-                d for d in documents if isinstance(d.content, InterleavedContent)
-            ]
-            url_items = [d for d in documents if isinstance(d.content, URL)]
+            content_items = []
+            url_items = []
             pattern = re.compile("^(https?://|file://|data:)")
-            url_items = [
-                URL(uri=a.content) for a in url_items if pattern.match(a.content)
-            ]
+            for d in documents:
+                if isinstance(d.content, URL):
+                    url_items.append(d.content)
+                elif pattern.match(d.content):
+                    url_items.append(URL(uri=d.content))
+                else:
+                    content_items.append(d)
+
             # Save the contents to a tempdir and use its path as a URL if code interpreter is present
             if code_interpreter_tool:
                 for c in content_items:
                     temp_file_path = os.path.join(
                         self.tempdir, f"{make_random_string()}.txt"
                     )
-                    with open(temp_file_path, "w") as temp_file:
-                        temp_file.write(c.content)
-                    url_items.append(URL(uri=f"file://{temp_file_path}"))
+                    with open(temp_file_path, "w") as temp_file:
+                        temp_file.write(c.content)
+                    url_items.append(URL(uri=f"file://{temp_file_path}"))
                # we try to load the data from the URLs and content items as a message to inference
                # and add it to the last message's context
-                input_messages[-1].context = content_items + load_data_from_urls(
+                input_messages[-1].context = content_items + await load_data_from_urls(
                     url_items
                 )
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index 522a8a4ebd..a8e06b7a29 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -14,6 +14,7 @@
 from llama_stack_client.lib.agents.event_logger import EventLogger
 from llama_stack_client.types import ToolResponseMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
+from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
 from llama_stack_client.types.memory_insert_params import Document
 from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter
@@ -208,7 +209,6 @@ def test_code_execution(llama_stack_client):
         model="meta-llama/Llama-3.1-70B-Instruct",
         instructions="You are a helpful assistant",
         tools=[
-            "brave_search",
             "code_interpreter",
         ],
         tool_choice="required",
@@ -217,49 +217,19 @@ def test_code_execution(llama_stack_client):
         enable_session_persistence=False,
     )
 
-    memory_bank_id = "inflation_data_memory_bank"
-    llama_stack_client.memory_banks.register(
-        memory_bank_id=memory_bank_id,
-        params={
-            "memory_bank_type": "vector",
-            "embedding_model": "all-MiniLM-L6-v2",
-            "chunk_size_in_tokens": 512,
-            "overlap_size_in_tokens": 64,
-        },
-    )
-    tool_name, _ = maybe_register_memory_tool(llama_stack_client)
-    agent_config["tools"].append(tool_name)
     codex_agent = Agent(llama_stack_client, agent_config)
     session_id = codex_agent.create_session("test-session")
-
-    llama_stack_client.memory.insert(
-        bank_id=memory_bank_id,
-        documents=[
-            Document(
-                document_id="inflation",
-                content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
-                mime_type="text/csv",
-                metadata={},
-            )
-        ],
+    inflation_doc = AgentDocument(
+        content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
+        mime_type="text/csv",
     )
 
-    user_prompts = [
-        {
-            "prompt": "Can you describe the data in the context?",
-            "tools": [{"name": "memory", "args": {"memory_bank_id": memory_bank_id}}],
-        },
-        {
-            "prompt": "Plot average yearly inflation as a time series",
-            "tools": [
-                {"name": "memory", "args": {"memory_bank_id": memory_bank_id}},
-                "code_interpreter",
-            ],
-        },
+    user_input = [
+        {"prompt": "Here is a csv, can you describe it?", "documents": [inflation_doc]},
+        {"prompt": "Plot average yearly inflation as a time series"},
     ]
 
-    for input in user_prompts:
-        print(f'User> {input["prompt"]}')
+    for input in user_input:
         response = codex_agent.create_turn(
             messages=[
                 {
@@ -268,13 +238,12 @@ def test_code_execution(llama_stack_client):
                 }
             ],
             session_id=session_id,
-            tools=input["tools"],
+            documents=input.get("documents", None),
         )
-        # for chunk in response:
-        #     print(chunk)
-
-        for log in EventLogger().log(response):
-            log.print()
+        logs = [str(log) for log in EventLogger().log(response) if log is not None]
+        logs_str = "".join(logs)
+        breakpoint()
+        print(logs_str)
 
 
 def test_custom_tool(llama_stack_client, agent_config):

From db0b2a60c1f5038ea72f847fa7f56c7ff215d028 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Mon, 6 Jan 2025 14:49:55 -0800
Subject: [PATCH 35/53] remove breakpoints

---
 .../providers/inline/agents/meta_reference/agent_instance.py | 1 -
 tests/client-sdk/agents/test_agents.py                       | 3 +--
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index ac49a06ce6..e4ebb30110 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -758,7 +758,6 @@ async def handle_documents(
         input_messages: List[Message],
         tool_defs: Dict[str, ToolDefinition],
     ) -> None:
-        breakpoint()
         memory_tool = tool_defs.get("memory", None)
         code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
         if documents:
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index a8e06b7a29..ca39ca03f1 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -211,7 +211,7 @@ def test_code_execution(llama_stack_client):
         tools=[
             "code_interpreter",
         ],
-        tool_choice="required",
+        tool_choice="auto",
         input_shields=[],
         output_shields=[],
         enable_session_persistence=False,
     )
@@ -242,7 +242,6 @@ def test_code_execution(llama_stack_client):
         )
         logs = [str(log) for log in EventLogger().log(response) if log is not None]
         logs_str = "".join(logs)
-        breakpoint()
         print(logs_str)

From e3775eb6f696ae83810e5bb3274d4dc888e9a400 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 7 Jan 2025 09:14:26 -0800
Subject: [PATCH 36/53] rename UserDefinedToolDef to ToolDef

---
 docs/resources/llama-stack-spec.html        | 170 +++++-------------
 docs/resources/llama-stack-spec.yaml        | 100 +++--------
 llama_stack/apis/agents/agents.py           |   4 +-
 llama_stack/apis/tools/tools.py             |  26 +--
 .../distribution/routers/routing_tables.py  |  57 ++----
 .../agents/meta_reference/agent_instance.py | 122 +++++++------
 .../code_interpreter/code_interpreter.py    |   3 +-
 tests/client-sdk/agents/test_agents.py      |  22 +--
 8 files changed, 181 insertions(+), 323 deletions(-)

diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 5a78f5baea..fb75259889 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -3714,7 +3714,7 @@
             "client_tools": {
               "type": "array",
               "items": {
-                "$ref": "#/components/schemas/UserDefinedToolDef"
+                "$ref": "#/components/schemas/ToolDef"
              }
            },
            "tool_choice": {
@@ -3792,60 +3792,9 @@
              }
            ]
          },
-          "ToolParameter": {
-            "type": "object",
-            "properties": {
-              "name": {
-                "type": "string"
-              },
-              "parameter_type": {
-                "type": "string"
-              },
-              "description": {
-                "type": "string"
-              },
-              "required": {
-                "type": "boolean"
-              },
-              "default": {
-                "oneOf": [
-                  {
-                    "type": "null"
-                  },
-                  {
-                    "type": "boolean"
-                  },
-                  {
-                    "type": "number"
-                  },
-                  {
-                    "type": "string"
-                  },
-                  {
-                    "type": "array"
-                  },
-                  {
-                    "type": "object"
-                  }
-                ]
-              }
-            },
-            "additionalProperties": false,
-            "required": [
-              "name",
-              "parameter_type",
-              "description",
-              "required"
-            ]
-          },
-          "UserDefinedToolDef": {
+          "ToolDef": {
            "type": "object",
            "properties": {
-              "type": {
-                "type": "string",
-                "const": "user_defined",
-                "default": "user_defined"
-              },
              "name": {
                "type": "string"
              },
@@ -3890,11 +3839,53 @@
            },
            "additionalProperties": false,
            "required": [
-              "type",
+              "name"
+            ]
+          },
+          "ToolParameter": {
+            "type": "object",
+            "properties": {
+              "name": {
+                "type": "string"
+              },
+              "parameter_type": {
+                "type": "string"
+              },
+              "description": {
+                "type": "string"
+              },
+              "required": {
+                "type": "boolean"
"boolean" + }, + "default": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "additionalProperties": false, + "required": [ "name", + "parameter_type", "description", - "parameters", - "metadata" + "required" ] }, "CreateAgentRequest": { @@ -4589,49 +4580,6 @@ "session_id" ] }, - "BuiltInToolDef": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "built_in", - "default": "built_in" - }, - "built_in_type": { - "$ref": "#/components/schemas/BuiltinTool" - }, - "metadata": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "built_in_type" - ] - }, "MCPToolGroupDef": { "type": "object", "properties": { @@ -4651,16 +4599,6 @@ ], "title": "A tool group that is defined by in a model context protocol server. Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." }, - "ToolDef": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserDefinedToolDef" - }, - { - "$ref": "#/components/schemas/BuiltInToolDef" - } - ] - }, "ToolGroupDef": { "oneOf": [ { @@ -7436,7 +7374,7 @@ "tool_group_id": { "type": "string" }, - "tool_group": { + "tool_group_def": { "$ref": "#/components/schemas/ToolGroupDef" }, "provider_id": { @@ -7446,7 +7384,7 @@ "additionalProperties": false, "required": [ "tool_group_id", - "tool_group" + "tool_group_def" ] }, "RunEvalRequest": { @@ -8098,10 +8036,6 @@ "name": "BenchmarkEvalTaskConfig", "description": "" }, - { - "name": "BuiltInToolDef", - "description": "" - }, { "name": "BuiltinTool", "description": "" @@ -8708,10 +8642,6 @@ "name": "UnstructuredLogEvent", "description": "" }, - { - "name": "UserDefinedToolDef", - "description": "" - }, { "name": "UserDefinedToolGroupDef", "description": "" @@ -8792,7 +8722,6 @@ "BatchCompletionRequest", "BatchCompletionResponse", "BenchmarkEvalTaskConfig", - "BuiltInToolDef", "BuiltinTool", "CancelTrainingJobRequest", "ChatCompletionRequest", @@ -8931,7 +8860,6 @@ "UnregisterModelRequest", "UnregisterToolGroupRequest", "UnstructuredLogEvent", - "UserDefinedToolDef", "UserDefinedToolGroupDef", "UserMessage", "VectorMemoryBank", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 72093b436a..0937d87224 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -19,7 +19,7 @@ components: properties: client_tools: items: - $ref: '#/components/schemas/UserDefinedToolDef' + $ref: '#/components/schemas/ToolDef' type: array enable_session_persistence: type: boolean @@ -396,29 +396,6 @@ components: - type - eval_candidate type: object - BuiltInToolDef: - additionalProperties: false - properties: - built_in_type: - $ref: '#/components/schemas/BuiltinTool' - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: built_in - default: built_in - type: string - required: - - type - - built_in_type - type: object BuiltinTool: enum: - brave_search @@ -1929,13 +1906,13 @@ components: properties: provider_id: type: string - tool_group: + tool_group_def: $ref: '#/components/schemas/ToolGroupDef' 
tool_group_id: type: string required: - tool_group_id - - tool_group + - tool_group_def type: object ResponseFormat: oneOf: @@ -2716,9 +2693,32 @@ components: - required type: string ToolDef: - oneOf: - - $ref: '#/components/schemas/UserDefinedToolDef' - - $ref: '#/components/schemas/BuiltInToolDef' + additionalProperties: false + properties: + description: + type: string + metadata: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + name: + type: string + parameters: + items: + $ref: '#/components/schemas/ToolParameter' + type: array + tool_prompt_format: + $ref: '#/components/schemas/ToolPromptFormat' + default: json + required: + - name + type: object ToolDefinition: additionalProperties: false properties: @@ -3087,41 +3087,6 @@ components: - message - severity type: object - UserDefinedToolDef: - additionalProperties: false - properties: - description: - type: string - metadata: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - name: - type: string - parameters: - items: - $ref: '#/components/schemas/ToolParameter' - type: array - tool_prompt_format: - $ref: '#/components/schemas/ToolPromptFormat' - default: json - type: - const: user_defined - default: user_defined - type: string - required: - - type - - name - - description - - parameters - - metadata - type: object UserDefinedToolGroupDef: additionalProperties: false properties: @@ -4823,8 +4788,6 @@ tags: - description: name: BenchmarkEvalTaskConfig -- description: - name: BuiltInToolDef - description: name: BuiltinTool - description: name: UnstructuredLogEvent -- description: - name: UserDefinedToolDef - description: name: UserDefinedToolGroupDef @@ -5316,7 +5276,6 @@ x-tagGroups: - BatchCompletionRequest - BatchCompletionResponse - BenchmarkEvalTaskConfig - - BuiltInToolDef - BuiltinTool - CancelTrainingJobRequest - ChatCompletionRequest @@ -5455,7 +5414,6 @@ x-tagGroups: - UnregisterModelRequest - UnregisterToolGroupRequest - UnstructuredLogEvent - - UserDefinedToolDef - UserDefinedToolGroupDef - UserMessage - VectorMemoryBank diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index acf8fa7486..db0e3ab3be 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -36,7 +36,7 @@ ) from llama_stack.apis.memory import MemoryBank from llama_stack.apis.safety import SafetyViolation -from llama_stack.apis.tools import UserDefinedToolDef +from llama_stack.apis.tools import ToolDef from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -157,7 +157,7 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) tools: Optional[List[AgentTool]] = Field(default_factory=list) - client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list) + client_tools: Optional[List[ToolDef]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 6585f3fd2a..bc19a8a027 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -48,30 +48,16 @@ class Tool(Resource): @json_schema_type -class 
UserDefinedToolDef(BaseModel): - type: Literal["user_defined"] = "user_defined" +class ToolDef(BaseModel): name: str - description: str - parameters: List[ToolParameter] - metadata: Dict[str, Any] + description: Optional[str] = None + parameters: Optional[List[ToolParameter]] = None + metadata: Optional[Dict[str, Any]] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) -@json_schema_type -class BuiltInToolDef(BaseModel): - type: Literal["built_in"] = "built_in" - built_in_type: BuiltinTool - metadata: Optional[Dict[str, Any]] = None - - -ToolDef = register_schema( - Annotated[Union[UserDefinedToolDef, BuiltInToolDef], Field(discriminator="type")], - name="ToolDef", -) - - @json_schema_type class MCPToolGroupDef(BaseModel): """ @@ -100,7 +86,7 @@ class UserDefinedToolGroupDef(BaseModel): @json_schema_type class ToolGroupInput(BaseModel): tool_group_id: str - tool_group: ToolGroupDef + tool_group_def: ToolGroupDef provider_id: Optional[str] = None @@ -127,7 +113,7 @@ class ToolGroups(Protocol): async def register_tool_group( self, tool_group_id: str, - tool_group: ToolGroupDef, + tool_group_def: ToolGroupDef, provider_id: Optional[str] = None, ) -> None: """Register a tool group""" diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index ccea470ae1..b51de8fef0 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -27,15 +27,12 @@ ) from llama_stack.apis.shields import Shield, Shields from llama_stack.apis.tools import ( - BuiltInToolDef, MCPToolGroupDef, Tool, ToolGroup, ToolGroupDef, ToolGroups, ToolHost, - ToolPromptFormat, - UserDefinedToolDef, UserDefinedToolGroupDef, ) from llama_stack.distribution.datatypes import ( @@ -514,7 +511,7 @@ async def get_tool(self, tool_name: str) -> Tool: async def register_tool_group( self, tool_group_id: str, - tool_group: ToolGroupDef, + tool_group_def: ToolGroupDef, provider_id: Optional[str] = None, ) -> None: tools = [] @@ -528,47 +525,31 @@ async def register_tool_group( provider_id = list(self.impls_by_provider_id.keys())[0] # parse tool group to the type if dict - tool_group = TypeAdapter(ToolGroupDef).validate_python(tool_group) - if isinstance(tool_group, MCPToolGroupDef): + tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def) + if isinstance(tool_group_def, MCPToolGroupDef): tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( - tool_group + tool_group_def ) tool_host = ToolHost.model_context_protocol - elif isinstance(tool_group, UserDefinedToolGroupDef): - tool_defs = tool_group.tools + elif isinstance(tool_group_def, UserDefinedToolGroupDef): + tool_defs = tool_group_def.tools else: - raise ValueError(f"Unknown tool group: {tool_group}") + raise ValueError(f"Unknown tool group: {tool_group_def}") for tool_def in tool_defs: - if isinstance(tool_def, UserDefinedToolDef): - tools.append( - Tool( - identifier=tool_def.name, - tool_group=tool_group_id, - description=tool_def.description, - parameters=tool_def.parameters, - provider_id=provider_id, - tool_prompt_format=tool_def.tool_prompt_format, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, - tool_host=tool_host, - ) - ) - elif isinstance(tool_def, BuiltInToolDef): - tools.append( - Tool( - identifier=tool_def.built_in_type.value, - tool_group=tool_group_id, - built_in_type=tool_def.built_in_type, - description="", - parameters=[], - 
provider_id=provider_id, - tool_prompt_format=ToolPromptFormat.json, - provider_resource_id=tool_def.built_in_type.value, - metadata=tool_def.metadata, - tool_host=tool_host, - ) + tools.append( + Tool( + identifier=tool_def.name, + tool_group=tool_group_id, + description=tool_def.description or "", + parameters=tool_def.parameters or [], + provider_id=provider_id, + tool_prompt_format=tool_def.tool_prompt_format, + provider_resource_id=tool_def.name, + metadata=tool_def.metadata, + tool_host=tool_host, ) + ) for tool in tools: existing_tool = await self.get_tool(tool.identifier) # Compare existing and new object if one exists diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e4ebb30110..cea4146e98 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -387,7 +387,7 @@ async def _run( ) ) extra_args = tool_args.get("memory", {}) - args = { + tool_args = { # Query memory with the last message's content "query": input_messages[-1], **extra_args, @@ -396,8 +396,8 @@ async def _run( session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: - args["memory_bank_id"] = session_info.memory_bank_id - serialized_args = tracing.serialize_value(args) + tool_args["memory_bank_id"] = session_info.memory_bank_id + serialized_args = tracing.serialize_value(tool_args) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -416,7 +416,7 @@ async def _run( ) result = await self.tool_runtime_api.invoke_tool( tool_name="memory", - args=args, + args=tool_args, ) yield AgentTurnResponseStreamChunk( @@ -482,11 +482,7 @@ async def _run( async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=[ - tool - for tool in tool_defs.values() - if tool.tool_name != "memory" - ], + tools=[tool for tool in tool_defs.values()], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, sampling_params=sampling_params, @@ -728,10 +724,17 @@ async def _get_tool_defs( continue tool_def = await self.tool_groups_api.get_tool(tool_name) - - if tool_def.built_in_type: - ret[tool_def.built_in_type] = ToolDefinition( - tool_name=tool_def.built_in_type + if tool_def is None: + raise ValueError(f"Tool {tool_name} not found") + + if tool_def.identifier.startswith("builtin::"): + built_in_type = tool_def.identifier[len("builtin::") :] + if built_in_type == "web_search": + built_in_type = "brave_search" + if built_in_type not in BuiltinTool.__members__: + raise ValueError(f"Unknown built-in tool: {built_in_type}") + ret[built_in_type] = ToolDefinition( + tool_name=BuiltinTool(built_in_type) ) continue @@ -759,52 +762,52 @@ async def handle_documents( tool_defs: Dict[str, ToolDefinition], ) -> None: memory_tool = tool_defs.get("memory", None) - code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) - if documents: - content_items = [] - url_items = [] - pattern = re.compile("^(https?://|file://|data:)") - for d in documents: - if isinstance(d.content, URL): - url_items.append(d.content) - elif pattern.match(d.content): - url_items.append(URL(uri=d.content)) - else: - content_items.append(d) - - # Save the contents to a tempdir and use its path as a URL if code interpreter is present - if 
code_interpreter_tool: - for c in content_items: - temp_file_path = os.path.join( - self.tempdir, f"{make_random_string()}.txt" - ) - with open(temp_file_path, "w") as temp_file: - temp_file.write(c.content) - url_items.append(URL(uri=f"file://{temp_file_path}")) - - if memory_tool and code_interpreter_tool: - # if both memory and code_interpreter are available, we download the URLs - # and attach the data to the last message. - msg = await attachment_message(self.tempdir, url_items) - input_messages.append(msg) - # Since memory is present, add all the data to the memory bank - await self.add_to_session_memory_bank(session_id, documents) - elif code_interpreter_tool: - # if only code_interpreter is available, we download the URLs to a tempdir - # and attach the path to them as a message to inference with the - # assumption that the model invokes the code_interpreter tool with the path - msg = await attachment_message(self.tempdir, url_items) - input_messages.append(msg) - elif memory_tool: - # if only memory is available, we load the data from the URLs and content items to the memory bank - await self.add_to_session_memory_bank(session_id, documents) + code_interpreter_tool = tool_defs.get("code_interpreter", None) + content_items = [] + url_items = [] + pattern = re.compile("^(https?://|file://|data:)") + for d in documents: + if isinstance(d.content, URL): + url_items.append(d.content) + elif pattern.match(d.content): + url_items.append(URL(uri=d.content)) else: - # if no memory or code_interpreter tool is available, - # we try to load the data from the URLs and content items as a message to inference - # and add it to the last message's context - input_messages[-1].context = content_items + await load_data_from_urls( - url_items + content_items.append(d) + + # Save the contents to a tempdir and use its path as a URL if code interpreter is present + if code_interpreter_tool: + for c in content_items: + temp_file_path = os.path.join( + self.tempdir, f"{make_random_string()}.txt" ) + with open(temp_file_path, "w") as temp_file: + temp_file.write(c.content) + url_items.append(URL(uri=f"file://{temp_file_path}")) + + if memory_tool and code_interpreter_tool: + # if both memory and code_interpreter are available, we download the URLs + # and attach the data to the last message. 
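+            # (annotation: attachment_message() below returns a message carrying
+            # only file:// paths, so code_interpreter can open the files itself,
+            # while the full contents are indexed into the session memory bank.)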
+ msg = await attachment_message(self.tempdir, url_items) + input_messages.append(msg) + # Since memory is present, add all the data to the memory bank + await self.add_to_session_memory_bank(session_id, documents) + elif code_interpreter_tool: + # if only code_interpreter is available, we download the URLs to a tempdir + # and attach the path to them as a message to inference with the + # assumption that the model invokes the code_interpreter tool with the path + msg = await attachment_message(self.tempdir, url_items) + input_messages.append(msg) + elif memory_tool: + # if only memory is available, we load the data from the URLs and content items to the memory bank + await self.add_to_session_memory_bank(session_id, documents) + else: + # if no memory or code_interpreter tool is available, + # we try to load the data from the URLs and content items as a message to inference + # and add it to the last message's context + input_messages[-1].context = "\n".join( + [doc.content for doc in content_items] + + await load_data_from_urls(url_items) + ) async def _ensure_memory_bank(self, session_id: str) -> str: session_info = await self.storage.get_session_info(session_id) @@ -909,7 +912,10 @@ async def execute_tool_call_maybe( tool_call = message.tool_calls[0] name = tool_call.tool_name if isinstance(name, BuiltinTool): - name = name.value + if name == BuiltinTool.brave_search: + name = "builtin::web_search" + else: + name = "builtin::" + name.value result = await tool_runtime_api.invoke_tool( tool_name=name, args=dict( diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 2e062d6d7f..0fe0d0243c 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -30,8 +30,7 @@ async def initialize(self): pass async def register_tool(self, tool: Tool): - if tool.identifier != "code_interpreter": - raise ValueError(f"Tool identifier {tool.identifier} is not supported") + pass async def unregister_tool(self, tool_id: str) -> None: return diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index ca39ca03f1..a760bb08ab 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -17,7 +17,7 @@ from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage -from llama_stack_client.types.tool_def_param import UserDefinedToolDefParameter +from llama_stack_client.types.tool_def_param import Parameter class TestClientTool(ClientTool): @@ -53,15 +53,15 @@ def get_name(self) -> str: def get_description(self) -> str: return "Get the boiling point of imaginary liquids (eg. 
polyjuice)" - def get_params_definition(self) -> Dict[str, UserDefinedToolDefParameter]: + def get_params_definition(self) -> Dict[str, Parameter]: return { - "liquid_name": UserDefinedToolDefParameter( + "liquid_name": Parameter( name="liquid_name", parameter_type="string", description="The name of the liquid", required=True, ), - "celcius": UserDefinedToolDefParameter( + "celcius": Parameter( name="celcius", parameter_type="boolean", description="Whether to return the boiling point in Celcius", @@ -149,11 +149,11 @@ def test_agent_simple(llama_stack_client, agent_config): assert "I can't" in logs_str -def test_builtin_tool_brave_search(llama_stack_client, agent_config): +def test_builtin_tool_web_search(llama_stack_client, agent_config): agent_config = { **agent_config, "tools": [ - "brave_search", + "builtin::web_search", ], } agent = Agent(llama_stack_client, agent_config) @@ -182,7 +182,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): agent_config = { **agent_config, "tools": [ - "code_interpreter", + "builtin::code_interpreter", ], } agent = Agent(llama_stack_client, agent_config) @@ -209,9 +209,9 @@ def test_code_execution(llama_stack_client): model="meta-llama/Llama-3.1-70B-Instruct", instructions="You are a helpful assistant", tools=[ - "code_interpreter", + "builtin::code_interpreter", ], - tool_choice="auto", + tool_choice="required", input_shields=[], output_shields=[], enable_session_persistence=False, @@ -242,7 +242,7 @@ def test_code_execution(llama_stack_client): ) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - print(logs_str) + assert "Tool:code_interpreter" in logs_str def test_custom_tool(llama_stack_client, agent_config): @@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config): agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "tools": ["brave_search"], + "tools": ["builtin::web_search"], "client_tools": [client_tool.get_tool_definition()], "tool_prompt_format": "python_list", } From ba242c04cccabb5d06043b9fa5027b68c10932ee Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 10:06:28 -0800 Subject: [PATCH 37/53] remove memory from available tools to agent --- .../inline/agents/meta_reference/agent_instance.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index cea4146e98..ceb764ffef 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -482,7 +482,11 @@ async def _run( async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=[tool for tool in tool_defs.values()], + tools=[ + tool + for tool in tool_defs.values() + if tool.tool_name != "memory" + ], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, sampling_params=sampling_params, From f9a98c278a325df0e12c8a982c46faea49c5f6c5 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 15:37:52 -0800 Subject: [PATCH 38/53] simplify toolgroups registration --- llama_stack/apis/agents/agents.py | 12 +- llama_stack/apis/tools/tools.py | 58 ++---- llama_stack/distribution/routers/routers.py | 14 +- .../distribution/routers/routing_tables.py | 57 ++---- .../agents/meta_reference/agent_instance.py | 180 +++++++++++------- 
.../inline/agents/meta_reference/agents.py | 4 +- .../code_interpreter/code_interpreter.py | 29 ++- .../tool_runtime/memory/context_retriever.py | 16 +- .../inline/tool_runtime/memory/memory.py | 40 +++- .../tool_runtime/brave_search/brave_search.py | 29 ++- .../model_context_protocol.py | 20 +- .../tavily_search/tavily_search.py | 29 ++- .../providers/tests/agents/test_agents.py | 72 +++++-- llama_stack/providers/tests/tools/fixtures.py | 39 +--- .../providers/tests/tools/test_tools.py | 9 +- 15 files changed, 351 insertions(+), 257 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index db0e3ab3be..f5fbcb9c40 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -137,15 +137,15 @@ class Session(BaseModel): memory_bank: Optional[MemoryBank] = None -class AgentToolWithArgs(BaseModel): +class AgentToolGroupWithArgs(BaseModel): name: str args: Dict[str, Any] -AgentTool = register_schema( +AgentToolGroup = register_schema( Union[ str, - AgentToolWithArgs, + AgentToolGroupWithArgs, ], name="AgentTool", ) @@ -156,7 +156,7 @@ class AgentConfigCommon(BaseModel): input_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list) - tools: Optional[List[AgentTool]] = Field(default_factory=list) + toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list) tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto) tool_prompt_format: Optional[ToolPromptFormat] = Field( @@ -278,7 +278,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): ] documents: Optional[List[Document]] = None - tools: Optional[List[AgentTool]] = None + toolgroups: Optional[List[AgentToolGroup]] = None stream: Optional[bool] = False @@ -317,7 +317,7 @@ async def create_agent_turn( ], stream: Optional[bool] = False, documents: Optional[List[Document]] = None, - tools: Optional[List[AgentTool]] = None, + tools: Optional[List[AgentToolGroup]] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/turn/get") diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index bc19a8a027..24845e1016 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -5,10 +5,10 @@ # the root directory of this source tree. from enum import Enum -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat -from llama_models.schema_utils import json_schema_type, register_schema, webmethod +from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable @@ -22,7 +22,7 @@ class ToolParameter(BaseModel): name: str parameter_type: str description: str - required: bool + required: bool = Field(default=True) default: Optional[Any] = None @@ -36,7 +36,7 @@ class ToolHost(Enum): @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool.value] = ResourceType.tool.value - tool_group: str + toolgroup_id: str tool_host: ToolHost description: str parameters: List[ToolParameter] @@ -58,41 +58,19 @@ class ToolDef(BaseModel): ) -@json_schema_type -class MCPToolGroupDef(BaseModel): - """ - A tool group that is defined by in a model context protocol server. 
- Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information. - """ - - type: Literal["model_context_protocol"] = "model_context_protocol" - endpoint: URL - - -@json_schema_type -class UserDefinedToolGroupDef(BaseModel): - type: Literal["user_defined"] = "user_defined" - tools: List[ToolDef] - - -ToolGroupDef = register_schema( - Annotated[ - Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type") - ], - name="ToolGroupDef", -) - - @json_schema_type class ToolGroupInput(BaseModel): - tool_group_id: str - tool_group_def: ToolGroupDef - provider_id: Optional[str] = None + toolgroup_id: str + provider_id: str + args: Optional[Dict[str, Any]] = None + mcp_endpoint: Optional[URL] = None @json_schema_type class ToolGroup(Resource): type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value + mcp_endpoint: Optional[URL] = None + args: Optional[Dict[str, Any]] = None @json_schema_type @@ -104,6 +82,7 @@ class ToolInvocationResult(BaseModel): class ToolStore(Protocol): def get_tool(self, tool_name: str) -> Tool: ... + def get_tool_group(self, tool_group_id: str) -> ToolGroup: ... @runtime_checkable @@ -112,9 +91,10 @@ class ToolGroups(Protocol): @webmethod(route="/toolgroups/register", method="POST") async def register_tool_group( self, - tool_group_id: str, - tool_group_def: ToolGroupDef, - provider_id: Optional[str] = None, + toolgroup_id: str, + provider_id: str, + mcp_endpoint: Optional[URL] = None, + args: Optional[Dict[str, Any]] = None, ) -> None: """Register a tool group""" ... @@ -122,7 +102,7 @@ async def register_tool_group( @webmethod(route="/toolgroups/get", method="GET") async def get_tool_group( self, - tool_group_id: str, + toolgroup_id: str, ) -> ToolGroup: ... @webmethod(route="/toolgroups/list", method="GET") @@ -149,8 +129,10 @@ async def unregister_tool_group(self, tool_group_id: str) -> None: class ToolRuntime(Protocol): tool_store: ToolStore - @webmethod(route="/tool-runtime/discover", method="POST") - async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ... + @webmethod(route="/tool-runtime/list-tools", method="POST") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: ... 
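+    # (annotation: list_tools replaces the old /tool-runtime/discover endpoint;
+    # built-in providers enumerate their own tool defs, and MCP servers are
+    # queried through the optional mcp_endpoint instead of a ToolGroupDef.)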
@webmethod(route="/tool-runtime/invoke", method="POST") async def invoke_tool( diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 84ef467eb5..230feea710 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, URL from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( AppEvalTaskConfig, @@ -38,7 +38,7 @@ ScoringFnParams, ) from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime +from llama_stack.apis.tools import ToolDef, ToolRuntime from llama_stack.providers.datatypes import RoutingTable @@ -417,7 +417,9 @@ async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any: args=args, ) - async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: - return await self.routing_table.get_provider_impl( - tool_group.name - ).discover_tools(tool_group) + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return await self.routing_table.get_provider_impl(tool_group_id).list_tools( + tool_group_id, mcp_endpoint + ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index b51de8fef0..4ed932807f 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -26,15 +26,7 @@ ScoringFunctions, ) from llama_stack.apis.shields import Shield, Shields -from llama_stack.apis.tools import ( - MCPToolGroupDef, - Tool, - ToolGroup, - ToolGroupDef, - ToolGroups, - ToolHost, - UserDefinedToolGroupDef, -) +from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost from llama_stack.distribution.datatypes import ( RoutableObject, RoutableObjectWithProvider, @@ -496,51 +488,38 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: tools = await self.get_all_with_type("tool") if tool_group_id: - tools = [tool for tool in tools if tool.tool_group == tool_group_id] + tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id] return tools async def list_tool_groups(self) -> List[ToolGroup]: return await self.get_all_with_type("tool_group") - async def get_tool_group(self, tool_group_id: str) -> ToolGroup: - return await self.get_object_by_identifier("tool_group", tool_group_id) + async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: + return await self.get_object_by_identifier("tool_group", toolgroup_id) async def get_tool(self, tool_name: str) -> Tool: return await self.get_object_by_identifier("tool", tool_name) async def register_tool_group( self, - tool_group_id: str, - tool_group_def: ToolGroupDef, - provider_id: Optional[str] = None, + toolgroup_id: str, + provider_id: str, + mcp_endpoint: Optional[URL] = None, + args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = [] - tool_host = ToolHost.distribution - if provider_id is None: - if len(self.impls_by_provider_id.keys()) > 1: - raise ValueError( - f"No provider_id specified and multiple providers available. Please specify a provider_id. 
Available providers: {', '.join(self.impls_by_provider_id.keys())}" - ) - provider_id = list(self.impls_by_provider_id.keys())[0] - - # parse tool group to the type if dict - tool_group_def = TypeAdapter(ToolGroupDef).validate_python(tool_group_def) - if isinstance(tool_group_def, MCPToolGroupDef): - tool_defs = await self.impls_by_provider_id[provider_id].discover_tools( - tool_group_def - ) - tool_host = ToolHost.model_context_protocol - elif isinstance(tool_group_def, UserDefinedToolGroupDef): - tool_defs = tool_group_def.tools - else: - raise ValueError(f"Unknown tool group: {tool_group_def}") + tool_defs = await self.impls_by_provider_id[provider_id].list_tools( + toolgroup_id, mcp_endpoint + ) + tool_host = ( + ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + ) for tool_def in tool_defs: tools.append( Tool( identifier=tool_def.name, - tool_group=tool_group_id, + toolgroup_id=toolgroup_id, description=tool_def.description or "", parameters=tool_def.parameters or [], provider_id=provider_id, @@ -565,9 +544,11 @@ async def register_tool_group( await self.dist_registry.register( ToolGroup( - identifier=tool_group_id, + identifier=toolgroup_id, provider_id=provider_id, - provider_resource_id=tool_group_id, + provider_resource_id=toolgroup_id, + mcp_endpoint=mcp_endpoint, + args=args, ) ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index ceb764ffef..cfe839dad8 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -13,7 +13,7 @@ import string import uuid from datetime import datetime -from typing import AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from urllib.parse import urlparse import httpx @@ -21,8 +21,8 @@ from llama_stack.apis.agents import ( AgentConfig, - AgentTool, - AgentToolWithArgs, + AgentToolGroup, + AgentToolGroupWithArgs, AgentTurnCreateRequest, AgentTurnResponseEvent, AgentTurnResponseEventType, @@ -76,6 +76,10 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") +MEMORY_TOOL_GROUP_ID = "builtin::memory" +MEMORY_QUERY_TOOL = "query_memory" +CODE_INTERPRETER_TOOL = "code_interpreter" +WEB_SEARCH_TOOL = "web_search" class ChatAgent(ShieldRunnerMixin): @@ -192,7 +196,7 @@ async def create_and_execute_turn( sampling_params=self.agent_config.sampling_params, stream=request.stream, documents=request.documents, - tools_for_turn=request.tools, + toolgroups_for_turn=request.toolgroups, ): if isinstance(chunk, CompletionMessage): log.info( @@ -243,7 +247,7 @@ async def run( sampling_params: SamplingParams, stream: bool = False, documents: Optional[List[Document]] = None, - tools_for_turn: Optional[List[AgentTool]] = None, + toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: # Doing async generators makes downstream code much simpler and everything amenable to # streaming. 
However, it also makes things complicated here because AsyncGenerators cannot @@ -266,7 +270,7 @@ async def run( sampling_params, stream, documents, - tools_for_turn, + toolgroups_for_turn, ): if isinstance(res, bool): return @@ -362,21 +366,24 @@ async def _run( sampling_params: SamplingParams, stream: bool = False, documents: Optional[List[Document]] = None, - tools_for_turn: Optional[List[AgentTool]] = None, + toolgroups_for_turn: Optional[List[AgentToolGroup]] = None, ) -> AsyncGenerator: - tool_args = {} - if tools_for_turn: - for tool in tools_for_turn: - if isinstance(tool, AgentToolWithArgs): - tool_args[tool.name] = tool.args - - tool_defs = await self._get_tool_defs(tools_for_turn) + toolgroup_args = {} + for toolgroup in self.agent_config.toolgroups: + if isinstance(toolgroup, AgentToolGroupWithArgs): + toolgroup_args[toolgroup.name] = toolgroup.args + if toolgroups_for_turn: + for toolgroup in toolgroups_for_turn: + if isinstance(toolgroup, AgentToolGroupWithArgs): + toolgroup_args[toolgroup.name] = toolgroup.args + + tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) if documents: await self.handle_documents( session_id, documents, input_messages, tool_defs ) - if "memory" in tool_defs and len(input_messages) > 0: - with tracing.span("memory_tool") as span: + if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0: + with tracing.span(MEMORY_QUERY_TOOL) as span: step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -386,18 +393,16 @@ async def _run( ) ) ) - extra_args = tool_args.get("memory", {}) - tool_args = { - # Query memory with the last message's content - "query": input_messages[-1], - **extra_args, + query_args = { + "messages": [msg.content for msg in input_messages], + **toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}), } session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: - tool_args["memory_bank_id"] = session_info.memory_bank_id - serialized_args = tracing.serialize_value(tool_args) + query_args["memory_bank_id"] = session_info.memory_bank_id + serialized_args = tracing.serialize_value(query_args) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -415,8 +420,8 @@ async def _run( ) ) result = await self.tool_runtime_api.invoke_tool( - tool_name="memory", - args=tool_args, + tool_name=MEMORY_QUERY_TOOL, + args=query_args, ) yield AgentTurnResponseStreamChunk( @@ -485,7 +490,8 @@ async def _run( tools=[ tool for tool in tool_defs.values() - if tool.tool_name != "memory" + if tool_to_group.get(tool.tool_name, None) + != MEMORY_TOOL_GROUP_ID ], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, @@ -632,6 +638,8 @@ async def _run( self.tool_runtime_api, session_id, [message], + toolgroup_args, + tool_to_group, ) assert ( len(result_messages) == 1 @@ -690,26 +698,37 @@ def interpret_content_as_attachment( n_iter += 1 async def _get_tool_defs( - self, tools_for_turn: Optional[List[AgentTool]] + self, toolgroups_for_turn: Optional[List[AgentToolGroup]] ) -> Dict[str, ToolDefinition]: # Determine which tools to include - agent_config_tools = set( - tool.name if isinstance(tool, AgentToolWithArgs) else tool - for tool in self.agent_config.tools + agent_config_toolgroups = set( + ( + toolgroup.name + if isinstance(toolgroup, AgentToolGroupWithArgs) + else toolgroup + ) + for toolgroup in self.agent_config.toolgroups ) 
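+        # (annotation: the intersection below means a turn can narrow, but never
+        # extend, the toolgroups declared on the agent config.)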
- tools_for_turn_set = ( - agent_config_tools - if tools_for_turn is None + toolgroups_for_turn_set = ( + agent_config_toolgroups + if toolgroups_for_turn is None else { - tool.name if isinstance(tool, AgentToolWithArgs) else tool - for tool in tools_for_turn + ( + toolgroup.name + if isinstance(toolgroup, AgentToolGroupWithArgs) + else toolgroup + ) + for toolgroup in toolgroups_for_turn } ) - ret = {} + tool_def_map = {} + tool_to_group = {} for tool_def in self.agent_config.client_tools: - ret[tool_def.name] = ToolDefinition( + if tool_def_map.get(tool_def.name, None): + raise ValueError(f"Tool {tool_def.name} already exists") + tool_def_map[tool_def.name] = ToolDefinition( tool_name=tool_def.name, description=tool_def.description, parameters={ @@ -722,41 +741,42 @@ async def _get_tool_defs( for param in tool_def.parameters }, ) - - for tool_name in agent_config_tools: - if tool_name not in tools_for_turn_set: - continue - - tool_def = await self.tool_groups_api.get_tool(tool_name) - if tool_def is None: - raise ValueError(f"Tool {tool_name} not found") - - if tool_def.identifier.startswith("builtin::"): - built_in_type = tool_def.identifier[len("builtin::") :] - if built_in_type == "web_search": - built_in_type = "brave_search" - if built_in_type not in BuiltinTool.__members__: - raise ValueError(f"Unknown built-in tool: {built_in_type}") - ret[built_in_type] = ToolDefinition( - tool_name=BuiltinTool(built_in_type) - ) + tool_to_group[tool_def.name] = "__client_tools__" + for toolgroup_name in agent_config_toolgroups: + if toolgroup_name not in toolgroups_for_turn_set: continue + tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name) + for tool_def in tools: + if tool_def.built_in_type: + if tool_def_map.get(tool_def.built_in_type, None): + raise ValueError( + f"Tool {tool_def.built_in_type} already exists" + ) - ret[tool_def.identifier] = ToolDefinition( - tool_name=tool_def.identifier, - description=tool_def.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, + tool_def_map[tool_def.built_in_type] = ToolDefinition( + tool_name=tool_def.built_in_type ) - for param in tool_def.parameters - }, - ) + tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id + continue - return ret + if tool_def_map.get(tool_def.identifier, None): + raise ValueError(f"Tool {tool_def.identifier} already exists") + tool_def_map[tool_def.identifier] = ToolDefinition( + tool_name=tool_def.identifier, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, + ) + tool_to_group[tool_def.identifier] = tool_def.toolgroup_id + + return tool_def_map, tool_to_group async def handle_documents( self, @@ -765,8 +785,8 @@ async def handle_documents( input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], ) -> None: - memory_tool = tool_defs.get("memory", None) - code_interpreter_tool = tool_defs.get("code_interpreter", None) + memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None) + code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None) content_items = [] url_items = [] pattern = re.compile("^(https?://|file://|data:)") @@ -903,7 +923,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa async def execute_tool_call_maybe( 
- tool_runtime_api: ToolRuntime, session_id: str, messages: List[CompletionMessage] + tool_runtime_api: ToolRuntime, + session_id: str, + messages: List[CompletionMessage], + toolgroup_args: Dict[str, Dict[str, Any]], + tool_to_group: Dict[str, str], ) -> List[ToolResponseMessage]: # While Tools.run interface takes a list of messages, # All tools currently only run on a single message @@ -915,18 +939,26 @@ async def execute_tool_call_maybe( tool_call = message.tool_calls[0] name = tool_call.tool_name + group_name = tool_to_group.get(name, None) + if group_name is None: + raise ValueError(f"Tool {name} not found in any tool group") + # get the arguments generated by the model and augment with toolgroup arg overrides for the agent + tool_call_args = tool_call.arguments + tool_call_args.update(toolgroup_args.get(group_name, {})) if isinstance(name, BuiltinTool): if name == BuiltinTool.brave_search: - name = "builtin::web_search" + name = WEB_SEARCH_TOOL else: - name = "builtin::" + name.value + name = name.value + result = await tool_runtime_api.invoke_tool( tool_name=name, args=dict( session_id=session_id, - **tool_call.arguments, + **tool_call_args, ), ) + return [ ToolResponseMessage( call_id=tool_call.call_id, diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 0181ef6095..2ea74300dd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -19,7 +19,7 @@ Agents, AgentSessionCreateResponse, AgentStepResponse, - AgentTool, + AgentToolGroup, AgentTurnCreateRequest, Document, Session, @@ -147,7 +147,7 @@ async def create_agent_turn( ToolResponseMessage, ] ], - tools: Optional[List[AgentTool]] = None, + tools: Optional[List[AgentToolGroup]] = None, documents: Optional[List[Document]] = None, stream: Optional[bool] = False, ) -> AsyncGenerator: diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 0fe0d0243c..fc568996da 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -7,9 +7,16 @@ import logging import tempfile -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional -from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) from llama_stack.providers.datatypes import ToolsProtocolPrivate from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor @@ -35,8 +42,22 @@ async def register_tool(self, tool: Tool): async def unregister_tool(self, tool_id: str) -> None: return - async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: - raise NotImplementedError("Code interpreter tool group not supported") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="code_interpreter", + description="Execute code", + parameters=[ + ToolParameter( + name="code", + description="The code to execute", + parameter_type="string", + ), + ], + ) + ] async def invoke_tool( self, tool_name: str, args: Dict[str, Any] diff --git 
a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py index 7ee751a173..1fb1d09920 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py @@ -5,6 +5,8 @@ # the root directory of this source tree. +from typing import List + from jinja2 import Template from llama_stack.apis.inference import Message, UserMessage @@ -22,7 +24,7 @@ async def generate_rag_query( config: MemoryQueryGeneratorConfig, - message: Message, + messages: List[Message], **kwargs, ): """ @@ -30,9 +32,9 @@ async def generate_rag_query( retrieving relevant information from the memory bank. """ if config.type == MemoryQueryGenerator.default.value: - query = await default_rag_query_generator(config, message, **kwargs) + query = await default_rag_query_generator(config, messages, **kwargs) elif config.type == MemoryQueryGenerator.llm.value: - query = await llm_rag_query_generator(config, message, **kwargs) + query = await llm_rag_query_generator(config, messages, **kwargs) else: raise NotImplementedError(f"Unsupported memory query generator {config.type}") return query @@ -40,21 +42,21 @@ async def generate_rag_query( async def default_rag_query_generator( config: DefaultMemoryQueryGeneratorConfig, - message: Message, + messages: List[Message], **kwargs, ): - return interleaved_content_as_str(message.content) + return config.sep.join(interleaved_content_as_str(m.content) for m in messages) async def llm_rag_query_generator( config: LLMMemoryQueryGeneratorConfig, - message: Message, + messages: List[Message], **kwargs, ): assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" inference_api = kwargs["inference_api"] - m_dict = {"messages": [message.model_dump()]} + m_dict = {"messages": [message.model_dump() for message in messages]} template = Template(config.template) content = template.render(m_dict) diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index cad123696b..c8c2cc772f 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -10,13 +10,14 @@ import string from typing import Any, Dict, List, Optional -from llama_stack.apis.inference import Inference, InterleavedContent, Message +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.memory import Memory, QueryDocumentsResponse from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.tools import ( ToolDef, - ToolGroupDef, ToolInvocationResult, + ToolParameter, ToolRuntime, ) from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -50,17 +51,31 @@ def __init__( async def initialize(self): pass - async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: - return [] + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="memory", + description="Retrieve context from memory", + parameters=[ + ToolParameter( + name="input_messages", + description="The input messages to search for", + parameter_type="array", + ), + ], + ) + ] async def _retrieve_context( - self, message: Message, bank_ids: List[str] + self, input_messages: List[str], bank_ids: List[str] ) -> 
Optional[List[InterleavedContent]]: if not bank_ids: return None query = await generate_rag_query( self.config.query_generator_config, - message, + input_messages, inference_api=self.inference_api, ) tasks = [ @@ -106,17 +121,22 @@ async def invoke_tool( self, tool_name: str, args: Dict[str, Any] ) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) + tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id) + final_args = tool_group.args or {} + final_args.update(args) config = MemoryToolConfig() - if tool.metadata.get("config") is not None: + if tool.metadata and tool.metadata.get("config") is not None: config = MemoryToolConfig(**tool.metadata["config"]) - if "memory_bank_id" in args: - bank_ids = [args["memory_bank_id"]] + if "memory_bank_ids" in final_args: + bank_ids = final_args["memory_bank_ids"] else: bank_ids = [ bank_config.bank_id for bank_config in config.memory_bank_configs ] + if "messages" not in final_args: + raise ValueError("messages are required") context = await self._retrieve_context( - args["query"], + final_args["messages"], bank_ids, ) if context is None: diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index cd0468d93b..162e82d629 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -4,11 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import requests -from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -41,8 +48,22 @@ def _get_api_key(self) -> str: ) return provider_data.api_key - async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: - raise NotImplementedError("Brave search tool group not supported") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="web_search", + description="Search the web for information", + parameters=[ + ToolParameter( + name="query", + description="The query to search for", + parameter_type="string", + ) + ], + ) + ] async def invoke_tool( self, tool_name: str, args: Dict[str, Any] diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 19ada8457e..dd2bb5e5e5 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -4,20 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from urllib.parse import urlparse from mcp import ClientSession from mcp.client.sse import sse_client +from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( - MCPToolGroupDef, ToolDef, - ToolGroupDef, ToolInvocationResult, ToolParameter, ToolRuntime, - UserDefinedToolDef, ) from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -31,12 +29,14 @@ def __init__(self, config: ModelContextProtocolConfig): async def initialize(self): pass - async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: - if not isinstance(tool_group, MCPToolGroupDef): - raise ValueError(f"Unsupported tool group type: {type(tool_group)}") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + if mcp_endpoint is None: + raise ValueError("mcp_endpoint is required") tools = [] - async with sse_client(tool_group.endpoint.uri) as streams: + async with sse_client(mcp_endpoint.uri) as streams: async with ClientSession(*streams) as session: await session.initialize() tools_result = await session.list_tools() @@ -53,12 +53,12 @@ async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ) ) tools.append( - UserDefinedToolDef( + ToolDef( name=tool.name, description=tool.description, parameters=parameters, metadata={ - "endpoint": tool_group.endpoint.uri, + "endpoint": mcp_endpoint.uri, }, ) ) diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index f4e9809293..6dc515be3b 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -5,11 +5,18 @@ # the root directory of this source tree. 
import json -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import requests -from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolsProtocolPrivate @@ -42,8 +49,22 @@ def _get_api_key(self) -> str: ) return provider_data.api_key - async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]: - raise NotImplementedError("Tavily search tool group not supported") + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="web_search", + description="Search the web for information", + parameters=[ + ToolParameter( + name="query", + description="The query to search for", + parameter_type="string", + ) + ], + ) + ] async def invoke_tool( self, tool_name: str, args: Dict[str, Any] diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 18dc904204..fb22e976e5 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -45,8 +45,7 @@ def common_params(inference_model): sampling_params=SamplingParams(temperature=0.7, top_p=0.95), input_shields=[], output_shields=[], - available_tools=[], - preprocessing_tools=[], + toolgroups=[], max_infer_iters=5, ) @@ -83,27 +82,27 @@ def query_attachment_messages(): ] -async def create_agent_turn_with_search_tool( +async def create_agent_turn_with_toolgroup( agents_stack: Dict[str, object], search_query_messages: List[object], common_params: Dict[str, str], - tool_name: str, + toolgroup_name: str, ) -> None: """ - Create an agent turn with a search tool. + Create an agent turn with a toolgroup. Args: agents_stack (Dict[str, object]): The agents stack. search_query_messages (List[object]): The search query messages. common_params (Dict[str, str]): The common parameters. - search_tool_definition (SearchToolDefinition): The search tool definition. + toolgroup_name (str): The name of the toolgroup. 
""" - # Create an agent with the search tool + # Create an agent with the toolgroup agent_config = AgentConfig( **{ **common_params, - "tools": [tool_name], + "toolgroups": [toolgroup_name], } ) @@ -249,7 +248,7 @@ async def test_rag_agent( agent_config = AgentConfig( **{ **common_params, - "tools": ["memory"], + "toolgroups": ["builtin::memory"], "tool_choice": ToolChoice.auto, } ) @@ -289,13 +288,58 @@ async def test_create_agent_turn_with_tavily_search( if "TAVILY_SEARCH_API_KEY" not in os.environ: pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") - await create_agent_turn_with_search_tool( - agents_stack, - search_query_messages, - common_params, - "brave_search", + # Create an agent with the toolgroup + agent_config = AgentConfig( + **{ + **common_params, + "toolgroups": ["builtin::web_search"], + } + ) + + agent_id, session_id = await create_agent_session( + agents_stack.impls[Api.agents], agent_config + ) + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=search_query_messages, + stream=True, ) + turn_response = [ + chunk + async for chunk in await agents_stack.impls[Api.agents].create_agent_turn( + **turn_request + ) + ] + + assert len(turn_response) > 0 + assert all( + isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response + ) + + check_event_types(turn_response) + + # Check for tool execution events + tool_execution_events = [ + chunk + for chunk in turn_response + if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload) + and chunk.event.payload.step_details.step_type + == StepType.tool_execution.value + ] + assert len(tool_execution_events) > 0, "No tool execution events found" + + # Check the tool execution details + tool_execution = tool_execution_events[0].event.payload.step_details + assert isinstance(tool_execution, ToolExecutionStep) + assert len(tool_execution.tool_calls) > 0 + actual_tool_name = tool_execution.tool_calls[0].tool_name + assert actual_tool_name == "web_search" + assert len(tool_execution.tool_responses) > 0 + + check_turn_complete_event(turn_response, session_id, search_query_messages) + def check_event_types(turn_response): event_types = [chunk.event.payload.event_type for chunk in turn_response] diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py index 58defd57d0..a9f923c87f 100644 --- a/llama_stack/providers/tests/tools/fixtures.py +++ b/llama_stack/providers/tests/tools/fixtures.py @@ -8,16 +8,9 @@ import pytest import pytest_asyncio -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.models import ModelInput, ModelType -from llama_stack.apis.tools import ( - BuiltInToolDef, - ToolGroupInput, - ToolParameter, - UserDefinedToolDef, - UserDefinedToolGroupDef, -) +from llama_stack.apis.tools import ToolGroupInput from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.tests.resolver import construct_stack_for_test @@ -47,30 +40,7 @@ def tool_runtime_memory_and_search() -> ProviderFixture: @pytest.fixture(scope="session") def tool_group_input_memory() -> ToolGroupInput: return ToolGroupInput( - tool_group_id="memory_group", - tool_group=UserDefinedToolGroupDef( - tools=[ - UserDefinedToolDef( - name="memory", - description="Query the memory bank", - parameters=[ - ToolParameter( - name="input_messages", - description="The input messages to search for in memory", - parameter_type="list", - required=True, - ), - ], - metadata={ - "config": { - "memory_bank_configs": [ - 
{"bank_id": "test_bank", "type": "vector"} - ] - } - }, - ) - ], - ), + toolgroup_id="builtin::memory", provider_id="memory-runtime", ) @@ -78,10 +48,7 @@ def tool_group_input_memory() -> ToolGroupInput: @pytest.fixture(scope="session") def tool_group_input_tavily_search() -> ToolGroupInput: return ToolGroupInput( - tool_group_id="tavily_search_group", - tool_group=UserDefinedToolGroupDef( - tools=[BuiltInToolDef(built_in_type=BuiltinTool.brave_search, metadata={})], - ), + toolgroup_id="builtin::web_search", provider_id="tavily-search", ) diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py index f33b4a61d8..917db55e14 100644 --- a/llama_stack/providers/tests/tools/test_tools.py +++ b/llama_stack/providers/tests/tools/test_tools.py @@ -43,8 +43,8 @@ def sample_documents(): class TestTools: @pytest.mark.asyncio - async def test_brave_search_tool(self, tools_stack, sample_search_query): - """Test the Brave search tool functionality.""" + async def test_web_search_tool(self, tools_stack, sample_search_query): + """Test the web search tool functionality.""" if "TAVILY_SEARCH_API_KEY" not in os.environ: pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") @@ -52,7 +52,7 @@ async def test_brave_search_tool(self, tools_stack, sample_search_query): # Execute the tool response = await tools_impl.invoke_tool( - tool_name="brave_search", args={"query": sample_search_query} + tool_name="web_search", args={"query": sample_search_query} ) # Verify the response @@ -89,11 +89,12 @@ async def test_memory_tool(self, tools_stack, sample_documents): response = await tools_impl.invoke_tool( tool_name="memory", args={ - "input_messages": [ + "messages": [ UserMessage( content="What are the main topics covered in the documentation?", ) ], + "memory_bank_ids": ["test_bank"], }, ) From 94cca7a72a42ea0b558afdd90340861f08f40877 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 16:07:03 -0800 Subject: [PATCH 39/53] add wolfram alpha, bing search --- llama_stack/apis/tools/tools.py | 1 + .../distribution/routers/routing_tables.py | 1 + .../code_interpreter/code_interpreter.py | 3 + .../providers/registry/tool_runtime.py | 20 +++ .../tool_runtime/bing_search/__init__.py | 21 +++ .../tool_runtime/bing_search/bing_search.py | 116 ++++++++++++++ .../remote/tool_runtime/bing_search/config.py | 16 ++ .../tool_runtime/brave_search/brave_search.py | 2 + .../tavily_search/tavily_search.py | 2 + .../tool_runtime/wolfram_alpha/__init__.py | 22 +++ .../tool_runtime/wolfram_alpha/config.py | 15 ++ .../wolfram_alpha/wolfram_alpha.py | 148 ++++++++++++++++++ llama_stack/providers/tests/tools/fixtures.py | 22 ++- .../providers/tests/tools/test_tools.py | 23 +++ 14 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 llama_stack/providers/remote/tool_runtime/bing_search/__init__.py create mode 100644 llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py create mode 100644 llama_stack/providers/remote/tool_runtime/bing_search/config.py create mode 100644 llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py create mode 100644 llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py create mode 100644 llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 24845e1016..0c2bb58633 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -53,6 +53,7 @@ class 
ToolDef(BaseModel): description: Optional[str] = None parameters: Optional[List[ToolParameter]] = None metadata: Optional[Dict[str, Any]] = None + built_in_type: Optional[BuiltinTool] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 4ed932807f..2f0288865c 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -527,6 +527,7 @@ async def register_tool_group( provider_resource_id=tool_def.name, metadata=tool_def.metadata, tool_host=tool_host, + built_in_type=tool_def.built_in_type, ) ) for tool in tools: diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index fc568996da..3b3f180320 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -9,6 +9,8 @@ import tempfile from typing import Any, Dict, List, Optional +from llama_models.llama3.api.datatypes import BuiltinTool + from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( Tool, @@ -56,6 +58,7 @@ async def list_tools( parameter_type="string", ), ], + built_in_type=BuiltinTool.code_interpreter, ) ] diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index d6e8925992..40299edad7 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -42,6 +42,16 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", ), ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="bing-search", + module="llama_stack.providers.remote.tool_runtime.bing_search", + config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", + ), + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( @@ -52,6 +62,16 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", ), ), + remote_provider_spec( + api=Api.tool_runtime, + adapter=AdapterSpec( + adapter_type="wolfram-alpha", + module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", + config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", + ), + ), remote_provider_spec( api=Api.tool_runtime, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py b/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py new file mode 100644 index 0000000000..8481737b5a --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .bing_search import BingSearchToolRuntimeImpl +from .config import BingSearchToolConfig + +__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"] +from pydantic import BaseModel + + +class BingSearchToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_adapter_impl(config: BingSearchToolConfig, _deps): + impl = BingSearchToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py new file mode 100644 index 0000000000..b0c30b0a08 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Any, Dict, List, Optional + +import requests +from llama_models.llama3.api.datatypes import BuiltinTool + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import BingSearchToolConfig + + +class BingSearchToolRuntimeImpl( + ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): + def __init__(self, config: BingSearchToolConfig): + self.config = config + self.url = "https://api.bing.microsoft.com/v7.0/search" + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + pass + + async def unregister_tool(self, tool_id: str) -> None: + return + + def _get_api_key(self) -> str: + if self.config.api_key: + return self.config.api_key + + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.api_key: + raise ValueError( + 'Pass Bing Search API Key in the header X-LlamaStack-ProviderData as { "api_key": }' + ) + return provider_data.api_key + + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="web_search", + description="Search the web using Bing Search API", + parameters=[ + ToolParameter( + name="query", + description="The query to search for", + parameter_type="string", + ) + ], + built_in_type=BuiltinTool.brave_search, + ) + ] + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + api_key = self._get_api_key() + headers = { + "Ocp-Apim-Subscription-Key": api_key, + } + params = { + "count": self.config.top_k, + "textDecorations": True, + "textFormat": "HTML", + "q": args["query"], + } + + response = requests.get( + url=self.url, + params=params, + headers=headers, + ) + response.raise_for_status() + + return ToolInvocationResult( + content=json.dumps(self._clean_response(response.json())) + ) + + def _clean_response(self, search_response): + clean_response = [] + query = search_response["queryContext"]["originalQuery"] + if "webPages" in search_response: + pages = search_response["webPages"]["value"] + for p in pages: + selected_keys = {"name", "url", "snippet"} + clean_response.append( + {k: v for k, v in 
p.items() if k in selected_keys} + ) + if "news" in search_response: + clean_news = [] + news = search_response["news"]["value"] + for n in news: + selected_keys = {"name", "url", "description"} + clean_news.append({k: v for k, v in n.items() if k in selected_keys}) + + clean_response.append(clean_news) + + return {"query": query, "top_k": clean_response} diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/config.py b/llama_stack/providers/remote/tool_runtime/bing_search/config.py new file mode 100644 index 0000000000..67283d8d59 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/bing_search/config.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel + + +class BingSearchToolConfig(BaseModel): + """Configuration for Bing Search Tool Runtime""" + + api_key: Optional[str] = None + top_k: int = 3 diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 162e82d629..dab6ce4398 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional import requests +from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -62,6 +63,7 @@ async def list_tools( parameter_type="string", ) ], + built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 6dc515be3b..d22f188b32 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional import requests +from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -63,6 +64,7 @@ async def list_tools( parameter_type="string", ) ], + built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py new file mode 100644 index 0000000000..aaa6e4e693 --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
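# Illustrative note (an assumption mirroring the other remote search providers
# in this series): the key may be set in WolframAlphaToolConfig or supplied per
# request via the provider-data header, e.g.
#     X-LlamaStack-ProviderData: {"api_key": "YOUR_WOLFRAM_ALPHA_API_KEY"}
# where YOUR_WOLFRAM_ALPHA_API_KEY is a placeholder, not a real credential.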
+ +from pydantic import BaseModel + +from .config import WolframAlphaToolConfig +from .wolfram_alpha import WolframAlphaToolRuntimeImpl + +__all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"] + + +class WolframAlphaToolProviderDataValidator(BaseModel): + api_key: str + + +async def get_adapter_impl(config: WolframAlphaToolConfig, _deps): + impl = WolframAlphaToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py new file mode 100644 index 0000000000..13996b639a --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/config.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel + + +class WolframAlphaToolConfig(BaseModel): + """Configuration for WolframAlpha Tool Runtime""" + + api_key: Optional[str] = None diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py new file mode 100644 index 0000000000..0f3fdfb39e --- /dev/null +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import Any, Dict, List, Optional + +import requests +from llama_models.llama3.api.datatypes import BuiltinTool + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolInvocationResult, + ToolParameter, + ToolRuntime, +) +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import WolframAlphaToolConfig + + +class WolframAlphaToolRuntimeImpl( + ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData +): + def __init__(self, config: WolframAlphaToolConfig): + self.config = config + self.url = "https://api.wolframalpha.com/v2/query" + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + pass + + async def unregister_tool(self, tool_id: str) -> None: + return + + def _get_api_key(self) -> str: + if self.config.api_key: + return self.config.api_key + + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.api_key: + raise ValueError( + 'Pass WolframAlpha API Key in the header X-LlamaStack-ProviderData as { "api_key": }' + ) + return provider_data.api_key + + async def list_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [ + ToolDef( + name="wolfram_alpha", + description="Query WolframAlpha for computational knowledge", + parameters=[ + ToolParameter( + name="query", + description="The query to compute", + parameter_type="string", + ) + ], + built_in_type=BuiltinTool.wolfram_alpha, + ) + ] + + async def invoke_tool( + self, tool_name: str, args: Dict[str, Any] + ) -> ToolInvocationResult: + api_key = self._get_api_key() + params = { + "input": args["query"], + "appid": api_key, + "format": "plaintext", + "output": 
"json", + } + response = requests.get( + self.url, + params=params, + ) + + return ToolInvocationResult( + content=json.dumps(self._clean_wolfram_alpha_response(response.json())) + ) + + def _clean_wolfram_alpha_response(self, wa_response): + remove = { + "queryresult": [ + "datatypes", + "error", + "timedout", + "timedoutpods", + "numpods", + "timing", + "parsetiming", + "parsetimedout", + "recalculate", + "id", + "host", + "server", + "related", + "version", + { + "pods": [ + "scanner", + "id", + "error", + "expressiontypes", + "states", + "infos", + "position", + "numsubpods", + ] + }, + "assumptions", + ], + } + for main_key in remove: + for key_to_remove in remove[main_key]: + try: + if key_to_remove == "assumptions": + if "assumptions" in wa_response[main_key]: + del wa_response[main_key][key_to_remove] + if isinstance(key_to_remove, dict): + for sub_key in key_to_remove: + if sub_key == "pods": + for i in range(len(wa_response[main_key][sub_key])): + if ( + wa_response[main_key][sub_key][i]["title"] + == "Result" + ): + del wa_response[main_key][sub_key][i + 1 :] + break + sub_items = wa_response[main_key][sub_key] + for i in range(len(sub_items)): + for sub_key_to_remove in key_to_remove[sub_key]: + if sub_key_to_remove in sub_items[i]: + del sub_items[i][sub_key_to_remove] + elif key_to_remove in wa_response[main_key]: + del wa_response[main_key][key_to_remove] + except KeyError: + pass + return wa_response diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py index a9f923c87f..a559dbf8c3 100644 --- a/llama_stack/providers/tests/tools/fixtures.py +++ b/llama_stack/providers/tests/tools/fixtures.py @@ -33,6 +33,13 @@ def tool_runtime_memory_and_search() -> ProviderFixture: "api_key": os.environ["TAVILY_SEARCH_API_KEY"], }, ), + Provider( + provider_id="wolfram-alpha", + provider_type="remote::wolfram-alpha", + config={ + "api_key": os.environ["WOLFRAM_ALPHA_API_KEY"], + }, + ), ], ) @@ -53,12 +60,24 @@ def tool_group_input_tavily_search() -> ToolGroupInput: ) +@pytest.fixture(scope="session") +def tool_group_input_wolfram_alpha() -> ToolGroupInput: + return ToolGroupInput( + toolgroup_id="builtin::wolfram_alpha", + provider_id="wolfram-alpha", + ) + + TOOL_RUNTIME_FIXTURES = ["memory_and_search"] @pytest_asyncio.fixture(scope="session") async def tools_stack( - request, inference_model, tool_group_input_memory, tool_group_input_tavily_search + request, + inference_model, + tool_group_input_memory, + tool_group_input_tavily_search, + tool_group_input_wolfram_alpha, ): fixture_dict = request.param @@ -104,6 +123,7 @@ async def tools_stack( models=models, tool_groups=[ tool_group_input_tavily_search, + tool_group_input_wolfram_alpha, tool_group_input_memory, ], ) diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py index 917db55e14..16081b939f 100644 --- a/llama_stack/providers/tests/tools/test_tools.py +++ b/llama_stack/providers/tests/tools/test_tools.py @@ -20,6 +20,11 @@ def sample_search_query(): return "What are the latest developments in quantum computing?" +@pytest.fixture +def sample_wolfram_alpha_query(): + return "What is the square root of 16?" 
+ + @pytest.fixture def sample_documents(): urls = [ @@ -61,6 +66,24 @@ async def test_web_search_tool(self, tools_stack, sample_search_query): assert len(response.content) > 0 assert isinstance(response.content, str) + @pytest.mark.asyncio + async def test_wolfram_alpha_tool(self, tools_stack, sample_wolfram_alpha_query): + """Test the wolfram alpha tool functionality.""" + if "WOLFRAM_ALPHA_API_KEY" not in os.environ: + pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test") + + tools_impl = tools_stack.impls[Api.tool_runtime] + + response = await tools_impl.invoke_tool( + tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query} + ) + + # Verify the response + assert isinstance(response, ToolInvocationResult) + assert response.content is not None + assert len(response.content) > 0 + assert isinstance(response.content, str) + @pytest.mark.asyncio async def test_memory_tool(self, tools_stack, sample_documents): """Test the memory tool functionality.""" From 87068278ac45588f22ac14e35f8aaaf211267e1b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 16:08:04 -0800 Subject: [PATCH 40/53] update open api spec --- docs/resources/llama-stack-spec.html | 203 +++++++++------------------ docs/resources/llama-stack-spec.yaml | 103 +++++--------- 2 files changed, 103 insertions(+), 203 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index fb75259889..2d423b3e62 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -462,46 +462,6 @@ } } }, - "/alpha/tool-runtime/discover": { - "post": { - "responses": { - "200": { - "description": "OK", - "content": { - "application/jsonl": { - "schema": { - "$ref": "#/components/schemas/ToolDef" - } - } - } - } - }, - "tags": [ - "ToolRuntime" - ], - "parameters": [ - { - "name": "X-LlamaStack-ProviderData", - "in": "header", - "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", - "required": false, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/DiscoverToolsRequest" - } - } - }, - "required": true - } - } - }, "/alpha/inference/embeddings": { "post": { "responses": { @@ -1215,7 +1175,7 @@ ], "parameters": [ { - "name": "tool_group_id", + "name": "toolgroup_id", "in": "query", "required": true, "schema": { @@ -1898,7 +1858,7 @@ } }, "tags": [ - "ToolGroups" + "ToolRuntime" ], "summary": "List tools with optional tool group", "parameters": [ @@ -3705,7 +3665,7 @@ "type": "string" } }, - "tools": { + "toolgroups": { "type": "array", "items": { "$ref": "#/components/schemas/AgentTool" @@ -3832,6 +3792,9 @@ ] } }, + "built_in_type": { + "$ref": "#/components/schemas/BuiltinTool" + }, "tool_prompt_format": { "$ref": "#/components/schemas/ToolPromptFormat", "default": "json" @@ -3855,7 +3818,8 @@ "type": "string" }, "required": { - "type": "boolean" + "type": "boolean", + "default": true }, "default": { "oneOf": [ @@ -4580,68 +4544,6 @@ "session_id" ] }, - "MCPToolGroupDef": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "model_context_protocol", - "default": "model_context_protocol" - }, - "endpoint": { - "$ref": "#/components/schemas/URL" - } - }, - "additionalProperties": false, - "required": [ - "type", - "endpoint" - ], - "title": "A tool group that is defined by in a model context protocol server. 
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information." - }, - "ToolGroupDef": { - "oneOf": [ - { - "$ref": "#/components/schemas/MCPToolGroupDef" - }, - { - "$ref": "#/components/schemas/UserDefinedToolGroupDef" - } - ] - }, - "UserDefinedToolGroupDef": { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "user_defined", - "default": "user_defined" - }, - "tools": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolDef" - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "tools" - ] - }, - "DiscoverToolsRequest": { - "type": "object", - "properties": { - "tool_group": { - "$ref": "#/components/schemas/ToolGroupDef" - } - }, - "additionalProperties": false, - "required": [ - "tool_group" - ] - }, "EmbeddingsRequest": { "type": "object", "properties": { @@ -5872,7 +5774,7 @@ "const": "tool", "default": "tool" }, - "tool_group": { + "toolgroup_id": { "type": "string" }, "tool_host": { @@ -5926,7 +5828,7 @@ "provider_resource_id", "provider_id", "type", - "tool_group", + "toolgroup_id", "tool_host", "description", "parameters" @@ -5956,6 +5858,34 @@ "type": "string", "const": "tool_group", "default": "tool_group" + }, + "mcp_endpoint": { + "$ref": "#/components/schemas/URL" + }, + "args": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, @@ -7371,20 +7301,45 @@ "RegisterToolGroupRequest": { "type": "object", "properties": { - "tool_group_id": { + "toolgroup_id": { "type": "string" }, - "tool_group_def": { - "$ref": "#/components/schemas/ToolGroupDef" - }, "provider_id": { "type": "string" + }, + "mcp_endpoint": { + "$ref": "#/components/schemas/URL" + }, + "args": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, "required": [ - "tool_group_id", - "tool_group_def" + "toolgroup_id", + "provider_id" ] }, "RunEvalRequest": { @@ -8122,10 +8077,6 @@ "name": "DeleteAgentsSessionRequest", "description": "" }, - { - "name": "DiscoverToolsRequest", - "description": "" - }, { "name": "EfficiencyConfig", "description": "" @@ -8250,10 +8201,6 @@ "name": "LoraFinetuningConfig", "description": "" }, - { - "name": "MCPToolGroupDef", - "description": "A tool group that is defined by in a model context protocol server. 
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.\n\n" - }, { "name": "Memory" }, @@ -8568,10 +8515,6 @@ "name": "ToolGroup", "description": "" }, - { - "name": "ToolGroupDef", - "description": "" - }, { "name": "ToolGroups" }, @@ -8642,10 +8585,6 @@ "name": "UnstructuredLogEvent", "description": "" }, - { - "name": "UserDefinedToolGroupDef", - "description": "" - }, { "name": "UserMessage", "description": "" @@ -8742,7 +8681,6 @@ "Dataset", "DeleteAgentsRequest", "DeleteAgentsSessionRequest", - "DiscoverToolsRequest", "EfficiencyConfig", "EmbeddingsRequest", "EmbeddingsResponse", @@ -8771,7 +8709,6 @@ "LogEventRequest", "LogSeverity", "LoraFinetuningConfig", - "MCPToolGroupDef", "MemoryBankDocument", "MemoryRetrievalStep", "Message", @@ -8843,7 +8780,6 @@ "ToolDefinition", "ToolExecutionStep", "ToolGroup", - "ToolGroupDef", "ToolHost", "ToolInvocationResult", "ToolParamDefinition", @@ -8860,7 +8796,6 @@ "UnregisterModelRequest", "UnregisterToolGroupRequest", "UnstructuredLogEvent", - "UserDefinedToolGroupDef", "UserMessage", "VectorMemoryBank", "VectorMemoryBankParams", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 0937d87224..bf3b515f27 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -46,7 +46,7 @@ components: tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' default: json - tools: + toolgroups: items: $ref: '#/components/schemas/AgentTool' type: array @@ -729,14 +729,6 @@ components: - agent_id - session_id type: object - DiscoverToolsRequest: - additionalProperties: false - properties: - tool_group: - $ref: '#/components/schemas/ToolGroupDef' - required: - - tool_group - type: object EfficiencyConfig: additionalProperties: false properties: @@ -1186,21 +1178,6 @@ components: - rank - alpha type: object - MCPToolGroupDef: - additionalProperties: false - properties: - endpoint: - $ref: '#/components/schemas/URL' - type: - const: model_context_protocol - default: model_context_protocol - type: string - required: - - type - - endpoint - title: A tool group that is defined by in a model context protocol server. Refer - to https://modelcontextprotocol.io/docs/concepts/tools for more information. 
- type: object MemoryBankDocument: additionalProperties: false properties: @@ -1904,15 +1881,25 @@ components: RegisterToolGroupRequest: additionalProperties: false properties: + args: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + mcp_endpoint: + $ref: '#/components/schemas/URL' provider_id: type: string - tool_group_def: - $ref: '#/components/schemas/ToolGroupDef' - tool_group_id: + toolgroup_id: type: string required: - - tool_group_id - - tool_group_def + - toolgroup_id + - provider_id type: object ResponseFormat: oneOf: @@ -2607,13 +2594,13 @@ components: type: string provider_resource_id: type: string - tool_group: - type: string tool_host: $ref: '#/components/schemas/ToolHost' tool_prompt_format: $ref: '#/components/schemas/ToolPromptFormat' default: json + toolgroup_id: + type: string type: const: tool default: tool @@ -2623,7 +2610,7 @@ components: - provider_resource_id - provider_id - type - - tool_group + - toolgroup_id - tool_host - description - parameters @@ -2695,6 +2682,8 @@ components: ToolDef: additionalProperties: false properties: + built_in_type: + $ref: '#/components/schemas/BuiltinTool' description: type: string metadata: @@ -2770,8 +2759,20 @@ components: ToolGroup: additionalProperties: false properties: + args: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object identifier: type: string + mcp_endpoint: + $ref: '#/components/schemas/URL' provider_id: type: string provider_resource_id: @@ -2786,10 +2787,6 @@ components: - provider_id - type type: object - ToolGroupDef: - oneOf: - - $ref: '#/components/schemas/MCPToolGroupDef' - - $ref: '#/components/schemas/UserDefinedToolGroupDef' ToolHost: enum: - distribution @@ -2847,6 +2844,7 @@ components: parameter_type: type: string required: + default: true type: boolean required: - name @@ -3087,21 +3085,6 @@ components: - message - severity type: object - UserDefinedToolGroupDef: - additionalProperties: false - properties: - tools: - items: - $ref: '#/components/schemas/ToolDef' - type: array - type: - const: user_defined - default: user_defined - type: string - required: - - type - - tools - type: object UserMessage: additionalProperties: false properties: @@ -4862,9 +4845,6 @@ tags: - description: name: DeleteAgentsSessionRequest -- description: - name: DiscoverToolsRequest - description: name: EfficiencyConfig @@ -4947,12 +4927,6 @@ tags: - description: name: LoraFinetuningConfig -- description: 'A tool group that is defined by in a model context protocol server. - Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information. 
- - - ' - name: MCPToolGroupDef - name: Memory - description: @@ -5158,8 +5132,6 @@ tags: name: ToolExecutionStep - description: name: ToolGroup -- description: - name: ToolGroupDef - name: ToolGroups - description: name: ToolHost @@ -5214,9 +5186,6 @@ tags: - description: name: UnstructuredLogEvent -- description: - name: UserDefinedToolGroupDef - description: name: UserMessage - description: Date: Tue, 7 Jan 2025 16:17:38 -0800 Subject: [PATCH 41/53] fix list tools method name --- docs/resources/llama-stack-spec.html | 64 ++++++++++++++++++- docs/resources/llama-stack-spec.yaml | 10 +++ llama_stack/apis/tools/tools.py | 4 +- llama_stack/distribution/routers/routers.py | 2 +- .../distribution/routers/routing_tables.py | 2 +- .../code_interpreter/code_interpreter.py | 2 +- .../inline/tool_runtime/memory/memory.py | 2 +- .../tool_runtime/bing_search/bing_search.py | 2 +- .../tool_runtime/brave_search/brave_search.py | 2 +- .../model_context_protocol.py | 2 +- .../tavily_search/tavily_search.py | 2 +- .../wolfram_alpha/wolfram_alpha.py | 2 +- 12 files changed, 84 insertions(+), 12 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 2d423b3e62..8ce86d3676 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -1752,6 +1752,54 @@ ] } }, + "/alpha/tool-runtime/list-tools": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/jsonl": { + "schema": { + "$ref": "#/components/schemas/ToolDef" + } + } + } + } + }, + "tags": [ + "ToolRuntime" + ], + "parameters": [ + { + "name": "tool_group_id", + "in": "query", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListRuntimeToolsRequest" + } + } + }, + "required": true + } + } + }, "/alpha/scoring-functions/list": { "get": { "responses": { @@ -1858,7 +1906,7 @@ } }, "tags": [ - "ToolRuntime" + "ToolGroups" ], "summary": "List tools with optional tool group", "parameters": [ @@ -6207,6 +6255,15 @@ "provider_types" ] }, + "ListRuntimeToolsRequest": { + "type": "object", + "properties": { + "mcp_endpoint": { + "$ref": "#/components/schemas/URL" + } + }, + "additionalProperties": false + }, "LogSeverity": { "type": "string", "enum": [ @@ -8189,6 +8246,10 @@ "name": "LLMAsJudgeScoringFnParams", "description": "" }, + { + "name": "ListRuntimeToolsRequest", + "description": "" + }, { "name": "LogEventRequest", "description": "" @@ -8706,6 +8767,7 @@ "KeywordMemoryBank", "KeywordMemoryBankParams", "LLMAsJudgeScoringFnParams", + "ListRuntimeToolsRequest", "LogEventRequest", "LogSeverity", "LoraFinetuningConfig", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index bf3b515f27..3e14bfe768 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -1122,6 +1122,12 @@ components: - type - judge_model type: object + ListRuntimeToolsRequest: + additionalProperties: false + properties: + mcp_endpoint: + $ref: '#/components/schemas/URL' + type: object LogEventRequest: additionalProperties: false properties: @@ -4919,6 +4925,9 @@ tags: - description: name: LLMAsJudgeScoringFnParams +- description: + name: 
ListRuntimeToolsRequest - description: name: LogEventRequest @@ -5290,6 +5299,7 @@ x-tagGroups: - KeywordMemoryBank - KeywordMemoryBankParams - LLMAsJudgeScoringFnParams + - ListRuntimeToolsRequest - LogEventRequest - LogSeverity - LoraFinetuningConfig diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 0c2bb58633..dbfd852206 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -130,8 +130,8 @@ async def unregister_tool_group(self, tool_group_id: str) -> None: class ToolRuntime(Protocol): tool_store: ToolStore - @webmethod(route="/tool-runtime/list-tools", method="POST") - async def list_tools( + @webmethod(route="/tool-runtime/list-tools", method="GET") + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 230feea710..05d43ad4f4 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -417,7 +417,7 @@ async def invoke_tool(self, tool_name: str, args: Dict[str, Any]) -> Any: args=args, ) - async def list_tools( + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: return await self.routing_table.get_provider_impl(tool_group_id).list_tools( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 2f0288865c..36ddda7a65 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -508,7 +508,7 @@ async def register_tool_group( args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_tools( + tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( toolgroup_id, mcp_endpoint ) tool_host = ( diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 3b3f180320..98026fa3d1 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -44,7 +44,7 @@ async def register_tool(self, tool: Tool): async def unregister_tool(self, tool_id: str) -> None: return - async def list_tools( + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: return [ diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index c8c2cc772f..f27cb9dd4a 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -51,7 +51,7 @@ def __init__( async def initialize(self): pass - async def list_tools( + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: return [ diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index b0c30b0a08..a69f08ce81 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -51,7 
+51,7 @@ def _get_api_key(self) -> str: ) return provider_data.api_key - async def list_tools( + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: return [ diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index dab6ce4398..05a3f25663 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -49,7 +49,7 @@ def _get_api_key(self) -> str: ) return provider_data.api_key - async def list_tools( + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: return [ diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index dd2bb5e5e5..a304167e92 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -29,7 +29,7 @@ def __init__(self, config: ModelContextProtocolConfig): async def initialize(self): pass - async def list_tools( + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: if mcp_endpoint is None: diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index d22f188b32..8f666a6fb4 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -50,7 +50,7 @@ def _get_api_key(self) -> str: ) return provider_data.api_key - async def list_tools( + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: return [ diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 0f3fdfb39e..13c298eb23 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -51,7 +51,7 @@ def _get_api_key(self) -> str: ) return provider_data.api_key - async def list_tools( + async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: return [ From c3865faf3775162094fde978cf61a841a1f768ac Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 16:37:39 -0800 Subject: [PATCH 42/53] minor fixes --- docs/resources/llama-stack-spec.html | 2 +- docs/resources/llama-stack-spec.yaml | 2 +- llama_stack/apis/agents/agents.py | 2 +- .../inline/agents/meta_reference/agents.py | 4 ++-- llama_stack/providers/tests/agents/test_agents.py | 3 ++- tests/client-sdk/agents/test_agents.py | 15 ++++++++++++--- 6 files changed, 19 insertions(+), 9 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 8ce86d3676..e98de6491e 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -4012,7 +4012,7 @@ ] } }, - "tools": { + "toolgroups": { "type": "array", "items": { 
"$ref": "#/components/schemas/AgentTool" diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 3e14bfe768..924fc32a10 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -624,7 +624,7 @@ components: type: string stream: type: boolean - tools: + toolgroups: items: $ref: '#/components/schemas/AgentTool' type: array diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index f5fbcb9c40..fb9df21e67 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -317,7 +317,7 @@ async def create_agent_turn( ], stream: Optional[bool] = False, documents: Optional[List[Document]] = None, - tools: Optional[List[AgentToolGroup]] = None, + toolgroups: Optional[List[AgentToolGroup]] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/turn/get") diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 2ea74300dd..faff716ce6 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -147,7 +147,7 @@ async def create_agent_turn( ToolResponseMessage, ] ], - tools: Optional[List[AgentToolGroup]] = None, + toolgroups: Optional[List[AgentToolGroup]] = None, documents: Optional[List[Document]] = None, stream: Optional[bool] = False, ) -> AsyncGenerator: @@ -156,7 +156,7 @@ async def create_agent_turn( session_id=session_id, messages=messages, stream=True, - tools=tools, + toolgroups=toolgroups, documents=documents, ) if stream: diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index fb22e976e5..3d18429563 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -8,6 +8,7 @@ from typing import Dict, List import pytest +from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, @@ -335,7 +336,7 @@ async def test_create_agent_turn_with_tavily_search( assert isinstance(tool_execution, ToolExecutionStep) assert len(tool_execution.tool_calls) > 0 actual_tool_name = tool_execution.tool_calls[0].tool_name - assert actual_tool_name == "web_search" + assert actual_tool_name == BuiltinTool.brave_search assert len(tool_execution.tool_responses) > 0 check_turn_complete_event(turn_response, session_id, search_query_messages) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index a760bb08ab..09cedced35 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -9,7 +9,7 @@ from uuid import uuid4 import pytest -from llama_stack_client.lib.agents.agent import Agent, maybe_register_memory_tool +from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types import ToolResponseMessage @@ -293,9 +293,18 @@ def test_rag_agent(llama_stack_client, agent_config): for i, url in enumerate(urls) ] - tool_name, memory_bank_id = maybe_register_memory_tool(llama_stack_client) - agent_config["tools"].append(tool_name) + agent_config["tools"].append("builtin::memory") agent = Agent(llama_stack_client, agent_config) + memory_bank_id = "test-memory-bank" + 
llama_stack_client.memory_banks.register( + memory_bank_id=memory_bank_id, + params={ + "memory_bank_type": "vector", + "embedding_model": "all-MiniLM-L6-v2", + "chunk_size_in_tokens": 512, + "overlap_size_in_tokens": 64, + }, + ) llama_stack_client.memory.insert( bank_id=memory_bank_id, documents=documents, From efe3189728be97954d3bbf6dae6c9ee889b737ca Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 16:59:01 -0800 Subject: [PATCH 43/53] client sdk test fixes --- .../agents/meta_reference/agent_instance.py | 13 ++++---- .../inline/tool_runtime/memory/memory.py | 2 +- tests/client-sdk/agents/test_agents.py | 31 +++++++++---------- 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index cfe839dad8..528246cdf8 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -402,7 +402,6 @@ async def _run( # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: query_args["memory_bank_id"] = session_info.memory_bank_id - serialized_args = tracing.serialize_value(query_args) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -412,8 +411,8 @@ async def _run( parse_status=ToolCallParseStatus.success, content=ToolCall( call_id="", - tool_name="memory", - arguments=serialized_args, + tool_name=MEMORY_QUERY_TOOL, + arguments={}, ), ), ) @@ -435,14 +434,14 @@ async def _run( tool_calls=[ ToolCall( call_id="", - tool_name="memory", - arguments=serialized_args, + tool_name=MEMORY_QUERY_TOOL, + arguments={}, ) ], tool_responses=[ ToolResponse( call_id="", - tool_name="memory", + tool_name=MEMORY_QUERY_TOOL, content=result.content, ) ], @@ -456,7 +455,7 @@ async def _run( span.set_attribute("output", result.content) span.set_attribute("error_code", result.error_code) span.set_attribute("error_message", result.error_message) - span.set_attribute("tool_name", "memory") + span.set_attribute("tool_name", MEMORY_QUERY_TOOL) if result.error_code == 0: last_message = input_messages[-1] last_message.context = result.content diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index f27cb9dd4a..a6ce744a68 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -56,7 +56,7 @@ async def list_runtime_tools( ) -> List[ToolDef]: return [ ToolDef( - name="memory", + name="query_memory", description="Retrieve context from memory", parameters=[ ToolParameter( diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 09cedced35..a4ad2278fa 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -101,7 +101,7 @@ def agent_config(llama_stack_client): "temperature": 1.0, "top_p": 0.9, }, - tools=[], + toolgroups=[], tool_choice="auto", tool_prompt_format="json", input_shields=available_shields, @@ -152,8 +152,8 @@ def test_agent_simple(llama_stack_client, agent_config): def test_builtin_tool_web_search(llama_stack_client, agent_config): agent_config = { **agent_config, - "tools": [ - "builtin::web_search", + "toolgroups": [ + "builtin::websearch", ], } agent = Agent(llama_stack_client, agent_config) @@ -181,7 
+181,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config): def test_builtin_tool_code_execution(llama_stack_client, agent_config): agent_config = { **agent_config, - "tools": [ + "toolgroups": [ "builtin::code_interpreter", ], } @@ -208,7 +208,7 @@ def test_code_execution(llama_stack_client): agent_config = AgentConfig( model="meta-llama/Llama-3.1-70B-Instruct", instructions="You are a helpful assistant", - tools=[ + toolgroups=[ "builtin::code_interpreter", ], tool_choice="required", @@ -250,7 +250,7 @@ def test_custom_tool(llama_stack_client, agent_config): agent_config = { **agent_config, "model": "meta-llama/Llama-3.2-3B-Instruct", - "tools": ["builtin::web_search"], + "toolgroups": ["builtin::websearch"], "client_tools": [client_tool.get_tool_definition()], "tool_prompt_format": "python_list", } @@ -293,9 +293,14 @@ def test_rag_agent(llama_stack_client, agent_config): for i, url in enumerate(urls) ] - agent_config["tools"].append("builtin::memory") - agent = Agent(llama_stack_client, agent_config) memory_bank_id = "test-memory-bank" + agent_config["toolgroups"].append( + dict( + name="builtin::memory", + args={"memory_bank_id": memory_bank_id}, + ) + ) + agent = Agent(llama_stack_client, agent_config) llama_stack_client.memory_banks.register( memory_bank_id=memory_bank_id, params={ @@ -326,16 +331,8 @@ def test_rag_agent(llama_stack_client, agent_config): } ], session_id=session_id, - tools=[ - { - "name": "memory", - "args": { - "memory_bank_id": memory_bank_id, - }, - } - ], ) logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - assert "Tool:memory" in logs_str + assert "Tool:query_memory" in logs_str From 82395ba65415b8abe8b11a24af8b5e13cba2b1db Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 21:33:25 -0800 Subject: [PATCH 44/53] fix the rag query generator types --- .../agents/meta_reference/agent_instance.py | 4 +++- .../tool_runtime/memory/context_retriever.py | 19 +++++++++++++------ .../inline/tool_runtime/memory/memory.py | 2 +- tests/client-sdk/agents/test_agents.py | 4 ++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 528246cdf8..f9ffb2ae09 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -401,7 +401,9 @@ async def _run( session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it if session_info.memory_bank_id: - query_args["memory_bank_id"] = session_info.memory_bank_id + if "memory_bank_ids" not in query_args: + query_args["memory_bank_ids"] = [] + query_args["memory_bank_ids"].append(session_info.memory_bank_id) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( diff --git a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py index 1fb1d09920..803981f074 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/memory/context_retriever.py @@ -8,8 +8,10 @@ from typing import List from jinja2 import Template +from pydantic import BaseModel -from llama_stack.apis.inference import Message, UserMessage +from 
llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import UserMessage from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) @@ -24,7 +26,7 @@ async def generate_rag_query( config: MemoryQueryGeneratorConfig, - messages: List[Message], + messages: List[InterleavedContent], **kwargs, ): """ @@ -42,21 +44,26 @@ async def generate_rag_query( async def default_rag_query_generator( config: DefaultMemoryQueryGeneratorConfig, - messages: List[Message], + messages: List[InterleavedContent], **kwargs, ): - return config.sep.join(interleaved_content_as_str(m.content) for m in messages) + return config.sep.join(interleaved_content_as_str(m) for m in messages) async def llm_rag_query_generator( config: LLMMemoryQueryGeneratorConfig, - messages: List[Message], + messages: List[InterleavedContent], **kwargs, ): assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api" inference_api = kwargs["inference_api"] - m_dict = {"messages": [message.model_dump() for message in messages]} + m_dict = { + "messages": [ + message.model_dump() if isinstance(message, BaseModel) else message + for message in messages + ] + } template = Template(config.template) content = template.render(m_dict) diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index a6ce744a68..f46b375105 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -69,7 +69,7 @@ async def list_runtime_tools( ] async def _retrieve_context( - self, input_messages: List[str], bank_ids: List[str] + self, input_messages: List[InterleavedContent], bank_ids: List[str] ) -> Optional[List[InterleavedContent]]: if not bank_ids: return None diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index a4ad2278fa..01ffe2025f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -206,7 +206,7 @@ def test_builtin_tool_code_execution(llama_stack_client, agent_config): def test_code_execution(llama_stack_client): agent_config = AgentConfig( - model="meta-llama/Llama-3.1-70B-Instruct", + model="meta-llama/Llama-3.1-8B-Instruct", instructions="You are a helpful assistant", toolgroups=[ "builtin::code_interpreter", @@ -297,7 +297,7 @@ def test_rag_agent(llama_stack_client, agent_config): agent_config["toolgroups"].append( dict( name="builtin::memory", - args={"memory_bank_id": memory_bank_id}, + args={"memory_bank_ids": [memory_bank_id]}, ) ) agent = Agent(llama_stack_client, agent_config) From db2ec110a1a85c7cadc644d44e71873617f0bea9 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 7 Jan 2025 22:13:33 -0800 Subject: [PATCH 45/53] fix failing code interpreter tests --- .../agents/meta_reference/agent_instance.py | 3 +- tests/client-sdk/agents/test_agents.py | 48 +++++++------------ 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index f9ffb2ae09..0c1d50b15f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -78,7 +78,6 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") MEMORY_TOOL_GROUP_ID 
= "builtin::memory" MEMORY_QUERY_TOOL = "query_memory" -CODE_INTERPRETER_TOOL = "code_interpreter" WEB_SEARCH_TOOL = "web_search" @@ -787,7 +786,7 @@ async def handle_documents( tool_defs: Dict[str, ToolDefinition], ) -> None: memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None) - code_interpreter_tool = tool_defs.get(CODE_INTERPRETER_TOOL, None) + code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) content_items = [] url_items = [] pattern = re.compile("^(https?://|file://|data:)") diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 01ffe2025f..a2ed687a49 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -275,14 +275,7 @@ def test_custom_tool(llama_stack_client, agent_config): def test_rag_agent(llama_stack_client, agent_config): - urls = [ - "memory_optimizations.rst", - "chat.rst", - "llama3.rst", - "datasets.rst", - "qat_finetune.rst", - "lora_finetune.rst", - ] + urls = ["chat.rst", "llama3.rst", "datasets.rst", "lora_finetune.rst"] documents = [ Document( document_id=f"num-{i}", @@ -292,15 +285,7 @@ def test_rag_agent(llama_stack_client, agent_config): ) for i, url in enumerate(urls) ] - memory_bank_id = "test-memory-bank" - agent_config["toolgroups"].append( - dict( - name="builtin::memory", - args={"memory_bank_ids": [memory_bank_id]}, - ) - ) - agent = Agent(llama_stack_client, agent_config) llama_stack_client.memory_banks.register( memory_bank_id=memory_bank_id, params={ @@ -314,25 +299,28 @@ def test_rag_agent(llama_stack_client, agent_config): bank_id=memory_bank_id, documents=documents, ) - session_id = agent.create_session(f"test-session-{uuid4()}") - + agent_config = { + **agent_config, + "toolgroups": [ + dict( + name="builtin::memory", + args={ + "memory_bank_ids": [memory_bank_id], + }, + ) + ], + } + rag_agent = Agent(llama_stack_client, agent_config) + session_id = rag_agent.create_session("test-session") user_prompts = [ - "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.", - "Was anything related to 'Llama3' discussed, if so what?", - "Tell me how to use LoRA", + "What are the top 5 topics that were explained? 
Only list succinct bullet points.", ] - for prompt in user_prompts: - response = agent.create_turn( - messages=[ - { - "role": "user", - "content": prompt, - } - ], + print(f"User> {prompt}") + response = rag_agent.create_turn( + messages=[{"role": "user", "content": prompt}], session_id=session_id, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "Tool:query_memory" in logs_str From 854fef74784c5794dae2cdfcd97b8af09edb2707 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 13:57:26 -0800 Subject: [PATCH 46/53] add unit tests for chat agent --- .../agents/meta_reference/agent_instance.py | 26 +- .../meta_reference/tests/test_chat_agent.py | 326 +++++++++++++----- .../agents/meta_reference/tools/safety.py | 42 --- .../providers/tests/agents/test_agents.py | 69 ---- 4 files changed, 259 insertions(+), 204 deletions(-) delete mode 100644 llama_stack/providers/inline/agents/meta_reference/tools/safety.py diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 0c1d50b15f..52293182cc 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -13,7 +13,7 @@ import string import uuid from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urlparse import httpx @@ -76,7 +76,6 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") -MEMORY_TOOL_GROUP_ID = "builtin::memory" MEMORY_QUERY_TOOL = "query_memory" WEB_SEARCH_TOOL = "web_search" @@ -382,6 +381,9 @@ async def _run( session_id, documents, input_messages, tool_defs ) if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0: + memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None) + if memory_tool_group is None: + raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}") with tracing.span(MEMORY_QUERY_TOOL) as span: step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -394,7 +396,7 @@ async def _run( ) query_args = { "messages": [msg.content for msg in input_messages], - **toolgroup_args.get(MEMORY_TOOL_GROUP_ID, {}), + **toolgroup_args.get(memory_tool_group, {}), } session_info = await self.storage.get_session_info(session_id) @@ -484,14 +486,20 @@ async def _run( stop_reason = None with tracing.span("inference") as span: + + def is_memory_group(tool): + memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None) + has_memory_tool = MEMORY_QUERY_TOOL in tool_defs + return ( + has_memory_tool + and tool_to_group.get(tool.tool_name, None) != memory_tool_group + ) + async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, tools=[ - tool - for tool in tool_defs.values() - if tool_to_group.get(tool.tool_name, None) - != MEMORY_TOOL_GROUP_ID + tool for tool in tool_defs.values() if not is_memory_group(tool) ], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, @@ -698,8 +706,8 @@ def interpret_content_as_attachment( n_iter += 1 async def _get_tool_defs( - self, toolgroups_for_turn: Optional[List[AgentToolGroup]] - ) -> Dict[str, ToolDefinition]: + self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None + ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]: # 
Determine which tools to include agent_config_toolgroups = set( ( diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index 0350543204..6e789bf19f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -4,21 +4,25 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import tempfile from typing import AsyncIterator, List, Optional, Union import pytest +from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.agents import ( AgentConfig, AgentTurnCreateRequest, AgentTurnResponseTurnCompletePayload, + StepType, ) - +from llama_stack.apis.common.content_types import URL from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEvent, ChatCompletionResponseStreamChunk, CompletionMessage, + LogProbConfig, Message, ResponseFormat, SamplingParams, @@ -27,13 +31,24 @@ UserMessage, ) from llama_stack.apis.memory import MemoryBank +from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank from llama_stack.apis.safety import RunShieldResponse - -from ..agents import ( - AGENT_INSTANCES_BY_ID, +from llama_stack.apis.tools import ( + Tool, + ToolDef, + ToolGroup, + ToolHost, + ToolInvocationResult, + ToolPromptFormat, +) +from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( + MEMORY_QUERY_TOOL, +) +from llama_stack.providers.inline.agents.meta_reference.agents import ( MetaReferenceAgentsImpl, - MetaReferenceInferenceConfig, + MetaReferenceAgentsImplConfig, ) +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class MockInferenceAPI: @@ -48,10 +63,10 @@ async def chat_completion( tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> AsyncIterator[ - Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse] + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ]: - if stream: + async def stream_response(): yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type="start", @@ -65,19 +80,7 @@ async def chat_completion( delta="AI is a fascinating field...", ) ) - # yield ChatCompletionResponseStreamChunk( - # event=ChatCompletionResponseEvent( - # event_type="progress", - # delta=ToolCallDelta( - # content=ToolCall( - # call_id="123", - # tool_name=BuiltinTool.brave_search.value, - # arguments={"query": "AI history"}, - # ), - # parse_status="success", - # ), - # ) - # ) + yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type="complete", @@ -85,12 +88,17 @@ async def chat_completion( stop_reason="end_of_turn", ) ) + + if stream: + return stream_response() else: - yield ChatCompletionResponse( + return ChatCompletionResponse( completion_message=CompletionMessage( - role="assistant", content="Mock response", stop_reason="end_of_turn" + role="assistant", + content="Mock response", + stop_reason="end_of_turn", ), - logprobs=[0.1, 0.2, 0.3] if logprobs else None, + logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None, ) @@ -165,6 +173,99 @@ async def delete_documents(self, bank_id, document_ids): self.documents[bank_id].pop(doc_id, None) +class MockToolGroupsAPI: + async def register_tool_group( + 
self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None + ) -> None: + pass + + async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: + return ToolGroup( + identifier=toolgroup_id, + provider_resource_id=toolgroup_id, + ) + + async def list_tool_groups(self) -> List[ToolGroup]: + return [] + + async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: + if tool_group_id == MEMORY_TOOLGROUP: + return [ + Tool( + identifier=MEMORY_QUERY_TOOL, + provider_resource_id=MEMORY_QUERY_TOOL, + toolgroup_id=MEMORY_TOOLGROUP, + tool_host=ToolHost.client, + description="Mock tool", + provider_id="mock_provider", + parameters=[], + ) + ] + if tool_group_id == CODE_INTERPRETER_TOOLGROUP: + return [ + Tool( + identifier="code_interpreter", + provider_resource_id="code_interpreter", + toolgroup_id=CODE_INTERPRETER_TOOLGROUP, + built_in_type=BuiltinTool.code_interpreter, + tool_host=ToolHost.client, + description="Mock tool", + provider_id="mock_provider", + parameters=[], + ) + ] + return [] + + async def get_tool(self, tool_name: str) -> Tool: + return Tool( + identifier=tool_name, + provider_resource_id=tool_name, + toolgroup_id="mock_group", + tool_host=ToolHost.client, + description="Mock tool", + provider_id="mock_provider", + parameters=[], + ) + + async def unregister_tool_group(self, tool_group_id: str) -> None: + pass + + +class MockToolRuntimeAPI: + async def list_runtime_tools( + self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None + ) -> List[ToolDef]: + return [] + + async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult: + return ToolInvocationResult(content={"result": "Mock tool result"}) + + +class MockMemoryBanksAPI: + async def list_memory_banks(self) -> List[MemoryBank]: + return [] + + async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: + return None + + async def register_memory_bank( + self, + memory_bank_id: str, + params: BankParams, + provider_id: Optional[str] = None, + provider_memory_bank_id: Optional[str] = None, + ) -> MemoryBank: + return VectorMemoryBank( + identifier=memory_bank_id, + provider_resource_id=provider_memory_bank_id or memory_bank_id, + embedding_model="mock_model", + chunk_size_in_tokens=512, + ) + + async def unregister_memory_bank(self, memory_bank_id: str) -> None: + pass + + @pytest.fixture def mock_inference_api(): return MockInferenceAPI() @@ -181,64 +282,107 @@ def mock_memory_api(): @pytest.fixture -async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api): +def mock_tool_groups_api(): + return MockToolGroupsAPI() + + +@pytest.fixture +def mock_tool_runtime_api(): + return MockToolRuntimeAPI() + + +@pytest.fixture +def mock_memory_banks_api(): + return MockMemoryBanksAPI() + + +@pytest.fixture +async def get_agents_impl( + mock_inference_api, + mock_safety_api, + mock_memory_api, + mock_memory_banks_api, + mock_tool_runtime_api, + mock_tool_groups_api, +): + sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") impl = MetaReferenceAgentsImpl( - config=MetaReferenceInferenceConfig(), + config=MetaReferenceAgentsImplConfig( + persistence_store=SqliteKVStoreConfig( + db_name=sqlite_file.name, + ), + ), inference_api=mock_inference_api, safety_api=mock_safety_api, memory_api=mock_memory_api, + memory_banks_api=mock_memory_banks_api, + tool_runtime_api=mock_tool_runtime_api, + tool_groups_api=mock_tool_groups_api, ) await impl.initialize() + return impl + +@pytest.fixture +async def 
get_chat_agent(get_agents_impl): + impl = await get_agents_impl agent_config = AgentConfig( model="test_model", instructions="You are a helpful assistant.", - sampling_params=SamplingParams(), - tools=[ - # SearchToolDefinition( - # name="brave_search", - # api_key="test_key", - # ), - ], + toolgroups=[], tool_choice=ToolChoice.auto, enable_session_persistence=False, - input_shields=[], - output_shields=[], + input_shields=["test_shield"], ) response = await impl.create_agent(agent_config) - agent = AGENT_INSTANCES_BY_ID[response.agent_id] - return agent + return await impl.get_agent(response.agent_id) -@pytest.mark.asyncio -async def test_chat_agent_create_session(chat_agent): - session = chat_agent.create_session("Test Session") - assert session.session_name == "Test Session" - assert session.turns == [] - assert session.session_id in chat_agent.sessions +MEMORY_TOOLGROUP = "builtin::memory" +CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter" + + +@pytest.fixture +async def get_chat_agent_with_tools(get_agents_impl, request): + impl = await get_agents_impl + toolgroups = request.param + agent_config = AgentConfig( + model="test_model", + instructions="You are a helpful assistant.", + toolgroups=toolgroups, + tool_choice=ToolChoice.auto, + enable_session_persistence=False, + input_shields=["test_shield"], + ) + response = await impl.create_agent(agent_config) + return await impl.get_agent(response.agent_id) @pytest.mark.asyncio -async def test_chat_agent_create_and_execute_turn(chat_agent): - session = chat_agent.create_session("Test Session") +async def test_chat_agent_create_and_execute_turn(get_chat_agent): + chat_agent = await get_chat_agent + session_id = await chat_agent.create_session("Test Session") request = AgentTurnCreateRequest( - agent_id="random", - session_id=session.session_id, + agent_id=chat_agent.agent_id, + session_id=session_id, messages=[UserMessage(content="Hello")], + stream=True, ) responses = [] async for response in chat_agent.create_and_execute_turn(request): responses.append(response) - print(responses) assert len(responses) > 0 - assert len(responses) == 4 # TurnStart, StepStart, StepComplete, TurnComplete + assert ( + len(responses) == 7 + ) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete assert responses[0].event.payload.turn_id is not None @pytest.mark.asyncio -async def test_run_multiple_shields_wrapper(chat_agent): +async def test_run_multiple_shields_wrapper(get_chat_agent): + chat_agent = await get_chat_agent messages = [UserMessage(content="Test message")] shields = ["test_shield"] @@ -254,69 +398,83 @@ async def test_run_multiple_shields_wrapper(chat_agent): assert len(responses) == 2 # StepStart, StepComplete assert responses[0].event.payload.step_type.value == "shield_call" - assert not responses[1].event.payload.step_details.response.is_violation + assert not responses[1].event.payload.step_details.violation @pytest.mark.asyncio -@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily") -async def test_chat_agent_complex_turn(chat_agent): - # Setup - session = chat_agent.create_session("Test Session") +async def test_chat_agent_complex_turn(get_chat_agent): + chat_agent = await get_chat_agent + session_id = await chat_agent.create_session("Test Session") request = AgentTurnCreateRequest( - agent_id="random", - session_id=session.session_id, + agent_id=chat_agent.agent_id, + session_id=session_id, messages=[UserMessage(content="Tell me about AI and then use a 
tool.")], stream=True, ) - # Execute the turn responses = [] async for response in chat_agent.create_and_execute_turn(request): responses.append(response) - # Assertions assert len(responses) > 0 - # Check for the presence of different step types step_types = [ response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type") ] - assert "shield_call" in step_types, "Shield call step is missing" - assert "inference" in step_types, "Inference step is missing" - assert "tool_execution" in step_types, "Tool execution step is missing" + assert StepType.shield_call in step_types, "Shield call step is missing" + assert StepType.inference in step_types, "Inference step is missing" - # Check for the presence of start and complete events event_types = [ response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type") ] - assert "start" in event_types, "Start event is missing" - assert "complete" in event_types, "Complete event is missing" + assert "turn_start" in event_types, "Start event is missing" + assert "turn_complete" in event_types, "Complete event is missing" - # Check for the presence of tool call - tool_calls = [ - response.event.payload.tool_call - for response in responses - if hasattr(response.event.payload, "tool_call") - ] - assert any( - tool_call - for tool_call in tool_calls - if tool_call and tool_call.content.get("name") == "memory" - ), "Memory tool call is missing" - - # Check for the final turn complete event assert any( isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses ), "Turn complete event is missing" + turn_complete_payload = next( + response.event.payload + for response in responses + if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) + ) + turn = turn_complete_payload.turn + assert turn.input_messages == request.messages, "Input messages do not match" - # Verify the turn was added to the session - assert len(session.turns) == 1, "Turn was not added to the session" - assert ( - session.turns[0].input_messages == request.messages - ), "Input messages do not match" + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "toolgroups, expected_memory, expected_code_interpreter", + [ + ([], False, False), # no tools + ([MEMORY_TOOLGROUP], True, False), # memory only + ([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only + ([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools + ], +) +async def test_chat_agent_tools( + get_agents_impl, toolgroups, expected_memory, expected_code_interpreter +): + impl = await get_agents_impl + agent_config = AgentConfig( + model="test_model", + instructions="You are a helpful assistant.", + toolgroups=toolgroups, + tool_choice=ToolChoice.auto, + enable_session_persistence=False, + input_shields=["test_shield"], + ) + response = await impl.create_agent(agent_config) + chat_agent = await impl.get_agent(response.agent_id) + + tool_defs, _ = await chat_agent._get_tool_defs() + if expected_memory: + assert MEMORY_QUERY_TOOL in tool_defs + if expected_code_interpreter: + assert BuiltinTool.code_interpreter in tool_defs diff --git a/llama_stack/providers/inline/agents/meta_reference/tools/safety.py b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py deleted file mode 100644 index a34649756e..0000000000 --- a/llama_stack/providers/inline/agents/meta_reference/tools/safety.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import List - -from llama_stack.apis.inference import Message -from llama_stack.apis.safety import Safety - -from ..safety import ShieldRunnerMixin -from .builtin import BaseTool - - -class SafeTool(BaseTool, ShieldRunnerMixin): - """A tool that makes other tools safety enabled""" - - def __init__( - self, - tool: BaseTool, - safety_api: Safety, - input_shields: List[str] = None, - output_shields: List[str] = None, - ): - self._tool = tool - ShieldRunnerMixin.__init__( - self, safety_api, input_shields=input_shields, output_shields=output_shields - ) - - def get_name(self) -> str: - return self._tool.get_name() - - async def run(self, messages: List[Message]) -> List[Message]: - if self.input_shields: - await self.run_multiple_shields(messages, self.input_shields) - # run the underlying tool - res = await self._tool.run(messages) - if self.output_shields: - await self.run_multiple_shields(messages, self.output_shields) - - return res diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 3d18429563..27fb905722 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import os -from typing import Dict, List import pytest from llama_models.llama3.api.datatypes import BuiltinTool @@ -83,74 +82,6 @@ def query_attachment_messages(): ] -async def create_agent_turn_with_toolgroup( - agents_stack: Dict[str, object], - search_query_messages: List[object], - common_params: Dict[str, str], - toolgroup_name: str, -) -> None: - """ - Create an agent turn with a toolgroup. - - Args: - agents_stack (Dict[str, object]): The agents stack. - search_query_messages (List[object]): The search query messages. - common_params (Dict[str, str]): The common parameters. - toolgroup_name (str): The name of the toolgroup. 
- """ - - # Create an agent with the toolgroup - agent_config = AgentConfig( - **{ - **common_params, - "toolgroups": [toolgroup_name], - } - ) - - agent_id, session_id = await create_agent_session( - agents_stack.impls[Api.agents], agent_config - ) - turn_request = dict( - agent_id=agent_id, - session_id=session_id, - messages=search_query_messages, - stream=True, - ) - - turn_response = [ - chunk - async for chunk in await agents_stack.impls[Api.agents].create_agent_turn( - **turn_request - ) - ] - - assert len(turn_response) > 0 - assert all( - isinstance(chunk, AgentTurnResponseStreamChunk) for chunk in turn_response - ) - - check_event_types(turn_response) - - # Check for tool execution events - tool_execution_events = [ - chunk - for chunk in turn_response - if isinstance(chunk.event.payload, AgentTurnResponseStepCompletePayload) - and chunk.event.payload.step_details.step_type == StepType.tool_execution.value - ] - assert len(tool_execution_events) > 0, "No tool execution events found" - - # Check the tool execution details - tool_execution = tool_execution_events[0].event.payload.step_details - assert isinstance(tool_execution, ToolExecutionStep) - assert len(tool_execution.tool_calls) > 0 - actual_tool_name = tool_execution.tool_calls[0].tool_name - assert actual_tool_name.value == tool_name - assert len(tool_execution.tool_responses) > 0 - - check_turn_complete_event(turn_response, session_id, search_query_messages) - - class TestAgents: @pytest.mark.asyncio async def test_agent_turns_with_safety( From 67b35613bb5e0f7e5222571e05d1c8c133494254 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 14:05:28 -0800 Subject: [PATCH 47/53] test turn overrides in unit tests --- .../agents/meta_reference/tests/test_chat_agent.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index 6e789bf19f..6b8a846ee5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -12,6 +12,7 @@ from llama_stack.apis.agents import ( AgentConfig, + AgentToolGroupWithArgs, AgentTurnCreateRequest, AgentTurnResponseTurnCompletePayload, StepType, @@ -478,3 +479,15 @@ async def test_chat_agent_tools( assert MEMORY_QUERY_TOOL in tool_defs if expected_code_interpreter: assert BuiltinTool.code_interpreter in tool_defs + if expected_memory and expected_code_interpreter: + # override the tools for turn + new_tool_defs, _ = await chat_agent._get_tool_defs( + toolgroups_for_turn=[ + AgentToolGroupWithArgs( + name=MEMORY_TOOLGROUP, + args={"memory_banks": ["test_memory_bank"]}, + ) + ] + ) + assert MEMORY_QUERY_TOOL in new_tool_defs + assert BuiltinTool.code_interpreter not in new_tool_defs From edcfd66be3f53450d601b82a77658b43e8ef6153 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 15:28:03 -0800 Subject: [PATCH 48/53] all templates to include toolgroups and tool runtime --- distributions/dependencies.json | 12 +++++++ .../self_hosted_distro/bedrock.md | 1 + .../self_hosted_distro/cerebras.md | 1 + .../self_hosted_distro/fireworks.md | 1 + .../self_hosted_distro/meta-reference-gpu.md | 1 + .../meta-reference-quantized-gpu.md | 1 + .../self_hosted_distro/ollama.md | 1 + .../self_hosted_distro/remote-vllm.md | 1 + .../distributions/self_hosted_distro/tgi.md | 1 + .../self_hosted_distro/together.md | 1 + 
.../tool_runtime/brave_search/config.py | 9 ++++- .../tool_runtime/tavily_search/config.py | 9 ++++- llama_stack/templates/bedrock/bedrock.py | 24 +++++++++++-- llama_stack/templates/bedrock/build.yaml | 6 +++- llama_stack/templates/bedrock/run.yaml | 27 ++++++++++++-- llama_stack/templates/cerebras/build.yaml | 6 +++- llama_stack/templates/cerebras/cerebras.py | 29 +++++++++++++-- llama_stack/templates/cerebras/run.yaml | 33 +++++++++++++---- llama_stack/templates/fireworks/build.yaml | 6 +++- llama_stack/templates/fireworks/fireworks.py | 29 +++++++++++++-- llama_stack/templates/fireworks/run.yaml | 33 +++++++++++++---- llama_stack/templates/hf-endpoint/build.yaml | 6 +++- .../templates/hf-endpoint/hf_endpoint.py | 29 ++++++++++++++- .../hf-endpoint/run-with-safety.yaml | 35 ++++++++++++++----- llama_stack/templates/hf-endpoint/run.yaml | 29 ++++++++++++--- .../templates/hf-serverless/build.yaml | 6 +++- .../templates/hf-serverless/hf_serverless.py | 28 ++++++++++++++- .../hf-serverless/run-with-safety.yaml | 35 ++++++++++++++----- llama_stack/templates/hf-serverless/run.yaml | 23 +++++++++--- .../templates/meta-reference-gpu/build.yaml | 6 +++- .../meta-reference-gpu/meta_reference.py | 29 +++++++++++++-- .../meta-reference-gpu/run-with-safety.yaml | 35 ++++++++++++++----- .../templates/meta-reference-gpu/run.yaml | 23 +++++++++--- .../meta-reference-quantized-gpu/build.yaml | 6 +++- .../meta_reference.py | 24 +++++++++++-- .../meta-reference-quantized-gpu/run.yaml | 29 ++++++++++++--- llama_stack/templates/ollama/build.yaml | 6 +++- llama_stack/templates/ollama/ollama.py | 29 +++++++++++++-- .../templates/ollama/run-with-safety.yaml | 35 ++++++++++++++----- llama_stack/templates/ollama/run.yaml | 23 +++++++++--- llama_stack/templates/remote-vllm/build.yaml | 6 +++- .../remote-vllm/run-with-safety.yaml | 35 ++++++++++++++----- llama_stack/templates/remote-vllm/run.yaml | 23 +++++++++--- llama_stack/templates/remote-vllm/vllm.py | 29 +++++++++++++-- llama_stack/templates/template.py | 15 ++++++-- llama_stack/templates/tgi/build.yaml | 6 +++- .../templates/tgi/run-with-safety.yaml | 34 +++++++++++++----- llama_stack/templates/tgi/run.yaml | 23 +++++++++--- llama_stack/templates/tgi/tgi.py | 29 +++++++++++++-- llama_stack/templates/together/build.yaml | 6 +++- llama_stack/templates/together/run.yaml | 33 +++++++++++++---- llama_stack/templates/together/together.py | 29 +++++++++++++-- llama_stack/templates/vllm-gpu/build.yaml | 6 +++- llama_stack/templates/vllm-gpu/run.yaml | 29 ++++++++++++--- llama_stack/templates/vllm-gpu/vllm.py | 28 ++++++++++++++- 55 files changed, 854 insertions(+), 145 deletions(-) diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 7a974b9177..bd363ea40c 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -23,6 +23,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -54,6 +55,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -86,6 +88,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -116,6 +119,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -148,6 +152,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -181,6 +186,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -213,6 +219,7 @@ 
"psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -247,6 +254,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentence-transformers", @@ -286,6 +294,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentence-transformers", @@ -319,6 +328,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -352,6 +362,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", @@ -385,6 +396,7 @@ "psycopg2-binary", "pypdf", "redis", + "requests", "scikit-learn", "scipy", "sentencepiece", diff --git a/docs/source/distributions/self_hosted_distro/bedrock.md b/docs/source/distributions/self_hosted_distro/bedrock.md index 7dab236557..db4c7a8c97 100644 --- a/docs/source/distributions/self_hosted_distro/bedrock.md +++ b/docs/source/distributions/self_hosted_distro/bedrock.md @@ -19,6 +19,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro | safety | `remote::bedrock` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | diff --git a/docs/source/distributions/self_hosted_distro/cerebras.md b/docs/source/distributions/self_hosted_distro/cerebras.md index a8886d39bb..f623ed0de5 100644 --- a/docs/source/distributions/self_hosted_distro/cerebras.md +++ b/docs/source/distributions/self_hosted_distro/cerebras.md @@ -9,6 +9,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr | memory | `inline::meta-reference` | | safety | `inline::llama-guard` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | ### Environment Variables diff --git a/docs/source/distributions/self_hosted_distro/fireworks.md b/docs/source/distributions/self_hosted_distro/fireworks.md index a78b0ee3fc..c5428306a6 100644 --- a/docs/source/distributions/self_hosted_distro/fireworks.md +++ b/docs/source/distributions/self_hosted_distro/fireworks.md @@ -22,6 +22,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | ### Environment Variables diff --git a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md index d460393184..0ca58e7df3 100644 --- a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md +++ b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md @@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs. 
diff --git a/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md b/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md index 837be744a1..87f4f4a617 100644 --- a/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md +++ b/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md @@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc. diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md index c915a7ac31..7fe2ae408b 100644 --- a/docs/source/distributions/self_hosted_distro/ollama.md +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -22,6 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.### Environment Variables diff --git a/docs/source/distributions/self_hosted_distro/remote-vllm.md b/docs/source/distributions/self_hosted_distro/remote-vllm.md index 27f917055b..e751567ce4 100644 --- a/docs/source/distributions/self_hosted_distro/remote-vllm.md +++ b/docs/source/distributions/self_hosted_distro/remote-vllm.md @@ -18,6 +18,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following | memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | | safety | `inline::llama-guard` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference. diff --git a/docs/source/distributions/self_hosted_distro/tgi.md b/docs/source/distributions/self_hosted_distro/tgi.md index 84b91da38a..8470188097 100644 --- a/docs/source/distributions/self_hosted_distro/tgi.md +++ b/docs/source/distributions/self_hosted_distro/tgi.md @@ -23,6 +23,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | You can use this distribution if you have GPUs and want to run an independent TGI server container for running inference. 
diff --git a/docs/source/distributions/self_hosted_distro/together.md b/docs/source/distributions/self_hosted_distro/together.md index 856fd264fa..72b0822260 100644 --- a/docs/source/distributions/self_hosted_distro/together.md +++ b/docs/source/distributions/self_hosted_distro/together.md @@ -22,6 +22,7 @@ The `llamastack/distribution-together` distribution consists of the following pr | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` | ### Environment Variables diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/config.py b/llama_stack/providers/remote/tool_runtime/brave_search/config.py index 565d428f79..ab60536090 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/config.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, Field @@ -18,3 +18,10 @@ class BraveSearchToolConfig(BaseModel): default=3, description="The maximum number of results to return", ) + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + return { + "api_key": "${env.BRAVE_SEARCH_API_KEY:}", + "max_results": 3, + } diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/config.py b/llama_stack/providers/remote/tool_runtime/tavily_search/config.py index f7a8f3f09b..945430bb1e 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/config.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, Field @@ -18,3 +18,10 @@ class TavilySearchToolConfig(BaseModel): default=3, description="The maximum number of results to return", ) + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + return { + "api_key": "${env.TAVILY_SEARCH_API_KEY:}", + "max_results": 3, + } diff --git a/llama_stack/templates/bedrock/bedrock.py b/llama_stack/templates/bedrock/bedrock.py index 0b5b7d90d5..a579e5b7f0 100644 --- a/llama_stack/templates/bedrock/bedrock.py +++ b/llama_stack/templates/bedrock/bedrock.py @@ -9,8 +9,7 @@ from llama_models.sku_list import all_registered_models from llama_stack.apis.models import ModelInput -from llama_stack.distribution.datatypes import Provider - +from llama_stack.distribution.datatypes import Provider, ToolGroupInput from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -26,6 +25,12 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "bedrock" memory_provider = Provider( @@ -46,6 +51,20 @@ def get_distribution_template() -> DistributionTemplate: ) for m in MODEL_ALIASES ] + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -61,6 +80,7 @@ def get_distribution_template() -> DistributionTemplate: "memory": [memory_provider], }, default_models=default_models, + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/bedrock/build.yaml b/llama_stack/templates/bedrock/build.yaml index cd36c320e0..a68a8f6fca 100644 --- a/llama_stack/templates/bedrock/build.yaml +++ b/llama_stack/templates/bedrock/build.yaml @@ -2,7 +2,6 @@ version: '2' name: bedrock distribution_spec: description: Use AWS Bedrock for running LLM inference and safety - docker_image: null providers: inference: - remote::bedrock @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml index 9aa5ca9144..1d07217738 100644 --- a/llama_stack/templates/bedrock/run.yaml +++ b/llama_stack/templates/bedrock/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: bedrock -docker_image: null conda_env: bedrock apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: bedrock @@ -65,8 +65,24 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + 
max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db models: @@ -90,3 +106,10 @@ memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml index a1fe93099f..307e0303a7 100644 --- a/llama_stack/templates/cerebras/build.yaml +++ b/llama_stack/templates/cerebras/build.yaml @@ -2,7 +2,6 @@ version: '2' name: cerebras distribution_spec: description: Use Cerebras for running LLM inference - docker_image: null providers: inference: - remote::cerebras @@ -14,4 +13,9 @@ distribution_spec: - inline::meta-reference telemetry: - inline::meta-reference + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py index 9acb244bdd..cbacdbaec9 100644 --- a/llama_stack/templates/cerebras/cerebras.py +++ b/llama_stack/templates/cerebras/cerebras.py @@ -9,8 +9,12 @@ from llama_models.sku_list import all_registered_models from llama_stack.apis.models.models import ModelType - -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) @@ -26,6 +30,12 @@ def get_distribution_template() -> DistributionTemplate: "memory": ["inline::meta-reference"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } inference_provider = Provider( @@ -58,6 +68,20 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name="cerebras", @@ -74,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate: }, default_models=default_models + [embedding_model], default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml index 05b21bf0ab..e06b17a50c 100644 --- a/llama_stack/templates/cerebras/run.yaml +++ b/llama_stack/templates/cerebras/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: cerebras -docker_image: null conda_env: cerebras apis: - agents @@ -8,6 +7,7 @@ apis: - memory - safety - telemetry 
+- tool_runtime providers: inference: - provider_id: cerebras @@ -45,8 +45,24 @@ providers: service_name: ${env.OTEL_SERVICE_NAME:llama-stack} sinks: ${env.TELEMETRY_SINKS:console,sqlite} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db models: @@ -64,14 +80,17 @@ models: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: -- params: null - shield_id: meta-llama/Llama-Guard-3-8B - provider_id: null - provider_shield_id: null +- shield_id: meta-llama/Llama-Guard-3-8B memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index 30ea347aef..e76cc86f11 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -2,7 +2,6 @@ version: '2' name: fireworks distribution_spec: description: Use Fireworks.AI for running LLM inference - docker_image: null providers: inference: - remote::fireworks @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py index cbcac0f929..090f98b59e 100644 --- a/llama_stack/templates/fireworks/fireworks.py +++ b/llama_stack/templates/fireworks/fireworks.py @@ -9,8 +9,12 @@ from llama_models.sku_list import all_registered_models from llama_stack.apis.models.models import ModelType - -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) @@ -30,6 +34,12 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "fireworks" @@ -69,6 +79,20 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + 
toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -86,6 +110,7 @@ def get_distribution_template() -> DistributionTemplate: }, default_models=default_models + [embedding_model], default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index 99f155a4a1..444679da79 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: fireworks -docker_image: null conda_env: fireworks apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: fireworks @@ -70,8 +70,24 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/registry.db models: @@ -129,14 +145,17 @@ models: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: -- params: null - shield_id: meta-llama/Llama-Guard-3-8B - provider_id: null - provider_shield_id: null +- shield_id: meta-llama/Llama-Guard-3-8B memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/hf-endpoint/build.yaml b/llama_stack/templates/hf-endpoint/build.yaml index 523cf5d835..c186898552 100644 --- a/llama_stack/templates/hf-endpoint/build.yaml +++ b/llama_stack/templates/hf-endpoint/build.yaml @@ -2,7 +2,6 @@ version: '2' name: hf-endpoint distribution_spec: description: Use (an external) Hugging Face Inference Endpoint for running LLM inference - docker_image: null providers: inference: - remote::hf::endpoint @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/hf-endpoint/hf_endpoint.py b/llama_stack/templates/hf-endpoint/hf_endpoint.py index 404440be6a..8bac2588d3 100644 --- a/llama_stack/templates/hf-endpoint/hf_endpoint.py +++ b/llama_stack/templates/hf-endpoint/hf_endpoint.py @@ -5,7 +5,12 @@ # the root directory of this source tree. 
from llama_stack.apis.models.models import ModelType -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) @@ -24,6 +29,12 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "hf-endpoint" inference_provider = Provider( @@ -58,6 +69,20 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -74,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate: "memory": [memory_provider], }, default_models=[inference_model, embedding_model], + default_tool_groups=default_tool_groups, ), "run-with-safety.yaml": RunConfigSettings( provider_overrides={ @@ -96,6 +122,7 @@ def get_distribution_template() -> DistributionTemplate: embedding_model, ], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/hf-endpoint/run-with-safety.yaml b/llama_stack/templates/hf-endpoint/run-with-safety.yaml index 8e566de9a0..a9d895d234 100644 --- a/llama_stack/templates/hf-endpoint/run-with-safety.yaml +++ b/llama_stack/templates/hf-endpoint/run-with-safety.yaml @@ -1,6 +1,5 @@ version: '2' image_name: hf-endpoint -docker_image: null conda_env: hf-endpoint apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: hf-endpoint @@ -75,33 +75,50 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: hf-endpoint - provider_model_id: null model_type: llm - metadata: {} model_id: ${env.SAFETY_MODEL} provider_id: hf-endpoint-safety - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: -- params: null - shield_id: ${env.SAFETY_MODEL} - provider_id: null - provider_shield_id: null +- shield_id: ${env.SAFETY_MODEL} memory_banks: [] datasets: [] scoring_fns: [] 
eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/hf-endpoint/run.yaml b/llama_stack/templates/hf-endpoint/run.yaml index c1b3a64d00..e9b58c9624 100644 --- a/llama_stack/templates/hf-endpoint/run.yaml +++ b/llama_stack/templates/hf-endpoint/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: hf-endpoint -docker_image: null conda_env: hf-endpoint apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: hf-endpoint @@ -70,24 +70,45 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: hf-endpoint - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: [] memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml index af7eb60fe2..a6b551e4a6 100644 --- a/llama_stack/templates/hf-serverless/build.yaml +++ b/llama_stack/templates/hf-serverless/build.yaml @@ -2,7 +2,6 @@ version: '2' name: hf-serverless distribution_spec: description: Use (an external) Hugging Face Inference Endpoint for running LLM inference - docker_image: null providers: inference: - remote::hf::serverless @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/hf-serverless/hf_serverless.py b/llama_stack/templates/hf-serverless/hf_serverless.py index 63b423412f..33eb594fe8 100644 --- a/llama_stack/templates/hf-serverless/hf_serverless.py +++ b/llama_stack/templates/hf-serverless/hf_serverless.py @@ -5,7 +5,12 @@ # the root directory of this source tree. 
from llama_stack.apis.models.models import ModelType -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) @@ -24,6 +29,12 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "hf-serverless" @@ -59,6 +70,20 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -97,6 +122,7 @@ def get_distribution_template() -> DistributionTemplate: embedding_model, ], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/hf-serverless/run-with-safety.yaml b/llama_stack/templates/hf-serverless/run-with-safety.yaml index 2b24ab0747..415cec648d 100644 --- a/llama_stack/templates/hf-serverless/run-with-safety.yaml +++ b/llama_stack/templates/hf-serverless/run-with-safety.yaml @@ -1,6 +1,5 @@ version: '2' image_name: hf-serverless -docker_image: null conda_env: hf-serverless apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: hf-serverless @@ -75,33 +75,50 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: hf-serverless - provider_model_id: null model_type: llm - metadata: {} model_id: ${env.SAFETY_MODEL} provider_id: hf-serverless-safety - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: -- params: null - shield_id: ${env.SAFETY_MODEL} - provider_id: null - provider_shield_id: null +- shield_id: ${env.SAFETY_MODEL} memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git 
a/llama_stack/templates/hf-serverless/run.yaml b/llama_stack/templates/hf-serverless/run.yaml index 394d689daa..ef9dedeed6 100644 --- a/llama_stack/templates/hf-serverless/run.yaml +++ b/llama_stack/templates/hf-serverless/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: hf-serverless -docker_image: null conda_env: hf-serverless apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: hf-serverless @@ -70,24 +70,39 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: hf-serverless - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: [] memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: [] diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index 300b75b14b..ba8413fa65 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -2,7 +2,6 @@ version: '2' name: meta-reference-gpu distribution_spec: description: Use Meta Reference for running LLM inference - docker_image: null providers: inference: - inline::meta-reference @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/meta-reference-gpu/meta_reference.py b/llama_stack/templates/meta-reference-gpu/meta_reference.py index 461d89a4a5..8ad56d7f55 100644 --- a/llama_stack/templates/meta-reference-gpu/meta_reference.py +++ b/llama_stack/templates/meta-reference-gpu/meta_reference.py @@ -7,8 +7,12 @@ from pathlib import Path from llama_stack.apis.models.models import ModelType - -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) @@ -29,6 +33,12 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "meta-reference-gpu" inference_provider = Provider( @@ -66,6 +76,20 @@ def get_distribution_template() -> DistributionTemplate: model_id="${env.SAFETY_MODEL}", provider_id="meta-reference-safety", ) + default_tool_groups = [ + 
ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -104,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate: embedding_model, ], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index deb6c4a912..4946fdab7f 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -1,6 +1,5 @@ version: '2' image_name: meta-reference-gpu -docker_image: null conda_env: meta-reference-gpu apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: meta-reference-inference @@ -77,33 +77,50 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: meta-reference-inference - provider_model_id: null model_type: llm - metadata: {} model_id: ${env.SAFETY_MODEL} provider_id: meta-reference-safety - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: -- params: null - shield_id: ${env.SAFETY_MODEL} - provider_id: null - provider_shield_id: null +- shield_id: ${env.SAFETY_MODEL} memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index c190666644..52345f3c19 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: meta-reference-gpu -docker_image: null conda_env: meta-reference-gpu apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: meta-reference-inference @@ -71,24 +71,39 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: 
${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: meta-reference-inference - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: [] memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: [] diff --git a/llama_stack/templates/meta-reference-quantized-gpu/build.yaml b/llama_stack/templates/meta-reference-quantized-gpu/build.yaml index 9d866de18f..41ab44e38b 100644 --- a/llama_stack/templates/meta-reference-quantized-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-quantized-gpu/build.yaml @@ -2,7 +2,6 @@ version: '2' name: meta-reference-quantized-gpu distribution_spec: description: Use Meta Reference with fp8, int4 quantization for running LLM inference - docker_image: null providers: inference: - inline::meta-reference-quantized @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py b/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py index c460860c56..6af7175f70 100644 --- a/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py +++ b/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py @@ -7,8 +7,7 @@ from pathlib import Path from llama_stack.apis.models.models import ModelType - -from llama_stack.distribution.datatypes import ModelInput, Provider +from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceQuantizedInferenceConfig, ) @@ -29,7 +28,27 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] name = "meta-reference-quantized-gpu" inference_provider = Provider( provider_id="meta-reference-inference", @@ -76,6 +95,7 @@ def get_distribution_template() -> DistributionTemplate: "memory": [memory_provider], }, default_models=[inference_model, embedding_model], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/meta-reference-quantized-gpu/run.yaml b/llama_stack/templates/meta-reference-quantized-gpu/run.yaml index 550170a00d..02a5bacaa9 100644 --- a/llama_stack/templates/meta-reference-quantized-gpu/run.yaml +++ 
b/llama_stack/templates/meta-reference-quantized-gpu/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: meta-reference-quantized-gpu -docker_image: null conda_env: meta-reference-quantized-gpu apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: meta-reference-inference @@ -73,24 +73,45 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: meta-reference-inference - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: [] memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index a021e4993b..cbd9101cfc 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -2,7 +2,6 @@ version: '2' name: ollama distribution_spec: description: Use (an external) Ollama server for running LLM inference - docker_image: null providers: inference: - remote::ollama @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index 1e3180a775..9a76e93713 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -7,8 +7,12 @@ from pathlib import Path from llama_stack.apis.models.models import ModelType - -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) @@ -27,6 +31,12 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "ollama" inference_provider = Provider( @@ -61,6 +71,20 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + 
provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -92,6 +116,7 @@ def get_distribution_template() -> DistributionTemplate: embedding_model, ], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 100886c958..96cb1d6684 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -1,6 +1,5 @@ version: '2' image_name: ollama -docker_image: null conda_env: ollama apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: ollama @@ -69,33 +69,50 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: ollama - provider_model_id: null model_type: llm - metadata: {} model_id: ${env.SAFETY_MODEL} provider_id: ollama - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: -- params: null - shield_id: ${env.SAFETY_MODEL} - provider_id: null - provider_shield_id: null +- shield_id: ${env.SAFETY_MODEL} memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index bcbed3e6ef..1764652993 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: ollama -docker_image: null conda_env: ollama apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: ollama @@ -69,24 +69,39 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: ollama - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: [] memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: [] diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index 9f4597cb0f..246e53db0d 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -2,7 +2,6 @@ version: '2' name: remote-vllm distribution_spec: description: Use (an external) vLLM server for running LLM inference - docker_image: null providers: inference: - remote::vllm @@ -16,4 +15,9 @@ distribution_spec: - inline::meta-reference telemetry: - inline::meta-reference + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index 7097bc6496..1babd04ac1 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -1,6 +1,5 @@ version: '2' image_name: remote-vllm -docker_image: null conda_env: remote-vllm apis: - agents @@ -8,6 +7,7 @@ apis: - memory - safety - telemetry +- tool_runtime providers: inference: - provider_id: vllm-inference @@ -52,33 +52,50 @@ providers: service_name: ${env.OTEL_SERVICE_NAME:llama-stack} sinks: ${env.TELEMETRY_SINKS:console,sqlite} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: vllm-inference - provider_model_id: null model_type: llm - metadata: {} model_id: ${env.SAFETY_MODEL} provider_id: vllm-safety - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: -- params: null - shield_id: ${env.SAFETY_MODEL} - provider_id: null - provider_shield_id: null +- shield_id: ${env.SAFETY_MODEL} memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index c957b05d08..a3a571423a 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: remote-vllm -docker_image: null conda_env: remote-vllm apis: 
- agents @@ -8,6 +7,7 @@ apis: - memory - safety - telemetry +- tool_runtime providers: inference: - provider_id: vllm-inference @@ -46,24 +46,39 @@ providers: service_name: ${env.OTEL_SERVICE_NAME:llama-stack} sinks: ${env.TELEMETRY_SINKS:console,sqlite} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: vllm-inference - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: [] memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: [] diff --git a/llama_stack/templates/remote-vllm/vllm.py b/llama_stack/templates/remote-vllm/vllm.py index e4c948fbfa..f12752f2b3 100644 --- a/llama_stack/templates/remote-vllm/vllm.py +++ b/llama_stack/templates/remote-vllm/vllm.py @@ -7,8 +7,12 @@ from pathlib import Path from llama_stack.apis.models.models import ModelType - -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) @@ -24,6 +28,12 @@ def get_distribution_template() -> DistributionTemplate: "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "remote-vllm" inference_provider = Provider( @@ -60,6 +70,20 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -97,6 +121,7 @@ def get_distribution_template() -> DistributionTemplate: embedding_model, ], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index 0ec8c1f09d..5bb88c821d 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -20,6 +20,7 @@ Provider, ShieldInput, StackRunConfig, + ToolGroupInput, ) from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type @@ -30,6 +31,7 @@ class RunConfigSettings(BaseModel): provider_overrides: Dict[str, List[Provider]] = 
Field(default_factory=dict) default_models: Optional[List[ModelInput]] = None default_shields: Optional[List[ShieldInput]] = None + default_tool_groups: Optional[List[ToolGroupInput]] = None def run_config( self, @@ -91,6 +93,7 @@ def run_config( ), models=self.default_models or [], shields=self.default_shields or [], + tool_groups=self.default_tool_groups or [], ) @@ -159,14 +162,22 @@ def enum_representer(dumper, data): build_config = self.build_config() with open(yaml_output_dir / "build.yaml", "w") as f: - yaml.safe_dump(build_config.model_dump(), f, sort_keys=False) + yaml.safe_dump( + build_config.model_dump(exclude_none=True), + f, + sort_keys=False, + ) for yaml_pth, settings in self.run_configs.items(): run_config = settings.run_config( self.name, self.providers, self.docker_image ) with open(yaml_output_dir / yaml_pth, "w") as f: - yaml.safe_dump(run_config.model_dump(), f, sort_keys=False) + yaml.safe_dump( + run_config.model_dump(exclude_none=True), + f, + sort_keys=False, + ) if self.template_path: docs = self.generate_markdown_docs() diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml index d90b505df6..399d4a6163 100644 --- a/llama_stack/templates/tgi/build.yaml +++ b/llama_stack/templates/tgi/build.yaml @@ -2,7 +2,6 @@ version: '2' name: tgi distribution_spec: description: Use (an external) TGI server for running LLM inference - docker_image: null providers: inference: - remote::tgi @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/tgi/run-with-safety.yaml b/llama_stack/templates/tgi/run-with-safety.yaml index ef8344a7ad..4134101f62 100644 --- a/llama_stack/templates/tgi/run-with-safety.yaml +++ b/llama_stack/templates/tgi/run-with-safety.yaml @@ -1,6 +1,5 @@ version: '2' image_name: tgi -docker_image: null conda_env: tgi apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: tgi-inference @@ -70,27 +70,45 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: tgi-inference - provider_model_id: null model_type: llm - metadata: {} model_id: ${env.SAFETY_MODEL} provider_id: tgi-safety - provider_model_id: null model_type: llm shields: -- params: null - shield_id: ${env.SAFETY_MODEL} - provider_id: null - provider_shield_id: null +- shield_id: ${env.SAFETY_MODEL} memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git 
a/llama_stack/templates/tgi/run.yaml b/llama_stack/templates/tgi/run.yaml index 22c08d1d3a..b0b78e33b0 100644 --- a/llama_stack/templates/tgi/run.yaml +++ b/llama_stack/templates/tgi/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: tgi -docker_image: null conda_env: tgi apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: tgi-inference @@ -69,24 +69,39 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: tgi-inference - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: [] memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: [] diff --git a/llama_stack/templates/tgi/tgi.py b/llama_stack/templates/tgi/tgi.py index c84f5b5feb..892d539d2b 100644 --- a/llama_stack/templates/tgi/tgi.py +++ b/llama_stack/templates/tgi/tgi.py @@ -7,8 +7,12 @@ from pathlib import Path from llama_stack.apis.models.models import ModelType - -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) @@ -27,6 +31,12 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "tgi" inference_provider = Provider( @@ -63,6 +73,20 @@ def get_distribution_template() -> DistributionTemplate: model_id="${env.SAFETY_MODEL}", provider_id="tgi-safety", ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -99,6 +123,7 @@ def get_distribution_template() -> DistributionTemplate: safety_model, ], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index 6930b76926..96f9f758eb 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -2,7 +2,6 @@ version: '2' name: together distribution_spec: description: Use Together.AI for running LLM 
inference - docker_image: null providers: inference: - remote::together @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index 44e33662b2..ed65ded57d 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: together -docker_image: null conda_env: together apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: together @@ -70,8 +70,24 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/registry.db models: @@ -124,14 +140,17 @@ models: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: -- params: null - shield_id: meta-llama/Llama-Guard-3-8B - provider_id: null - provider_shield_id: null +- shield_id: meta-llama/Llama-Guard-3-8B memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py index 994cf55498..d73e23e77c 100644 --- a/llama_stack/templates/together/together.py +++ b/llama_stack/templates/together/together.py @@ -9,8 +9,12 @@ from llama_models.sku_list import all_registered_models from llama_stack.apis.models.models import ModelType - -from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) @@ -30,6 +34,12 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", "inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } name = "together" inference_provider = Provider( @@ -59,6 +69,20 @@ def get_distribution_template() -> DistributionTemplate: ) for m in MODEL_ALIASES ] + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + 
provider_id="code-interpreter", + ), + ] embedding_model = ModelInput( model_id="all-MiniLM-L6-v2", provider_id="sentence-transformers", @@ -83,6 +107,7 @@ def get_distribution_template() -> DistributionTemplate: "memory": [memory_provider], }, default_models=default_models + [embedding_model], + default_tool_groups=default_tool_groups, default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], ), }, diff --git a/llama_stack/templates/vllm-gpu/build.yaml b/llama_stack/templates/vllm-gpu/build.yaml index 4289296ec2..959f91d3e0 100644 --- a/llama_stack/templates/vllm-gpu/build.yaml +++ b/llama_stack/templates/vllm-gpu/build.yaml @@ -2,7 +2,6 @@ version: '2' name: vllm-gpu distribution_spec: description: Use a built-in vLLM engine for running LLM inference - docker_image: null providers: inference: - inline::vllm @@ -25,4 +24,9 @@ distribution_spec: - inline::basic - inline::llm-as-judge - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::memory-runtime image_type: conda diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml index 171f25d632..48ec57cfbd 100644 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ b/llama_stack/templates/vllm-gpu/run.yaml @@ -1,6 +1,5 @@ version: '2' image_name: vllm-gpu -docker_image: null conda_env: vllm-gpu apis: - agents @@ -11,6 +10,7 @@ apis: - safety - scoring - telemetry +- tool_runtime providers: inference: - provider_id: vllm @@ -73,24 +73,45 @@ providers: provider_type: inline::braintrust config: openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: memory-runtime + provider_type: inline::memory-runtime + config: {} metadata_store: - namespace: null type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/registry.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: vllm - provider_model_id: null model_type: llm - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 provider_id: sentence-transformers - provider_model_id: null model_type: embedding shields: [] memory_banks: [] datasets: [] scoring_fns: [] eval_tasks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::memory + provider_id: memory-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter diff --git a/llama_stack/templates/vllm-gpu/vllm.py b/llama_stack/templates/vllm-gpu/vllm.py index fe6fb7186a..5cf4789907 100644 --- a/llama_stack/templates/vllm-gpu/vllm.py +++ b/llama_stack/templates/vllm-gpu/vllm.py @@ -11,7 +11,11 @@ ) from llama_stack.providers.inline.inference.vllm import VLLMConfig from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings +from llama_stack.templates.template import ( + DistributionTemplate, + RunConfigSettings, + ToolGroupInput, +) def get_distribution_template() -> DistributionTemplate: @@ -24,7 +28,14 @@ def get_distribution_template() -> DistributionTemplate: "eval": ["inline::meta-reference"], "datasetio": ["remote::huggingface", 
"inline::localfs"], "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::memory-runtime", + ], } + name = "vllm-gpu" inference_provider = Provider( provider_id="vllm", @@ -54,6 +65,20 @@ def get_distribution_template() -> DistributionTemplate: "embedding_dimension": 384, }, ) + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::memory", + provider_id="memory-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] return DistributionTemplate( name=name, @@ -70,6 +95,7 @@ def get_distribution_template() -> DistributionTemplate: "memory": [memory_provider], }, default_models=[inference_model, embedding_model], + default_tool_groups=default_tool_groups, ), }, run_config_env_vars={ From e08b7f4432cef851aeda6ee08326727d2aab810f Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 17:47:29 -0800 Subject: [PATCH 49/53] move _interpret_content_as_attachment to outside --- .../agents/meta_reference/agent_instance.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 52293182cc..b728566743 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -679,21 +679,7 @@ def is_memory_group(tool): # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially - def interpret_content_as_attachment( - content: str, - ) -> Optional[Attachment]: - match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) - if match: - snippet = match.group(1) - data = json.loads(snippet) - return Attachment( - url=URL(uri="file://" + data["filepath"]), - mime_type=data["mimetype"], - ) - - return None - - if out_attachment := interpret_content_as_attachment( + if out_attachment := _interpret_content_as_attachment( result_message.content ): # NOTE: when we push this message back to the model, the model may ignore the @@ -974,3 +960,18 @@ async def execute_tool_call_maybe( content=result.content, ) ] + + +def _interpret_content_as_attachment( + content: str, +) -> Optional[Attachment]: + match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content) + if match: + snippet = match.group(1) + data = json.loads(snippet) + return Attachment( + url=URL(uri="file://" + data["filepath"]), + mime_type=data["mimetype"], + ) + + return None From a7a55748cacbda8f0e30ee26bf5d71cf2da66920 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 18:24:35 -0800 Subject: [PATCH 50/53] address feedback --- llama_stack/apis/tools/tools.py | 5 ++-- .../distribution/routers/routing_tables.py | 1 - .../agents/meta_reference/agent_instance.py | 25 +++++++++++++------ .../meta_reference/tests/test_chat_agent.py | 5 ++-- .../code_interpreter/code_interpreter.py | 3 --- .../tool_runtime/bing_search/bing_search.py | 2 -- .../tavily_search/tavily_search.py | 2 -- .../wolfram_alpha/wolfram_alpha.py | 2 -- 8 files changed, 21 insertions(+), 24 deletions(-) diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index dbfd852206..e430ec46d4 100644 --- 
a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional -from llama_models.llama3.api.datatypes import BuiltinTool, ToolPromptFormat +from llama_models.llama3.api.datatypes import ToolPromptFormat from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable @@ -40,7 +40,6 @@ class Tool(Resource): tool_host: ToolHost description: str parameters: List[ToolParameter] - built_in_type: Optional[BuiltinTool] = None metadata: Optional[Dict[str, Any]] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json @@ -53,7 +52,6 @@ class ToolDef(BaseModel): description: Optional[str] = None parameters: Optional[List[ToolParameter]] = None metadata: Optional[Dict[str, Any]] = None - built_in_type: Optional[BuiltinTool] = None tool_prompt_format: Optional[ToolPromptFormat] = Field( default=ToolPromptFormat.json ) @@ -130,6 +128,7 @@ async def unregister_tool_group(self, tool_group_id: str) -> None: class ToolRuntime(Protocol): tool_store: ToolStore + # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed. @webmethod(route="/tool-runtime/list-tools", method="GET") async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 36ddda7a65..d4cb708a27 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -527,7 +527,6 @@ async def register_tool_group( provider_resource_id=tool_def.name, metadata=tool_def.metadata, tool_host=tool_host, - built_in_type=tool_def.built_in_type, ) ) for tool in tools: diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index b728566743..2cd86bcaa3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -78,6 +78,7 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") MEMORY_QUERY_TOOL = "query_memory" WEB_SEARCH_TOOL = "web_search" +MEMORY_GROUP = "builtin::memory" class ChatAgent(ShieldRunnerMixin): @@ -741,16 +742,24 @@ async def _get_tool_defs( continue tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name) for tool_def in tools: - if tool_def.built_in_type: - if tool_def_map.get(tool_def.built_in_type, None): - raise ValueError( - f"Tool {tool_def.built_in_type} already exists" - ) + if ( + toolgroup_name.startswith("builtin") + and toolgroup_name != MEMORY_GROUP + ): + tool_name = tool_def.identifier + built_in_type = BuiltinTool.brave_search + if tool_name == "web_search": + built_in_type = BuiltinTool.brave_search + else: + built_in_type = BuiltinTool(tool_name) + + if tool_def_map.get(built_in_type, None): + raise ValueError(f"Tool {built_in_type} already exists") - tool_def_map[tool_def.built_in_type] = ToolDefinition( - tool_name=tool_def.built_in_type + tool_def_map[built_in_type] = ToolDefinition( + tool_name=built_in_type ) - tool_to_group[tool_def.built_in_type] = tool_def.toolgroup_id + tool_to_group[built_in_type] = tool_def.toolgroup_id continue if 
tool_def_map.get(tool_def.identifier, None): diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index 6b8a846ee5..a7e6efc8cf 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -198,7 +198,7 @@ async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: toolgroup_id=MEMORY_TOOLGROUP, tool_host=ToolHost.client, description="Mock tool", - provider_id="mock_provider", + provider_id="builtin::memory", parameters=[], ) ] @@ -208,10 +208,9 @@ async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: identifier="code_interpreter", provider_resource_id="code_interpreter", toolgroup_id=CODE_INTERPRETER_TOOLGROUP, - built_in_type=BuiltinTool.code_interpreter, tool_host=ToolHost.client, description="Mock tool", - provider_id="mock_provider", + provider_id="builtin::code_interpreter", parameters=[], ) ] diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 98026fa3d1..361c91a92a 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -9,8 +9,6 @@ import tempfile from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import BuiltinTool - from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( Tool, @@ -58,7 +56,6 @@ async def list_runtime_tools( parameter_type="string", ), ], - built_in_type=BuiltinTool.code_interpreter, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index a69f08ce81..5cf36acbc1 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional import requests -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -65,7 +64,6 @@ async def list_runtime_tools( parameter_type="string", ) ], - built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 8f666a6fb4..8f86edfb1d 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional import requests -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -64,7 +63,6 @@ async def list_runtime_tools( parameter_type="string", ) ], - built_in_type=BuiltinTool.brave_search, ) ] diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 13c298eb23..af99d7b2aa 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ 
b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional import requests -from llama_models.llama3.api.datatypes import BuiltinTool from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( @@ -65,7 +64,6 @@ async def list_runtime_tools( parameter_type="string", ) ], - built_in_type=BuiltinTool.wolfram_alpha, ) ] From d0c8dced65d6fd0cac0eb3f23f561524db6d9061 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 18:31:47 -0800 Subject: [PATCH 51/53] resolve conflicts --- docs/resources/llama-stack-spec.html | 59 ++++++-- docs/resources/llama-stack-spec.yaml | 197 ++++++++++++++++++++++++++- 2 files changed, 240 insertions(+), 16 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index e98de6491e..377adf4666 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -2737,6 +2737,40 @@ } } }, + "/alpha/toolgroups/unregister": { + "post": { + "responses": { + "200": { + "description": "OK" + } + }, + "tags": [ + "ToolGroups" + ], + "summary": "Unregister a tool group", + "parameters": [ + { + "name": "X-LlamaStack-ProviderData", + "in": "header", + "description": "JSON-encoded provider data which will be made available to the adapter servicing the API", + "required": false, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnregisterToolGroupRequest" + } + } + }, + "required": true + } + } + }, "/alpha/version": { "get": { "responses": { @@ -3840,9 +3874,6 @@ ] } }, - "built_in_type": { - "$ref": "#/components/schemas/BuiltinTool" - }, "tool_prompt_format": { "$ref": "#/components/schemas/ToolPromptFormat", "default": "json" @@ -5837,9 +5868,6 @@ "$ref": "#/components/schemas/ToolParameter" } }, - "built_in_type": { - "$ref": "#/components/schemas/BuiltinTool" - }, "metadata": { "type": "object", "additionalProperties": { @@ -7933,6 +7961,18 @@ "model_id" ] }, + "UnregisterToolGroupRequest": { + "type": "object", + "properties": { + "tool_group_id": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "tool_group_id" + ] + }, "VersionInfo": { "type": "object", "properties": { @@ -8665,10 +8705,6 @@ { "name": "ViolationLevel", "description": "" - }, - { - "name": "WolframAlphaToolDefinition", - "description": "" } ], "x-tagGroups": [ @@ -8862,8 +8898,7 @@ "VectorMemoryBank", "VectorMemoryBankParams", "VersionInfo", - "ViolationLevel", - "WolframAlphaToolDefinition" + "ViolationLevel" ] } ] diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 924fc32a10..f642553419 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -2576,8 +2576,6 @@ components: Tool: additionalProperties: false properties: - built_in_type: - $ref: '#/components/schemas/BuiltinTool' description: type: string identifier: @@ -2688,8 +2686,6 @@ components: ToolDef: additionalProperties: false properties: - built_in_type: - $ref: '#/components/schemas/BuiltinTool' description: type: string metadata: @@ -4683,6 +4679,199 @@ paths: description: OK tags: - Telemetry + /alpha/tool-runtime/invoke: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + 
requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InvokeToolRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/ToolInvocationResult' + description: OK + summary: Run a tool with the given arguments + tags: + - ToolRuntime + /alpha/tool-runtime/list-tools: + post: + parameters: + - in: query + name: tool_group_id + required: false + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ListRuntimeToolsRequest' + required: true + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/ToolDef' + description: OK + tags: + - ToolRuntime + /alpha/toolgroups/get: + get: + parameters: + - in: query + name: toolgroup_id + required: true + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/ToolGroup' + description: OK + tags: + - ToolGroups + /alpha/toolgroups/list: + get: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/ToolGroup' + description: OK + summary: List tool groups with optional provider + tags: + - ToolGroups + /alpha/toolgroups/register: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RegisterToolGroupRequest' + required: true + responses: + '200': + description: OK + summary: Register a tool group + tags: + - ToolGroups + /alpha/toolgroups/unregister: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UnregisterToolGroupRequest' + required: true + responses: + '200': + description: OK + summary: Unregister a tool group + tags: + - ToolGroups + /alpha/tools/get: + get: + parameters: + - in: query + name: tool_name + required: true + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/Tool' + description: OK + tags: + - ToolGroups + /alpha/tools/list: + get: + parameters: + - in: query + name: tool_group_id + required: false + schema: + type: string + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + 
type: string + responses: + '200': + content: + application/jsonl: + schema: + $ref: '#/components/schemas/Tool' + description: OK + summary: List tools with optional tool group + tags: + - ToolGroups /alpha/version: get: parameters: From b46d94d87db5622f8e2dfa743a8dfed45b231f8a Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 18:53:32 -0800 Subject: [PATCH 52/53] do not pass memory tools to inference --- .../inline/agents/meta_reference/agent_instance.py | 13 +++---------- .../providers/inline/tool_runtime/memory/memory.py | 2 +- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2cd86bcaa3..24448a28f8 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -487,20 +487,13 @@ async def _run( stop_reason = None with tracing.span("inference") as span: - - def is_memory_group(tool): - memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None) - has_memory_tool = MEMORY_QUERY_TOOL in tool_defs - return ( - has_memory_tool - and tool_to_group.get(tool.tool_name, None) != memory_tool_group - ) - async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, tools=[ - tool for tool in tool_defs.values() if not is_memory_group(tool) + tool + for tool in tool_defs.values() + if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP ], tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, diff --git a/llama_stack/providers/inline/tool_runtime/memory/memory.py b/llama_stack/providers/inline/tool_runtime/memory/memory.py index f46b375105..fe6325abbd 100644 --- a/llama_stack/providers/inline/tool_runtime/memory/memory.py +++ b/llama_stack/providers/inline/tool_runtime/memory/memory.py @@ -60,7 +60,7 @@ async def list_runtime_tools( description="Retrieve context from memory", parameters=[ ToolParameter( - name="input_messages", + name="messages", description="The input messages to search for", parameter_type="array", ), From 3bdc1d92b49df0815e7cd9c83070671b71b0c2c0 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 8 Jan 2025 18:56:12 -0800 Subject: [PATCH 53/53] update registry prefix --- llama_stack/distribution/store/registry.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 686054dd26..d26b4447c8 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR - from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -36,7 +35,7 @@ async def delete(self, type: str, identifier: str) -> None: ... REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v3" +KEY_VERSION = "v4" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
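
The KEY_VERSION bump in the final commit works because every registry key is namespaced by its format version, so entries persisted under v3 are simply never read back once the prefix moves to v4. A minimal sketch of the key construction, reusing the constants from the registry.py diff above (the example identifier is illustrative only):

REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v4"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

# Building a key for a registered tool group:
key = KEY_FORMAT.format(type="tool_group", identifier="builtin::websearch")
# -> "distributions:registry:v4::tool_group:builtin::websearch"
# Stale v3 entries remain in the kvstore but are invisible under the new
# prefix, which forces re-registration against the updated Tool schema.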
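Across every template touched in this series, the same three builtin tool groups are wired to tool_runtime providers. A condensed sketch of that shared pattern, assuming ToolGroupInput as imported from llama_stack.distribution.datatypes in the diffs above (the helper name default_tool_groups is hypothetical, introduced only to show the repeated wiring in one place):

from typing import List

from llama_stack.distribution.datatypes import ToolGroupInput


def default_tool_groups() -> List[ToolGroupInput]:
    # Each builtin group maps onto a tool_runtime provider declared in the
    # corresponding run.yaml (tavily-search, memory-runtime, code-interpreter).
    return [
        ToolGroupInput(toolgroup_id="builtin::websearch", provider_id="tavily-search"),
        ToolGroupInput(toolgroup_id="builtin::memory", provider_id="memory-runtime"),
        ToolGroupInput(toolgroup_id="builtin::code_interpreter", provider_id="code-interpreter"),
    ]

Templates then pass this list as default_tool_groups to RunConfigSettings, and run_config() serializes it into the tool_groups section of run.yaml, as the template.py diff shows.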
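The helper hoisted out of _run in PATCH 49 is easiest to see with a concrete input. A small usage sketch against the TOOLS_ATTACHMENT_KEY_REGEX marker format from the agent_instance.py diff, assuming the call happens inside that module where _interpret_content_as_attachment is in scope (the file path and mime type are made up for illustration):

content = 'done __tools_attachment__={"filepath": "/tmp/chart.png", "mimetype": "image/png"}'
attachment = _interpret_content_as_attachment(content)
# The non-greedy regex captures the first {...} after the marker, so:
#   attachment.url.uri == "file:///tmp/chart.png"
#   attachment.mime_type == "image/png"
# Content without the marker returns None, and the agent loop only appends
# the "Please provide a response..." follow-up message when an attachment
# is actually found.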