Make embedding generation go through inference (#606)
This PR does the following:
1) Adds the ability to generate embeddings in all supported inference providers.
2) Moves all the memory providers to use the inference API, and improves the memory tests to set up the inference stack correctly and use the embedding models.

This is a merge of #589 and #598.
dineshyv authored Dec 12, 2024
1 parent a14785a commit 96e158e
Showing 37 changed files with 677 additions and 156 deletions.
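
Taken together, the changes let a distribution register an embedding model like any other model and serve embeddings through the ordinary inference path. A minimal sketch of the new flow, assuming already-constructed models and inference API handles (the model and provider ids are illustrative, not taken from this diff):

    from llama_stack.apis.models.models import ModelType

    # Register an embedding model. The embedding_dimension metadata is mandatory
    # for ModelType.embedding_model (enforced in routing_tables.py below).
    await models.register_model(
        model_id="all-MiniLM-L6-v2",
        provider_id="sentence-transformers",
        metadata={"embedding_dimension": 384},
        model_type=ModelType.embedding_model,
    )

    # Embeddings now go through the inference router, which checks the model
    # type before dispatching to the provider.
    response = await inference.embeddings(
        model_id="all-MiniLM-L6-v2",
        contents=["What is retrieval-augmented generation?"],
    )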
1 change: 1 addition & 0 deletions llama_stack/apis/memory_banks/memory_banks.py
@@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
overlap_size_in_tokens: Optional[int] = None
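
For orientation, the default of 384 matches all-MiniLM-L6-v2, so vector banks that predate this field keep a sensible dimension. A quick construction sketch, assuming the identifier and provider fields come from MemoryBankResourceMixin (values illustrative):

    bank = VectorMemoryBank(
        identifier="docs",
        provider_id="faiss",
        provider_resource_id="docs",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
    )
    assert bank.embedding_dimension == 384  # the default declared above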


11 changes: 10 additions & 1 deletion llama_stack/apis/models/models.py
@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
@@ -20,6 +21,11 @@ class CommonModelFields(BaseModel):
)


class ModelType(Enum):
llm = "llm"
embedding_model = "embedding"


@json_schema_type
class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value
@@ -34,12 +40,14 @@ def provider_model_id(self) -> str:

model_config = ConfigDict(protected_namespaces=())

model_type: ModelType = Field(default=ModelType.llm)


class ModelInput(CommonModelFields):
model_id: str
provider_id: Optional[str] = None
provider_model_id: Optional[str] = None

model_type: Optional[ModelType] = ModelType.llm
model_config = ConfigDict(protected_namespaces=())


@@ -59,6 +67,7 @@ async def register_model(
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ...

@webmethod(route="/models/unregister", method="POST")
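Note that ModelInput defaults model_type to ModelType.llm, so existing run configurations that only list LLMs need no changes; only embedding models have to opt in. A small sketch (model ids illustrative):

    llm = ModelInput(model_id="Llama3.1-8B-Instruct")  # model_type defaults to ModelType.llm
    emb = ModelInput(
        model_id="all-MiniLM-L6-v2",
        metadata={"embedding_dimension": 384},
        model_type=ModelType.embedding_model,
    )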
24 changes: 23 additions & 1 deletion llama_stack/distribution/routers/routers.py
@@ -88,9 +88,10 @@ async def register_model(
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
)

async def chat_completion(
@@ -105,6 +106,13 @@ async def chat_completion(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
params = dict(
model_id=model_id,
messages=messages,
@@ -131,6 +139,13 @@ async def completion(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding_model:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@@ -150,6 +165,13 @@ async def embeddings(
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
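
The net effect is a symmetric type check at the router boundary. A small illustration of both directions, as it might appear inside an async test (the router setup, UserMessage, and the model ids are assumed, not from this diff):

    import pytest

    # An embedding model can no longer be sent down the generation paths.
    with pytest.raises(ValueError, match="does not support chat completions"):
        await router.chat_completion(
            model_id="all-MiniLM-L6-v2",
            messages=[UserMessage(content="hi")],
        )

    # And an LLM can no longer be asked for embeddings.
    with pytest.raises(ValueError, match="does not support embeddings"):
        await router.embeddings(
            model_id="Llama3.1-8B-Instruct",
            contents=["hi"],
        )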
44 changes: 34 additions & 10 deletions llama_stack/distribution/routers/routing_tables.py
@@ -209,6 +209,7 @@ async def register_model(
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
@@ -222,11 +223,21 @@
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
if (
"embedding_dimension" not in metadata
and model_type == ModelType.embedding_model
):
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
@@ -298,16 +309,29 @@ async def register_memory_bank(
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
model = await self.get_object_by_identifier("model", params.embedding_model)
if model is None:
raise ValueError(f"Model {params.embedding_model} not found")
if model.model_type != ModelType.embedding_model:
raise ValueError(
f"Model {params.embedding_model} is not an embedding model"
)
if "embedding_dimension" not in model.metadata:
raise ValueError(
f"Model {params.embedding_model} does not have an embedding dimension"
)
memory_bank_data = {
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
}
if params.memory_bank_type == MemoryBankType.vector.value:
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
await self.register_object(memory_bank)
return memory_bank

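The payoff is that a vector bank's dimension is now derived rather than user-supplied: it is copied from the registered embedding model's metadata. A sketch of the registration call (VectorMemoryBankParams is not shown in this diff and is assumed; ids are illustrative):

    bank = await routing_table.register_memory_bank(
        memory_bank_id="docs",
        params=VectorMemoryBankParams(
            embedding_model="all-MiniLM-L6-v2",  # must be registered as an embedding model
            chunk_size_in_tokens=512,
        ),
    )
    assert bank.embedding_dimension == 384  # copied from the model's metadata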
2 changes: 1 addition & 1 deletion llama_stack/distribution/store/registry.py
Expand Up @@ -40,7 +40,7 @@ async def delete(self, type: str, identifier: str) -> None: ...


REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v2"
KEY_VERSION = "v3"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
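
Bumping KEY_VERSION to v3 means previously persisted v2 entries are never read back, presumably so that registry objects are re-created under the new schema carrying model_type and embedding_dimension. The resulting key shape, for illustration:

    >>> KEY_FORMAT.format(type="model", identifier="all-MiniLM-L6-v2")
    'distributions:registry:v3::model:all-MiniLM-L6-v2'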


5 changes: 4 additions & 1 deletion llama_stack/providers/datatypes.py
@@ -200,10 +200,13 @@ def provider_data_validator(self) -> Optional[str]:
return self.adapter.provider_data_validator


def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
def remote_provider_spec(
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
adapter=adapter,
api_dependencies=api_dependencies or [],
)
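
The new api_dependencies parameter is what lets a remote adapter, such as a memory provider, declare that it needs the inference API at runtime. A hedged sketch; the AdapterSpec field values are illustrative rather than taken from this diff:

    spec = remote_provider_spec(
        api=Api.memory,
        adapter=AdapterSpec(
            adapter_type="chromadb",
            pip_packages=["chromadb-client"],
            module="llama_stack.providers.remote.memory.chroma",
            config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
        ),
        api_dependencies=[Api.inference],  # memory providers now depend on inference
    )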
30 changes: 17 additions & 13 deletions llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -16,12 +16,14 @@
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
)

from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
@@ -32,21 +34,24 @@
SEMAPHORE = asyncio.Semaphore(1)


class MetaReferenceInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
model.descriptor(),
model.core_model_id.value,
)
],
)
self.model = model
# verify that the checkpoint actually is for this model lol

@@ -76,6 +81,12 @@ def check_model(self, request) -> None:
async def unregister_model(self, model_id: str) -> None:
pass

async def register_model(self, model: Model) -> Model:
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding_model:
self._load_sentence_transformer_model(model.provider_resource_id)
return model

async def completion(
self,
model_id: str,
@@ -394,13 +405,6 @@ def impl():
for x in impl():
yield x



async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest],
20 changes: 20 additions & 0 deletions llama_stack/providers/inline/inference/sentence_transformers/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)


async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps,
):
from .sentence_transformers import SentenceTransformersInferenceImpl

impl = SentenceTransformersInferenceImpl(config)
await impl.initialize()
return impl
10 changes: 10 additions & 0 deletions llama_stack/providers/inline/inference/sentence_transformers/config.py
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel


class SentenceTransformersInferenceConfig(BaseModel): ...
74 changes: 74 additions & 0 deletions llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from typing import AsyncGenerator, List, Optional, Union

from llama_stack.apis.inference import (
CompletionResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from .config import SentenceTransformersInferenceConfig

log = logging.getLogger(__name__)


class SentenceTransformersInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config

async def initialize(self) -> None:
pass

async def shutdown(self) -> None:
pass

async def register_model(self, model: Model) -> Model:
_ = self._load_sentence_transformer_model(model.provider_resource_id)
return model

async def unregister_model(self, model_id: str) -> None:
pass

async def completion(
self,
model_id: str,
content: str,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncGenerator]:
raise ValueError("Sentence transformers don't support completion")

async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
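
All generative entry points raise by design; the class exists purely to serve embeddings via SentenceTransformerEmbeddingMixin, whose implementation is not shown in this diff. A minimal usage sketch under that assumption (model loading behavior is assumed as well):

    import asyncio

    from llama_stack.providers.inline.inference.sentence_transformers.config import (
        SentenceTransformersInferenceConfig,
    )
    from llama_stack.providers.inline.inference.sentence_transformers.sentence_transformers import (
        SentenceTransformersInferenceImpl,
    )

    async def main() -> None:
        impl = SentenceTransformersInferenceImpl(SentenceTransformersInferenceConfig())
        await impl.initialize()
        # embeddings() comes from the mixin; it is expected to load the
        # sentence-transformers model on demand.
        response = await impl.embeddings(
            model_id="all-MiniLM-L6-v2",
            contents=["an example sentence to embed"],
        )
        print(len(response.embeddings[0]))  # 384 for all-MiniLM-L6-v2

    asyncio.run(main())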