Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: support Hunyuan Embedding #23160

Merged
merged 21 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@
from langchain_community.embeddings.huggingface_hub import (
HuggingFaceHubEmbeddings,
)
from langchain_community.embeddings.hunyuan import (
HunyuanEmbeddings,
)
from langchain_community.embeddings.infinity import (
InfinityEmbeddings,
)
Expand Down Expand Up @@ -327,6 +330,7 @@
"XinferenceEmbeddings",
"YandexGPTEmbeddings",
"ZhipuAIEmbeddings",
"HunyuanEmbeddings",
]

_module_lookup = {
Expand Down Expand Up @@ -412,6 +416,7 @@
"YandexGPTEmbeddings": "langchain_community.embeddings.yandex",
"AscendEmbeddings": "langchain_community.embeddings.ascend",
"ZhipuAIEmbeddings": "langchain_community.embeddings.zhipuai",
"HunyuanEmbeddings": "langchain_community.embeddings.hunyuan",
}


Expand Down
114 changes: 114 additions & 0 deletions libs/community/langchain_community/embeddings/hunyuan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import json
from typing import Any, Dict, List, Literal, Type

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tqdm import tqdm


class HunyuanEmbeddings(Embeddings, BaseModel):
"""Tencent Hunyuan embedding models API by Tencent.

For more information, see https://cloud.tencent.com/document/product/1729
"""

hunyuan_secret_id: SecretStr = Field(alias="secret_id", default=None)
"""Hunyuan Secret ID"""
hunyuan_secret_key: SecretStr = Field(alias="secret_key", default=None)
"""Hunyuan Secret Key"""
region: Literal["ap-guangzhou", "ap-beijing"] = "ap-guangzhou"
"""The region of hunyuan service."""
embedding_ctx_length: int = 1024
"""The max embedding context length of hunyuan embedding. Just note that it is 1024."""
show_progress_bar: bool = False
"""Show progress bar when embedding. Default is False."""

client: Any = Field(default=None, exclude=True)
"""The tencentcloud client."""
request_cls: Type = Field(default=None, exclude=True)
"""The request class of tencentcloud sdk."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["hunyuan_secret_id"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"hunyuan_secret_id",
"HUNYUAN_SECRET_ID",
)
)
values["hunyuan_secret_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"hunyuan_secret_key",
"HUNYUAN_SECRET_KEY",
)
)

try:
from tencentcloud.common.credential import Credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.hunyuan.v20230901.hunyuan_client import HunyuanClient
from tencentcloud.hunyuan.v20230901.models import GetEmbeddingRequest
except ImportError:
raise ImportError(
'Could not import tencentcloud sdk python package. Please install it with `pip install "tencentcloud-sdk-python>=3.0.1139"`.'
)

client_profile = ClientProfile()
client_profile.httpProfile.pre_conn_pool_size = 3

credential = Credential(
values["hunyuan_secret_id"].get_secret_value(),
values["hunyuan_secret_key"].get_secret_value(),
)

values["request_cls"] = GetEmbeddingRequest

values["client"] = HunyuanClient(credential, values["region"], client_profile)
return values

def _embed_text(self, text: str) -> List[float]:
request = self.request_cls()
request.Input = text

response = self.client.GetEmbedding(request)

_response: Dict[str, Any] = json.loads(response.to_json_string())

data: List[Dict[str, Any]] | None = _response.get("Data")
if not data:
raise RuntimeError("Occur hunyuan embedding error: Data is empty")

embedding = data[0].get("Embedding")
if not embedding:
raise RuntimeError("Occur hunyuan embedding error: Embedding is empty")

return embedding

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
embeddings = []
if self.show_progress_bar:
_iter = tqdm(iterable=texts, desc="Hunyuan Embedding")
else:
_iter = texts
for text in _iter:
embeddings.append(self.embed_query(text))

return embeddings

def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
return self._embed_text(text)

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs."""
return await run_in_executor(None, self.embed_documents, texts)

async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return await run_in_executor(None, self.embed_query, text)
25 changes: 25 additions & 0 deletions libs/community/tests/integration_tests/embeddings/test_hunyuan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from langchain_community.embeddings.hunyuan import HunyuanEmbeddings


def test_embedding_query() -> None:
query = "foo"
embedding = HunyuanEmbeddings()
output = embedding.embed_query(query)
assert len(output) == 1024


def test_embedding_document() -> None:
documents = ["foo bar"]
embedding = HunyuanEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 1024


def test_embedding_documents() -> None:
documents = ["foo", "bar"]
embedding = HunyuanEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 1024
assert len(output[1]) == 1024
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"AscendEmbeddings",
"ZhipuAIEmbeddings",
"TextEmbedEmbeddings",
"HunyuanEmbeddings",
]


Expand Down
Loading