Skip to content

Commit

Permalink
Expose "safe_serialization" parameter from AutoModel to HuggingFaceEm… (
Browse files Browse the repository at this point in the history
#11939)

* Expose "safe_serialization" parameter from AutoModel to HuggingFaceEmbedding

* Change "safe_serialization" parameter to Optional.

* Add comma for linter check

* cr

---------

Co-authored-by: Haotian Zhang <[email protected]>
  • Loading branch information
hkristof03 and hatianzhang authored Mar 15, 2024
1 parent 1bd378c commit cd6700a
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
trust_remote_code: bool = False,
device: Optional[str] = None,
callback_manager: Optional[CallbackManager] = None,
safe_serialization: Optional[bool] = None,
):
self._device = device or infer_torch_device()

Expand All @@ -80,7 +81,10 @@ def __init__(
else DEFAULT_HUGGINGFACE_EMBEDDING_MODEL
)
model = AutoModel.from_pretrained(
model_name, cache_dir=cache_folder, trust_remote_code=trust_remote_code
model_name,
cache_dir=cache_folder,
trust_remote_code=trust_remote_code,
safe_serialization=safe_serialization,
)
elif model_name is None: # Extract model_name from model
model_name = model.name_or_path
Expand Down

0 comments on commit cd6700a

Please sign in to comment.