From 594aa81f75e1317120f77a81acbc153aa22eb48d Mon Sep 17 00:00:00 2001 From: Dalton Flanagan <6599399+dltn@users.noreply.github.com> Date: Thu, 7 Nov 2024 16:04:11 -0500 Subject: [PATCH] Factor out create_dist_registry --- llama_stack/distribution/server/server.py | 19 ++------------ llama_stack/distribution/store/registry.py | 30 ++++++++++++++++++++-- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 16c0fd0e02..1438137808 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -31,7 +31,7 @@ get_provider_registry, ) -from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -42,8 +42,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls -from llama_stack.distribution.store import CachedDiskDistributionRegistry -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .endpoints import get_all_api_endpoints @@ -281,21 +279,8 @@ def main( config = StackRunConfig(**yaml.safe_load(fp)) app = FastAPI() - # instantiate kvstore for storing and retrieving distribution metadata - if config.metadata_store: - dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store)) - else: - dist_kvstore = asyncio.run( - kvstore_impl( - SqliteKVStoreConfig( - db_path=( - DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" - ).as_posix() - ) - ) - ) - dist_registry = CachedDiskDistributionRegistry(dist_kvstore) + dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config)) impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) if Api.telemetry in impls: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 994fb475c6..897bb90d04 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -9,9 +9,17 @@ import pydantic -from llama_stack.distribution.datatypes import RoutableObjectWithProvider +from llama_stack.distribution.datatypes import ( + RoutableObjectWithProvider, + StackRunConfig, +) +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR -from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.kvstore import ( + KVStore, + kvstore_impl, + SqliteKVStoreConfig, +) class DistributionRegistry(Protocol): @@ -133,3 +141,21 @@ async def register(self, obj: RoutableObjectWithProvider) -> bool: self.cache[obj.identifier].append(obj) return success + + +async def create_dist_registry( + config: StackRunConfig, +) -> tuple[CachedDiskDistributionRegistry, KVStore]: + # instantiate kvstore for storing and retrieving distribution metadata + if config.metadata_store: + dist_kvstore = await kvstore_impl(config.metadata_store) + else: + dist_kvstore = await kvstore_impl( + SqliteKVStoreConfig( + db_path=( + DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" + ).as_posix() + ) + ) + + return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore