From cd6700a274e3c77e41215d39e4b5f8799276a3ef Mon Sep 17 00:00:00 2001 From: hkristof03 Date: Fri, 15 Mar 2024 17:24:33 +0100 Subject: [PATCH] =?UTF-8?q?Expose=20"safe=5Fserialization"=20parameter=20f?= =?UTF-8?q?rom=20AutoModel=20to=20HuggingFaceEm=E2=80=A6=20(#11939)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Expose "safe_serialization" parameter from AutoModel to HuggingFaceEmbedding * Change "safe_serialization" parameter to Optional. * Add comma for linter check * cr --------- Co-authored-by: Haotian Zhang --- .../llama_index/embeddings/huggingface/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py index cca7da667cb78..4c1a81f2a2a6c 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py @@ -68,6 +68,7 @@ def __init__( trust_remote_code: bool = False, device: Optional[str] = None, callback_manager: Optional[CallbackManager] = None, + safe_serialization: Optional[bool] = None, ): self._device = device or infer_torch_device() @@ -80,7 +81,10 @@ def __init__( else DEFAULT_HUGGINGFACE_EMBEDDING_MODEL ) model = AutoModel.from_pretrained( - model_name, cache_dir=cache_folder, trust_remote_code=trust_remote_code + model_name, + cache_dir=cache_folder, + trust_remote_code=trust_remote_code, + safe_serialization=safe_serialization, ) elif model_name is None: # Extract model_name from model model_name = model.name_or_path