
Optimize broadcast + transpose for nonscalars #2271

Merged · 12 commits · Oct 14, 2023
Conversation

@kahmed10 (Collaborator) commented Oct 2, 2023

Includes optimizations to find_inner_broadcast; without them, a number of redundant multibroadcasts appeared after the transformation.
TODO:

  • see if the scalar case can be written generically as well (unsqueeze is giving problems with scalars)

@codecov codecov bot commented Oct 2, 2023

Codecov Report

Merging #2271 (d084b4c) into develop (6816143) will increase coverage by 0.00%.
The diff coverage is 100.00%.

❗ Current head d084b4c differs from the pull request's most recent head 870cff3. Consider uploading reports for commit 870cff3 to get more accurate results.

@@           Coverage Diff            @@
##           develop    #2271   +/-   ##
========================================
  Coverage    91.32%   91.33%           
========================================
  Files          434      434           
  Lines        16245    16262   +17     
========================================
+ Hits         14836    14853   +17     
  Misses        1409     1409           
Files                        Coverage Δ
src/simplify_algebra.cpp     96.88% <100.00%> (+0.03%) ⬆️
src/simplify_reshapes.cpp    98.75% <100.00%> (+0.02%) ⬆️

@migraphx-bot (Collaborator) commented Oct 2, 2023

Test    Batch    Rate new (ea62d7)    Rate old (a3cf99)    Diff
torchvision-resnet50 64 2,322.90 2,321.52 0.06%
torchvision-resnet50_fp16 64 5,350.99 5,353.90 -0.05%
torchvision-densenet121 32 1,849.78 1,844.93 0.26%
torchvision-densenet121_fp16 32 3,417.86 3,416.32 0.05%
torchvision-inceptionv3 32 1,294.75 1,292.15 0.20%
torchvision-inceptionv3_fp16 32 2,538.37 2,537.75 0.02%
cadene-inceptionv4 16 619.84 619.76 0.01%
cadene-resnext64x4 16 589.30 588.72 0.10%
slim-mobilenet 64 7,213.86 7,203.01 0.15%
slim-nasnetalarge 64 236.59 236.36 0.10%
slim-resnet50v2 64 2,555.11 2,555.32 -0.01%
bert-mrpc-onnx 8 824.85 825.15 -0.04%
bert-mrpc-tf 1 388.46 389.70 -0.32%
pytorch-examples-wlang-gru 1 296.50 297.68 -0.40%
pytorch-examples-wlang-lstm 1 311.70 318.37 -2.10%
torchvision-resnet50_1 1 546.31 546.51 -0.04%
torchvision-inceptionv3_1 1 300.66 306.80 -2.00%
cadene-dpn92_1 1 352.67 355.12 -0.69%
cadene-resnext101_1 1 218.21 219.52 -0.59%
slim-vgg16_1 1 224.13 223.90 0.10%
slim-mobilenet_1 1 1,510.54 1,505.39 0.34%
slim-inceptionv4_1 1 217.16 218.69 -0.70%
onnx-taau-downsample 1 306.07 306.66 -0.19%
dlrm-criteoterabyte 1 21.68 21.70 -0.09%
dlrm-criteoterabyte_fp16 1 40.73 40.64 0.22%
agentmodel 1 5,817.17 5,866.62 -0.84%
unet_fp16 2 55.80 55.11 1.25%
resnet50v1_fp16 1 751.68 767.81 -2.10%
bert_base_cased_fp16 64 971.10 970.20 0.09%
bert_large_uncased_fp16 32 305.12 304.76 0.12%
bert_large_fp16 1 166.83 166.79 0.02%
distilgpt2_fp16 16 1,351.33 1,350.58 0.06%

This build is OK for merge ✅

@migraphx-bot (Collaborator) commented
✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
✅ torchvision-inceptionv3_1: PASSED: MIGraphX meets tolerance
✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
✅ slim-vgg16_1: PASSED: MIGraphX meets tolerance
✅ slim-mobilenet_1: PASSED: MIGraphX meets tolerance
✅ slim-inceptionv4_1: PASSED: MIGraphX meets tolerance
✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
✅ agentmodel: PASSED: MIGraphX meets tolerance
✅ unet: PASSED: MIGraphX meets tolerance
✅ resnet50v1: PASSED: MIGraphX meets tolerance
🔴 bert_base_cased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
✅ bert_large: PASSED: MIGraphX meets tolerance
🔴 distilgpt2_fp16: FAILED: MIGraphX is not within tolerance - check verbose output

