
Merge branch 'main' into update-ort-trainer-to-4.32
JingyaHuang committed Oct 18, 2023
2 parents 363f8db + e7bd60d commit a19269e
Showing 26 changed files with 1,092 additions and 801 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/test_onnxruntime_slow.yml
@@ -0,0 +1,33 @@
name: ONNX Runtime slow / Python - Test

on:
  workflow_dispatch:
  schedule:
    - cron: 0 7 * * * # every day at 7am

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.8, 3.9]
        os: [ubuntu-20.04]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies for export
        run: |
          pip install .[tests,onnxruntime]
      - name: Test with unittest
        working-directory: tests
        run: |
          RUN_SLOW=1 pytest onnxruntime -s -m "run_slow" --durations=0
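For reference, a minimal sketch (not the repository's actual test code; it assumes a `run_slow` marker registered with pytest) of how a test can be gated on both the `-m "run_slow"` selection and the `RUN_SLOW` environment variable the workflow sets:

```python
import os

import pytest


@pytest.mark.run_slow  # selected by `pytest -m "run_slow"`; assumes the marker is registered
@pytest.mark.skipif(os.environ.get("RUN_SLOW", "0") != "1", reason="set RUN_SLOW=1 to run slow tests")
def test_onnxruntime_slow_path():
    assert True  # placeholder body
```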
2 changes: 2 additions & 0 deletions docs/source/bettertransformer/overview.mdx
@@ -50,6 +50,7 @@ The list of supported model below:
- [DeiT](https://arxiv.org/abs/2012.12877)
- [Electra](https://arxiv.org/abs/2003.10555)
- [Ernie](https://arxiv.org/abs/1904.09223)
- [Falcon](https://arxiv.org/abs/2306.01116)
- [FSMT](https://arxiv.org/abs/1907.06616)
- [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [GPT-j](https://huggingface.co/EleutherAI/gpt-j-6B)
@@ -58,6 +59,7 @@ The list of supported model below:
- [GPT BigCode](https://arxiv.org/abs/2301.03988) (SantaCoder, StarCoder)
- [HuBERT](https://arxiv.org/pdf/2106.07447.pdf)
- [LayoutLM](https://arxiv.org/abs/1912.13318)
- [Llama & Llama2](https://arxiv.org/abs/2302.13971)
- [MarkupLM](https://arxiv.org/abs/2110.08518)
- [Marian](https://arxiv.org/abs/1804.00344)
- [MBart](https://arxiv.org/abs/2001.08210)
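As a usage illustration for the models added to this list, a minimal sketch (the checkpoint name is illustrative; the call follows the Optimum BetterTransformer API):

```python
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

# Convert a model from the supported list above to its BetterTransformer variant.
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = BetterTransformer.transform(model)
```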
6 changes: 2 additions & 4 deletions docs/source/onnxruntime/usage_guides/pipelines.mdx
@@ -55,11 +55,9 @@ There are tags on the Model Hub that allow you to filter for a model you'd like

<Tip>

To be able to load the model with the ONNX Runtime backend, the export to ONNX needs
to be supported for the considered architecture.
To be able to load the model with the ONNX Runtime backend, the export to ONNX needs to be supported for the considered architecture.

You can check the list of supported architectures
[here](/exporters/onnx/package_reference/configuration#Supported-architectures).
You can check the list of supported architectures [here](https://huggingface.co/docs/optimum/exporters/onnx/overview#overview).

</Tip>
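A minimal sketch of loading such a model through an ONNX Runtime-backed pipeline (the model id is illustrative; `accelerator="ort"` follows the Optimum pipelines API):

```python
from optimum.pipelines import pipeline

# Build a pipeline backed by ONNX Runtime; the model is exported to ONNX
# on the fly when the architecture is supported.
classifier = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    accelerator="ort",
)
print(classifier("ONNX Runtime pipelines are easy to use."))
```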

18 changes: 9 additions & 9 deletions optimum/commands/export/onnx.py
@@ -136,14 +136,6 @@ def parse_args_onnx(parser):
default=None,
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"),
)
optional_group.add_argument(
"--no-position-ids",
action="store_true",
help=(
"Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
)
@@ -217,6 +209,14 @@ def parse_args_onnx(parser):
default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"],
help="For Segment Anything. It corresponds to the number of points per segmentation masks.",
)
optional_group.add_argument(
"--legacy",
action="store_true",
help=(
"Export decoder only models in three files (without + with past and the resulting merged model)."
"Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)

# deprecated argument
parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS)
@@ -255,6 +255,6 @@ def run(self):
use_subprocess=True,
_variant=self.args.variant,
library_name=self.args.library_name,
no_position_ids=self.args.no_position_ids,
legacy=self.args.legacy,
**input_shapes,
)
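A hedged sketch of the new flag in use from Python (the checkpoint and output directory are illustrative; `--legacy` replaces the removed `--no-position-ids` option):

```python
import subprocess

# Invoke the exporter CLI with the new --legacy flag.
subprocess.run(
    [
        "optimum-cli", "export", "onnx",
        "--model", "gpt2",
        "--task", "text-generation",
        "--legacy",
        "gpt2_onnx/",
    ],
    check=True,
)
```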
21 changes: 11 additions & 10 deletions optimum/exporters/onnx/__main__.py
@@ -68,7 +68,7 @@ def _get_submodels_and_onnx_configs(
float_dtype: str = "fp32",
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
legacy: bool = False,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
@@ -82,8 +82,8 @@
model=model, exporter="onnx", task=task
)
onnx_config_kwargs = {}
if task.startswith("text-generation") and no_position_ids:
onnx_config_kwargs["no_position_ids"] = no_position_ids
if task.startswith("text-generation") and legacy:
onnx_config_kwargs["no_position_ids"] = legacy

onnx_config = onnx_config_constructor(
model.config,
@@ -106,7 +106,7 @@
):
models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
elif task.startswith("text-generation") and not monolith:
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy)
elif model.config.model_type == "sam":
models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
else:
@@ -184,7 +184,7 @@ def main_export(
use_subprocess: bool = False,
_variant: str = "default",
library_name: Optional[str] = None,
no_position_ids: bool = False,
legacy: bool = False,
**kwargs_shapes,
):
"""
@@ -264,8 +264,8 @@
library_name (`Optional[str]`, defaults to `None`):
The library of the model (`"transformers"` or `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect
the library name for the checkpoint.
no_position_ids (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
legacy (`bool`, defaults to `False`):
Disables the use of position_ids for text-generation models that require it for batched generation. Also exports decoder-only models in three files (without past, with past, and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
@@ -353,9 +353,9 @@
is_stable_diffusion = "stable-diffusion" in task
model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")

if no_position_ids and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"):
logger.warning(
f"no_position_ids=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `no_position_ids=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381"
)

if not is_stable_diffusion:
@@ -424,7 +424,7 @@
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant=_variant,
no_position_ids=no_position_ids,
legacy=legacy,
)

if not is_stable_diffusion:
@@ -610,6 +610,7 @@ def main():
pad_token_id=args.pad_token_id,
for_ort=args.for_ort,
library_name=args.library_name,
legacy=args.legacy,
**input_shapes,
)
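The programmatic equivalent, as a minimal sketch (checkpoint, output path, and task are illustrative; `main_export` and its `legacy` parameter come from the diff above):

```python
from optimum.exporters.onnx import main_export

# Export a decoder model in the deprecated three-file layout, without
# position_ids, by opting into the legacy behavior.
main_export(
    model_name_or_path="gpt2",
    output="gpt2_onnx_legacy/",
    task="text-generation",
    legacy=True,
)
```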

2 changes: 1 addition & 1 deletion optimum/exporters/onnx/base.py
@@ -585,7 +585,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
elif self.task == "feature-extraction":
common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}})
else:
common_outputs = OrderedDict({"logits": {0: "batch_size"}})
common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}})
if self.use_past:
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs")
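To see why axis 1 of `logits` must be dynamic, consider a decoder's prefill step (a small illustration; the checkpoint is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Without past key values, the first forward pass consumes the whole prompt,
# so the sequence dimension of the logits varies with the prompt length.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello there, friend", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)  # (1, num_prompt_tokens, vocab_size)
```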
7 changes: 2 additions & 5 deletions optimum/exporters/onnx/config.py
@@ -92,7 +92,7 @@ def __init__(
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
common_inputs = {"input_ids": {0: "batch_size"}}
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
else:
@@ -164,10 +164,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task == "text-generation":
if self.use_past_in_inputs:
common_inputs["position_ids"] = {0: "batch_size"}
else:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs
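A small illustration of why both branches can now share the same axes (shapes only; this assumes the merged-decoder case where one ONNX signature serves both steps):

```python
# Prefill: the full prompt goes through at once (sequence_length == prompt length).
prefill_shapes = {"input_ids": (2, 7), "position_ids": (2, 7)}

# Decode: one new token per sequence (sequence_length == 1), with past key values.
decode_shapes = {"input_ids": (2, 1), "position_ids": (2, 1)}

# The same dynamic axes {0: "batch_size", 1: "sequence_length"} cover both.
print(prefill_shapes, decode_shapes)
```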

99 changes: 80 additions & 19 deletions optimum/exporters/onnx/model_configs.py
@@ -56,7 +56,15 @@
TextSeq2SeqOnnxConfig,
VisionOnnxConfig,
)
from .model_patcher import SAMModelPatcher, WavLMModelPatcher
from .model_patcher import (
BartModelPatcher,
BloomModelPatcher,
LlamaModelPatcher,
MistralModelPatcher,
OPTModelPatcher,
SAMModelPatcher,
WavLMModelPatcher,
)


if TYPE_CHECKING:
@@ -216,13 +224,23 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return OPTModelPatcher(self, model, model_kwargs=model_kwargs)


class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)


class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
# The ONNX export of this architecture needs the Trilu operator support, available since opset 14
@@ -233,6 +251,11 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MistralModelPatcher(self, model, model_kwargs=model_kwargs)


class MPTOnnxConfig(TextDecoderOnnxConfig):
# MPT does not require position_ids input.
@@ -241,6 +264,11 @@ class MPTOnnxConfig(TextDecoderOnnxConfig):
num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return BloomModelPatcher(self, model, model_kwargs=model_kwargs)


class BloomOnnxConfig(TextDecoderOnnxConfig):
# Bloom does not require position_ids input.
@@ -274,6 +302,11 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
1: decoder_sequence_name,
}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return BloomModelPatcher(self, model, model_kwargs=model_kwargs)
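For context, a hedged sketch of how an exporter typically consumes these `patch_model_for_export` overrides (the returned `ModelPatcher` is used as a context manager; the exact call site is assumed, not shown in this diff):

```python
from transformers import AutoModelForCausalLM

from optimum.exporters.onnx.model_configs import BloomOnnxConfig

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # illustrative checkpoint
onnx_config = BloomOnnxConfig(model.config, task="text-generation")

# The patcher temporarily rewires the model's forward for ONNX tracing.
with onnx_config.patch_model_for_export(model):
    pass  # the actual ONNX export would run here
```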


class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
@@ -413,7 +446,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return int_tensor


class BartOnnxConfig(TextSeq2SeqOnnxConfig):
class M2M100OnnxConfig(TextSeq2SeqOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
encoder_num_layers="encoder_layers",
decoder_num_layers="decoder_layers",
@@ -537,11 +570,14 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
)


class MBartOnnxConfig(BartOnnxConfig):
pass
class BartOnnxConfig(M2M100OnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return BartModelPatcher(self, model, model_kwargs=model_kwargs)


class M2M100OnnxConfig(BartOnnxConfig):
class MBartOnnxConfig(BartOnnxConfig):
pass
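A brief check of the inheritance inversion above, as a sketch (the shared seq2seq body now lives in `M2M100OnnxConfig`):

```python
from optimum.exporters.onnx.model_configs import BartOnnxConfig, M2M100OnnxConfig, MBartOnnxConfig

# Bart subclasses M2M100 only to attach BartModelPatcher; MBart aliases Bart.
assert issubclass(BartOnnxConfig, M2M100OnnxConfig)
assert issubclass(MBartOnnxConfig, BartOnnxConfig)
```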


@@ -998,12 +1034,37 @@ class Data2VecAudioOnnxConfig(AudioOnnxConfig):


class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"],
height: int = DEFAULT_DUMMY_SHAPES["height"],
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
num_channels=num_channels,
width=width,
height=height,
**kwargs,
)

from transformers.onnx.utils import get_preprocessor

preprocessor = get_preprocessor(normalized_config._name_or_path)
if preprocessor is not None and hasattr(preprocessor, "size"):
self.height = preprocessor.size.get("height", self.height)
self.width = preprocessor.size.get("width", self.width)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
input_ = super().generate(
input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
)
# if input_name == "pixel_values":
# input_ = input_[None, :]
return input_
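A small sketch of the preprocessor-driven sizing introduced above (the checkpoint is illustrative):

```python
from transformers.onnx.utils import get_preprocessor

# For a checkpoint with an image processor, the dummy input takes its
# height/width from the processor's `size` dict instead of generic defaults.
preprocessor = get_preprocessor("deepmind/vision-perceiver-conv")
if preprocessor is not None and hasattr(preprocessor, "size"):
    print(preprocessor.size)  # e.g. {"height": 224, "width": 224}
```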


@@ -1038,22 +1099,22 @@ def inputs_name(self):

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
# TODO: validate that.
dynamic_axis = {0: "batch_size", 1: "sequence_length"}
return {
self.inputs_name: dynamic_axis,
# TODO: should we add the attention_mask?
# This breaks things for image-classification, suspected bug is the DummyInputGenerators not having the
# same num_channels / sequence_length.
# "attention_mask": dynamic_axis,
}
if self.inputs_name in ["input_ids", "inputs"]:
dynamic_axis = {0: "batch_size", 1: "sequence_length"}
return {
"input_ids": dynamic_axis,
"attention_mask": dynamic_axis,
}
else:
dynamic_axis = {0: "batch_size", 1: "sequence_length", 2: "width", 3: "height"}
return {
"pixel_values": dynamic_axis,
}

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
self.is_generating_dummy_inputs = True
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
specialized_inputs_name = self.inputs_name
self.is_generating_dummy_inputs = True
dummy_inputs[self.inputs_name] = dummy_inputs.pop(specialized_inputs_name)
dummy_inputs[self.inputs_name] = dummy_inputs.pop(self.inputs_name)
return dummy_inputs

