Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concat - multibroadcast fix #3096

Merged
merged 13 commits into from
May 21, 2024
Merged

Concat - multibroadcast fix #3096

merged 13 commits into from
May 21, 2024

Conversation

CharlieL7
Copy link
Collaborator

@CharlieL7 CharlieL7 added bugfix Fixes a bug found in the code. TorchMIGraphX labels May 16, 2024
@CharlieL7 CharlieL7 self-assigned this May 16, 2024
@CharlieL7 CharlieL7 requested a review from causten as a code owner May 16, 2024 16:19
Copy link

codecov bot commented May 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.82%. Comparing base (93d77e9) to head (b314f90).
Report is 150 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3096      +/-   ##
===========================================
+ Coverage    91.81%   91.82%   +0.01%     
===========================================
  Files          486      486              
  Lines        18977    18991      +14     
===========================================
+ Hits         17423    17438      +15     
+ Misses        1554     1553       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@migraphx-bot
Copy link
Collaborator

migraphx-bot commented May 16, 2024

Test Batch Rate new
b314f9
Rate old
3b3eca
Diff Compare
torchvision-resnet50 64 2,959.02 2,950.34 0.29%
torchvision-resnet50_fp16 64 6,565.81 6,567.55 -0.03%
torchvision-densenet121 32 2,422.22 2,421.55 0.03%
torchvision-densenet121_fp16 32 3,942.03 3,971.69 -0.75%
torchvision-inceptionv3 32 1,658.08 1,659.48 -0.08%
torchvision-inceptionv3_fp16 32 2,599.24 2,599.62 -0.01%
cadene-inceptionv4 16 777.83 776.25 0.20%
cadene-resnext64x4 16 740.79 740.67 0.02%
slim-mobilenet 64 6,919.60 6,926.37 -0.10%
slim-nasnetalarge 64 177.25 177.12 0.08%
slim-resnet50v2 64 2,877.84 2,877.81 0.00%
bert-mrpc-onnx 8 1,064.93 1,064.78 0.01%
bert-mrpc-tf 1 485.45 499.76 -2.86%
pytorch-examples-wlang-gru 1 428.92 431.20 -0.53%
pytorch-examples-wlang-lstm 1 388.23 349.07 11.22% 🔆
torchvision-resnet50_1 1 803.28 794.82 1.07%
cadene-dpn92_1 1 444.01 397.45 11.72% 🔆
cadene-resnext101_1 1 367.29 361.22 1.68%
onnx-taau-downsample 1 349.34 349.66 -0.09%
dlrm-criteoterabyte 1 33.49 33.64 -0.45%
dlrm-criteoterabyte_fp16 1 56.28 56.59 -0.55%
agentmodel 1 7,251.16 7,792.96 -6.95% 🔴
unet_fp16 2 56.25 57.44 -2.08%
resnet50v1_fp16 1 908.31 902.13 0.68%
resnet50v1_int8 1 781.68 822.26 -4.94% 🔴
bert_base_cased_fp16 64 1,011.41 1,012.77 -0.13%
bert_large_uncased_fp16 32 316.50 316.81 -0.10%
bert_large_fp16 1 nan nan nan%
distilgpt2_fp16 16 1,992.60 1,994.70 -0.11%
yolov5s 1 515.12 514.72 0.08%
tinyllama 1 45.03 45.02 0.01%
vicuna-fastchat 1 177.69 180.47 -1.54%
whisper-tiny-encoder 1 405.08 403.06 0.50%
whisper-tiny-decoder 1 427.94 424.49 0.81%

This build is not recommended to merge 🔴

@migraphx-bot
Copy link
Collaborator

migraphx-bot commented May 16, 2024


❌bert-mrpc-onnx: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/huggingface-transformers/bert_mrpc1.onnx


     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

