Skip to content

Commit

Permalink
Factor out create_dist_registry (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
dltn authored Nov 7, 2024
1 parent 694c142 commit 345ae07
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
19 changes: 2 additions & 17 deletions llama_stack/distribution/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
get_provider_registry,
)

from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.store.registry import create_dist_registry

from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
Expand All @@ -42,8 +42,6 @@
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig

from .endpoints import get_all_api_endpoints

Expand Down Expand Up @@ -281,21 +279,8 @@ def main(
config = StackRunConfig(**yaml.safe_load(fp))

app = FastAPI()
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
else:
dist_kvstore = asyncio.run(
kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
)

dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))

impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls:
Expand Down
30 changes: 28 additions & 2 deletions llama_stack/distribution/store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@

import pydantic

from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.distribution.datatypes import (
RoutableObjectWithProvider,
StackRunConfig,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR

from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.kvstore import (
KVStore,
kvstore_impl,
SqliteKVStoreConfig,
)


class DistributionRegistry(Protocol):
Expand Down Expand Up @@ -133,3 +141,21 @@ async def register(self, obj: RoutableObjectWithProvider) -> bool:
self.cache[obj.identifier].append(obj)

return success


async def create_dist_registry(
config: StackRunConfig,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = await kvstore_impl(config.metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)

return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore

0 comments on commit 345ae07

Please sign in to comment.