diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index a17e8e48d7..b037dfa663 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin): memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value embedding_model: str chunk_size_in_tokens: int + embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2 overlap_size_in_tokens: Optional[int] = None diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index cb9cb1117e..71101ec8b9 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod @@ -20,6 +21,11 @@ class CommonModelFields(BaseModel): ) +class ModelType(Enum): + llm = "llm" + embedding_model = "embedding" + + @json_schema_type class Model(CommonModelFields, Resource): type: Literal[ResourceType.model.value] = ResourceType.model.value @@ -34,12 +40,14 @@ def provider_model_id(self) -> str: model_config = ConfigDict(protected_namespaces=()) + model_type: ModelType = Field(default=ModelType.llm) + class ModelInput(CommonModelFields): model_id: str provider_id: Optional[str] = None provider_model_id: Optional[str] = None - + model_type: Optional[ModelType] = ModelType.llm model_config = ConfigDict(protected_namespaces=()) @@ -59,6 +67,7 @@ async def register_model( provider_model_id: Optional[str] = None, provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + model_type: Optional[ModelType] = None, ) -> Model: ... 
@webmethod(route="/models/unregister", method="POST") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 5b75a525b3..51be318cb3 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -88,9 +88,10 @@ async def register_model( provider_model_id: Optional[str] = None, provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + model_type: Optional[ModelType] = None, ) -> None: await self.routing_table.register_model( - model_id, provider_model_id, provider_id, metadata + model_id, provider_model_id, provider_id, metadata, model_type ) async def chat_completion( @@ -105,6 +106,13 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.embedding_model: + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) params = dict( model_id=model_id, messages=messages, @@ -131,6 +139,13 @@ async def completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.embedding_model: + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -150,6 +165,13 @@ async def embeddings( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: + model = await self.routing_table.get_model(model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + if model.model_type == ModelType.llm: + raise ValueError( + f"Model '{model_id}' is an LLM model and does not support embeddings" + ) return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 2fb5a5e1c0..bc3de8be08 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -209,6 +209,7 @@ async def register_model( provider_model_id: Optional[str] = None, provider_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + model_type: Optional[ModelType] = None, ) -> Model: if provider_model_id is None: provider_model_id = model_id @@ -222,11 +223,21 @@ async def register_model( ) if metadata is None: metadata = {} + if model_type is None: + model_type = ModelType.llm + if ( + "embedding_dimension" not in metadata + and model_type == ModelType.embedding_model + ): + raise ValueError( + "Embedding model must have an embedding dimension in its metadata" + ) model = Model( identifier=model_id, provider_resource_id=provider_model_id, provider_id=provider_id, metadata=metadata, + model_type=model_type, ) registered_model = await self.register_object(model) return registered_model @@ -298,16 +309,29 @@ async def register_memory_bank( raise ValueError( "No provider specified and multiple providers available. Please specify a provider_id." 
) - memory_bank = parse_obj_as( - MemoryBank, - { - "identifier": memory_bank_id, - "type": ResourceType.memory_bank.value, - "provider_id": provider_id, - "provider_resource_id": provider_memory_bank_id, - **params.model_dump(), - }, - ) + model = await self.get_object_by_identifier("model", params.embedding_model) + if model is None: + raise ValueError(f"Model {params.embedding_model} not found") + if model.model_type != ModelType.embedding_model: + raise ValueError( + f"Model {params.embedding_model} is not an embedding model" + ) + if "embedding_dimension" not in model.metadata: + raise ValueError( + f"Model {params.embedding_model} does not have an embedding dimension" + ) + memory_bank_data = { + "identifier": memory_bank_id, + "type": ResourceType.memory_bank.value, + "provider_id": provider_id, + "provider_resource_id": provider_memory_bank_id, + **params.model_dump(), + } + if params.memory_bank_type == MemoryBankType.vector.value: + memory_bank_data["embedding_dimension"] = model.metadata[ + "embedding_dimension" + ] + memory_bank = parse_obj_as(MemoryBank, memory_bank_data) await self.register_object(memory_bank) return memory_bank diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 041a5677c1..8f93c0c4b3 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -40,7 +40,7 @@ async def delete(self, type: str, identifier: str) -> None: ... REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v2" +KEY_VERSION = "v3" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 241497050a..27490954bc 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -200,10 +200,13 @@ def provider_data_validator(self) -> Optional[str]: return self.adapter.provider_data_validator -def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec: +def remote_provider_spec( + api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None +) -> RemoteProviderSpec: return RemoteProviderSpec( api=api, provider_type=f"remote::{adapter.adapter_type}", config_class=adapter.config_class, adapter=adapter, + api_dependencies=api_dependencies or [], ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 07fd4af446..e7abde2273 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -16,12 +16,14 @@ from llama_stack.providers.utils.inference.model_registry import build_model_alias from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.utils.inference.embedding_mixin import ( + SentenceTransformerEmbeddingMixin, +) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_media_to_url, request_has_media, ) - from .config import MetaReferenceInferenceConfig from .generation import Llama from .model_parallel import LlamaModelParallelGenerator @@ -32,12 +34,17 @@ SEMAPHORE = asyncio.Semaphore(1) -class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate): +class MetaReferenceInferenceImpl( + 
SentenceTransformerEmbeddingMixin, + Inference, + ModelsProtocolPrivate, +): def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config model = resolve_model(config.model) - ModelRegistryHelper.__init__( - self, + if model is None: + raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") + self.model_registry_helper = ModelRegistryHelper( [ build_model_alias( model.descriptor(), @@ -45,8 +52,6 @@ def __init__(self, config: MetaReferenceInferenceConfig) -> None: ) ], ) - if model is None: - raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") self.model = model # verify that the checkpoint actually is for this model lol @@ -76,6 +81,12 @@ def check_model(self, request) -> None: async def unregister_model(self, model_id: str) -> None: pass + async def register_model(self, model: Model) -> Model: + model = await self.model_registry_helper.register_model(model) + if model.model_type == ModelType.embedding_model: + self._load_sentence_transformer_model(model.provider_resource_id) + return model + async def completion( self, model_id: str, @@ -394,13 +405,6 @@ def impl(): for x in impl(): yield x - async def embeddings( - self, - model_id: str, - contents: List[InterleavedTextMedia], - ) -> EmbeddingsResponse: - raise NotImplementedError() - async def request_with_localized_media( request: Union[ChatCompletionRequest, CompletionRequest], diff --git a/llama_stack/providers/inline/inference/sentence_transformers/__init__.py b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py new file mode 100644 index 0000000000..d5710f7fd7 --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.inline.inference.sentence_transformers.config import ( + SentenceTransformersInferenceConfig, +) + + +async def get_provider_impl( + config: SentenceTransformersInferenceConfig, + _deps, +): + from .sentence_transformers import SentenceTransformersInferenceImpl + + impl = SentenceTransformersInferenceImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/inference/sentence_transformers/config.py b/llama_stack/providers/inline/inference/sentence_transformers/config.py new file mode 100644 index 0000000000..aec6d56d81 --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/config.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class SentenceTransformersInferenceConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py new file mode 100644 index 0000000000..0896b44af7 --- /dev/null +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import logging +from typing import AsyncGenerator, List, Optional, Union + +from llama_stack.apis.inference import ( + CompletionResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.utils.inference.embedding_mixin import ( + SentenceTransformerEmbeddingMixin, +) +from .config import SentenceTransformersInferenceConfig + +log = logging.getLogger(__name__) + + +class SentenceTransformersInferenceImpl( + SentenceTransformerEmbeddingMixin, + Inference, + ModelsProtocolPrivate, +): + def __init__(self, config: SentenceTransformersInferenceConfig) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_model(self, model: Model) -> Model: + _ = self._load_sentence_transformer_model(model.provider_resource_id) + return model + + async def unregister_model(self, model_id: str) -> None: + pass + + async def completion( + self, + model_id: str, + content: str, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, AsyncGenerator]: + raise ValueError("Sentence transformers don't support completion") + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + raise ValueError("Sentence transformers don't support chat completion") diff --git a/llama_stack/providers/inline/memory/faiss/__init__.py b/llama_stack/providers/inline/memory/faiss/__init__.py index 16c383be3f..2d7ede3b14 100644 --- a/llama_stack/providers/inline/memory/faiss/__init__.py +++ b/llama_stack/providers/inline/memory/faiss/__init__.py @@ -4,16 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree.
+from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec from .config import FaissImplConfig -async def get_provider_impl(config: FaissImplConfig, _deps): +async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]): from .faiss import FaissMemoryImpl assert isinstance( config, FaissImplConfig ), f"Unexpected config type: {type(config)}" - impl = FaissMemoryImpl(config) + impl = FaissMemoryImpl(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 78de131209..7c27aca85f 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -19,11 +19,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.memory.vector_store import ( - ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, EmbeddingIndex, ) @@ -32,7 +31,8 @@ logger = logging.getLogger(__name__) -MEMORY_BANKS_PREFIX = "memory_banks:v1::" +MEMORY_BANKS_PREFIX = "memory_banks:v2::" +FAISS_INDEX_PREFIX = "faiss_index:v2::" class FaissIndex(EmbeddingIndex): @@ -56,7 +56,7 @@ async def initialize(self) -> None: if not self.kvstore: return - index_key = f"faiss_index:v1::{self.bank_id}" + index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}" stored_data = await self.kvstore.get(index_key) if stored_data: @@ -85,16 +85,25 @@ async def _save_index(self): "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), } - index_key = f"faiss_index:v1::{self.bank_id}" + index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}" await self.kvstore.set(key=index_key, value=json.dumps(data)) async def delete(self): if not self.kvstore or not self.bank_id: return - await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}") + await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}") async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + # Add dimension check + embedding_dim = ( + embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0] + ) + if embedding_dim != self.index.d: + raise ValueError( + f"Embedding dimension mismatch. 
Expected {self.index.d}, got {embedding_dim}" + ) + indexlen = len(self.id_by_index) for i, chunk in enumerate(chunks): self.chunk_by_index[indexlen + i] = chunk @@ -124,8 +133,9 @@ async def query( class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: FaissImplConfig) -> None: + def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.cache = {} self.kvstore = None @@ -139,10 +149,11 @@ async def initialize(self) -> None: for bank_data in stored_banks: bank = VectorMemoryBank.model_validate_json(bank_data) index = BankWithIndex( - bank=bank, - index=await FaissIndex.create( - ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier + bank, + await FaissIndex.create( + bank.embedding_dimension, self.kvstore, bank.identifier ), + self.inference_api, ) self.cache[bank.identifier] = index @@ -166,13 +177,13 @@ async def register_memory_bank( ) # Store in cache - index = BankWithIndex( - bank=memory_bank, - index=await FaissIndex.create( - ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, + await FaissIndex.create( + memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier ), + self.inference_api, ) - self.cache[memory_bank.identifier] = index async def list_memory_banks(self) -> List[MemoryBank]: return [i.bank for i in self.cache.values()] diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 13d463ad8a..0ff557b9f9 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -18,6 +18,7 @@ "transformers", "zmq", "lm-format-enforcer", + "sentence-transformers", ] @@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.inline.inference.vllm", config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig", ), + InlineProviderSpec( + api=Api.inference, + provider_type="inline::sentence-transformers", + pip_packages=["sentence-transformers"], + module="llama_stack.providers.inline.inference.sentence_transformers", + config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index c52aba6c6b..27c07e0079 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", deprecation_warning="Please use the `inline::faiss` provider instead.", + api_dependencies=[Api.inference], ), InlineProviderSpec( api=Api.memory, @@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]: pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], module="llama_stack.providers.inline.memory.faiss", config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", + api_dependencies=[Api.inference], ), remote_provider_spec( Api.memory, @@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.chroma", config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig", ), + api_dependencies=[Api.inference], ), InlineProviderSpec( 
api=Api.memory, @@ -71,6 +74,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.pgvector", config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig", ), + api_dependencies=[Api.inference], ), remote_provider_spec( Api.memory, @@ -81,6 +85,7 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig", provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData", ), + api_dependencies=[Api.inference], ), remote_provider_spec( api=Api.memory, @@ -90,6 +95,7 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.sample", config_class="llama_stack.providers.remote.memory.sample.SampleConfig", ), + api_dependencies=[], ), remote_provider_spec( Api.memory, @@ -99,5 +105,6 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.remote.memory.qdrant", config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig", ), + api_dependencies=[Api.inference], ), ] diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index f575d9dc33..96cbcaa67c 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -5,6 +5,7 @@ # the root directory of this source tree. from typing import * # noqa: F403 +import json from botocore.client import BaseClient from llama_models.datatypes import CoreModelId @@ -19,8 +20,10 @@ from llama_stack.apis.inference import * # noqa: F403 + from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.utils.bedrock.client import create_bedrock_client +from llama_stack.providers.utils.inference.prompt_adapter import content_has_media model_aliases = [ @@ -448,4 +451,21 @@ async def embeddings( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + embeddings = [] + for content in contents: + assert not content_has_media( + content + ), "Bedrock does not support media for embeddings" + input_text = interleaved_text_media_as_str(content) + input_body = {"inputText": input_text} + body = json.dumps(input_body) + response = self.client.invoke_model( + body=body, + modelId=model.provider_resource_id, + accept="application/json", + contentType="application/json", + ) + response_body = json.loads(response.get("body").read()) + embeddings.append(response_body.get("embedding")) + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 062c1e1eac..e699269424 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -13,7 +13,7 @@ @json_schema_type class FireworksImplConfig(BaseModel): url: str = Field( - default="https://api.fireworks.ai/inference", + default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", ) api_key: Optional[str] = Field( @@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel): @classmethod def sample_run_config(cls) -> Dict[str, Any]: return { - "url": "https://api.fireworks.ai/inference", + "url": "https://api.fireworks.ai/inference/v1", "api_key": "${env.FIREWORKS_API_KEY}", } diff 
--git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index c3e6341550..b0e93305e3 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks from llama_models.datatypes import CoreModelId @@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + content_has_media, convert_message_to_dict, request_has_media, ) @@ -89,17 +90,19 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass - def _get_client(self) -> Fireworks: - fireworks_api_key = None + def _get_api_key(self) -> str: if self.config.api_key is not None: - fireworks_api_key = self.config.api_key + return self.config.api_key else: provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.fireworks_api_key: raise ValueError( 'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key> }' ) - fireworks_api_key = provider_data.fireworks_api_key + return provider_data.fireworks_api_key + + def _get_client(self) -> Fireworks: + fireworks_api_key = self._get_api_key() return Fireworks(api_key=fireworks_api_key) async def completion( @@ -264,4 +267,19 @@ async def embeddings( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + kwargs = {} + if model.metadata.get("embedding_dimension"): + kwargs["dimensions"] = model.metadata.get("embedding_dimension") + assert all( + not content_has_media(content) for content in contents + ), "Fireworks does not support media for embeddings" + response = self._get_client().embeddings.create( + model=model.provider_resource_id, + input=[interleaved_text_media_as_str(content) for content in contents], + **kwargs, + ) + + embeddings = [data.embedding for data in response.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index d6fa20835c..1ba4ad5994 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + content_has_media, convert_image_media_to_url, request_has_media, ) @@ -321,9 +322,30 @@ async def embeddings( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + assert all( + not content_has_media(content) for content in contents + ), "Ollama does not support media for embeddings" + response = await self.client.embed( + model=model.provider_resource_id, + input=[interleaved_text_media_as_str(content) for content in contents], + ) + embeddings = response["embeddings"] + + return EmbeddingsResponse(embeddings=embeddings) async def register_model(self, model: Model) -> Model: + # Ollama does not have embedding models running; check that the model is in the list of locally available models instead.
+ if model.model_type == ModelType.embedding_model: + response = await self.client.list() + available_models = [m["model"] for m in response["models"]] + if model.provider_resource_id not in available_models: + raise ValueError( + f"Model '{model.provider_resource_id}' is not available in Ollama. " + f"Available models: {', '.join(available_models)}" + ) + return model model = await self.register_helper.register_model(model) models = await self.client.ps() available_models = [m["model"] for m in models["models"]] diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index e7c96ce98c..7cd798d160 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + content_has_media, convert_message_to_dict, request_has_media, ) @@ -253,4 +254,13 @@ async def embeddings( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + assert all( + not content_has_media(content) for content in contents + ), "Together does not support media for embeddings" + r = self._get_client().embeddings.create( + model=model.provider_resource_id, + input=[interleaved_text_media_as_str(content) for content in contents], + ) + embeddings = [item.embedding for item in r.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 57f3db802b..7ad5cef0f1 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, completion_request_to_prompt, + content_has_media, convert_message_to_dict, request_has_media, ) @@ -203,4 +204,20 @@ async def embeddings( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - raise NotImplementedError() + model = await self.model_store.get_model(model_id) + + kwargs = {} + assert model.model_type == ModelType.embedding_model + assert model.metadata.get("embedding_dimension") + kwargs["dimensions"] = model.metadata.get("embedding_dimension") + assert all( + not content_has_media(content) for content in contents + ), "VLLM does not support media for embeddings" + response = self.client.embeddings.create( + model=model.provider_resource_id, + input=[interleaved_text_media_as_str(content) for content in contents], + **kwargs, + ) + + embeddings = [data.embedding for data in response.data] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/providers/remote/memory/chroma/__init__.py b/llama_stack/providers/remote/memory/chroma/__init__.py index 63e9eae7d7..581d60e754 100644 --- a/llama_stack/providers/remote/memory/chroma/__init__.py +++ b/llama_stack/providers/remote/memory/chroma/__init__.py @@ -4,12 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree.
+from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import ChromaRemoteImplConfig -async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps): +async def get_adapter_impl( + config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec] +): from .chroma import ChromaMemoryAdapter - impl = ChromaMemoryAdapter(config) + impl = ChromaMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index f4fb50a7ce..20c81da3e7 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -13,8 +13,7 @@ from numpy.typing import NDArray from llama_stack.apis.memory import * # noqa: F403 - -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, @@ -87,10 +86,14 @@ async def delete(self): class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): def __init__( - self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig] + self, + config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig], + inference_api: Api.inference, ) -> None: log.info(f"Initializing ChromaMemoryAdapter with url: {config}") self.config = config + self.inference_api = inference_api + self.client = None self.cache = {} @@ -127,10 +130,9 @@ async def register_memory_bank( metadata={"bank": memory_bank.model_dump_json()}, ) ) - bank_index = BankWithIndex( - bank=memory_bank, index=ChromaIndex(self.client, collection) + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, ChromaIndex(self.client, collection), self.inference_api ) - self.cache[memory_bank.identifier] = bank_index async def unregister_memory_bank(self, memory_bank_id: str) -> None: await self.cache[memory_bank_id].index.delete() @@ -166,6 +168,8 @@ async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: collection = await maybe_await(self.client.get_collection(bank_id)) if not collection: raise ValueError(f"Bank {bank_id} not found in Chroma") - index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection)) + index = BankWithIndex( + bank, ChromaIndex(self.client, collection), self.inference_api + ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/remote/memory/pgvector/__init__.py b/llama_stack/providers/remote/memory/pgvector/__init__.py index 4ac30452fd..b4620cae0c 100644 --- a/llama_stack/providers/remote/memory/pgvector/__init__.py +++ b/llama_stack/providers/remote/memory/pgvector/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import PGVectorConfig -async def get_adapter_impl(config: PGVectorConfig, _deps): +async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]): from .pgvector import PGVectorMemoryAdapter - impl = PGVectorMemoryAdapter(config) + impl = PGVectorMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index 9ec76e8ca4..0f295f38ae 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -16,9 +16,9 @@ from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate + from llama_stack.providers.utils.memory.vector_store import ( - ALL_MINILM_L6_V2_DIMENSION, BankWithIndex, EmbeddingIndex, ) @@ -120,8 +120,9 @@ async def delete(self): class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: PGVectorConfig) -> None: + def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.cursor = None self.conn = None self.cache = {} @@ -160,27 +161,17 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass - async def register_memory_bank( - self, - memory_bank: MemoryBank, - ) -> None: + async def register_memory_bank(self, memory_bank: MemoryBank) -> None: assert ( memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.memory_bank_type}" - upsert_models( - self.cursor, - [ - (memory_bank.identifier, memory_bank), - ], + upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)]) + index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor) + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, index, self.inference_api ) - index = BankWithIndex( - bank=memory_bank, - index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), - ) - self.cache[memory_bank.identifier] = index - async def unregister_memory_bank(self, memory_bank_id: str) -> None: await self.cache[memory_bank_id].index.delete() del self.cache[memory_bank_id] @@ -203,14 +194,11 @@ async def query_documents( index = await self._get_and_cache_bank_index(bank_id) return await index.query_documents(query, params) async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: if bank_id in self.cache: return self.cache[bank_id] bank = await self.memory_bank_store.get_memory_bank(bank_id) - index = BankWithIndex( - bank=bank, - index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor), - ) - self.cache[bank_id] = index - return index + index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor) + self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api) + return self.cache[bank_id] diff --git a/llama_stack/providers/remote/memory/qdrant/__init__.py b/llama_stack/providers/remote/memory/qdrant/__init__.py index 9f54babad2..54605fcf91 100644 --- a/llama_stack/providers/remote/memory/qdrant/__init__.py +++ b/llama_stack/providers/remote/memory/qdrant/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in
the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import QdrantConfig -async def get_adapter_impl(config: QdrantConfig, _deps): +async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]): from .qdrant import QdrantVectorMemoryAdapter - impl = QdrantVectorMemoryAdapter(config) + impl = QdrantVectorMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index a9badbd6ab..0f1a7c7d10 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -101,10 +101,11 @@ async def query( class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): - def __init__(self, config: QdrantConfig) -> None: + def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None: self.config = config self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) self.cache = {} + self.inference_api = inference_api async def initialize(self) -> None: pass @@ -123,6 +124,7 @@ async def register_memory_bank( index = BankWithIndex( bank=memory_bank, index=QdrantIndex(self.client, memory_bank.identifier), + inference_api=self.inference_api, ) self.cache[memory_bank.identifier] = index @@ -138,6 +140,7 @@ async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithInde index = BankWithIndex( bank=bank, index=QdrantIndex(client=self.client, collection_name=bank_id), + inference_api=self.inference_api, ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/remote/memory/weaviate/__init__.py b/llama_stack/providers/remote/memory/weaviate/__init__.py index 504bd15084..f7120bec03 100644 --- a/llama_stack/providers/remote/memory/weaviate/__init__.py +++ b/llama_stack/providers/remote/memory/weaviate/__init__.py @@ -4,12 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401 -async def get_adapter_impl(config: WeaviateConfig, _deps): +async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]): from .weaviate import WeaviateMemoryAdapter - impl = WeaviateMemoryAdapter(config) + impl = WeaviateMemoryAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index f05fc663ed..510915e659 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -12,10 +12,11 @@ import weaviate.classes as wvc from numpy.typing import NDArray from weaviate.classes.init import Auth +from weaviate.classes.query import Filter from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate +from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, EmbeddingIndex, @@ -80,12 +81,21 @@ async def query( return QueryDocumentsResponse(chunks=chunks, scores=scores) + async def delete(self, chunk_ids: List[str]) -> None: + collection = self.client.collections.get(self.collection_name) + collection.data.delete_many( + where=Filter.by_property("id").contains_any(chunk_ids) + ) + class WeaviateMemoryAdapter( - Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate + Memory, + NeedsRequestProviderData, + MemoryBanksProtocolPrivate, ): - def __init__(self, config: WeaviateConfig) -> None: + def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None: self.config = config + self.inference_api = inference_api self.client_cache = {} self.cache = {} @@ -117,7 +127,7 @@ async def register_memory_bank( memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.memory_bank_type == MemoryBankType.vector + memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.memory_bank_type}" client = self._get_client() @@ -135,11 +145,11 @@ async def register_memory_bank( ], ) - index = BankWithIndex( - bank=memory_bank, - index=WeaviateIndex(client=client, collection_name=memory_bank.identifier), + self.cache[memory_bank.identifier] = BankWithIndex( + memory_bank, + WeaviateIndex(client=client, collection_name=memory_bank.identifier), + self.inference_api, ) - self.cache[memory_bank.identifier] = index async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]: if bank_id in self.cache: @@ -156,6 +166,7 @@ async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithInde index = BankWithIndex( bank=bank, index=WeaviateIndex(client=client, collection_name=bank_id), + inference_api=self.inference_api, ) self.cache[bank_id] = index return index diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py index 7fe19b4037..54ebcd83a0 100644 --- a/llama_stack/providers/tests/inference/conftest.py +++ b/llama_stack/providers/tests/inference/conftest.py @@ -18,6 +18,12 @@ def pytest_addoption(parser): default=None, help="Specify the inference model to use for testing", ) + parser.addoption( + "--embedding-model", + 
action="store", + default=None, + help="Specify the embedding model to use for testing", + ) def pytest_configure(config): diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 21e1221491..ed0b0302d9 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -9,9 +9,9 @@ import pytest import pytest_asyncio -from llama_stack.apis.models import ModelInput - +from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider + from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) @@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) + # If embedding dimension is set, use the 8B model for testing + if os.getenv("EMBEDDING_DIMENSION"): + inference_model = ["meta-llama/Llama-3.1-8B-Instruct"] return ProviderFixture( providers=[ @@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) - if "Llama3.1-8B-Instruct" in inference_model: + if inference_model and "Llama3.1-8B-Instruct" in inference_model: pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing") return ProviderFixture( @@ -232,11 +235,23 @@ def model_id(inference_model) -> str: async def inference_stack(request, inference_model): fixture_name = request.param inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") + model_type = ModelType.llm + metadata = {} + if os.getenv("EMBEDDING_DIMENSION"): + model_type = ModelType.embedding_model + metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION") + test_stack = await construct_stack_for_test( [Api.inference], {"inference": inference_fixture.providers}, inference_fixture.provider_data, - models=[ModelInput(model_id=inference_model)], + models=[ + ModelInput( + model_id=inference_model, + model_type=model_type, + metadata=metadata, + ) + ], ) return test_stack.impls[Api.inference], test_stack.impls[Api.models] diff --git a/llama_stack/providers/tests/inference/test_embeddings.py b/llama_stack/providers/tests/inference/test_embeddings.py new file mode 100644 index 0000000000..3502c6b20b --- /dev/null +++ b/llama_stack/providers/tests/inference/test_embeddings.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import pytest + +from llama_stack.apis.inference import EmbeddingsResponse, ModelType + +# How to run this test: +# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py + + +class TestEmbeddings: + @pytest.mark.asyncio + async def test_embeddings(self, inference_model, inference_stack): + inference_impl, models_impl = inference_stack + model = await models_impl.get_model(inference_model) + + if model.model_type != ModelType.embedding_model: + pytest.skip("This test is only applicable for embedding models") + + response = await inference_impl.embeddings( + model_id=inference_model, + contents=["Hello, world!"], + ) + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) > 0 + assert all(isinstance(embedding, list) for embedding in response.embeddings) + assert all( + isinstance(value, float) + for embedding in response.embeddings + for value in embedding + ) + + @pytest.mark.asyncio + async def test_batch_embeddings(self, inference_model, inference_stack): + inference_impl, models_impl = inference_stack + model = await models_impl.get_model(inference_model) + + if model.model_type != ModelType.embedding_model: + pytest.skip("This test is only applicable for embedding models") + + texts = ["Hello, world!", "This is a test", "Testing embeddings"] + + response = await inference_impl.embeddings( + model_id=inference_model, + contents=texts, + ) + + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) == len(texts) + assert all(isinstance(embedding, list) for embedding in response.embeddings) + assert all( + isinstance(value, float) + for embedding in response.embeddings + for value in embedding + ) + + embedding_dim = len(response.embeddings[0]) + assert all(len(embedding) == embedding_dim for embedding in response.embeddings) diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/memory/conftest.py index 99ecbe794b..7595538eba 100644 --- a/llama_stack/providers/tests/memory/conftest.py +++ b/llama_stack/providers/tests/memory/conftest.py @@ -6,9 +6,65 @@ import pytest +from ..conftest import get_provider_fixture_overrides + +from ..inference.fixtures import INFERENCE_FIXTURES from .fixtures import MEMORY_FIXTURES +DEFAULT_PROVIDER_COMBINATIONS = [ + pytest.param( + { + "inference": "meta_reference", + "memory": "faiss", + }, + id="meta_reference", + marks=pytest.mark.meta_reference, + ), + pytest.param( + { + "inference": "ollama", + "memory": "pgvector", + }, + id="ollama", + marks=pytest.mark.ollama, + ), + pytest.param( + { + "inference": "together", + "memory": "chroma", + }, + id="chroma", + marks=pytest.mark.chroma, + ), + pytest.param( + { + "inference": "bedrock", + "memory": "qdrant", + }, + id="qdrant", + marks=pytest.mark.qdrant, + ), + pytest.param( + { + "inference": "fireworks", + "memory": "weaviate", + }, + id="weaviate", + marks=pytest.mark.weaviate, + ), +] + + +def pytest_addoption(parser): + parser.addoption( + "--inference-model", + action="store", + default=None, + help="Specify the inference model to use for testing", + ) + + def pytest_configure(config): for fixture_name in MEMORY_FIXTURES: config.addinivalue_line( @@ -18,12 +74,22 @@ def pytest_configure(config): def pytest_generate_tests(metafunc): + if "inference_model" in metafunc.fixturenames: + model = metafunc.config.getoption("--inference-model") + if not model: + raise ValueError( + "No inference model specified. Please provide a valid inference model." 
+ ) + params = [pytest.param(model, id="")] + + metafunc.parametrize("inference_model", params, indirect=True) if "memory_stack" in metafunc.fixturenames: - metafunc.parametrize( - "memory_stack", - [ - pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) - for fixture_name in MEMORY_FIXTURES - ], - indirect=True, + available_fixtures = { + "inference": INFERENCE_FIXTURES, + "memory": MEMORY_FIXTURES, + } + combinations = ( + get_provider_fixture_overrides(metafunc.config, available_fixtures) + or DEFAULT_PROVIDER_COMBINATIONS ) + metafunc.parametrize("memory_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index cc57bb9162..92fd1720e9 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,6 +10,8 @@ import pytest import pytest_asyncio +from llama_stack.apis.inference import ModelInput, ModelType + from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig @@ -105,14 +107,30 @@ def memory_chroma() -> ProviderFixture: @pytest_asyncio.fixture(scope="session") -async def memory_stack(request): - fixture_name = request.param - fixture = request.getfixturevalue(f"memory_{fixture_name}") +async def memory_stack(inference_model, request): + fixture_dict = request.param + + providers = {} + provider_data = {} + for key in ["inference", "memory"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) test_stack = await construct_stack_for_test( - [Api.memory], - {"memory": fixture.providers}, - fixture.provider_data, + [Api.memory, Api.inference], + providers, + provider_data, + models=[ + ModelInput( + model_id=inference_model, + model_type=ModelType.embedding_model, + metadata={ + "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"), + }, + ) + ], ) return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index b6e2e0a76e..03597d073d 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -45,12 +45,14 @@ def sample_documents(): ] -async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank: +async def register_memory_bank( + banks_impl: MemoryBanks, inference_model: str +) -> MemoryBank: bank_id = f"test_bank_{uuid.uuid4().hex}" return await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank: class TestMemory: @pytest.mark.asyncio - async def test_banks_list(self, memory_stack): + async def test_banks_list(self, memory_stack, inference_model): _, banks_impl = memory_stack # Register a test bank - registered_bank = await register_memory_bank(banks_impl) + registered_bank = await register_memory_bank(banks_impl, inference_model) try: # Verify our bank shows up in list @@ -84,7 +86,7 @@ async def test_banks_list(self, memory_stack): ) @pytest.mark.asyncio - async def test_banks_register(self, 
memory_stack): + async def test_banks_register(self, memory_stack, inference_model): _, banks_impl = memory_stack bank_id = f"test_bank_{uuid.uuid4().hex}" @@ -94,7 +96,7 @@ async def test_banks_register(self, memory_stack): await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -109,7 +111,7 @@ async def test_banks_register(self, memory_stack): await banks_impl.register_memory_bank( memory_bank_id=bank_id, params=VectorMemoryBankParams( - embedding_model="all-MiniLM-L6-v2", + embedding_model=inference_model, chunk_size_in_tokens=512, overlap_size_in_tokens=64, ), @@ -126,13 +128,15 @@ async def test_banks_register(self, memory_stack): await banks_impl.unregister_memory_bank(bank_id) @pytest.mark.asyncio - async def test_query_documents(self, memory_stack, sample_documents): + async def test_query_documents( + self, memory_stack, inference_model, sample_documents + ): memory_impl, banks_impl = memory_stack with pytest.raises(ValueError): await memory_impl.insert_documents("test_bank", sample_documents) - registered_bank = await register_memory_bank(banks_impl) + registered_bank = await register_memory_bank(banks_impl, inference_model) await memory_impl.insert_documents( registered_bank.memory_bank_id, sample_documents ) @@ -165,13 +169,13 @@ async def test_query_documents(self, memory_stack, sample_documents): # Test case 5: Query with threshold on similarity score query5 = "quantum computing" # Not directly related to any document - params5 = {"score_threshold": 0.2} + params5 = {"score_threshold": 0.01} response5 = await memory_impl.query_documents( registered_bank.memory_bank_id, query5, params5 ) assert_valid_response(response5) print("The scores are:", response5.scores) - assert all(score >= 0.2 for score in response5.scores) + assert all(score >= 0.01 for score in response5.scores) def assert_valid_response(response: QueryDocumentsResponse): diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py new file mode 100644 index 0000000000..b53f8cd32f --- /dev/null +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import logging +from typing import List + +from llama_models.llama3.api.datatypes import InterleavedTextMedia + +from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore + +EMBEDDING_MODELS = {} + + +log = logging.getLogger(__name__) + + +class SentenceTransformerEmbeddingMixin: + model_store: ModelStore + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + model = await self.model_store.get_model(model_id) + embedding_model = self._load_sentence_transformer_model( + model.provider_resource_id + ) + embeddings = embedding_model.encode(contents) + return EmbeddingsResponse(embeddings=embeddings) + + def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": + global EMBEDDING_MODELS + + loaded_model = EMBEDDING_MODELS.get(model) + if loaded_model is not None: + return loaded_model + + log.info(f"Loading sentence transformer for {model}...") + from sentence_transformers import SentenceTransformer + + loaded_model = SentenceTransformer(model) + EMBEDDING_MODELS[model] = loaded_model + return loaded_model diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 8dbfab14aa..be2642cdb2 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -9,6 +9,7 @@ from llama_models.sku_list import all_registered_models +from llama_stack.apis.models.models import ModelType from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference import ( @@ -77,7 +78,13 @@ def get_llama_model(self, provider_model_id: str) -> str: return None async def register_model(self, model: Model) -> Model: - provider_resource_id = self.get_provider_model_id(model.provider_resource_id) + if model.model_type == ModelType.embedding_model: + # embedding models are always registered by their provider model id and do not need to be mapped to a llama model + provider_resource_id = model.provider_resource_id + else: + provider_resource_id = self.get_provider_model_id( + model.provider_resource_id + ) if provider_resource_id: model.provider_resource_id = provider_resource_id else: diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index eb83aa6715..cebe897bc1 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -22,28 +22,10 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.memory import * # noqa: F403 +from llama_stack.providers.datatypes import Api log = logging.getLogger(__name__) -ALL_MINILM_L6_V2_DIMENSION = 384 - -EMBEDDING_MODELS = {} - - -def get_embedding_model(model: str) -> "SentenceTransformer": - global EMBEDDING_MODELS - - loaded_model = EMBEDDING_MODELS.get(model) - if loaded_model is not None: - return loaded_model - - log.info(f"Loading sentence transformer for {model}...") - from sentence_transformers import SentenceTransformer - - loaded_model = SentenceTransformer(model) - EMBEDDING_MODELS[model] = loaded_model - return loaded_model - def parse_pdf(data: bytes) -> str: # For PDF and DOC/DOCX files, we can't reliably convert to string @@ -166,12 +148,12 @@ async def delete(self): class BankWithIndex: bank: VectorMemoryBank index: EmbeddingIndex + inference_api: Api.inference async def insert_documents( self, documents: 
List[MemoryBankDocument], ) -> None: - model = get_embedding_model(self.bank.embedding_model) for doc in documents: content = await content_from_doc(doc) chunks = make_overlapped_chunks( @@ -183,7 +165,10 @@ async def insert_documents( ) if not chunks: continue - embeddings = model.encode([x.content for x in chunks]).astype(np.float32) + embeddings_response = await self.inference_api.embeddings( + self.bank.embedding_model, [x.content for x in chunks] + ) + embeddings = np.array(embeddings_response.embeddings) await self.index.add_chunks(chunks, embeddings) @@ -208,6 +193,8 @@ def _process(c) -> str: else: query_str = _process(query) - model = get_embedding_model(self.bank.embedding_model) - query_vector = model.encode([query_str])[0].astype(np.float32) + embeddings_response = await self.inference_api.embeddings( + self.bank.embedding_model, [query_str] + ) + query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) return await self.index.query(query_vector, k, score_threshold)
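
Taken together, these changes replace the hardcoded all-MiniLM-L6-v2 / 384-dimension assumption with a registry-driven flow: embedding models are registered as first-class models carrying their dimension in metadata, and vector memory banks inherit that dimension from the registered model. Below is a minimal sketch of the flow, assuming `models_impl` and `banks_impl` are the Models and MemoryBanks implementations returned by a constructed stack (as in the test fixtures above); the handle names and the `VectorMemoryBankParams` import path are illustrative, not prescribed by this diff.

```python
from llama_stack.apis.memory_banks import VectorMemoryBankParams  # assumed path
from llama_stack.apis.models import ModelType


async def register_embedding_flow(models_impl, banks_impl):
    # Embedding models must declare their dimension up front; the routing
    # table rejects registration without "embedding_dimension" in metadata.
    await models_impl.register_model(
        model_id="all-MiniLM-L6-v2",
        model_type=ModelType.embedding_model,
        metadata={"embedding_dimension": 384},
    )

    # The routing table looks up the registered model, verifies it is an
    # embedding model, and copies embedding_dimension onto the vector bank,
    # so the provider (faiss, pgvector, ...) can size its index correctly.
    bank = await banks_impl.register_memory_bank(
        memory_bank_id="docs",
        params=VectorMemoryBankParams(
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
            overlap_size_in_tokens=64,
        ),
    )
    assert bank.embedding_dimension == 384
```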
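The router-level type checks also change failure behavior: using a model with the wrong API now fails fast with a `ValueError` at the router instead of surfacing a provider error later. A pytest-style sketch of the expected behavior, under the assumption that `inference_impl` is a routed inference handle like the one the tests above use and that the model ids are already registered:

```python
import pytest

from llama_models.llama3.api.datatypes import UserMessage


@pytest.mark.asyncio
async def test_model_type_guards(inference_impl):
    # chat_completion (and completion) reject embedding models...
    with pytest.raises(ValueError, match="embedding model"):
        await inference_impl.chat_completion(
            model_id="all-MiniLM-L6-v2",
            messages=[UserMessage(content="hello")],
        )

    # ...and embeddings rejects LLMs.
    with pytest.raises(ValueError, match="does not support embeddings"):
        await inference_impl.embeddings(
            model_id="meta-llama/Llama-3.1-8B-Instruct",
            contents=["hello"],
        )
```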