
Add ONNX export for DinoV2 models #1580

Closed
wants to merge 3 commits into from

Conversation

@xenova (Contributor) commented Dec 9, 2023

What does this PR do?

As title says :)


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@xenova (Contributor, Author) commented Dec 9, 2023

Although validation passes for all the facebook models on the Hub, I'm getting a few issues when actually running the models. It appears to stem from this line, since bicubic interpolation isn't supported for this operation in onnxruntime. Changing it to "bilinear" seems to work (and output doesn't differ too much).
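
The size of that difference can be sanity-checked with a pure-Python sketch (hypothetical helpers, not the modeling code): it compares 1-D bilinear resampling against Keys bicubic resampling with a = -0.75, the kernel PyTorch uses for mode="bicubic", on a smooth signal of the kind a trained positional embedding resembles.

```python
import math

def cubic_weight(d, a=-0.75):
    """Keys cubic convolution kernel; a=-0.75 matches PyTorch's bicubic."""
    d = abs(d)
    if d <= 1:
        return (a + 2) * d**3 - (a + 3) * d**2 + 1
    if d < 2:
        return a * d**3 - 5 * a * d**2 + 8 * a * d - 4 * a
    return 0.0

def resize_1d(x, size, mode):
    """Resample a 1-D signal, mimicking align_corners=False sampling."""
    n = len(x)
    scale = n / size
    clamp = lambda j: min(max(j, 0), n - 1)
    out = []
    for i in range(size):
        src = (i + 0.5) * scale - 0.5  # source coordinate of output sample i
        i0 = math.floor(src)
        t = src - i0
        if mode == "bilinear":
            out.append((1 - t) * x[clamp(i0)] + t * x[clamp(i0 + 1)])
        else:  # bicubic: 4-tap Keys kernel over i0-1 .. i0+2
            out.append(sum(x[clamp(i0 + k - 1)] * cubic_weight(t - (k - 1))
                           for k in range(4)))
    return out

# A smooth signal, upsampled 16 -> 37 with both modes
x = [math.sin(2 * math.pi * i / 16) for i in range(16)]
diff = max(abs(a - b) for a, b in
           zip(resize_1d(x, 37, "bilinear"), resize_1d(x, 37, "bicubic")))
```

For smooth inputs the two modes stay within a few hundredths of each other, which is consistent with the small logit differences observed after the swap.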

@xenova (Contributor, Author) commented Dec 9, 2023

I've updated the converter to use dummy values from the image preprocessor, so that the branch which interpolates the positional embeddings is triggered. As mentioned above, however, this doesn't work with the current state of that function (fails with mode="bicubic" and succeeds with mode="bilinear").
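
The model-patcher idea boils down to temporarily swapping the problematic method for an export-friendly one and restoring it afterwards. Optimum's actual ModelPatcher API differs; this is just a minimal sketch of the mechanism with hypothetical names.

```python
import contextlib

class TinyEncoder:
    """Stand-in for a model whose method must be patched during export."""
    def interpolate_pos_encoding(self, h, w):
        return f"bicubic({h}x{w})"  # pretend this path uses the unsupported mode

@contextlib.contextmanager
def patch_method(obj, name, replacement):
    """Swap a method for the duration of the export, then restore it."""
    original = getattr(obj, name)
    setattr(obj, name, replacement.__get__(obj))  # bind replacement to obj
    try:
        yield obj
    finally:
        setattr(obj, name, original)

def bilinear_pos_encoding(self, h, w):
    return f"bilinear({h}x{w})"

model = TinyEncoder()
with patch_method(model, "interpolate_pos_encoding", bilinear_pos_encoding):
    during = model.interpolate_pos_encoding(16, 16)  # patched for the export
after = model.interpolate_pos_encoding(16, 16)       # original restored
```

The export (tracing) would run inside the `with` block, so the ONNX graph records the supported interpolation mode while the Python model is left untouched afterwards.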

@fxmarty I suppose we can fix this with a model patcher?

@fxmarty (Contributor) commented Jan 26, 2024

@xenova Is bicubic/linear not a config option? If it is not, I am afraid we need to patch indeed. What is the issue exactly? The model not loadable in ORT?

@xenova (Contributor, Author) commented Jan 26, 2024

> @xenova Is bicubic/linear not a config option? If it is not, I am afraid we need to patch indeed.

Unfortunately not; this seems to be hard-coded in the model code to interpolate the positional embeddings to the correct shape.

> What is the issue exactly? The model not loadable in ORT?

It seems that the required bicubic interpolation is just not supported:

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/upsamplebase.h:147 onnxruntime::UpsampleBase::UpsampleBase(const onnxruntime::OpKernelInfo&) [ONNXRuntimeError] : 1 : FAIL : upsamplebase.h:365 ScalesValidation 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs with the corresponding outermost 2 scale values being 1 in the Resize operator

Here's the full error log:
config.json: 100% 954/954 [00:00<00:00, 2.71MB/s]
Framework not specified. Using pt to export the model.
model.safetensors: 100% 99.2M/99.2M [00:00<00:00, 133MB/s]
preprocessor_config.json: 100% 437/437 [00:00<00:00, 2.12MB/s]
Using the export variant default. Available variants are:
    - default: The default ONNX variant.
Using framework PyTorch: 2.1.0+cu121
/content/transformers.js/transformers/src/transformers/models/dinov2/modeling_dinov2.py:164: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_channels != self.num_channels:
/content/transformers.js/transformers/src/transformers/models/dinov2/modeling_dinov2.py:94: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_patches == num_positions and height == width:
/content/transformers.js/transformers/src/transformers/models/dinov2/modeling_dinov2.py:104: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
/content/transformers.js/transformers/src/transformers/models/dinov2/modeling_dinov2.py:109: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))),
/content/transformers.js/transformers/src/transformers/models/dinov2/modeling_dinov2.py:113: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
/content/transformers.js/transformers/src/transformers/models/dinov2/modeling_dinov2.py:113: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
/content/transformers.js/transformers/src/transformers/models/depth_anything/modeling_depth_anything.py:200: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if hidden_state.shape != residual.shape:
/content/transformers.js/transformers/src/transformers/models/depth_anything/modeling_depth_anything.py:349: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  (int(patch_height * self.patch_size), int(patch_width * self.patch_size)),
Post-processing the exported models...
Deduplicating shared (tied) weights...
Validating ONNX model models/LiheYoung/depth-anything-small-hf/model.onnx...
2024-01-26 09:56:43.735130317 [E:onnxruntime:, inference_session.cc:1644 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/upsamplebase.h:147 onnxruntime::UpsampleBase::UpsampleBase(const onnxruntime::OpKernelInfo&) [ONNXRuntimeError] : 1 : FAIL : upsamplebase.h:365 ScalesValidation 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs with the corresponding outermost 2 scale values being 1 in the Resize operator

Traceback (most recent call last):
  File "/content/transformers.js/optimum/optimum/exporters/onnx/__main__.py", line 604, in main_export
    validate_models_outputs(
  File "/content/transformers.js/optimum/optimum/exporters/onnx/convert.py", line 174, in validate_models_outputs
    raise exceptions[-1][1]
  File "/content/transformers.js/optimum/optimum/exporters/onnx/convert.py", line 156, in validate_models_outputs
    validate_model_outputs(
  File "/content/transformers.js/optimum/optimum/exporters/onnx/convert.py", line 228, in validate_model_outputs
    _run_validation(
  File "/content/transformers.js/optimum/optimum/exporters/onnx/convert.py", line 280, in _run_validation
    session = PickableInferenceSession(onnx_model.as_posix(), sess_options=session_options, providers=[provider])
  File "/content/transformers.js/optimum/optimum/exporters/onnx/utils.py", line 539, in __init__
    self.sess = ort.InferenceSession(self.model_path, sess_options=sess_options, providers=providers)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 383, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 435, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/upsamplebase.h:147 onnxruntime::UpsampleBase::UpsampleBase(const onnxruntime::OpKernelInfo&) [ONNXRuntimeError] : 1 : FAIL : upsamplebase.h:365 ScalesValidation 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs with the corresponding outermost 2 scale values being 1 in the Resize operator


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/transformers.js/scripts/convert.py", line 465, in <module>
    main()
  File "/content/transformers.js/scripts/convert.py", line 426, in main
    main_export(**export_kwargs)
  File "/content/transformers.js/optimum/optimum/exporters/onnx/__main__.py", line 628, in main_export
    raise Exception(
Exception: An error occured during validation, but the model was saved nonetheless at models/LiheYoung/depth-anything-small-hf. Detailed error: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/upsamplebase.h:147 onnxruntime::UpsampleBase::UpsampleBase(const onnxruntime::OpKernelInfo&) [ONNXRuntimeError] : 1 : FAIL : upsamplebase.h:365 ScalesValidation 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs with the corresponding outermost 2 scale values being 1 in the Resize operator
.

(error copied from a depth_anything export, but it uses dinov2 as a backbone)

For all the depth_anything and dinov2 models I've exported and released on the HF hub, I had to manually override the code in the following ways:

  • update the positional encoding interpolation to use bilinear instead of bicubic interpolation. Although this produced minor differences in the logits, the visible output fortunately remained very close to the original implementation.
  • fix various casts to Python numbers (specifically in resize operators). In fact, I believe all of the casts in the modeling file (int(...) or float(...)) are problematic.
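
Why those casts are problematic can be illustrated with a toy tracer (hypothetical, not torch.jit's actual machinery): any `int(...)` or `float(...)` forces the traced value into a plain Python number, so the graph only records the constant seen at trace time — exactly what the TracerWarnings in the log above are flagging.

```python
class TracedValue:
    """Toy stand-in for a traced tensor: symbolic expr + trace-time value."""
    def __init__(self, expr, concrete):
        self.expr, self.concrete = expr, concrete

    def __mul__(self, other):
        # Arithmetic on traced values extends the symbolic graph
        return TracedValue(f"({self.expr} * {other})", self.concrete * other)

    def __int__(self):
        # The trap: the trace only sees the number observed here, so the
        # result is frozen as a constant for every future input shape.
        return int(self.concrete)

height = TracedValue("height", 14)  # traced with a 14x14 patch grid
frozen = int(height * 2)            # bakes in the trace-time value
symbolic = height * 2               # stays in the graph, tracks new inputs
```

Keeping shape arithmetic in tensor form (the second pattern) is what makes the exported graph generalize to input sizes other than the dummy inputs used for tracing.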

Link to models I've converted with these fixes:

@xenova xenova marked this pull request as draft February 4, 2024 21:07
@pcuenca (Member) commented May 5, 2024

I followed the same path independently, and can confirm that bilinear instead of bicubic interpolation for the position encodings results in unnoticeable visual differences in the generated depth map.

@xenova (Contributor, Author) commented Aug 30, 2024

Closed in favor of #2001

@xenova xenova closed this Aug 30, 2024
@xenova xenova deleted the add-dino branch August 30, 2024 18:33
4 participants