diff --git a/docs/package_reference/util.md b/docs/package_reference/util.md
index 3e81f6de2..495a914fc 100644
--- a/docs/package_reference/util.md
+++ b/docs/package_reference/util.md
@@ -10,7 +10,7 @@
 ## Model Optimization
 ```eval_rst
 .. automodule:: sentence_transformers.backend
-    :members: export_optimized_onnx_model, export_dynamic_quantized_onnx_model
+    :members: export_optimized_onnx_model, export_dynamic_quantized_onnx_model, export_static_quantized_openvino_model
 ```
 
 ## Similarity Metrics
diff --git a/docs/sentence_transformer/usage/efficiency.rst b/docs/sentence_transformer/usage/efficiency.rst
index c30770078..e2c988f15 100644
--- a/docs/sentence_transformer/usage/efficiency.rst
+++ b/docs/sentence_transformer/usage/efficiency.rst
@@ -290,6 +290,77 @@ To convert a model to OpenVINO format, you can use the following code:
 
     model = SentenceTransformer("intfloat/multilingual-e5-small", backend="openvino")
     model.push_to_hub("intfloat/multilingual-e5-small", create_pr=True)
 
+Quantizing OpenVINO Models
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+OpenVINO models can be quantized to int8 precision using Optimum Intel to speed up inference.
+To do this, you can use the :func:`~sentence_transformers.backend.export_static_quantized_openvino_model` function,
+which saves the quantized model in a directory or model repository that you specify.
+Post-Training Static Quantization expects:
+
+- ``model``: a Sentence Transformer model loaded with the OpenVINO backend.
+- ``quantization_config``: an :class:`~optimum.intel.OVQuantizationConfig` instance specifying the quantization settings.
+- ``model_name_or_path``: a path to save the quantized model file, or the repository name if you want to push it to the Hugging Face Hub.
+- ``push_to_hub``: (Optional) a boolean to push the quantized model to the Hugging Face Hub.
+- ``create_pr``: (Optional) a boolean to create a pull request when pushing to the Hugging Face Hub. Useful when you don't have write access to the repository.
+- ``file_suffix``: (Optional) a string to append to the model name when saving it. If not specified, ``"qint8_quantized"`` will be used.
+
+See this example for quantizing a model to ``int8`` with static quantization:
+
+.. tab:: Hugging Face Hub Model
+
+   Only quantize once::
+
+      from sentence_transformers import SentenceTransformer, export_static_quantized_openvino_model
+      from optimum.intel import OVQuantizationConfig
+
+      model = SentenceTransformer("all-MiniLM-L6-v2", backend="openvino")
+      quantization_config = OVQuantizationConfig()
+      export_static_quantized_openvino_model(model, quantization_config, "all-MiniLM-L6-v2", push_to_hub=True, create_pr=True)
+
+   Before the pull request gets merged::
+
+      from sentence_transformers import SentenceTransformer
+
+      pull_request_nr = 2  # TODO: Update this to the number of your pull request
+      model = SentenceTransformer(
+          "all-MiniLM-L6-v2",
+          backend="openvino",
+          model_kwargs={"file_name": "openvino/openvino_model_qint8_quantized.xml"},
+          revision=f"refs/pr/{pull_request_nr}",
+      )
+
+   Once the pull request gets merged::
+
+      from sentence_transformers import SentenceTransformer
+
+      model = SentenceTransformer(
+          "all-MiniLM-L6-v2",
+          backend="openvino",
+          model_kwargs={"file_name": "openvino/openvino_model_qint8_quantized.xml"},
+      )
+
+.. tab:: Local Model
+
+   Only quantize once::
+
+      from sentence_transformers import SentenceTransformer, export_static_quantized_openvino_model
+      from optimum.intel import OVQuantizationConfig
+
+      model = SentenceTransformer("path/to/my/mpnet-legal-finetuned", backend="openvino")
+      quantization_config = OVQuantizationConfig()
+      export_static_quantized_openvino_model(model, quantization_config, "path/to/my/mpnet-legal-finetuned")
+
+   After quantizing::
+
+      from sentence_transformers import SentenceTransformer
+
+      model = SentenceTransformer(
+          "path/to/my/mpnet-legal-finetuned",
+          backend="openvino",
+          model_kwargs={"file_name": "openvino/openvino_model_qint8_quantized.xml"},
+      )
+
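+As a quick sanity check after quantizing, you can compare embeddings from the original and the
+quantized model. The following is a minimal sketch that assumes the local model path and the
+default ``qint8_quantized`` file suffix from the example above::
+
+   from sentence_transformers import SentenceTransformer
+
+   base_model = SentenceTransformer("path/to/my/mpnet-legal-finetuned", backend="openvino")
+   quantized_model = SentenceTransformer(
+       "path/to/my/mpnet-legal-finetuned",
+       backend="openvino",
+       model_kwargs={"file_name": "openvino/openvino_model_qint8_quantized.xml"},
+   )
+
+   sentences = ["A small test sentence to compare the two models."]
+   base_embeddings = base_model.encode(sentences)
+   quantized_embeddings = quantized_model.encode(sentences)
+   # Cosine similarity should stay close to 1.0 if quantization preserved quality
+   print(base_model.similarity(base_embeddings, quantized_embeddings))
+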
 Benchmarks
 ----------
diff --git a/sentence_transformers/__init__.py b/sentence_transformers/__init__.py
index 1ba4558e8..22d417ab2 100644
--- a/sentence_transformers/__init__.py
+++ b/sentence_transformers/__init__.py
@@ -7,6 +7,7 @@ import os
 
 from sentence_transformers.backend import export_dynamic_quantized_onnx_model, export_optimized_onnx_model
+from sentence_transformers.backend import export_static_quantized_openvino_model
 from sentence_transformers.cross_encoder.CrossEncoder import CrossEncoder
 from sentence_transformers.datasets import ParallelSentencesDataset, SentencesDataset
 from sentence_transformers.LoggingHandler import LoggingHandler
@@ -37,4 +38,5 @@
     "quantize_embeddings",
     "export_optimized_onnx_model",
     "export_dynamic_quantized_onnx_model",
+    "export_static_quantized_openvino_model",
 ]
diff --git a/sentence_transformers/backend.py b/sentence_transformers/backend.py
index 355f40d83..6130faa6a 100644
--- a/sentence_transformers/backend.py
+++ b/sentence_transformers/backend.py
@@ -16,6 +16,7 @@
 try:
     from optimum.onnxruntime.configuration import OptimizationConfig, QuantizationConfig
+    from optimum.intel import OVQuantizationConfig
 except ImportError:
     pass
 
@@ -97,7 +98,7 @@ def export_optimized_onnx_model(
     if file_suffix is None:
         file_suffix = "optimized"
 
-    save_or_push_to_hub_onnx_model(
+    save_or_push_to_hub_model(
         export_function=lambda save_dir: optimizer.optimize(optimization_config, save_dir, file_suffix=file_suffix),
         export_function_name="export_optimized_onnx_model",
         config=optimization_config,
@@ -105,6 +106,7 @@
         push_to_hub=push_to_hub,
         create_pr=create_pr,
         file_suffix=file_suffix,
+        backend="onnx",
     )
 
 
@@ -180,7 +182,7 @@ def export_dynamic_quantized_onnx_model(
     if file_suffix is None:
         file_suffix = f"{quantization_config.weights_dtype.name.lower()}_quantized"
 
-    save_or_push_to_hub_onnx_model(
+    save_or_push_to_hub_model(
         export_function=lambda save_dir: quantizer.quantize(quantization_config, save_dir, file_suffix=file_suffix),
         export_function_name="export_dynamic_quantized_onnx_model",
         config=quantization_config,
@@ -188,10 +190,92 @@
         push_to_hub=push_to_hub,
         create_pr=create_pr,
         file_suffix=file_suffix,
+        backend="onnx",
     )
 
 
-def save_or_push_to_hub_onnx_model(
+def export_static_quantized_openvino_model(
+    model: SentenceTransformer,
+    quantization_config: OVQuantizationConfig,
+    model_name_or_path: str,
+    push_to_hub: bool = False,
+    create_pr: bool = False,
+    file_suffix: str = "qint8_quantized",
+) -> None:
+    """
+    Export a quantized OpenVINO model from a SentenceTransformer model.
+
+    This function applies Post-Training Static Quantization (PTQ) using a calibration dataset, which calibrates
+    quantization constants without requiring model retraining. The default quantization configuration converts
+    the model to int8 precision, enabling faster inference while maintaining accuracy.
+
+    See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for more information & benchmarks.
+
+    Args:
+        model (SentenceTransformer): The SentenceTransformer model to be quantized. Must be loaded with `backend="openvino"`.
+        quantization_config (OVQuantizationConfig): The quantization configuration.
+        model_name_or_path (str): The path or Hugging Face Hub repository name where the quantized model will be saved.
+        push_to_hub (bool, optional): Whether to push the quantized model to the Hugging Face Hub. Defaults to False.
+        create_pr (bool, optional): Whether to create a pull request when pushing to the Hugging Face Hub. Defaults to False.
+        file_suffix (str, optional): The suffix to add to the quantized model file name. Defaults to `qint8_quantized`.
+
+    Raises:
+        ImportError: If the required packages `optimum` and `openvino` are not installed.
+        ValueError: If the provided model is not a valid SentenceTransformer model loaded with `backend="openvino"`.
+        ValueError: If the provided quantization_config is not valid.
+
+    Returns:
+        None
+    """
+    from sentence_transformers import SentenceTransformer
+    from sentence_transformers.models.Transformer import Transformer
+
+    try:
+        from optimum.intel import OVConfig, OVModelForFeatureExtraction, OVQuantizer
+    except ImportError:
+        raise ImportError(
+            "Please install Optimum and OpenVINO to use this function. "
+            "You can install them with pip: `pip install optimum[openvino]`"
+        )
+
+    if (
+        not isinstance(model, SentenceTransformer)
+        or not len(model)
+        or not isinstance(model[0], Transformer)
+        or not isinstance(model[0].auto_model, OVModelForFeatureExtraction)
+    ):
+        raise ValueError(
+            'The model must be a Transformer-based SentenceTransformer model loaded with `backend="openvino"`.'
+        )
+
+    ov_model: OVModelForFeatureExtraction = model[0].auto_model
+    ov_config = OVConfig(quantization_config=quantization_config)
+    quantizer = OVQuantizer.from_pretrained(ov_model)
+
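+    # PTQ needs a small calibration set to estimate activation ranges; 300 sentences
+    # from GLUE's SST-2 train split serve as a readily available default here.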
+    def preprocess_function(examples):
+        return model.tokenizer(examples["sentence"], padding="max_length", max_length=384, truncation=True)
+
+    calibration_dataset = quantizer.get_calibration_dataset(
+        dataset_name="glue",
+        dataset_config_name="sst2",
+        preprocess_function=preprocess_function,
+        num_samples=300,
+        dataset_split="train",
+    )
+
+    save_or_push_to_hub_model(
+        export_function=lambda save_dir: quantizer.quantize(calibration_dataset, save_directory=save_dir, ov_config=ov_config),
+        export_function_name="export_static_quantized_openvino_model",
+        config=quantization_config,
+        model_name_or_path=model_name_or_path,
+        push_to_hub=push_to_hub,
+        create_pr=create_pr,
+        file_suffix=file_suffix,
+        backend="openvino",
+    )
+
+
+def save_or_push_to_hub_model(
     export_function: Callable,
     export_function_name: str,
     config,
@@ -199,13 +283,25 @@
     push_to_hub: bool = False,
     create_pr: bool = False,
     file_suffix: str | None = None,
+    backend: str = "onnx",
 ):
+    if backend == "onnx":
+        file_name = f"model_{file_suffix}.onnx"
+    elif backend == "openvino":
+        file_name = "openvino_model.xml"
+        destination_file_name = Path(f"openvino_model_{file_suffix}.xml")
+    else:
+        raise NotImplementedError(f"Unsupported backend type: {backend}")
+
     if push_to_hub:
         with tempfile.TemporaryDirectory() as save_dir:
             export_function(save_dir)
 
-            file_name = f"model_{file_suffix}.onnx"
-            source = (Path(save_dir) / file_name).as_posix()
-            destination = (Path("onnx") / file_name).as_posix()
+            if backend == "onnx":
+                source = (Path(save_dir) / file_name).as_posix()
+                destination = Path(backend) / file_name
+            elif backend == "openvino":
+                source = (Path(save_dir) / backend / file_name).as_posix()
+                destination = Path(backend) / destination_file_name
 
             commit_description = ""
             if create_pr:
@@ -230,7 +326,7 @@
     model = SentenceTransformer(
         "{model_name_or_path}",
         revision=f"refs/pr/{{pr_number}}",
-        backend="onnx",
+        backend="{backend}",
         model_kwargs={{"file_name": "{destination}"}},
     )
 
@@ -245,10 +341,10 @@
             huggingface_hub.upload_file(
                 path_or_fileobj=source,
-                path_in_repo=destination,
+                path_in_repo=destination.as_posix(),
                 repo_id=model_name_or_path,
                 repo_type="model",
-                commit_message=f"Add exported ONNX model {file_name!r}",
+                commit_message=f"Add exported {backend} model {destination.name!r}",
                 commit_description=commit_description,
                 create_pr=create_pr,
             )
 
@@ -257,9 +353,18 @@
         with tempfile.TemporaryDirectory() as save_dir:
             export_function(save_dir)
 
-            file_name = f"model_{file_suffix}.onnx"
-            source = os.path.join(save_dir, file_name)
-            destination = os.path.join(model_name_or_path, "onnx", file_name)
+            dst_dir = os.path.join(model_name_or_path, backend)
 
             # Create destination if it does not exist
-            os.makedirs(os.path.dirname(destination), exist_ok=True)
-            shutil.copy(source, destination)
+            os.makedirs(dst_dir, exist_ok=True)
+
+            if backend == "openvino":
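+                # OpenVINO stores a model as an .xml/.bin pair (graph topology and
+                # weights), so both files are copied under the suffixed name.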
+                source = Path(save_dir) / backend / file_name
+                bin_file = source.with_suffix(".bin")
+                xml_destination = os.path.join(dst_dir, destination_file_name)
+                bin_destination = os.path.join(dst_dir, destination_file_name.with_suffix(".bin"))
+                shutil.copy(source, xml_destination)
+                shutil.copy(bin_file, bin_destination)
+            else:
+                source = os.path.join(save_dir, file_name)
+                destination = os.path.join(dst_dir, file_name)
+                shutil.copy(source, destination)