
Add ONNX export support for DETA model #2025

Draft · wants to merge 7 commits into main

Conversation

@TheMattBin commented on Sep 15, 2024

What does this PR do?

I'm trying to add support for the DETA model by running:
optimum-cli export onnx -m 'jozhang97/deta-resnet-50' --task 'object-detection' --framework 'pt' deta_onnx
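
On the optimum side, the export support itself is essentially a new ONNX config for the deta model type. A rough sketch of the shape it takes, loosely following the existing DETR config in optimum/exporters/onnx/model_configs.py (the exact base class, opset, and pixel_mask handling in this branch may differ):

from optimum.exporters.onnx.model_configs import ViTOnnxConfig


class DetaOnnxConfig(ViTOnnxConfig):
    # Assumed opset; DETA's two-stage decoder (topk/NMS paths) needs a fairly recent one.
    DEFAULT_ONNX_OPSET = 16

    @property
    def inputs(self):
        # Object detection only needs pixel_values here; dynamic axes follow the
        # convention used by the other vision configs.
        return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}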

When running the export, however, I hit the issues mentioned in #2018, so I adjusted the HF DETA modeling code as shown below to investigate:

  topk = self.two_stage_num_proposals
+ topk = min(20, topk)
  proposal_logit = enc_outputs_class[..., 0]
  
  if self.assign_first_stage:
      proposal_boxes = center_to_corners_format(enc_outputs_coord_logits.sigmoid().float()).clamp(0, 1)
      topk_proposals = []
      for b in range(batch_size):
          prop_boxes_b = proposal_boxes[b]
          prop_logits_b = proposal_logit[b]
  
          # pre-nms per-level topk
-         pre_nms_topk = 1000
+         pre_nms_topk = 50
          pre_nms_inds = []
          for lvl in range(len(spatial_shapes)):
              lvl_mask = level_ids == lvl
              pre_nms_inds.append(torch.topk(prop_logits_b.sigmoid() * lvl_mask, pre_nms_topk)[1])
          pre_nms_inds = torch.cat(pre_nms_inds)
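
As an aside, these hardcoded Python ints end up baked into the traced graph as constants, which is why shrinking them changes what gets exported at all. A tiny illustration of that tracing behaviour (not DETA code):

import torch


class HardcodedTopK(torch.nn.Module):
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        k = 50  # a plain Python int: the tracer records it as a constant
        return torch.topk(logits, k).indices


# The printed graph contains k=50 as a literal constant.
traced = torch.jit.trace(HardcodedTopK(), torch.randn(1, 300))
print(traced.graph)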
Logs after hardcoding the values a bit:
2024-09-15 04:44:46.389938: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-15 04:44:46.410966: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-15 04:44:46.417727: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-15 04:44:47.482377: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Using the export variant default. Available variants are:
  - default: The default ONNX variant.

***** Exporting submodel 1/1: DetaForObjectDetection *****
Using framework PyTorch: 2.4.0+cu121
/usr/local/lib/python3.10/dist-packages/transformers/models/resnet/modeling_resnet.py:91: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if num_channels != self.num_channels:
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1689: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1146: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
for level, (height, width) in enumerate(spatial_shapes):
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:651: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:672: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if num_coordinates == 2:
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:531: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:531: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:534: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
for level_id, (height, width) in enumerate(value_spatial_shapes):
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1559: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
for level, (height, width) in enumerate(spatial_shapes):
300
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1750: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
for lvl in range(len(spatial_shapes)):
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1763: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
if len(keep_inds) < self.two_stage_num_proposals:
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1765: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
f"[WARNING] nms proposals ({len(keep_inds)}) < {self.two_stage_num_proposals}, running"
[WARNING] nms proposals (5) < 300, running naive topk
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1771: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
q_per_l = topk // len(spatial_shapes)
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1774: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
== torch.arange(len(spatial_shapes), device=level_ids.device)[:, None]
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1780: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if keep_inds_mask.sum() < topk:
[WARNING] nms proposals (5) < 300, running naive topk
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1332: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if reference_points.shape[-1] == 4:
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:774: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:810: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:678: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
elif num_coordinates == 4:
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:1373: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if reference_points.shape[-1] == 4:
/usr/local/lib/python3.10/dist-packages/transformers/models/deprecated/deta/modeling_deta.py:2000: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if reference.shape[-1] == 4:
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:1738: UserWarning: The exported ONNX model failed ONNX shape inference. The model will not be executable by the ONNX Runtime. If this is unintended and you believe there is a bug, please report an issue at https://github.com/pytorch/pytorch/issues. Error reported by strict ONNX shape inference: [ShapeInferenceError] (op_type:CumSum, node name: /model/CumSum): x typestr: T, has unsupported type: tensor(bool) (Triggered internally at ../torch/csrc/jit/serialization/export.cpp:1469.)
_C._check_onnx_proto(proto)
Traceback (most recent call last):
File "/usr/local/bin/optimum-cli", line 8, in <module>
  sys.exit(main())
File "/usr/local/lib/python3.10/dist-packages/optimum/commands/optimum_cli.py", line 208, in main
  service.run()
File "/usr/local/lib/python3.10/dist-packages/optimum/commands/export/onnx.py", line 265, in run
  main_export(
File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/__main__.py", line 374, in main_export
  onnx_export_from_model(
File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 1171, in onnx_export_from_model
  _, onnx_outputs = export_models(
File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 776, in export_models
  export(
File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 910, in export
  config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/base.py", line 306, in fix_dynamic_axes
  session = InferenceSession(model_path.as_posix(), providers=providers, sess_options=session_options)
File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
  self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 480, in _create_inference_session
  sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from deta_test/model.onnx failed:This is an invalid model. Type Error: Type 'tensor(bool)' of input parameter (/model/Equal_7_output_0) of operator (CumSum) in node (/model/CumSum_1) is invalid.
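
For context, the export itself now completes, but ONNX Runtime rejects the resulting graph because a CumSum node receives a boolean input. The usual workaround for this class of error is to cast the mask to a numeric dtype before the cumsum; a minimal sketch of the pattern (which line in modeling_deta.py actually produces this node is an assumption on my side):

import torch


class MaskCumsum(torch.nn.Module):
    def forward(self, mask: torch.Tensor) -> torch.Tensor:
        # Casting first keeps the exported CumSum node on a numeric dtype,
        # which ONNX Runtime accepts.
        return mask.to(torch.int64).cumsum(dim=1)


mask = torch.zeros(1, 16, dtype=torch.bool)
torch.onnx.export(MaskCumsum(), (mask,), "mask_cumsum.onnx", opset_version=16)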
Fixes #2018

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@IlyasMoutawwakil (Member) commented

Hi! Are you able to export the model now?

@TheMattBin (Author) commented

Sorry, quite busy this week. Will look at it again this weekend!

@TheMattBin (Author) commented

It seems I'm able to obtain the ONNX model now, but I'm getting issues when running inference during testing.
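
For reference, this is the kind of smoke test I'm running against the exported file; the input name and shape are assumptions based on the object-detection export, and the exported graph may also expect a pixel_mask input:

import numpy as np
import onnxruntime as ort

# deta_onnx/model.onnx is the output of the optimum-cli command above
session = ort.InferenceSession("deta_onnx/model.onnx", providers=["CPUExecutionProvider"])

# A single random 3-channel image as a stand-in for a preprocessed input
pixel_values = np.random.rand(1, 3, 800, 800).astype(np.float32)
outputs = session.run(None, {"pixel_values": pixel_values})
print([o.name for o in session.get_outputs()], [o.shape for o in outputs])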
