add embedding model by default to distribution templates (#617)
# What does this PR do?
Adds the sentence-transformers inference provider and the `all-MiniLM-L6-v2`
embedding model to the default models registered in the `run.yaml` of every
distribution template.

## Test Plan
```
llama stack build --template together --image-type conda
llama stack run ~/.llama/distributions/llamastack-together/together-run.yaml
```
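Once the stack is up, one way to confirm the default embedding model is being served is via the Python client. This is a hedged sketch only: the base URL, port, and the exact client method layout are assumptions (parameter names `model_id`/`contents` are taken from the inference API in this PR) and may differ by setup.

```python
from llama_stack_client import LlamaStackClient

# Assumes a locally running stack; adjust the base URL to your run config.
client = LlamaStackClient(base_url="http://localhost:5000")

# Mirrors the inference API's embeddings(model_id, contents) surface.
response = client.inference.embeddings(
    model_id="all-MiniLM-L6-v2",
    contents=["Hello, world!"],
)
print(len(response.embeddings[0]))  # expect 384 for all-MiniLM-L6-v2
```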
dineshyv authored Dec 13, 2024
1 parent e893b22 commit 516e1a3
Showing 41 changed files with 473 additions and 64 deletions.
2 changes: 2 additions & 0 deletions distributions/dependencies.json
@@ -249,6 +249,7 @@
     "redis",
     "scikit-learn",
     "scipy",
+    "sentence-transformers",
     "sentencepiece",
     "torch",
     "torchvision",
@@ -287,6 +288,7 @@
     "redis",
     "scikit-learn",
     "scipy",
+    "sentence-transformers",
     "sentencepiece",
     "torch",
     "torchao==0.5.0",
5 changes: 3 additions & 2 deletions llama_stack/apis/models/models.py
@@ -21,9 +21,10 @@ class CommonModelFields(BaseModel):
     )


-class ModelType(Enum):
+@json_schema_type
+class ModelType(str, Enum):
     llm = "llm"
-    embedding_model = "embedding"
+    embedding = "embedding"


 @json_schema_type
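For context on the `str` mixin added above: a plain `Enum` member is not JSON-serializable, while a `str`-backed enum compares and dumps as its literal value, which is presumably why it now pairs with `@json_schema_type`. A minimal illustration:

```python
import json
from enum import Enum


class ModelType(str, Enum):
    llm = "llm"
    embedding = "embedding"


# A str-backed enum member is itself a string, so it serializes without
# an explicit .value lookup and compares equal to plain strings.
assert ModelType.embedding == "embedding"
print(json.dumps({"model_type": ModelType.embedding}))  # {"model_type": "embedding"}
```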
4 changes: 2 additions & 2 deletions llama_stack/distribution/routers/routers.py
@@ -109,7 +109,7 @@ async def chat_completion(
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.embedding_model:
+        if model.model_type == ModelType.embedding:
             raise ValueError(
                 f"Model '{model_id}' is an embedding model and does not support chat completions"
             )
@@ -142,7 +142,7 @@ async def completion(
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.embedding_model:
+        if model.model_type == ModelType.embedding:
             raise ValueError(
                 f"Model '{model_id}' is an embedding model and does not support chat completions"
             )
16 changes: 10 additions & 6 deletions llama_stack/distribution/routers/routing_tables.py
@@ -225,10 +225,7 @@ async def register_model(
         metadata = {}
         if model_type is None:
             model_type = ModelType.llm
-        if (
-            "embedding_dimension" not in metadata
-            and model_type == ModelType.embedding_model
-        ):
+        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
             raise ValueError(
                 "Embedding model must have an embedding dimension in its metadata"
             )
@@ -311,8 +308,15 @@ async def register_memory_bank(
         )
         model = await self.get_object_by_identifier("model", params.embedding_model)
         if model is None:
-            raise ValueError(f"Model {params.embedding_model} not found")
-        if model.model_type != ModelType.embedding_model:
+            if params.embedding_model == "all-MiniLM-L6-v2":
+                raise ValueError(
+                    "Embeddings are now served via Inference providers. "
+                    "Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
+                    "See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
+                )
+            else:
+                raise ValueError(f"Model {params.embedding_model} not found")
+        if model.model_type != ModelType.embedding:
             raise ValueError(
                 f"Model {params.embedding_model} is not an embedding model"
             )
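The dimension check above means every embedding model registered through the routing table must declare `embedding_dimension` in its metadata. A conforming registration looks like the `ModelInput` the templates below add, repeated here as a minimal sketch:

```python
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput

# Embedding models must carry their output dimension so downstream
# consumers (e.g. memory banks) can size their vector stores.
embedding_model = ModelInput(
    model_id="all-MiniLM-L6-v2",
    provider_id="sentence-transformers",
    model_type=ModelType.embedding,
    metadata={"embedding_dimension": 384},
)
```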
@@ -83,7 +83,7 @@ async def unregister_model(self, model_id: str) -> None:

     async def register_model(self, model: Model) -> Model:
         model = await self.model_registry_helper.register_model(model)
-        if model.model_type == ModelType.embedding_model:
+        if model.model_type == ModelType.embedding:
             self._load_sentence_transformer_model(model.provider_resource_id)
         return model
@@ -4,7 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Any, Dict
+
 from pydantic import BaseModel


-class SentenceTransformersInferenceConfig(BaseModel): ...
+class SentenceTransformersInferenceConfig(BaseModel):
+
+    @classmethod
+    def sample_run_config(cls) -> Dict[str, Any]:
+        return {}
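The provider takes no configuration, so `sample_run_config` returns an empty mapping, which the templates render as `config: {}` in the generated run.yaml. A one-line sanity check (a sketch, assuming the import path used by the template files below):

```python
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)

# No knobs to set: the sample config is empty and renders as `config: {}`.
assert SentenceTransformersInferenceConfig.sample_run_config() == {}
```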
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/inference/ollama/ollama.py
@@ -337,7 +337,7 @@ async def embeddings(

     async def register_model(self, model: Model) -> Model:
         # ollama does not have embedding models running. Check if the model is in list of available models.
-        if model.model_type == ModelType.embedding_model:
+        if model.model_type == ModelType.embedding:
             response = await self.client.list()
             available_models = [m["model"] for m in response["models"]]
             if model.provider_resource_id not in available_models:
2 changes: 1 addition & 1 deletion llama_stack/providers/remote/inference/vllm/vllm.py
@@ -207,7 +207,7 @@ async def embeddings(
         model = await self.model_store.get_model(model_id)

         kwargs = {}
-        assert model.model_type == ModelType.embedding_model
+        assert model.model_type == ModelType.embedding
         assert model.metadata.get("embedding_dimensions")
         kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
         assert all(
2 changes: 1 addition & 1 deletion llama_stack/providers/tests/inference/fixtures.py
@@ -238,7 +238,7 @@ async def inference_stack(request, inference_model):
     model_type = ModelType.llm
     metadata = {}
     if os.getenv("EMBEDDING_DIMENSION"):
-        model_type = ModelType.embedding_model
+        model_type = ModelType.embedding
         metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")

     test_stack = await construct_stack_for_test(
4 changes: 2 additions & 2 deletions llama_stack/providers/tests/inference/test_embeddings.py
@@ -18,7 +18,7 @@ async def test_embeddings(self, inference_model, inference_stack):
         inference_impl, models_impl = inference_stack
         model = await models_impl.get_model(inference_model)

-        if model.model_type != ModelType.embedding_model:
+        if model.model_type != ModelType.embedding:
             pytest.skip("This test is only applicable for embedding models")

         response = await inference_impl.embeddings(
@@ -39,7 +39,7 @@ async def test_batch_embeddings(self, inference_model, inference_stack):
         inference_impl, models_impl = inference_stack
         model = await models_impl.get_model(inference_model)

-        if model.model_type != ModelType.embedding_model:
+        if model.model_type != ModelType.embedding:
             pytest.skip("This test is only applicable for embedding models")

         texts = ["Hello, world!", "This is a test", "Testing embeddings"]
2 changes: 1 addition & 1 deletion llama_stack/providers/tests/memory/fixtures.py
@@ -125,7 +125,7 @@ async def memory_stack(inference_model, request):
         models=[
             ModelInput(
                 model_id=inference_model,
-                model_type=ModelType.embedding_model,
+                model_type=ModelType.embedding,
                 metadata={
                     "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
                 },
2 changes: 1 addition & 1 deletion llama_stack/providers/utils/inference/model_registry.py
@@ -78,7 +78,7 @@ def get_llama_model(self, provider_model_id: str) -> str:
         return None

     async def register_model(self, model: Model) -> Model:
-        if model.model_type == ModelType.embedding_model:
+        if model.model_type == ModelType.embedding:
             # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
             provider_resource_id = model.provider_resource_id
         else:
24 changes: 21 additions & 3 deletions llama_stack/templates/cerebras/cerebras.py
@@ -8,10 +8,14 @@

 from llama_models.sku_list import all_registered_models

+from llama_stack.apis.models.models import ModelType
+
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases
-
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


@@ -29,6 +33,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::cerebras",
         config=CerebrasImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )

     core_model_to_hf_repo = {
         m.descriptor(): m.huggingface_repo for m in all_registered_models()
@@ -37,9 +46,18 @@ def get_distribution_template() -> DistributionTemplate:
         ModelInput(
             model_id=core_model_to_hf_repo[m.llama_model],
             provider_model_id=m.provider_model_id,
+            provider_id="cerebras",
         )
         for m in model_aliases
     ]
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )

     return DistributionTemplate(
         name="cerebras",
@@ -52,9 +70,9 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                 },
-                default_models=default_models,
+                default_models=default_models + [embedding_model],
                 default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
             ),
         },
15 changes: 13 additions & 2 deletions llama_stack/templates/cerebras/run.yaml
@@ -15,6 +15,9 @@ providers:
       config:
         base_url: https://api.cerebras.ai
         api_key: ${env.CEREBRAS_API_KEY}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -49,12 +52,20 @@ metadata_store:
 models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: null
+  provider_id: cerebras
   provider_model_id: llama3.1-8b
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: null
+  provider_id: cerebras
   provider_model_id: llama3.1-70b
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: meta-llama/Llama-Guard-3-8B
24 changes: 21 additions & 3 deletions llama_stack/templates/fireworks/fireworks.py
@@ -8,11 +8,15 @@

 from llama_models.sku_list import all_registered_models

+from llama_stack.apis.models.models import ModelType
+
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.fireworks.fireworks import MODEL_ALIASES
-
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


@@ -35,6 +39,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::fireworks",
         config=FireworksImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -48,9 +57,18 @@ def get_distribution_template() -> DistributionTemplate:
         ModelInput(
             model_id=core_model_to_hf_repo[m.llama_model],
             provider_model_id=m.provider_model_id,
+            provider_id="fireworks",
         )
         for m in MODEL_ALIASES
     ]
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )

     return DistributionTemplate(
         name=name,
@@ -63,10 +81,10 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=default_models,
+                default_models=default_models + [embedding_model],
                 default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
             ),
         },
38 changes: 28 additions & 10 deletions llama_stack/templates/fireworks/run.yaml
@@ -16,8 +16,11 @@ providers:
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference
+      url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -74,40 +77,55 @@ metadata_store:
 models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-8b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-70b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-405b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-1b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-3b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-11b-vision-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-90b-vision-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-8B
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-guard-3-8b
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-guard-3-11b-vision
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: meta-llama/Llama-Guard-3-8B
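The `embedding_dimension: 384` metadata in these run.yaml entries must match the model's actual output width. A quick way to verify it against the sentence-transformers library directly (a sketch, assuming the `sentence-transformers` package from dependencies.json is installed):

```python
from sentence_transformers import SentenceTransformer

# all-MiniLM-L6-v2 emits 384-dimensional vectors; the run.yaml metadata
# above must agree with this number.
model = SentenceTransformer("all-MiniLM-L6-v2")
print(model.get_sentence_embedding_dimension())  # 384
```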