-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for Scan operator #2936
Conversation
music-dino
commented
Apr 1, 2024
- Implement ONNX parsing support for the Scan operator
- Resolves Scan operator is unsupported migraphx-benchmark/AMDMIGraphX#116
Check results before merge 🔆 |
🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #2936 +/- ##
===========================================
+ Coverage 92.24% 92.27% +0.03%
===========================================
Files 495 497 +2
Lines 19849 19985 +136
===========================================
+ Hits 18309 18441 +132
- Misses 1540 1544 +4 ☔ View full report in Codecov by Sentry. |
src/onnx/parse_scan.cpp
Outdated
} | ||
|
||
std::vector<int64_t> | ||
parse_dirs(onnx_parser::node_info& info, const std::string& name, long expected_size) const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
parse_dirs(onnx_parser::node_info& info, const std::string& name, long expected_size) const | |
parse_dirs(onnx_parser::node_info& info, const std::string& name, size_t expected_size) const |
for consistency
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| Scan | ✅ | UINT8, UINT16, | ``identity``, | | ||
| | | UINT32, UINT64, | ``sequence`` | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why identity and sequence are mentioned here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They're mentioned in the Loop op entry as well. Since scan relies on Loop, I thought I'd carry over the info.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay. I do not know what that means though. @attila-dusnoki-htec can you comment this limitation is about ?
for(auto i = 0; i < n; ++i) | ||
new_params.push_back( | ||
mod->add_parameter("state_var" + std::to_string(i), params[i]->get_shape())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can use std::transform
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I use the loop iterator value in the loop body within the std::to_string as well, so this isn't that well suited to std::transform.
{ | ||
std::vector<int64_t> perm(rank); | ||
std::iota(perm.begin(), perm.end(), 0); | ||
std::copy(perm.begin() + 1, perm.begin() + 1 + axis, perm.begin()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if begin() + 1 + axis
would go out of bound if axis is the last axis.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For rank r, the last axis would be r-1, meaning the range would be [begin + 1, begin + r). Since the perm vector has r elements, begin + r is the first element beyond the last element, making it the valid end iterator.
src/onnx/parse_scan.cpp
Outdated
// Loop scan_outputs are concatenated along axis 0, so it must be transposed to the | ||
// index specified by the corresponding scan_output_axis | ||
auto perm = make_perm_for_scan_out(o->get_shape().ndim(), scan_output_axes[i]); | ||
ret.push_back(info.add_instruction(make_op("transpose", {{"permutation", perm}}), o)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am missing something here. Can you explain why the transpose is necessary ? Transpose will probably not result in correct element order.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've expanded the comment, hopefully it's more helpful now.
There are a couple of test cases that cover this scenario, and the correct result is produced.
mod->replace_return(returns); | ||
} | ||
|
||
std::vector<int64_t> make_perm_for_scan_out(int64_t rank, int64_t axis) const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand this. can you add a docstring on what this is trying to do ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added a comment.
src/include/migraphx/op/loop.hpp
Outdated
int64_t iter, | ||
int64_t iter_num) const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to name them as curr_iter
and num_iters
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is hitting a failure. Build this and run "make check".
[2024-07-17T00:46:28.905Z] 358/358 Test #357: test_py_3.10_backend ......................................................***Failed 965.59 sec
[2024-07-17T00:46:28.905Z] .s.s.s.s.sssssssss.s.s.sssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssss.sss.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.s.s.sssssssss.sssssssssssssssssss.s.sssssssssssssssssssssssssssssssss.s.s.s.s.s.s.s.sssssssssssssssssssss.s.s.s.sssssssssssssssssssssssss.s.s.s.sssssssssssssssssssssssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssssssssssssssssss.s.s.s.s.s.s.s.s.s.s.s.s.sssssssssssss.s.s.s.sssssss.sss.s.s.s.s.s.s.s.s.s.sssssssssssssssssss.s.s.s.sssssssssssssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sss.s.s.s.s.s.s.s.s.s.s.s.s.sssss.s.s.sssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sss.s.s.s.s.s.s.s.s.s.s.s.s.s.sssssssssssssssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssss.sssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssssssssssssssssssssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssssssss.s.s.sss.s.s.s.s.s.s.s.s.s.s.s.s.s.sss.s.s.s.s.s.s.s.s.s.s.s.s.s.sssss.s.s.s.s.s.s.s.s.sssss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sss.sssssssss.sssssssssss.s.s.sssssssssssssssssssssssssssssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssss.s.sssssssssssss.s.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.s.sssssssssssssssssssssssssssss.s.sssssssssssss.s.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.s.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.s.s.s.s.sEsss.s.s.s.s.s.s.s.s.s.s.s.s.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.s.s.s.s.s.sssssssssssssssssssssssssssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssssssssssssssss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssssssssssssssssssssss.s.sssssssssssssssssssss.s.s.s.s.s.s.s.s.s.s.sssssssssssssss.s.s.s.s.s.sssssssssssssssssssssss.s.s.s.s.s.s.s.sss.sssssss.sssss.sssssssssss.sssssssssssssssssssssssssssssss.s.sss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sssss.s.s.s.s.s.s.s.s.s.s.s.sss.sss.s.s.s.sss.sss.sss.s.s.s.sss.s.s.s.s.s.s.s.s.s.s.s.sssssssssss.s.s.s.s.s.sss.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.sss.s.s.s.s.s.sssssssssssssssssssssssssssss.s.s.sssssssssssss
[2024-07-17T00:46:28.905Z] ======================================================================
[2024-07-17T00:46:28.905Z] ERROR: test_scan9_sum_cpu (main.OnnxBackendNodeModelTest)
[2024-07-17T00:46:28.905Z] ----------------------------------------------------------------------
[2024-07-17T00:46:28.905Z] Traceback (most recent call last):
[2024-07-17T00:46:28.905Z] File "/usr/local/lib/python3.10/dist-packages/onnx/backend/test/runner/init.py", line 290, in device_test_func
[2024-07-17T00:46:28.905Z] return test_func(*args, device=device, **kwargs)
[2024-07-17T00:46:28.905Z] File "/usr/local/lib/python3.10/dist-packages/onnx/backend/test/runner/init.py", line 382, in run
[2024-07-17T00:46:28.905Z] prepared_model = self.backend.prepare(model, device)
[2024-07-17T00:46:28.905Z] File "/home/jenkins/workspace/AMDMIGraphX_PR-2936/build/lib/onnx_migraphx/backend.py", line 125, in prepare
[2024-07-17T00:46:28.905Z] return cls.prepare(bin, device, **kwargs)
[2024-07-17T00:46:28.905Z] File "/home/jenkins/workspace/AMDMIGraphX_PR-2936/build/lib/onnx_migraphx/backend.py", line 112, in prepare
[2024-07-17T00:46:28.905Z] inf = migraphx.parse_onnx_buffer(model)
[2024-07-17T00:46:28.905Z] RuntimeError: /home/jenkins/workspace/AMDMIGraphX_PR-2936/src/onnx/parse_scan.cpp:124: parse: Slice: Sliced scan input 0 shape {float_type, {3}, {1}} does not match corresponding body input shape {float_type, {2}, {1}}
[2024-07-17T00:46:28.905Z]
[2024-07-17T00:46:28.905Z] ----------------------------------------------------------------------
[2024-07-17T00:46:28.905Z] Ran 2634 tests in 963.388s
[2024-07-17T00:46:28.905Z]
[2024-07-17T00:46:28.905Z] FAILED (errors=1, skipped=1896)
[2024-07-17T00:46:28.905Z] Default GPU device is used ....
The issue is resolved. |