Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug in find_concat_op when input broadcasts are on different axes #3242

Merged
merged 4 commits into from
Jul 4, 2024

Conversation

shivadbhavsar
Copy link
Contributor

Fixes issue #3224
Which in turn fixes one of the issues exposed by #3104

@shivadbhavsar shivadbhavsar added bugfix Fixes a bug found in the code. TorchMIGraphX labels Jul 2, 2024
@shivadbhavsar shivadbhavsar self-assigned this Jul 2, 2024
@shivadbhavsar shivadbhavsar requested a review from causten as a code owner July 2, 2024 21:48
@shivadbhavsar shivadbhavsar linked an issue Jul 2, 2024 that may be closed by this pull request
Copy link

codecov bot commented Jul 2, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.20%. Comparing base (1b7653d) to head (0150cbc).
Report is 149 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3242      +/-   ##
===========================================
+ Coverage    92.16%   92.20%   +0.03%     
===========================================
  Files          493      493              
  Lines        19690    19700      +10     
===========================================
+ Hits         18148    18164      +16     
+ Misses        1542     1536       -6     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@migraphx-bot
Copy link
Collaborator

migraphx-bot commented Jul 2, 2024

Test Batch Rate new
0150cb
Rate old
e2d6ef
Diff Compare
torchvision-resnet50 64 1,741.30 1,749.23 -0.45%
torchvision-resnet50_fp16 64 4,064.50 4,078.43 -0.34%
torchvision-densenet121 32 1,459.79 1,467.35 -0.52%
torchvision-densenet121_fp16 32 2,520.24 2,532.63 -0.49%
torchvision-inceptionv3 32 885.25 888.93 -0.41%
torchvision-inceptionv3_fp16 32 1,478.60 1,483.54 -0.33%
cadene-inceptionv4 16 410.51 411.78 -0.31%
cadene-resnext64x4 16 417.53 419.22 -0.40%
slim-mobilenet 64 3,987.61 4,003.58 -0.40%
slim-nasnetalarge 64 100.56 100.96 -0.40%
slim-resnet50v2 64 1,672.18 1,678.56 -0.38%
bert-mrpc-onnx 8 612.15 616.93 -0.77%
bert-mrpc-tf 1 276.91 279.35 -0.87%
pytorch-examples-wlang-gru 1 321.83 364.12 -11.61% 🔴
pytorch-examples-wlang-lstm 1 291.65 294.38 -0.93%
torchvision-resnet50_1 1 467.29 465.00 0.49%
cadene-dpn92_1 1 246.51 247.39 -0.36%
cadene-resnext101_1 1 203.22 204.15 -0.46%
onnx-taau-downsample 1 205.50 206.13 -0.31%
dlrm-criteoterabyte 1 22.83 22.91 -0.36%
dlrm-criteoterabyte_fp16 1 42.57 42.71 -0.32%
agentmodel 1 6,342.24 6,228.16 1.83%
unet_fp16 2 34.19 34.30 -0.31%
resnet50v1_fp16 1 584.69 595.94 -1.89%
resnet50v1_int8 1 579.90 570.46 1.66%
bert_base_cased_fp16 64 642.42 645.74 -0.51%
bert_large_uncased_fp16 32 197.82 198.84 -0.51%
bert_large_fp16 1 117.07 117.52 -0.38%
distilgpt2_fp16 16 1,205.09 1,209.51 -0.37%
yolov5s 1 301.13 300.77 0.12%
tinyllama 1 23.21 23.32 -0.46%
vicuna-fastchat 1 133.20 133.25 -0.04%
whisper-tiny-encoder 1 243.31 244.46 -0.47%
whisper-tiny-decoder 1 255.30 256.38 -0.42%

This build is not recommended to merge 🔴

@migraphx-bot
Copy link
Collaborator


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@@ -811,6 +811,18 @@ struct find_concat_op
op.attributes().contains("pointwise");
}

static bool is_valid_concat(std::vector<instruction_ref> ins, size_t axis)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there is another bug on line 886

   auto pred = [](auto i, auto j) {
            return i->get_operator() == j->get_operator() and
                   i->inputs().size() == i->inputs().size() and
                   i->outputs().size() == i->outputs().size();
        };

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tracking in #3245

@TedThemistokleous TedThemistokleous self-requested a review July 4, 2024 14:20
@umangyadav umangyadav merged commit d9367cb into develop Jul 4, 2024
46 checks passed
@umangyadav umangyadav deleted the find_concat_fix branch July 4, 2024 15:34
umangyadav pushed a commit that referenced this pull request Jul 4, 2024
lajagapp pushed a commit to lajagapp/AMDMIGraphX that referenced this pull request Aug 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bugfix Fixes a bug found in the code. TorchMIGraphX
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bug in find_concat_op
4 participants