-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Concat - multibroadcast fix #3096
Conversation
CharlieL7
commented
May 16, 2024
- Resolves normalize_compute_shape: CONCAT: all input dimensions should match along axis 2 #3023
- Adds checks and their respective tests for cases that the concat-multibroadcast matcher cannot currently handle.
- See the issue for a breakdown and Concat broadcast rewrite to handle more cases #3095 to track a rewrite of the this matcher to handle these cases.
- Added documentation and chose new variable names to make what the matcher is doing more clear.
other than the concat axis
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## develop #3096 +/- ##
===========================================
+ Coverage 91.81% 91.82% +0.01%
===========================================
Files 486 486
Lines 18977 18991 +14
===========================================
+ Hits 17423 17438 +15
+ Misses 1554 1553 -1 ☔ View full report in Codecov by Sentry. |
This build is not recommended to merge 🔴 |
❌bert-mrpc-onnx: ERROR - check error outputTraceback (most recent call last):File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in main() File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main model = migraphx.parse_onnx(model_name, default_dim_value=batch) RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/huggingface-transformers/bert_mrpc1.onnx ❌cadene-resnext101_1: ERROR - check error output2024-05-21 00:17:05.354644647 [W:onnxruntime:, model.cc:183 Model] ONNX Runtime only guarantees support for models stamped with opset version 7 or above for opset domain 'ai.onnx'. Please upgrade your model to opset 7 or higher. For now, this opset 6 model may run depending upon legacy support of some older opset version operators.2024-05-21 00:17:05.360674208 [W:onnxruntime:, transpose_optimizer.cc:28 ApplyImpl] Transpose optimizer failed: Unsupported ONNX opset: 6 Traceback (most recent call last): File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in main() File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 267, in main sess = ort.InferenceSession(model_name, File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in init self._create_inference_session(providers, provider_options, disabled_optimizers) File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 463, in _create_inference_session sess.initialize_session(providers, provider_options, disabled_optimizers) onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for BatchNormalization(6) node with name '' ❌unet: ERROR - check error outputTraceback (most recent call last):File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in main() File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 207, in main model = migraphx.parse_onnx(model_name, RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/unet/model.onnx 🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output❌bert_large: ERROR - check error outputTraceback (most recent call last):File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in main() File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main model = migraphx.parse_onnx(model_name, default_dim_value=batch) RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/bert/model.onnx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comment suggestions, and one test case that I believe needs changing because it's ineffective.
@@ -918,9 +918,10 @@ TEST_CASE(concat_multibroadcasts3) | |||
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2); | |||
} | |||
|
|||
// Broadcasted batch dim, axis is broadcasted dim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're going to have a problem with these tests in the future if any change to simplify_reshapes breaks one of them, in that it's not made clear which of the simplify_reshapes
matchers each test is intended to match. There's currently also no way to check whether a test that's expected to match but then exit without doing anything actually matched. For now, suggest you add to the descriptive comment for each test: what matcher it's supposed to match (or not match). The current test names are very similar to struct find_concat_multibroadcasts
but just different enough that the intent might not be clear to someone new.
A similar problem would come up if the order of passes in simplify_reshapes
changes and causes a match where there was none before, or vice versa--a different matcher modifies an instruction graph before the intended matcher can see it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A similar problem would come up if the order of passes in simplify_reshapes changes and causes a match where there was none before, or vice versa--a different matcher modifies an instruction graph before the intended matcher can see it.
That's possible, but quite unlikely given how small these test cases are. The only way I can think of to actually prevent this is to have passes that only run one matcher.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're going to have a problem with these tests in the future if any change to simplify_reshapes breaks one of them, in that it's not made clear which of the simplify_reshapes matchers each test is intended to match.
I don't follow your logic here. All we would have to do is turn on MIGRAPHX_TRACE_MATCHES
.
src/simplify_reshapes.cpp
Outdated
const auto& front_in_lens = mb_inputs.front()->get_shape().lens(); | ||
for(std::size_t ax = 0; ax < front_in_lens.size(); ++ax) | ||
{ | ||
if(ax != concat_op.axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens when the axis is negative? Do we need to normalize the operator first?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought that normalize_compute_shape()
would replace the operator with another one with a normalized axis
attribute?
Co-authored-by: Brian Pickrell <[email protected]>
Co-authored-by: Brian Pickrell <[email protected]>
Co-authored-by: Brian Pickrell <[email protected]>
auto m_original = m; | ||
run_pass(m); | ||
EXPECT(m == m_original); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add TODO comments on these test that we will simplify them in the future?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I understand what you mean by simplify these in the future. I'm only intending to change what the comment says.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean when we rewrite the matcher for doing a broadcast before if there's atleast one common broadcast axis?
Another test case that should be added is for the inputs For bonus points, we should also reject |