Skip to content

Commit

Permalink
working tools runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
dineshyv committed Dec 18, 2024
1 parent 744eb08 commit 84d01fe
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 5 deletions.
6 changes: 6 additions & 0 deletions llama_stack/apis/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class Tool(Resource):
)


class ToolStore(Protocol):
def get_tool(self, identifier: str) -> Tool: ...


@runtime_checkable
@trace_protocol
class Tools(Protocol):
Expand Down Expand Up @@ -88,6 +92,8 @@ async def unregister_tool(self, tool_id: str) -> None:
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
tool_store: ToolStore

@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
"""Run a tool with the given arguments"""
Expand Down
4 changes: 4 additions & 0 deletions llama_stack/distribution/routers/routing_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ async def add_objects(
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
elif api == Api.tool_runtime:
p.tool_store = self

async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
Expand All @@ -129,6 +131,8 @@ def apiname_object():
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
elif isinstance(self, ToolsRoutingTable):
return ("Tools", "tool")
else:
raise ValueError("Unknown routing table type")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict

from pydantic import BaseModel

from .config import MetaReferenceToolRuntimeConfig
from .meta_reference import MetaReferenceToolRuntimeImpl


async def get_provider_impl(config: MetaReferenceToolRuntimeConfig, _deps):
class MetaReferenceProviderDataValidator(BaseModel):
api_key: str


async def get_provider_impl(
config: MetaReferenceToolRuntimeConfig, _deps: Dict[str, Any]
):
impl = MetaReferenceToolRuntimeImpl(config)
await impl.initialize()
return impl
165 changes: 165 additions & 0 deletions llama_stack/providers/inline/tool_runtime/meta_reference/builtins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging

import requests

logger = logging.getLogger(__name__)


async def bing_search(query: str, __api_key__: str, top_k: int = 3, **kwargs) -> str:
url = "https://api.bing.microsoft.com/v7.0/search"
headers = {
"Ocp-Apim-Subscription-Key": __api_key__,
}
params = {
"count": top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": query,
}

response = requests.get(url=url, params=params, headers=headers)
response.raise_for_status()
clean = _bing_clean_response(response.json())
return json.dumps(clean)


def _bing_clean_response(search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append({k: v for k, v in p.items() if k in selected_keys})
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})

clean_response.append(clean_news)

return {"query": query, "top_k": clean_response}


async def brave_search(query: str, __api_key__: str) -> str:
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": __api_key__,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": query}
response = requests.get(url=url, params=payload, headers=headers)
return json.dumps(_clean_brave_response(response.json()))


def _clean_brave_response(search_response, top_k=3):
query = None
clean_response = []
if "query" in search_response:
if "original" in search_response["query"]:
query = search_response["query"]["original"]
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][:top_k]:
r_type = m["type"]
results = search_response[r_type]["results"]
if r_type == "web":
# For web data - add a single output from the search
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"date",
"extra_snippets",
]
cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}
elif r_type == "faq":
# For faw data - take a list of all the questions & answers
selected_keys = ["type", "question", "answer", "title", "url"]
cleaned = []
for q in results:
cleaned.append({k: v for k, v in q.items() if k in selected_keys})
elif r_type == "infobox":
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"long_desc",
]
cleaned = {k: v for k, v in results[idx].items() if k in selected_keys}
elif r_type == "videos":
selected_keys = [
"type",
"url",
"title",
"description",
"date",
]
cleaned = []
for q in results:
cleaned.append({k: v for k, v in q.items() if k in selected_keys})
elif r_type == "locations":
# For faw data - take a list of all the questions & answers
selected_keys = [
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
]
cleaned = []
for q in results:
cleaned.append({k: v for k, v in q.items() if k in selected_keys})
elif r_type == "news":
# For faw data - take a list of all the questions & answers
selected_keys = [
"type",
"title",
"url",
"description",
]
cleaned = []
for q in results:
cleaned.append({k: v for k, v in q.items() if k in selected_keys})
else:
cleaned = []

clean_response.append(cleaned)

return {"query": query, "top_k": clean_response}


async def tavily_search(query: str, __api_key__: str) -> str:
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": __api_key__, "query": query},
)
return json.dumps(_clean_tavily_response(response.json()))


def _clean_tavily_response(search_response, top_k=3):
return {"query": search_response["query"], "top_k": search_response["results"]}


async def print_tool(query: str, __api_key__: str) -> str:
logger.info(f"print_tool called with query: {query} and api_key: {__api_key__}")
return json.dumps({"result": "success"})
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,31 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from enum import Enum
from typing import Any, Dict

import llama_stack.providers.inline.tool_runtime.meta_reference.builtins as builtins

from llama_stack.apis.tools import Tool, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import MetaReferenceToolRuntimeConfig

logger = logging.getLogger(__name__)


class ToolType(Enum):
bing_search = "bing_search"
brave_search = "brave_search"
tavily_search = "tavily_search"
print_tool = "print_tool"


class MetaReferenceToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
class MetaReferenceToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: MetaReferenceToolRuntimeConfig):
self.config = config

Expand All @@ -21,10 +37,27 @@ async def initialize(self):

async def register_tool(self, tool: Tool):
print(f"registering tool {tool.identifier}")
pass
if tool.provider_resource_id not in ToolType.__members__:
raise ValueError(
f"Tool {tool.identifier} not a supported tool by Meta Reference"
)

async def unregister_tool(self, tool_id: str) -> None:
pass
raise NotImplementedError("Meta Reference does not support unregistering tools")

async def invoke_tool(self, tool_id: str, args: Dict[str, Any]) -> Any:
pass
tool = await self.tool_store.get_tool(tool_id)
if args.get("__api_key__") is not None:
logger.warning(
"__api_key__ is a reserved argument for this tool: {tool_id}"
)
args["__api_key__"] = self._get_api_key()
return await getattr(builtins, tool.provider_resource_id)(**args)

def _get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
1 change: 1 addition & 0 deletions llama_stack/providers/registry/tool_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.meta_reference",
config_class="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceToolRuntimeConfig",
provider_data_validator="llama_stack.providers.inline.tool_runtime.meta_reference.MetaReferenceProviderDataValidator",
),
]

0 comments on commit 84d01fe

Please sign in to comment.