Add direct client implementation #15

Merged 3 commits on Nov 7, 2024
5 changes: 4 additions & 1 deletion src/llama_stack_client/lib/agents/agent.py
@@ -36,7 +36,10 @@ def create_session(self, session_name: str) -> int:
         return self.session_id

     async def create_turn(
-        self, messages: List[Union[UserMessage, ToolResponseMessage]], attachments: Optional[List[Attachment]] = None, session_id: Optional[str] = None,
+        self,
+        messages: List[Union[UserMessage, ToolResponseMessage]],
+        attachments: Optional[List[Attachment]] = None,
+        session_id: Optional[str] = None,
     ):
         response = self.client.agents.turn.create(
             agent_id=self.agent_id,
3 changes: 2 additions & 1 deletion src/llama_stack_client/lib/agents/custom_tool.py
@@ -9,10 +9,11 @@
 from abc import abstractmethod
 from typing import Dict, List, Union

-from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam
 from llama_stack_client.types import ToolResponseMessage, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfigToolFunctionCallToolDefinition

+from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam
+

 class CustomTool:
     """
3 changes: 1 addition & 2 deletions src/llama_stack_client/lib/cli/configure.py
@@ -9,10 +9,9 @@

 import yaml

-from llama_stack_client.lib.cli.constants import get_config_file_path
+from llama_stack_client.lib.cli.constants import get_config_file_path, LLAMA_STACK_CLIENT_CONFIG_DIR
 from llama_stack_client.lib.cli.subcommand import Subcommand

-from llama_stack_client.lib.cli.constants import LLAMA_STACK_CLIENT_CONFIG_DIR

 def get_config():
     config_file = get_config_file_path()
6 changes: 2 additions & 4 deletions src/llama_stack_client/lib/cli/llama_stack_client.py
@@ -40,13 +40,11 @@ def parse_args(self) -> argparse.Namespace:
         return self.parser.parse_args()

     def command_requires_config(self, args: argparse.Namespace) -> bool:
-        return not (hasattr(args.func, '__self__') and isinstance(args.func.__self__, ConfigureParser))
+        return not (hasattr(args.func, "__self__") and isinstance(args.func.__self__, ConfigureParser))

     def run(self, args: argparse.Namespace) -> None:
         if self.command_requires_config(args) and not get_config_file_path().exists():
-            print(
-                "Config file not found. Please run 'llama-stack-client configure' to create one."
-            )
+            print("Config file not found. Please run 'llama-stack-client configure' to create one.")
             return

         args.func(args)
105 changes: 105 additions & 0 deletions src/llama_stack_client/lib/direct/direct.py
@@ -0,0 +1,105 @@
import inspect
from typing import Any, cast, get_args, get_origin, Type

from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls

Contributor:
Should we add llama-stack as a dependency for the llama-stack-client package?

Contributor:
Nope, it should be the reverse, as we discussed: this code will only ever be exercised when the person already has llama-stack in their environment (as a library or via pip).

@yanxi0830 (Contributor), Nov 8, 2024:
Hmm, should this class LlamaStackDirectClient be inside the llama-stack repo instead of the llama-stack-client-python repo?

1. A user who wants to use llama-stack as a library installs the llama-stack package (which depends on the llama-stack-client package) and can use LlamaStackDirectClient.

2. A user who installs only the llama-stack-client package cannot use LlamaStackDirectClient without also installing llama-stack.

Contributor:
@yanxi0830 yeah I think that makes sense to me actually.
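
The agreed direction implies llama-stack is expected to be present whenever this module is used. As a minimal sketch (illustrative only, not part of this diff), the direct-client module could guard its llama-stack imports so that a client-only install fails with an actionable message:

# Hypothetical guard, not in this PR: turn a bare ImportError into an
# actionable message when llama-stack is absent from the environment.
try:
    from llama_stack.distribution.datatypes import StackRunConfig
except ImportError as exc:
    raise ImportError(
        "LlamaStackDirectClient requires the llama-stack package; "
        "install llama-stack to use the direct client."
    ) from exc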

from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.server import is_streaming_request

from llama_stack.distribution.store.registry import create_dist_registry
from pydantic import BaseModel

from ..._base_client import ResponseT
from ..._client import LlamaStackClient
from ..._streaming import Stream
from ..._types import Body, NOT_GIVEN, RequestFiles, RequestOptions


class LlamaStackDirectClient(LlamaStackClient):
    def __init__(self, config: StackRunConfig, **kwargs):
        super().__init__(**kwargs)
        self.endpoints = get_all_api_endpoints()
        self.config = config
        self.dist_registry = None
        self.impls = None

    async def initialize(self) -> None:
        self.dist_registry, _ = await create_dist_registry(self.config)
        self.impls = await resolve_impls(self.config, get_provider_registry(), self.dist_registry)

    def _convert_param(self, param_type: Any, value: Any) -> Any:
        origin = get_origin(param_type)
        if origin == list:
            item_type = get_args(param_type)[0]
            if isinstance(item_type, type) and issubclass(item_type, BaseModel):
                return [item_type(**item) for item in value]
            return value

        elif origin == dict:
            _, val_type = get_args(param_type)
            if isinstance(val_type, type) and issubclass(val_type, BaseModel):
                return {k: val_type(**v) for k, v in value.items()}
            return value

        elif isinstance(param_type, type) and issubclass(param_type, BaseModel):
            return param_type(**value)

        # Return as-is for primitive types
        return value

    async def _call_endpoint(self, path: str, method: str, body: dict = None) -> Any:
        for api, endpoints in self.endpoints.items():
            for endpoint in endpoints:
                if endpoint.route == path:
                    impl = self.impls[api]
                    func = getattr(impl, endpoint.name)
                    sig = inspect.signature(func)

                    if body:
                        # Strip NOT_GIVENs to use the defaults in signature
                        body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

                        # Convert parameters to Pydantic models where needed
                        converted_body = {}
                        for param_name, param in sig.parameters.items():
                            if param_name in body:
                                value = body.get(param_name)
                                converted_body[param_name] = self._convert_param(param.annotation, value)
                        body = converted_body

                    if is_streaming_request(endpoint.name, body):
                        async for chunk in func(**(body or {})):
                            yield chunk
                    else:
                        yield await func(**(body or {}))
                    # Return here so the fall-through error below is raised only
                    # when no endpoint matched the requested path.
                    return

        raise ValueError(f"No endpoint found for {path}")

    async def get(
        self,
        path: str,
        *,
        cast_to: Type[ResponseT],
        options: RequestOptions = None,
        stream: bool = False,
        stream_cls: type[Stream[Any]] | None = None,
    ) -> ResponseT:
        options = options or {}
        async for response in self._call_endpoint(path, "GET"):
            return cast(ResponseT, response)

    async def post(
        self,
        path: str,
        *,
        cast_to: Type[ResponseT],
        body: Body | None = None,
        options: RequestOptions = None,
        files: RequestFiles | None = None,
        stream: bool = False,
        stream_cls: type[Stream[Any]] | None = None,
    ) -> ResponseT:
        options = options or {}
        async for response in self._call_endpoint(path, "POST", body):
            return cast(ResponseT, response)
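
A standalone restatement of _convert_param's list branch may help make the parameter conversion concrete. This is a sketch under one assumption: that UserMessage subclasses pydantic.BaseModel, which its keyword construction in the test script below suggests but this diff does not state.

from typing import List, get_args, get_origin

from pydantic import BaseModel

from llama_stack_client.types import UserMessage


# Sketch of _convert_param's list branch: dicts inside a List[Model]
# annotation are re-hydrated into model instances; anything else passes
# through unchanged.
def convert_list(param_type, value):
    if get_origin(param_type) is list:
        item_type = get_args(param_type)[0]
        if isinstance(item_type, type) and issubclass(item_type, BaseModel):
            return [item_type(**item) for item in value]
    return value


messages = convert_list(List[UserMessage], [{"role": "user", "content": "What is the capital of France?"}])
# messages[0] is now a UserMessage instance rather than a plain dict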
36 changes: 36 additions & 0 deletions src/llama_stack_client/lib/direct/test.py
@@ -0,0 +1,36 @@
import argparse

import yaml
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack_client.lib.direct.direct import LlamaStackDirectClient
from llama_stack_client.types import UserMessage


async def main(config_path: str):
    with open(config_path, "r") as f:
        config_dict = yaml.safe_load(f)

    run_config = parse_and_maybe_upgrade_config(config_dict)

    client = LlamaStackDirectClient(config=run_config)
    await client.initialize()

    response = await client.models.list()
    print(response)

    response = await client.inference.chat_completion(
        messages=[UserMessage(content="What is the capital of France?", role="user")],
        model="Llama3.1-8B-Instruct",
        stream=False,
    )
    print("\nChat completion response:")
    print(response)


if __name__ == "__main__":
    import asyncio

    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", help="Path to the config YAML file")
    args = parser.parse_args()
    asyncio.run(main(args.config_path))