From a7dd22988b1251c00a588c7184b8b538796ac769 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 17 Dec 2024 16:06:57 -0800 Subject: [PATCH] tool def --- .../meta_reference/meta_reference.py | 53 ++++++++----------- .../tool_runtime/meta_reference/tools/base.py | 51 +++++++++++++++++- 2 files changed, 72 insertions(+), 32 deletions(-) diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py index 89efafecd5..47e6c22570 100644 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py +++ b/llama_stack/providers/inline/tool_runtime/meta_reference/meta_reference.py @@ -25,7 +25,6 @@ class MetaReferenceToolRuntimeImpl( def __init__(self, config: MetaReferenceToolRuntimeConfig): self.config = config self.tools: Dict[str, Type[BaseTool]] = {} - self.tool_instances: Dict[str, BaseTool] = {} self._discover_tools() def _discover_tools(self): @@ -44,26 +43,6 @@ def _discover_tools(self): ): self.tools[attr.tool_id()] = attr - async def _create_tool_instance( - self, tool_id: str, tool_def: Optional[Tool] = None - ) -> BaseTool: - """Create a new tool instance with proper configuration""" - if tool_id not in self.tools: - raise ValueError(f"Tool {tool_id} not found in available tools") - - tool_class = self.tools[tool_id] - - # Get tool definition if not provided - if tool_def is None: - tool_def = await self.tool_store.get_tool(tool_id) - - # Build configuration - config = dict(tool_def.provider_metadata.get("config") or {}) - if tool_class.requires_api_key: - config["api_key"] = self._get_api_key() - - return tool_class(config=config) - async def initialize(self): pass @@ -81,24 +60,36 @@ async def register_tool(self, tool: Tool): ): config_type(**tool.provider_metadata.get("config")) - self.tool_instances[tool.identifier] = await self._create_tool_instance( - tool.identifier, tool - ) - async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any: if tool_id not in self.tools: raise ValueError(f"Tool {tool_id} not found") - if tool_id not in self.tool_instances: - self.tool_instances[tool_id] = await self._create_tool_instance(tool_id) - - return await self.tool_instances[tool_id].execute(**args) + tool_instance = await self._create_tool_instance(tool_id) + return await tool_instance.execute(**args) async def unregister_tool(self, tool_id: str) -> None: - if tool_id in self.tool_instances: - del self.tool_instances[tool_id] raise NotImplementedError("Meta Reference does not support unregistering tools") + async def _create_tool_instance( + self, tool_id: str, tool_def: Optional[Tool] = None + ) -> BaseTool: + """Create a new tool instance with proper configuration""" + if tool_id not in self.tools: + raise ValueError(f"Tool {tool_id} not found in available tools") + + tool_class = self.tools[tool_id] + + # Get tool definition if not provided + if tool_def is None: + tool_def = await self.tool_store.get_tool(tool_id) + + # Build configuration + config = dict(tool_def.provider_metadata.get("config") or {}) + if tool_class.requires_api_key: + config["api_key"] = self._get_api_key() + + return tool_class(config=config) + def _get_api_key(self) -> str: provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.api_key: diff --git a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py index 79e20f85ec..d9acfa9be7 100644 --- a/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py +++ b/llama_stack/providers/inline/tool_runtime/meta_reference/tools/base.py @@ -4,8 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, get_type_hints, List, Optional, Type, TypeVar + +from llama_models.llama3.api.datatypes import ToolPromptFormat +from llama_stack.apis.tools.tools import Tool, ToolParameter, ToolReturn T = TypeVar("T") @@ -33,3 +37,48 @@ async def execute(self, **kwargs) -> Any: def get_provider_config_type(cls) -> Optional[Type[T]]: """Override to specify a Pydantic model for tool configuration""" return None + + @classmethod + def get_tool_definition(cls) -> Tool: + """Generate a Tool definition from the class implementation""" + # Get execute method + execute_method = cls.execute + signature = inspect.signature(execute_method) + docstring = execute_method.__doc__ or "No description available" + + # Extract parameters + parameters: List[ToolParameter] = [] + type_hints = get_type_hints(execute_method) + + for name, param in signature.parameters.items(): + if name == "self": + continue + + param_type = type_hints.get(name, Any).__name__ + required = param.default == param.empty + default = None if param.default == param.empty else param.default + + parameters.append( + ToolParameter( + name=name, + type_hint=param_type, + description=f"Parameter: {name}", # Could be enhanced with docstring parsing + required=required, + default=default, + ) + ) + + # Extract return info + return_type = type_hints.get("return", Any).__name__ + + return Tool( + identifier=cls.tool_id(), + provider_resource_id=cls.tool_id(), + name=cls.__name__, + description=docstring, + parameters=parameters, + returns=ToolReturn( + type_hint=return_type, description="Tool execution result" + ), + tool_prompt_format=ToolPromptFormat.json, + )