From af2d3b6ac6bd88c484b626b5e11a8421f7ed6501 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 18 Nov 2024 20:40:54 -0800 Subject: [PATCH] use class methods and construct_stack --- src/llama_stack_client/lib/direct/direct.py | 34 +++++++++++++-------- src/llama_stack_client/lib/direct/test.py | 2 +- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/llama_stack_client/lib/direct/direct.py b/src/llama_stack_client/lib/direct/direct.py index 3f4ef9f..c1ef4ee 100644 --- a/src/llama_stack_client/lib/direct/direct.py +++ b/src/llama_stack_client/lib/direct/direct.py @@ -9,8 +9,7 @@ from llama_stack.distribution.stack import ( get_stack_run_config_from_template, ) - -from llama_stack.distribution.store.registry import create_dist_registry +from llama_stack.distribution.stack import construct_stack from pydantic import BaseModel from ..._base_client import ResponseT @@ -20,22 +19,31 @@ class LlamaStackDirectClient(LlamaStackClient): - def __init__(self, config: StackRunConfig | str, **kwargs): + def __init__(self, config: StackRunConfig, **kwargs): + raise TypeError("Use from_yaml() or from_template() instead of direct initialization") + + @classmethod + async def from_config(cls, config: StackRunConfig, **kwargs): + instance = object.__new__(cls) + await instance._initialize(config, **kwargs) + return instance + + @classmethod + async def from_template(cls, template_name: str, **kwargs): + config = get_stack_run_config_from_template(template_name) + instance = object.__new__(cls) + await instance._initialize(config, **kwargs) + return instance + + async def _initialize(self, config: StackRunConfig, **kwargs) -> None: super().__init__(**kwargs) self.endpoints = get_all_api_endpoints() - - # Allow initialization with template name string - if isinstance(config, str): - self.config = get_stack_run_config_from_template(config) - else: - self.config = config - - self.dist_registry = None + self.config = config self.impls = None + await self.initialize() async def initialize(self) -> None: - self.dist_registry, _ = await create_dist_registry(self.config.metadata_store, image_name="") - self.impls = await resolve_impls(self.config, get_provider_registry(), self.dist_registry) + self.impls = await construct_stack(self.config) def _convert_param(self, param_type: Any, value: Any) -> Any: origin = get_origin(param_type) diff --git a/src/llama_stack_client/lib/direct/test.py b/src/llama_stack_client/lib/direct/test.py index 28cd310..4f21ce4 100644 --- a/src/llama_stack_client/lib/direct/test.py +++ b/src/llama_stack_client/lib/direct/test.py @@ -12,7 +12,7 @@ async def main(config_path: str): run_config = parse_and_maybe_upgrade_config(config_dict) - client = LlamaStackDirectClient(config=run_config) + client = await LlamaStackDirectClient.from_config(run_config) await client.initialize() response = await client.models.list()