Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add tags field for models with dynamic and user-defined population #629

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llama_stack/apis/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ class CommonModelFields(BaseModel):
description="Any additional metadata for this model",
)

tags: Dict[str, str] = Field(
default_factory=dict,
description="Tags associated with this model as a dictionary",
)


@json_schema_type
class ModelType(str, Enum):
Expand Down Expand Up @@ -69,6 +74,7 @@ async def register_model(
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
tags: Optional[Dict[str, str]] = None,
) -> Model: ...

@webmethod(route="/models/unregister", method="POST")
Expand Down
27 changes: 26 additions & 1 deletion llama_stack/distribution/routers/routing_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
Registry = Dict[str, List[RoutableObjectWithProvider]]


def extract_tags_from_identifier(identifier: str) -> Dict[str, str]:
tags = {}
version_match = re.search(r"(\d+\.\d+)", identifier)
model_type_match = re.search(r"(Instruct|Vision|Other|chat)", identifier)
size_match = re.search(r"(\d+)(B|M)", identifier)

if version_match:
tags["llama_version"] = version_match.group(1)
if model_type_match:
tags["model_type"] = model_type_match.group(1)
if size_match:
tags["model_size"] = size_match.group(1) + size_match.group(2)
return tags


class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
Expand Down Expand Up @@ -198,7 +213,14 @@ async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]

class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[Model]:
return await self.get_all_with_type("model")
models = await self.get_all_with_type("model")
for model in models:
if not model.tags: # If there are no tags, assign them
tags = extract_tags_from_identifier(model.identifier)
model.tags = tags
await self.dist_registry.register(model)

return models

async def get_model(self, identifier: str) -> Optional[Model]:
return await self.get_object_by_identifier("model", identifier)
Expand All @@ -210,6 +232,7 @@ async def register_model(
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
tags: Optional[Dict[str, str]] = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
Expand All @@ -229,12 +252,14 @@ async def register_model(
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
tags = extract_tags_from_identifier(model_id)
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
tags=tags,
)
registered_model = await self.register_object(model)
return registered_model
Expand Down