Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate tools to new runtime #646

Open
wants to merge 2 commits into
base: tool-runtime
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,91 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import importlib
import logging
from enum import Enum
from typing import Any, Dict

import llama_stack.providers.inline.tool_runtime.meta_reference.builtins as builtins
import pkgutil
from typing import Any, Dict, Optional, Type

from llama_stack.apis.tools import Tool, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool

from .config import MetaReferenceToolRuntimeConfig

logger = logging.getLogger(__name__)


class ToolType(Enum):
bing_search = "bing_search"
brave_search = "brave_search"
tavily_search = "tavily_search"
print_tool = "print_tool"


class MetaReferenceToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: MetaReferenceToolRuntimeConfig):
self.config = config
self.tools: Dict[str, Type[BaseTool]] = {}
self._discover_tools()

def _discover_tools(self):
# Import all tools from the tools package
tools_package = "llama_stack.providers.inline.tool_runtime.tools"
package = importlib.import_module(tools_package)

for _, name, _ in pkgutil.iter_modules(package.__path__):
module = importlib.import_module(f"{tools_package}.{name}")
for attr_name in dir(module):
attr = getattr(module, attr_name)
if (
isinstance(attr, type)
and issubclass(attr, BaseTool)
and attr != BaseTool
):
self.tools[attr.tool_id()] = attr

async def initialize(self):
pass

async def register_tool(self, tool: Tool):
print(f"registering tool {tool.identifier}")
if tool.provider_resource_id not in ToolType.__members__:
raise ValueError(
f"Tool {tool.identifier} not a supported tool by Meta Reference"
)
if tool.identifier not in self.tools:
raise ValueError(f"Tool {tool.identifier} not found in available tools")

# Validate provider_metadata against tool's config type if specified
tool_class = self.tools[tool.identifier]
config_type = tool_class.get_provider_config_type()
if (
config_type
and tool.provider_metadata
and tool.provider_metadata.get("config")
):
config_type(**tool.provider_metadata.get("config"))

async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found")

tool_instance = await self._create_tool_instance(tool_id)
return await tool_instance.execute(**args)

async def unregister_tool(self, tool_id: str) -> None:
raise NotImplementedError("Meta Reference does not support unregistering tools")

async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
tool = await self.tool_store.get_tool(tool_id)
if args.get("__api_key__") is not None:
logger.warning(
"__api_key__ is a reserved argument for this tool: {tool_id}"
)
args["__api_key__"] = self._get_api_key()
return await getattr(builtins, tool.provider_resource_id)(**args)
async def _create_tool_instance(
self, tool_id: str, tool_def: Optional[Tool] = None
) -> BaseTool:
"""Create a new tool instance with proper configuration"""
if tool_id not in self.tools:
raise ValueError(f"Tool {tool_id} not found in available tools")

tool_class = self.tools[tool_id]

# Get tool definition if not provided
if tool_def is None:
tool_def = await self.tool_store.get_tool(tool_id)

# Build configuration
config = dict(tool_def.provider_metadata.get("config") or {})
if tool_class.requires_api_key:
config["api_key"] = self._get_api_key()

return tool_class(config=config)

def _get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, get_type_hints, List, Optional, Type, TypeVar

from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_stack.apis.tools.tools import Tool, ToolParameter, ToolReturn

T = TypeVar("T")


class BaseTool(ABC):
"""Base class for all tools"""

requires_api_key: bool = False

def __init__(self, config: Optional[Dict[str, Any]] = None):
self.config = config or {}

@classmethod
@abstractmethod
def tool_id(cls) -> str:
"""Unique identifier for the tool"""
pass

@abstractmethod
async def execute(self, **kwargs) -> Any:
"""Execute the tool with given arguments"""
pass

@classmethod
def get_provider_config_type(cls) -> Optional[Type[T]]:
"""Override to specify a Pydantic model for tool configuration"""
return None

@classmethod
def get_tool_definition(cls) -> Tool:
"""Generate a Tool definition from the class implementation"""
# Get execute method
execute_method = cls.execute
signature = inspect.signature(execute_method)
docstring = execute_method.__doc__ or "No description available"

# Extract parameters
parameters: List[ToolParameter] = []
type_hints = get_type_hints(execute_method)

for name, param in signature.parameters.items():
if name == "self":
continue

param_type = type_hints.get(name, Any).__name__
required = param.default == param.empty
default = None if param.default == param.empty else param.default

parameters.append(
ToolParameter(
name=name,
type_hint=param_type,
description=f"Parameter: {name}", # Could be enhanced with docstring parsing
required=required,
default=default,
)
)

# Extract return info
return_type = type_hints.get("return", Any).__name__

return Tool(
identifier=cls.tool_id(),
provider_resource_id=cls.tool_id(),
name=cls.__name__,
description=docstring,
parameters=parameters,
returns=ToolReturn(
type_hint=return_type, description="Tool execution result"
),
tool_prompt_format=ToolPromptFormat.json,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import List

import requests

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class BingSearchConfig(BaseModel):
api_key: str
max_results: int = 5


class BingSearchTool(BaseTool):
requires_api_key: bool = True

@classmethod
def tool_id(cls) -> str:
return "bing_search"

@classmethod
def get_provider_config_type(cls):
return BingSearchConfig

async def execute(self, query: str) -> List[dict]:
config = BingSearchConfig(**self.config)
url = "https://api.bing.microsoft.com/v7.0/search"
headers = {
"Ocp-Apim-Subscription-Key": config.api_key,
}
params = {
"count": config.max_results,
"textDecorations": True,
"textFormat": "HTML",
"q": query,
}

response = requests.get(url=url, params=params, headers=headers)
response.raise_for_status()
return json.dumps(self._clean_response(response.json()))

def _clean_response(self, search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
clean_response.append(clean_news)

return {"query": query, "results": clean_response}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

import requests

from llama_stack.providers.inline.tool_runtime.meta_reference.tools.base import BaseTool
from pydantic import BaseModel


class BraveSearchConfig(BaseModel):
api_key: str
max_results: int = 3


class BraveSearchTool(BaseTool):
requires_api_key: bool = True

@classmethod
def tool_id(cls) -> str:
return "brave_search"

@classmethod
def get_provider_config_type(cls):
return BraveSearchConfig

async def execute(self, query: str) -> List[dict]:
config = BraveSearchConfig(**self.config)
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": config.api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": query}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
return self._clean_brave_response(response.json(), config.max_results)

def _clean_brave_response(self, search_response, top_k=3):
query = None
clean_response = []
if "query" in search_response:
if "original" in search_response["query"]:
query = search_response["query"]["original"]
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][:top_k]:
r_type = m["type"]
results = search_response[r_type]["results"]
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
clean_response.append(cleaned)

return {"query": query, "results": clean_response}

def _clean_result_by_type(self, r_type, results, idx=None):
type_cleaners = {
"web": (
["type", "title", "url", "description", "date", "extra_snippets"],
lambda x: x[idx],
),
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
"infobox": (
["type", "title", "url", "description", "long_desc"],
lambda x: x[idx],
),
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
"locations": (
[
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
],
lambda x: x,
),
"news": (["type", "title", "url", "description"], lambda x: x),
}

if r_type not in type_cleaners:
return []

selected_keys, result_selector = type_cleaners[r_type]
results = result_selector(results)

if isinstance(results, list):
return [
{k: v for k, v in item.items() if k in selected_keys}
for item in results
]
return {k: v for k, v in results.items() if k in selected_keys}
Loading
Loading