❌cadene-resnext101_1: ERROR - check error output2024-05-21 00:17:05.354644647 [W:onnxruntime:, model.cc:183 Model] ONNX Runtime only guarantees support for models stamped with opset version 7 or above for opset domain 'ai.onnx'. Please upgrade your model to opset 7 or higher. For now, this opset 6 model may run depending upon legacy support of some older opset version operators.
2024-05-21 00:17:05.360674208 [W:onnxruntime:, transpose_optimizer.cc:28 ApplyImpl] Transpose optimizer failed: Unsupported ONNX opset: 6
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 267, in main
sess = ort.InferenceSession(model_name,
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in init
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 463, in _create_inference_session
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for BatchNormalization(6) node with name ''


     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

❌unet: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 207, in main
model = migraphx.parse_onnx(model_name,
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/unet/model.onnx


     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


❌bert_large: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/bert/model.onnx


     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

src/simplify_reshapes.cpp Outdated Show resolved Hide resolved
src/simplify_reshapes.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@bpickrel bpickrel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comment suggestions, and one test case that I believe needs changing because it's ineffective.

@@ -918,9 +918,10 @@ TEST_CASE(concat_multibroadcasts3)
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
}

// Broadcasted batch dim, axis is broadcasted dim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're going to have a problem with these tests in the future if any change to simplify_reshapes breaks one of them, in that it's not made clear which of the simplify_reshapes matchers each test is intended to match. There's currently also no way to check whether a test that's expected to match but then exit without doing anything actually matched. For now, suggest you add to the descriptive comment for each test: what matcher it's supposed to match (or not match). The current test names are very similar to struct find_concat_multibroadcasts but just different enough that the intent might not be clear to someone new.

A similar problem would come up if the order of passes in simplify_reshapes changes and causes a match where there was none before, or vice versa--a different matcher modifies an instruction graph before the intended matcher can see it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A similar problem would come up if the order of passes in simplify_reshapes changes and causes a match where there was none before, or vice versa--a different matcher modifies an instruction graph before the intended matcher can see it.

That's possible, but quite unlikely given how small these test cases are. The only way I can think of to actually prevent this is to have passes that only run one matcher.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're going to have a problem with these tests in the future if any change to simplify_reshapes breaks one of them, in that it's not made clear which of the simplify_reshapes matchers each test is intended to match.

I don't follow your logic here. All we would have to do is turn on MIGRAPHX_TRACE_MATCHES.

test/simplify_reshapes_test.cpp Outdated Show resolved Hide resolved
test/simplify_reshapes_test.cpp Outdated Show resolved Hide resolved
test/simplify_reshapes_test.cpp Outdated Show resolved Hide resolved
test/simplify_reshapes_test.cpp Outdated Show resolved Hide resolved
const auto& front_in_lens = mb_inputs.front()->get_shape().lens();
for(std::size_t ax = 0; ax < front_in_lens.size(); ++ax)
{
if(ax != concat_op.axis)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when the axis is negative? Do we need to normalize the operator first?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought that normalize_compute_shape() would replace the operator with another one with a normalized axis attribute?

auto m_original = m;
run_pass(m);
EXPECT(m == m_original);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add TODO comments on these test that we will simplify them in the future?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand what you mean by simplify these in the future. I'm only intending to change what the comment says.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean when we rewrite the matcher for doing a broadcast before if there's atleast one common broadcast axis?

@pfultz2
Copy link
Collaborator

pfultz2 commented May 16, 2024

Another test case that should be added is for the inputs {64, 64} and {60, 1, 192} with axis=2. We should not do the transformation for this case ever(so no TODOs in the comments).

For bonus points, we should also reject {1, 64, 64} and {60, 1, 192} even when they are the same rank. For that you need to check there is a common broadcasted axes across all inputs. There is an example of how to that here in this old version of find_inner_broadcast(since then its been replaced), but you could use it here.

@CharlieL7 CharlieL7 requested a review from lakhinderwalia May 20, 2024 19:02
@CharlieL7 CharlieL7 assigned lakhinderwalia and pfultz2 and unassigned pfultz2 May 20, 2024
@CharlieL7 CharlieL7 removed the request for review from lakhinderwalia May 20, 2024 19:10
@CharlieL7 CharlieL7 requested review from bpickrel and pfultz2 May 20, 2024 19:10
@kahmed10 kahmed10 merged commit 1f07af9 into develop May 21, 2024
44 checks passed
@kahmed10 kahmed10 deleted the concat_mb_fix branch May 21, 2024 09:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bugfix Fixes a bug found in the code. TorchMIGraphX
Projects
None yet
Development

Successfully merging this pull request may close these issues.

normalize_compute_shape: CONCAT: all input dimensions should match along axis 2
6 participants