Skip to content

Commit

Permalink
use class methods and construct_stack
Browse files Browse the repository at this point in the history
  • Loading branch information
Dinesh Yeduguru committed Nov 19, 2024
1 parent 655a46f commit af2d3b6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
34 changes: 21 additions & 13 deletions src/llama_stack_client/lib/direct/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from llama_stack.distribution.stack import (
get_stack_run_config_from_template,
)

from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.stack import construct_stack
from pydantic import BaseModel

from ..._base_client import ResponseT
Expand All @@ -20,22 +19,31 @@


class LlamaStackDirectClient(LlamaStackClient):
def __init__(self, config: StackRunConfig | str, **kwargs):
def __init__(self, config: StackRunConfig, **kwargs):
raise TypeError("Use from_yaml() or from_template() instead of direct initialization")

@classmethod
async def from_config(cls, config: StackRunConfig, **kwargs):
instance = object.__new__(cls)
await instance._initialize(config, **kwargs)
return instance

@classmethod
async def from_template(cls, template_name: str, **kwargs):
config = get_stack_run_config_from_template(template_name)
instance = object.__new__(cls)
await instance._initialize(config, **kwargs)
return instance

async def _initialize(self, config: StackRunConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.endpoints = get_all_api_endpoints()

# Allow initialization with template name string
if isinstance(config, str):
self.config = get_stack_run_config_from_template(config)
else:
self.config = config

self.dist_registry = None
self.config = config
self.impls = None
await self.initialize()

async def initialize(self) -> None:
self.dist_registry, _ = await create_dist_registry(self.config.metadata_store, image_name="")
self.impls = await resolve_impls(self.config, get_provider_registry(), self.dist_registry)
self.impls = await construct_stack(self.config)

def _convert_param(self, param_type: Any, value: Any) -> Any:
origin = get_origin(param_type)
Expand Down
2 changes: 1 addition & 1 deletion src/llama_stack_client/lib/direct/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ async def main(config_path: str):

run_config = parse_and_maybe_upgrade_config(config_dict)

client = LlamaStackDirectClient(config=run_config)
client = await LlamaStackDirectClient.from_config(run_config)
await client.initialize()

response = await client.models.list()
Expand Down

0 comments on commit af2d3b6

Please sign in to comment.