From 3faa79d76c4762ac2dda4edc4ce80d613ed85035 Mon Sep 17 00:00:00 2001 From: Habeb Nawatha <109216430+HabebNawatha@users.noreply.github.com> Date: Sat, 14 Dec 2024 21:40:18 +0200 Subject: [PATCH] feat: Add tags field for models with dynamic and user-defined population - Implemented a Python function to extract tags from the model identifier field for dynamic population. - Enabled users to specify tags manually when registering a model. - Tags are now included when retrieving model data. Signed-off-by: Habeb Nawatha --- llama_stack/apis/models/models.py | 6 +++++ .../distribution/routers/routing_tables.py | 27 ++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 0ee23ecc17..ba1d3be7cd 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -20,6 +20,11 @@ class CommonModelFields(BaseModel): description="Any additional metadata for this model", ) + tags: Dict[str, str] = Field( + default_factory=dict, + description="Tags associated with this model as a dictionary", + ) + @json_schema_type class ModelType(str, Enum): @@ -69,6 +74,7 @@ async def register_model( provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, model_type: Optional[ModelType] = None, + tags: Optional[Dict[str, str]] = None, ) -> Model: ... @webmethod(route="/models/unregister", method="POST") diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 01edf4e5ac..3182e5ce42 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -66,6 +66,21 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: Registry = Dict[str, List[RoutableObjectWithProvider]] +def extract_tags_from_identifier(identifier: str) -> Dict[str, str]: + tags = {} + version_match = re.search(r"(\d+\.\d+)", identifier) + model_type_match = re.search(r"(Instruct|Vision|Other|chat)", identifier) + size_match = re.search(r"(\d+)(B|M)", identifier) + + if version_match: + tags["llama_version"] = version_match.group(1) + if model_type_match: + tags["model_type"] = model_type_match.group(1) + if size_match: + tags["model_size"] = size_match.group(1) + size_match.group(2) + return tags + + class CommonRoutingTableImpl(RoutingTable): def __init__( self, @@ -198,7 +213,14 @@ async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider] class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[Model]: - return await self.get_all_with_type("model") + models = await self.get_all_with_type("model") + for model in models: + if not model.tags: # If there are no tags, assign them + tags = extract_tags_from_identifier(model.identifier) + model.tags = tags + await self.dist_registry.register(model) + + return models async def get_model(self, identifier: str) -> Optional[Model]: return await self.get_object_by_identifier("model", identifier) @@ -210,6 +232,7 @@ async def register_model( provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, model_type: Optional[ModelType] = None, + tags: Optional[Dict[str, str]] = None, ) -> Model: if provider_model_id is None: provider_model_id = model_id @@ -229,12 +252,14 @@ async def register_model( raise ValueError( "Embedding model must have an embedding dimension in its metadata" ) + tags = extract_tags_from_identifier(model_id) model = Model( identifier=model_id, provider_resource_id=provider_model_id, provider_id=provider_id, metadata=metadata, model_type=model_type, + tags=tags, ) registered_model = await self.register_object(model) return registered_model