@kahmed10 kahmed10 marked this pull request as ready for review October 4, 2023 20:44
@kahmed10 kahmed10 requested review from umangyadav, CharlieL7 and bpickrel and removed request for CharlieL7 October 4, 2023 20:44
@kahmed10 kahmed10 self-assigned this Oct 4, 2023
@TedThemistokleous TedThemistokleous added the enhancement (New feature or request), Cleanup (Cleans up code from stale bits/warnings/previous changes for a previous feature PR), dependencies (Pull requests that update a dependency file), and Matchers (Updates or adds a change to compile time Matchers) labels, then removed the Cleanup label, Oct 6, 2023
@bpickrel (Contributor) left a comment

Code and tests look good, but:

  1. Update the copyright date.
  2. Of all the structs defined in this file, only two have brief comments outlining the transformation they define. Could you take the opportunity to add one for this struct too?

Comment on lines +681 to +682
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
Member:

> Includes optimizations to find_inner_broadcast because otherwise a bunch of multibroadcasts were appearing after the transformation

Do you have a test for this specific case, where multibroadcasts appear after find_inner_broadcast and are not cleaned up by simplify_reshapes? Because for this test it will add a multibroadcast, but I think it will later get cleaned up by find_nop_reshaper.

@kahmed10 (Collaborator, Author) commented Oct 12, 2023

Yes, this exact test will currently fail on develop and insert a bunch of multibroadcasts:

./bin/test_simplify_algebra_test simplify_inner_broadcast_no_common_axis
[   RUN    ] simplify_inner_broadcast_no_common_axis
void simplify_inner_broadcast_no_common_axis()
/code/AMDMIGraphX/test/simplify_algebra_test.cpp:686:
    FAILED: m1 == m2 [ y = @param:y -> int32_type, {1, 5, 1}, {5, 1, 1}, target_id=0
x = @param:x -> int32_type, {5, 10}, {10, 1}, target_id=0
@2 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](x) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@3 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](y) -> int32_type, {1, 5, 10}, {5, 1, 0}, target_id=0
@4 = add(@2,@3) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@5 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@4) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@6 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@5) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@7 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@6) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@8 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@7) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@9 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@8) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@10 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@9) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@11 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@10) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@12 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](@11) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@13 = pass(@12) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
 == y = @param:y -> int32_type, {1, 5, 1}, {5, 1, 1}, target_id=0
x = @param:x -> int32_type, {5, 10}, {10, 1}, target_id=0
@2 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](x) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@3 = multibroadcast[out_lens={1, 5, 10},out_dyn_dims={}](y) -> int32_type, {1, 5, 10}, {5, 1, 0}, target_id=0
@4 = add(@2,@3) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
@5 = pass(@4) -> int32_type, {1, 5, 10}, {0, 10, 1}, target_id=0
 ]
[  FAILED  ] simplify_inner_broadcast_no_common_axis: Test failure
[==========] 1 tests ran
[  FAILED  ] 1 tests failed
[  FAILED  ] simplify_inner_broadcast_no_common_axis

Member:

Those multibroadcasts should get cleaned up by find_nop_reshaper later.

Member:

Also, where are all those broadcasts being added? It looks to me like find_inner_broadcast would only add one multibroadcast after the add.

@umangyadav (Member) commented

> see if scalar case can be written generically as well (unsqueeze is giving problems with scalar)

Inside simplify_reshapes, the unsqueeze is only added if the input is not a scalar. Is it still an issue?

@kahmed10 (Collaborator, Author) commented

> see if scalar case can be written generically as well (unsqueeze is giving problems with scalar)
>
> Inside simplify_reshapes, the unsqueeze is only added if the input is not a scalar. Is it still an issue?

This is the behavior of unsqueeze on a scalar:

x = @param:x -> float_type, {1}, {0}, target_id=0
@2 = unsqueeze[axes={0, 1},steps={}](x) -> float_type, {1}, {1}, target_id=0

Looks like it completely ignores the extra axes and instead converts the scalar (stride 0) into a shape with stride 1.

}
}
// if no common broadcast axis, transformation is not useful
if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) {
Collaborator:

You could use std::none_of instead.

@causten causten merged commit 271eedd into develop Oct 14, 2023
14 of 15 checks passed
@causten causten deleted the bcast_transpose_generic branch October 14, 2023 03:51
Labels
dependencies (Pull requests that update a dependency file), enhancement (New feature or request), Matchers (Updates or adds a change to compile time Matchers)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

  • Skip find_inner_broadcast() if there is no common broadcast axis
  • Rewrite broadcast+transpose
7 participants