Simpler azure embedding #2751

Open · wants to merge 8 commits into base: main
@@ -0,0 +1,30 @@
"""add api_version and deployment_name to search settings

Revision ID: 5d12a446f5c0
Revises: e4334d5b33ba
Create Date: 2024-10-08 15:56:07.975636

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5d12a446f5c0"
down_revision = "e4334d5b33ba"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
    )
    op.add_column(
        "embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("embedding_provider", "deployment_name")
    op.drop_column("embedding_provider", "api_version")
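
For reference, the new Alembic migration above can be applied through Alembic's Python API as well as the CLI. A minimal sketch, assuming the config path below is a placeholder for the repository's actual `alembic.ini`:

```python
from alembic import command
from alembic.config import Config

# Placeholder path; point this at the project's real alembic.ini.
alembic_cfg = Config("backend/alembic.ini")

# Applies 5d12a446f5c0 (and any other pending revisions) to the database.
command.upgrade(alembic_cfg, "head")

# To roll back just this revision:
# command.downgrade(alembic_cfg, "e4334d5b33ba")
```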
18 changes: 18 additions & 0 deletions backend/danswer/db/models.py
@@ -615,6 +615,7 @@ class SearchSettings(Base):
normalize: Mapped[bool] = mapped_column(Boolean)
query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)

status: Mapped[IndexModelStatus] = mapped_column(
Enum(IndexModelStatus, native_enum=False)
)
@@ -670,6 +671,20 @@ def __repr__(self) -> str:
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"

    @property
    def api_version(self) -> str | None:
        return (
            self.cloud_provider.api_version if self.cloud_provider is not None else None
        )

    @property
    def deployment_name(self) -> str | None:
        return (
            self.cloud_provider.deployment_name
            if self.cloud_provider is not None
            else None
        )

    @property
    def api_url(self) -> str | None:
        return self.cloud_provider.api_url if self.cloud_provider is not None else None
@@ -1164,6 +1179,9 @@ class CloudEmbeddingProvider(Base):
)
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)

search_settings: Mapped[list["SearchSettings"]] = relationship(
"SearchSettings",
back_populates="cloud_provider",
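The two properties added to SearchSettings simply proxy the new CloudEmbeddingProvider columns, falling back to None when no cloud provider is linked (e.g. local models). A standalone analogue of that pattern, for illustration only, not the ORM classes themselves:

```python
from dataclasses import dataclass


@dataclass
class Provider:
    api_version: str | None = None
    deployment_name: str | None = None


@dataclass
class Settings:
    cloud_provider: Provider | None = None

    @property
    def api_version(self) -> str | None:
        return self.cloud_provider.api_version if self.cloud_provider else None

    @property
    def deployment_name(self) -> str | None:
        return self.cloud_provider.deployment_name if self.cloud_provider else None


print(Settings().api_version)  # None -- no cloud provider attached (local model)
print(Settings(Provider(api_version="2023-05-15")).api_version)  # "2023-05-15"
```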
12 changes: 12 additions & 0 deletions backend/danswer/indexing/embedder.py
@@ -32,6 +32,8 @@ def __init__(
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
@@ -41,6 +43,8 @@ def __init__(
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name

self.embedding_model = EmbeddingModel(
model_name=model_name,
@@ -50,6 +54,8 @@ def __init__(
api_key=api_key,
provider_type=provider_type,
api_url=api_url,
api_version=api_version,
deployment_name=deployment_name,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
@@ -75,6 +81,8 @@ def __init__(
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
@@ -85,6 +93,8 @@ def __init__(
provider_type,
api_key,
api_url,
api_version,
deployment_name,
heartbeat,
)

@@ -193,5 +203,7 @@ def from_db_search_settings(
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
heartbeat=heartbeat,
)
@@ -97,6 +97,8 @@ def __init__(
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -106,6 +108,8 @@ def __init__(
self.model_name = model_name
self.retrim_content = retrim_content
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
@@ -157,6 +161,8 @@ def _batch_encode_texts(
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
api_version=self.api_version,
deployment_name=self.deployment_name,
max_context_length=max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
@@ -239,6 +245,8 @@ def from_db_model(
provider_type=search_settings.provider_type,
api_url=search_settings.api_url,
retrim_content=retrim_content,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
)


2 changes: 2 additions & 0 deletions backend/danswer/server/manage/embedding/api.py
@@ -43,6 +43,8 @@ def test_embedding_configuration(
api_url=test_llm_request.api_url,
provider_type=test_llm_request.provider_type,
model_name=test_llm_request.model_name,
api_version=test_llm_request.api_version,
deployment_name=test_llm_request.deployment_name,
normalize=False,
query_prefix=None,
passage_prefix=None,
8 changes: 8 additions & 0 deletions backend/danswer/server/manage/embedding/models.py
@@ -17,6 +17,8 @@ class TestEmbeddingRequest(BaseModel):
api_key: str | None = None
api_url: str | None = None
model_name: str | None = None
api_version: str | None = None
deployment_name: str | None = None

# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@@ -26,6 +28,8 @@ class CloudEmbeddingProvider(BaseModel):
provider_type: EmbeddingProvider
api_key: str | None = None
api_url: str | None = None
api_version: str | None = None
deployment_name: str | None = None

@classmethod
def from_request(
@@ -35,10 +39,14 @@ def from_request(
provider_type=cloud_provider_model.provider_type,
api_key=cloud_provider_model.api_key,
api_url=cloud_provider_model.api_url,
api_version=cloud_provider_model.api_version,
deployment_name=cloud_provider_model.deployment_name,
)


class CloudEmbeddingProviderCreationRequest(BaseModel):
provider_type: EmbeddingProvider
api_key: str | None = None
api_url: str | None = None
api_version: str | None = None
deployment_name: str | None = None
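
With the two new fields, an Azure provider creation request now carries the API version and deployment name alongside the key and URL. A standalone replica of the request model above, assuming pydantic v2 and placeholder credentials:

```python
from enum import Enum

from pydantic import BaseModel


class EmbeddingProvider(str, Enum):
    AZURE = "azure"


# Replica of CloudEmbeddingProviderCreationRequest, for illustration only.
class CloudEmbeddingProviderCreationRequest(BaseModel):
    provider_type: EmbeddingProvider
    api_key: str | None = None
    api_url: str | None = None
    api_version: str | None = None
    deployment_name: str | None = None


request = CloudEmbeddingProviderCreationRequest(
    provider_type=EmbeddingProvider.AZURE,
    api_key="<azure-api-key>",                        # placeholder
    api_url="https://my-resource.openai.azure.com",   # placeholder
    api_version="2023-05-15",                         # placeholder
    deployment_name="my-embedding-deployment",        # placeholder
)
print(request.model_dump())
```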
49 changes: 44 additions & 5 deletions backend/model_server/encoders.py
@@ -10,6 +10,7 @@
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import embedding
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
@@ -54,7 +55,11 @@


def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None
api_key: str,
provider: EmbeddingProvider,
model: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
) -> Any:
if provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
@@ -69,6 +74,8 @@ def _initialize_client(
project_id = json.loads(api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
elif provider == EmbeddingProvider.AZURE:
return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
else:
raise ValueError(f"Unsupported provider: {provider}")

@@ -78,11 +85,15 @@ def __init__(
self,
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
# Only for Google as is needed on client setup
model: str | None = None,
) -> None:
self.provider = provider
self.client = _initialize_client(api_key, self.provider, model)
self.client = _initialize_client(
api_key, self.provider, model, api_url, api_version
)

def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
if not model:
@@ -144,6 +155,18 @@ def _embed_voyage(
)
return response.embeddings

    def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
        response = embedding(
            model=model,
            input=texts,
            api_key=self.client["api_key"],
            api_base=self.client["api_url"],
            api_version=self.client["api_version"],
        )
        embeddings = [embedding["embedding"] for embedding in response.data]

        return embeddings

    def _embed_vertex(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
@@ -169,10 +192,13 @@ def embed(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
Expand All @@ -190,10 +216,14 @@ def embed(

@staticmethod
def create(
api_key: str, provider: EmbeddingProvider, model: str | None = None
api_key: str,
provider: EmbeddingProvider,
model: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, model)
return CloudEmbedding(api_key, provider, model, api_url, api_version)


def get_embedding_model(
@@ -260,12 +290,14 @@ def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
deployment_name: str | None,
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: EmbeddingProvider | None,
prefix: str | None,
api_url: str | None,
api_version: str | None,
) -> list[Embedding]:
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")

@@ -307,11 +339,16 @@ def embed_text(
)

cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
api_key=api_key,
provider=provider_type,
model=model_name,
api_url=api_url,
api_version=api_version,
)
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)

@@ -405,12 +442,14 @@ async def process_embed_request(
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
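
The Azure branch above hands the request off to litellm with an `azure/<deployment>` model string built from the new `deployment_name` field. A minimal standalone sketch of that call shape; the endpoint, key, deployment, and API version below are placeholders:

```python
from litellm import embedding

# All values are placeholders; in the PR they come from CloudEmbeddingProvider.
response = embedding(
    model="azure/my-embedding-deployment",  # "azure/" prefix routes litellm to Azure OpenAI
    input=["first passage", "second passage"],
    api_key="<azure-api-key>",
    api_base="https://my-resource.openai.azure.com",
    api_version="2023-05-15",
)

# Each item in response.data carries the vector under the "embedding" key.
vectors = [item["embedding"] for item in response.data]
print(len(vectors), len(vectors[0]))  # number of inputs, embedding dimension
```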
1 change: 1 addition & 0 deletions backend/shared_configs/enums.py
@@ -7,6 +7,7 @@ class EmbeddingProvider(str, Enum):
VOYAGE = "voyage"
GOOGLE = "google"
LITELLM = "litellm"
AZURE = "azure"


class RerankerProvider(str, Enum):
3 changes: 2 additions & 1 deletion backend/shared_configs/model_server_models.py
@@ -20,6 +20,7 @@ class EmbedRequest(BaseModel):
texts: list[str]
# Can be none for cloud embedding model requests, error handling logic exists for other cases
model_name: str | None = None
deployment_name: str | None = None
max_context_length: int
normalize_embeddings: bool
api_key: str | None = None
@@ -28,7 +29,7 @@
manual_query_prefix: str | None = None
manual_passage_prefix: str | None = None
api_url: str | None = None

api_version: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

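
Pydantic v2 reserves the `model_` prefix for its own namespace, so a field named `model_name` triggers a warning unless that namespace is cleared, which is what the `model_config` line above does. A small illustration with a throwaway class, not the real `EmbedRequest`:

```python
from pydantic import BaseModel


class Example(BaseModel):
    model_name: str | None = None
    api_version: str | None = None

    # Without this, pydantic v2 warns that "model_name" conflicts with
    # its protected "model_" namespace.
    model_config = {"protected_namespaces": ()}


print(Example(model_name="text-embedding-3-small", api_version="2023-05-15"))
```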
2 changes: 1 addition & 1 deletion web/package.json
@@ -56,4 +56,4 @@
"eslint-config-next": "^14.1.0",
"prettier": "2.8.8"
}
}
}