diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index 93a3718a05..67ea94c19c 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -18,6 +18,7 @@ class ResourceType(Enum): dataset = "dataset" scoring_function = "scoring_function" eval_task = "eval_task" + tool = "tool" class Resource(BaseModel): diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 572a749980..d9baa33de8 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -7,12 +7,11 @@ from typing import Any, Dict, List, Literal, Optional from llama_models.llama3.api.datatypes import ToolPromptFormat - -from llama_models.schema_utils import json_schema_type +from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable -from llama_stack.apis.resource import Resource +from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -39,7 +38,7 @@ class ToolReturn(BaseModel): class Tool(Resource): """Represents a tool that can be provided by different providers""" - resource_type: Literal["tool"] = "tool" + type: Literal[ResourceType.tool.value] = ResourceType.tool.value name: str description: str parameters: List[ToolParameter] @@ -53,6 +52,7 @@ class Tool(Resource): @runtime_checkable @trace_protocol class Tools(Protocol): + @webmethod(route="/tools/register", method="POST") async def register_tool( self, tool_id: str, @@ -60,27 +60,33 @@ async def register_tool( description: str, parameters: List[ToolParameter], returns: ToolReturn, + provider_id: Optional[str] = None, provider_metadata: Optional[Dict[str, Any]] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, ) -> Tool: """Register a tool with provider-specific metadata""" ... 
+ @webmethod(route="/tools/get", method="GET") async def get_tool( self, - identifier: str, + tool_id: str, ) -> Tool: ... - async def list_tools( - self, - provider_id: Optional[str] = None, - ) -> List[Tool]: + @webmethod(route="/tools/list", method="GET") + async def list_tools(self) -> List[Tool]: - """List tools with optional provider""" + """List all registered tools""" + @webmethod(route="/tools/unregister", method="POST") + async def unregister_tool(self, tool_id: str) -> None: + """Unregister a tool""" + ... + @runtime_checkable @trace_protocol class ToolRuntime(Protocol): - def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any: + @webmethod(route="/tool-runtime/invoke", method="POST") + async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any: """Run a tool with the given arguments""" ... diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 6fc4545c78..1478737dad 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -47,6 +47,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: routing_table_api=Api.eval_tasks, router_api=Api.eval, ), + AutoRoutedApiInfo( + routing_table_api=Api.tools, + router_api=Api.tool_runtime, + ), ] diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 885e9bbc0f..14e5e7a86b 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -30,7 +30,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry -from llama_stack.apis.tools import Tools +from llama_stack.apis.tools import ToolRuntime, Tools from llama_stack.distribution.client import get_client_impl from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.store import DistributionRegistry @@ -61,6 +61,8 @@ def api_protocol_map() -> 
Dict[Api, Any]: Api.eval: Eval, Api.eval_tasks: EvalTasks, Api.post_training: PostTraining, + Api.tools: Tools, + Api.tool_runtime: ToolRuntime, } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 57e81ac30d..05e741598d 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -17,6 +17,7 @@ ModelsRoutingTable, ScoringFunctionsRoutingTable, ShieldsRoutingTable, + ToolsRoutingTable, ) @@ -33,6 +34,7 @@ async def get_routing_table_impl( "datasets": DatasetsRoutingTable, "scoring_functions": ScoringFunctionsRoutingTable, "eval_tasks": EvalTasksRoutingTable, + "tools": ToolsRoutingTable, } if api.value not in api_to_tables: @@ -51,6 +53,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> MemoryRouter, SafetyRouter, ScoringRouter, + ToolRuntimeRouter, ) api_to_routers = { @@ -60,6 +63,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "datasetio": DatasetIORouter, "scoring": ScoringRouter, "eval": EvalRouter, + "tool_runtime": ToolRuntimeRouter, } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 16ae353574..9acdf0a5d2 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -15,6 +15,7 @@ from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.scoring import * # noqa: F403 from llama_stack.apis.eval import * # noqa: F403 +from llama_stack.apis.tools import * # noqa: F403 class MemoryRouter(Memory): @@ -372,3 +373,23 @@ async def job_result( task_id, job_id, ) + + +class ToolRuntimeRouter(ToolRuntime): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def initialize(self) -> None: + pass + + async def 
shutdown(self) -> None: + pass + + async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any: + return await self.routing_table.get_provider_impl(tool_id).invoke_tool( + tool_id=tool_id, + args=args, + ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 01edf4e5ac..0cd451ae5a 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -15,7 +15,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.eval_tasks import * # noqa: F403 - +from llama_stack.apis.tools import * # noqa: F403 from llama_models.llama3.api.datatypes import URL @@ -47,6 +47,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable return await p.register_scoring_function(obj) elif api == Api.eval: return await p.register_eval_task(obj) + elif api == Api.tool_runtime: + return await p.register_tool(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -59,6 +61,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_model(obj.identifier) elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) + elif api == Api.tool_runtime: + return await p.unregister_tool(obj.identifier) else: raise ValueError(f"Unregister not supported for {api}") @@ -464,3 +468,53 @@ async def register_eval_task( provider_resource_id=provider_eval_task_id, ) await self.register_object(eval_task) + + +class ToolsRoutingTable(CommonRoutingTableImpl, Tools): + async def list_tools(self) -> List[Tool]: + return await self.get_all_with_type("tool") + + async def get_tool(self, tool_id: str) -> Tool: + return await self.get_object_by_identifier("tool", tool_id) + + async def register_tool( + self, + tool_id: str, + name: str, + description: str, + parameters: 
List[ToolParameter], + returns: ToolReturn, + provider_id: Optional[str] = None, + provider_metadata: Optional[Dict[str, Any]] = None, + tool_prompt_format: Optional[ToolPromptFormat] = None, + ) -> Tool: + """Build a Tool from the arguments, register it, and return it (as the Tools protocol declares).""" + if provider_metadata is None: + provider_metadata = {} + if tool_prompt_format is None: + tool_prompt_format = ToolPromptFormat.json + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + tool = Tool( + identifier=tool_id, + name=name, + description=description, + parameters=parameters, + returns=returns, + provider_id=provider_id, + provider_metadata=provider_metadata, + tool_prompt_format=tool_prompt_format, + ) + await self.register_object(tool) + return tool + + async def unregister_tool(self, tool_id: str) -> None: + tool = await self.get_tool(tool_id) + if tool is None: + raise ValueError(f"Tool {tool_id} not found") + await self.unregister_object(tool) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index f49222bcaa..7a82e282ed 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -30,6 +30,7 @@ class Api(Enum): scoring = "scoring" eval = "eval" post_training = "post_training" + tool_runtime = "tool_runtime" telemetry = "telemetry" @@ -39,6 +40,7 @@ class Api(Enum): datasets = "datasets" scoring_functions = "scoring_functions" eval_tasks = "eval_tasks" + tools = "tools" # built-in API inspect = "inspect" @@ -79,6 +81,8 @@ async def register_eval_task(self, eval_task: EvalTask) -> None: ... class ToolsProtocolPrivate(Protocol): async def register_tool(self, tool: Tool) -> None: ... + async def unregister_tool(self, tool_id: str) -> None: ... 
+ @json_schema_type class ProviderSpec(BaseModel): diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py b/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py new file mode 100644 index 0000000000..f7d52c1f06 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/meta_reference/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import MetaReferenceToolRuntimeConfig +from .meta_reference import MetaReferenceToolRuntimeImpl + + +async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps): + impl = MetaReferenceToolRuntimeImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/tools/meta_reference/config.py b/llama_stack/providers/inline/tool_runtime/meta_reference/config.py similarity index 83% rename from llama_stack/providers/inline/tools/meta_reference/config.py rename to llama_stack/providers/inline/tool_runtime/meta_reference/config.py index 61dfcf52e3..3f6146c518 100644 --- a/llama_stack/providers/inline/tools/meta_reference/config.py +++ b/llama_stack/providers/inline/tool_runtime/meta_reference/config.py @@ -7,5 +7,5 @@ from pydantic import BaseModel -class MetaReferenceToolConfig(BaseModel): +class MetaReferenceToolRuntimeConfig(BaseModel): pass diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py new file mode 100644 index 0000000000..8e4718d852 --- /dev/null +++ b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict + +from llama_stack.apis.tools import Tool, ToolRuntime +from llama_stack.providers.datatypes import ToolsProtocolPrivate + +from .config import MetaReferenceToolRuntimeConfig + + +class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): + def __init__(self, config: MetaReferenceToolRuntimeConfig): + self.config = config + + async def initialize(self): + pass + + async def register_tool(self, tool: Tool): + # Intentionally a no-op, like the other stub methods in this class. + pass + + async def unregister_tool(self, tool_id: str) -> None: + pass + + async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any: + pass diff --git a/llama_stack/providers/inline/tools/meta_reference/__init__.py b/llama_stack/providers/inline/tools/meta_reference/__init__.py deleted file mode 100644 index da392fdb30..0000000000 --- a/llama_stack/providers/inline/tools/meta_reference/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .meta_reference import * # noqa: F401 F403 diff --git a/llama_stack/providers/inline/tools/meta_reference/meta_reference.py b/llama_stack/providers/inline/tools/meta_reference/meta_reference.py deleted file mode 100644 index c69e832031..0000000000 --- a/llama_stack/providers/inline/tools/meta_reference/meta_reference.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.apis.tools import Tool, Tools - -from .config import MetaReferenceToolConfig - - -class MetaReferenceTool(Tools): - def __init__(self, config: MetaReferenceToolConfig): - self.config = config - - async def register_tool(self, tool: Tool): - pass diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py new file mode 100644 index 0000000000..c0e7a3d1be --- /dev/null +++ b/llama_stack/providers/registry/tool_runtime.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List + +from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def available_providers() -> List[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.tool_runtime, + provider_type="inline::meta-reference", + pip_packages=[], + module="llama_stack.providers.inline.tool_runtime.meta_reference", + config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig", + ), + ]