Add direct client implementation #15

Merged · 3 commits · Nov 7, 2024
Changes from 1 commit
115 changes: 115 additions & 0 deletions src/llama_stack_client/lib/direct/direct.py
@@ -0,0 +1,115 @@
import inspect
from typing import Any, Type, cast, get_args, get_origin

from pydantic import BaseModel

from llama_stack.distribution.datatypes import *  # noqa: F403 (provides StackRunConfig)
from llama_stack.distribution.distribution import (
    builtin_automatically_routed_apis,
    get_provider_registry,
)
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.server import is_streaming_request
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig, kvstore_impl

from ..._base_client import ResponseT
from ..._client import LlamaStackClient
from ..._streaming import Stream
from ..._types import NOT_GIVEN, Body, RequestFiles, RequestOptions


class LlamaStackDirectClient(LlamaStackClient):
    """Client that calls llama-stack provider implementations in-process instead of over HTTP."""

    def __init__(self, config: StackRunConfig, **kwargs):
        super().__init__(**kwargs)
        self.endpoints = get_all_api_endpoints()
        self.config = config
        self.dist_registry = None
        self.impls = None

    async def initialize(self) -> None:
        if self.config.metadata_store:
            dist_kvstore = await kvstore_impl(self.config.metadata_store)
        else:
            dist_kvstore = await kvstore_impl(
                SqliteKVStoreConfig(
                    db_path=(DISTRIBS_BASE_DIR / self.config.image_name / "kvstore.db").as_posix()
                )
            )

        self.dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
        self.impls = await resolve_impls(self.config, get_provider_registry(), self.dist_registry)

    def _convert_param(self, param_type: Any, value: Any) -> Any:
        # TODO: are there any APIs with even more complex nesting that fail this?
        if get_origin(param_type) is list:
            item_type = get_args(param_type)[0]
            if isinstance(item_type, type) and issubclass(item_type, BaseModel):
                return [item_type(**item) for item in value]
            return value

        elif isinstance(param_type, type) and issubclass(param_type, BaseModel):
            return param_type(**value)

        # Return as-is for primitive types
        return value

    async def _call_endpoint(self, path: str, method: str, body: dict | None = None) -> Any:
        for api, endpoints in self.endpoints.items():
            for endpoint in endpoints:
                if endpoint.route == path:
                    impl = self.impls[api]
                    func = getattr(impl, endpoint.name)
                    sig = inspect.signature(func)

                    if body:
                        # Strip NOT_GIVENs to use the defaults in signature
                        body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

                        # Convert parameters to Pydantic models where needed
                        converted_body = {}
                        for param_name, param in sig.parameters.items():
                            if param_name in body:
                                value = body.get(param_name)
                                converted_body[param_name] = self._convert_param(param.annotation, value)
                        body = converted_body

                    if is_streaming_request(endpoint.name, body):
                        async for chunk in func(**(body or {})):
                            yield chunk
                    else:
                        yield await func(**(body or {}))
                    # Stop here so an exhausted generator does not fall through
                    # to the "no endpoint" error after a successful call.
                    return

        raise ValueError(f"No endpoint found for {path}")


    async def get(
        self,
        path: str,
        *,
        cast_to: Type[ResponseT],
        options: RequestOptions = {},
        stream: bool = False,
        stream_cls: type[Stream[Any]] | None = None,
    ) -> ResponseT:
        async for response in self._call_endpoint(path, "GET"):
            return cast(ResponseT, response)

    async def post(
        self,
        path: str,
        *,
        cast_to: Type[ResponseT],
        body: Body | None = None,
        options: RequestOptions = {},
        files: RequestFiles | None = None,
        stream: bool = False,
        stream_cls: type[Stream[Any]] | None = None,
    ) -> ResponseT:
        async for response in self._call_endpoint(path, "POST", body):
            return cast(ResponseT, response)
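
The subtle piece above is _convert_param: resource methods hand the transport a JSON-style dict body, but the in-process provider implementations expect real Pydantic models, so the client rebuilds them (including lists of models) by inspecting each function signature. Below is a minimal standalone sketch of that rehydration step; the Message model and send() function are hypothetical stand-ins for illustration, not part of this PR:

import inspect
from typing import List, get_args, get_origin

from pydantic import BaseModel


class Message(BaseModel):  # hypothetical model, standing in for e.g. UserMessage
    role: str
    content: str


def send(messages: List[Message], model: str) -> None:  # hypothetical endpoint impl
    print(f"{model}: {[m.content for m in messages]}")


def convert_param(param_type, value):
    # Mirrors LlamaStackDirectClient._convert_param: rebuild Pydantic models
    # (and lists of them) from plain dicts; pass primitives through unchanged.
    if get_origin(param_type) is list:
        item_type = get_args(param_type)[0]
        if isinstance(item_type, type) and issubclass(item_type, BaseModel):
            return [item_type(**item) for item in value]
        return value
    if isinstance(param_type, type) and issubclass(param_type, BaseModel):
        return param_type(**value)
    return value


body = {"messages": [{"role": "user", "content": "hi"}], "model": "demo"}
sig = inspect.signature(send)
converted = {name: convert_param(p.annotation, body[name]) for name, p in sig.parameters.items()}
send(**converted)  # prints: demo: ['hi']

Running the sketch prints demo: ['hi'], i.e. the dict body arrives at the function as typed models rather than raw JSON.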
35 changes: 35 additions & 0 deletions src/llama_stack_client/lib/direct/test.py
@@ -0,0 +1,35 @@
import argparse
import asyncio

import yaml

from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import *  # noqa: F403

from llama_stack_client.lib.direct.direct import LlamaStackDirectClient
from llama_stack_client.types import UserMessage


async def main(config_path: str):
    with open(config_path, "r") as f:
        config_dict = yaml.safe_load(f)

    run_config = parse_and_maybe_upgrade_config(config_dict)

    client = LlamaStackDirectClient(config=run_config)
    await client.initialize()

    response = await client.models.list()
    print(response)

    response = await client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model="Llama3.1-8B-Instruct",
        stream=False,
    )
    print("\nChat completion response:")
    print(response)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", help="Path to the config YAML file")
    args = parser.parse_args()
    asyncio.run(main(args.config_path))
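
For context on why this test runs with no server: SDK resource calls such as client.models.list() and client.inference.chat_completion() ultimately funnel into the client's get/post transport methods, which LlamaStackDirectClient overrides to dispatch in-process. A toy sketch of that override pattern; BaseClient, list_models(), and the route table here are hypothetical, for illustration only:

import asyncio
from typing import Any


class BaseClient:
    # Stands in for LlamaStackClient: resource methods funnel into get/post.
    async def get(self, path: str, **kwargs: Any) -> Any:
        raise NotImplementedError("would perform an HTTP request")

    async def list_models(self) -> Any:  # hypothetical resource method
        return await self.get("/models/list")


class DirectClient(BaseClient):
    # Overriding the transport method swaps HTTP for an in-process lookup,
    # the same pattern LlamaStackDirectClient.get/post use above.
    async def get(self, path: str, **kwargs: Any) -> Any:
        return {"/models/list": ["Llama3.1-8B-Instruct"]}[path]


print(asyncio.run(DirectClient().list_models()))  # ['Llama3.1-8B-Instruct']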