Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client Reject Incompatible models #1056

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from haystack.utils import Secret, deserialize_secrets_inplace
from tqdm import tqdm

from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation
from haystack_integrations.utils.nvidia import Model, NimBackend, is_hosted, url_validation, validate_hosted_model

from .truncate import EmbeddingTruncateMode

Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(
def default_model(self):
"""Set default model in local NIM mode."""
valid_models = [
model.id for model in self.backend.models() if not model.base_model or model.base_model == model.id
model.id for model in self.available_models if not model.base_model or model.base_model == model.id
]
name = next(iter(valid_models), None)
if name:
Expand Down Expand Up @@ -129,12 +129,15 @@ def warm_up(self):
api_url=self.api_url,
api_key=self.api_key,
model_kwargs=model_kwargs,
client=self.__class__.__name__,
model_type="embedding",
)

self._initialized = True

if not self.model:
self.default_model()
validate_hosted_model(self.__class__.__name__, self.model, self)

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -157,6 +160,13 @@ def to_dict(self) -> Dict[str, Any]:
truncate=str(self.truncate) if self.truncate is not None else None,
)

@property
def available_models(self) -> List[Model]:
    """
    Get a list of available models that work with NvidiaDocumentEmbedder.

    Returns an empty list if no backend is set (e.g. before the
    component has been warmed up).
    """
    # Docstring previously referenced "ChatNVIDIA" — a class from the
    # langchain-nvidia project, not this component; fixed to name this class.
    return self.backend.models() if self.backend else []

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NvidiaDocumentEmbedder":
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation
from haystack_integrations.utils.nvidia import Model, NimBackend, is_hosted, url_validation, validate_hosted_model

from .truncate import EmbeddingTruncateMode

Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
def default_model(self):
"""Set default model in local NIM mode."""
valid_models = [
model.id for model in self.backend.models() if not model.base_model or model.base_model == model.id
model.id for model in self.available_models if not model.base_model or model.base_model == model.id
]
name = next(iter(valid_models), None)
if name:
Expand Down Expand Up @@ -113,12 +113,15 @@ def warm_up(self):
api_url=self.api_url,
api_key=self.api_key,
model_kwargs=model_kwargs,
client=self.__class__.__name__,
model_type="embedding",
)

self._initialized = True

if not self.model:
self.default_model()
validate_hosted_model(self.__class__.__name__, self.model, self)

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -137,6 +140,13 @@ def to_dict(self) -> Dict[str, Any]:
truncate=str(self.truncate) if self.truncate is not None else None,
)

@property
def available_models(self) -> List[Model]:
    """
    Get a list of available models that work with NvidiaTextEmbedder.

    Returns an empty list if no backend is set (e.g. before the
    component has been warmed up).
    """
    # Docstring previously referenced "ChatNVIDIA" — a class from the
    # langchain-nvidia project, not this component; fixed to name this class.
    return self.backend.models() if self.backend else []

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NvidiaTextEmbedder":
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation
from haystack_integrations.utils.nvidia import Model, NimBackend, is_hosted, url_validation, validate_hosted_model

_DEFAULT_API_URL = "https://integrate.api.nvidia.com/v1"

Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
def default_model(self):
"""Set default model in local NIM mode."""
valid_models = [
model.id for model in self._backend.models() if not model.base_model or model.base_model == model.id
model.id for model in self.available_models if not model.base_model or model.base_model == model.id
]
name = next(iter(valid_models), None)
if name:
Expand Down Expand Up @@ -113,10 +113,13 @@ def warm_up(self):
api_url=self._api_url,
api_key=self._api_key,
model_kwargs=self._model_arguments,
client=self.__class__.__name__,
model_type="chat",
)

if not self.is_hosted and not self._model:
self.default_model()
validate_hosted_model(self.__class__.__name__, self._model, self)

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -133,6 +136,13 @@ def to_dict(self) -> Dict[str, Any]:
model_arguments=self._model_arguments,
)

@property
def available_models(self) -> List[Model]:
    """
    Get a list of available models that work with NvidiaGenerator.

    Returns an empty list if no backend is set (e.g. before the
    component has been warmed up).
    """
    # Docstring previously referenced "ChatNVIDIA" — a class from the
    # langchain-nvidia project, not this component; fixed to name this class.
    return self._backend.models() if self._backend else []

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NvidiaGenerator":
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .nim_backend import Model, NimBackend
from .utils import is_hosted, url_validation
from .nim_backend import NimBackend
from .statics import Model
from .utils import determine_model, is_hosted, url_validation, validate_hosted_model

__all__ = ["NimBackend", "Model", "is_hosted", "url_validation"]
__all__ = ["NimBackend", "Model", "is_hosted", "url_validation", "validate_hosted_model", "determine_model"]
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

import requests
from haystack import Document
from haystack.utils import Secret

REQUEST_TIMEOUT = 60


@dataclass
class Model:
"""
Model information.
from .statics import Model

id: unique identifier for the model, passed as model parameter for requests
aliases: list of aliases for the model
base_model: root model for the model
All aliases are deprecated and will trigger a warning when used.
"""

id: str
aliases: Optional[List[str]] = field(default_factory=list)
base_model: Optional[str] = None
REQUEST_TIMEOUT = 60


class NimBackend:
Expand All @@ -31,6 +16,8 @@ def __init__(
api_url: str,
api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"),
model_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[Literal["NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder"]] = None,
model_type: Optional[Literal["chat", "embedding"]] = None,
):
headers = {
"Content-Type": "application/json",
Expand All @@ -46,6 +33,8 @@ def __init__(
self.model = model
self.api_url = api_url
self.model_kwargs = model_kwargs or {}
self.client = client
self.model_type = model_type

def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]:
url = f"{self.api_url}/embeddings"
Expand Down Expand Up @@ -125,7 +114,11 @@ def models(self) -> List[Model]:
res.raise_for_status()

data = res.json()["data"]
models = [Model(element["id"]) for element in data if "id" in element]
models = [
Model(id=element["id"], client=self.client, model_type=self.model_type, base_model=element.get("root"))
for element in data
if "id" in element
]
if not models:
msg = f"No hosted model were found at URL '{url}'."
raise ValueError(msg)
Expand Down
Loading