Skip to content

Commit

Permalink
tool def
Browse files Browse the repository at this point in the history
  • Loading branch information
dineshyv committed Dec 18, 2024
1 parent 482a0e4 commit a7dd229
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class MetaReferenceToolRuntimeImpl(
def __init__(self, config: MetaReferenceToolRuntimeConfig):
self.config = config
self.tools: Dict[str, Type[BaseTool]] = {}
self.tool_instances: Dict[str, BaseTool] = {}
self._discover_tools()

def _discover_tools(self):
Expand All @@ -44,26 +43,6 @@ def _discover_tools(self):
):
self.tools[attr.tool_id()] = attr

async def _create_tool_instance(
self, tool_id: str, tool_def: Optional[Tool] = None
) -> BaseTool:
"""Create a new tool instance with proper configuration"""
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found in available tools")

tool_class = self.tools[tool_id]

# Get tool definition if not provided
if tool_def is None:
tool_def = await self.tool_store.get_tool(tool_id)

# Build configuration
config = dict(tool_def.provider_metadata.get("config") or {})
if tool_class.requires_api_key:
config["api_key"] = self._get_api_key()

return tool_class(config=config)

async def initialize(self):
pass

Expand All @@ -81,24 +60,36 @@ async def register_tool(self, tool: Tool):
):
config_type(**tool.provider_metadata.get("config"))

self.tool_instances[tool.identifier] = await self._create_tool_instance(
tool.identifier, tool
)

async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found")

if tool_id not in self.tool_instances:
self.tool_instances[tool_id] = await self._create_tool_instance(tool_id)

return await self.tool_instances[tool_id].execute(**args)
tool_instance = await self._create_tool_instance(tool_id)
return await tool_instance.execute(**args)

async def unregister_tool(self, tool_id: str) -> None:
if tool_id in self.tool_instances:
del self.tool_instances[tool_id]
raise NotImplementedError("Meta Reference does not support unregistering tools")

async def _create_tool_instance(
self, tool_id: str, tool_def: Optional[Tool] = None
) -> BaseTool:
"""Create a new tool instance with proper configuration"""
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found in available tools")

tool_class = self.tools[tool_id]

# Get tool definition if not provided
if tool_def is None:
tool_def = await self.tool_store.get_tool(tool_id)

# Build configuration
config = dict(tool_def.provider_metadata.get("config") or {})
if tool_class.requires_api_key:
config["api_key"] = self._get_api_key()

return tool_class(config=config)

def _get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar
from typing import Any, Dict, get_type_hints, List, Optional, Type, TypeVar

from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_stack.apis.tools.tools import Tool, ToolParameter, ToolReturn

T = TypeVar("T")

Expand Down Expand Up @@ -33,3 +37,48 @@ async def execute(self, **kwargs) -> Any:
def get_provider_config_type(cls) -> Optional[Type[T]]:
"""Override to specify a Pydantic model for tool configuration"""
return None

@classmethod
def get_tool_definition(cls) -> Tool:
"""Generate a Tool definition from the class implementation"""
# Get execute method
execute_method = cls.execute
signature = inspect.signature(execute_method)
docstring = execute_method.__doc__ or "No description available"

# Extract parameters
parameters: List[ToolParameter] = []
type_hints = get_type_hints(execute_method)

for name, param in signature.parameters.items():
if name == "self":
continue

param_type = type_hints.get(name, Any).__name__
required = param.default == param.empty
default = None if param.default == param.empty else param.default

parameters.append(
ToolParameter(
name=name,
type_hint=param_type,
description=f"Parameter: {name}", # Could be enhanced with docstring parsing
required=required,
default=default,
)
)

# Extract return info
return_type = type_hints.get("return", Any).__name__

return Tool(
identifier=cls.tool_id(),
provider_resource_id=cls.tool_id(),
name=cls.__name__,
description=docstring,
parameters=parameters,
returns=ToolReturn(
type_hint=return_type, description="Tool execution result"
),
tool_prompt_format=ToolPromptFormat.json,
)

0 comments on commit a7dd229

Please sign in to comment.