Prevent collapsing batch dims in dot ops with constants #2823
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
## develop #2823 +/- ##
Coverage 91.92% 91.93%
Files 489 489
Lines 19275 19301 +26
+ Hits 17719 17744 +25
- Misses 1556 1557 +1
SDXL perf results for reference:
Torch-MIGraphX (end to end):
ONNX Unet (4x attn trim):
As expected, this doesn't affect the ONNX version much because there is an extra convert in the middle. Once the convert is handled, the perf number drops to 5.47 ms.
src/simplify_reshapes.cpp (Outdated)

    auto sq_const =
        m.insert_instruction(mbr, make_op("squeeze", {{"axes", sq_axes}}), constant);
    m.replace_instruction(mbr, mbr->get_operator(), sq_const);
Couldn't we replace it with broadcast instead?
This is just removing any unnecessary leading dims in literals, e.g. {1, 1, 640, 640}, which are later broadcast to something like {2, 32, 640, 640}. Would broadcast work for this? I thought it only handles one axis.
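To make the shape bookkeeping in this reply concrete, here is a minimal standalone sketch of picking the leading unit axes to squeeze from a literal's dimensions. It is not taken from the MIGraphX source: `leading_unit_axes`, `squeeze_lens`, and the `keep` parameter (how many trailing matrix dims are left alone) are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from MIGraphX): collect the leading axes whose
// length is 1, never touching the last `keep` dimensions.
std::vector<std::int64_t> leading_unit_axes(const std::vector<std::size_t>& lens,
                                            std::size_t keep = 2)
{
    std::vector<std::int64_t> axes;
    if(lens.size() <= keep)
        return axes;
    const std::size_t limit = lens.size() - keep;
    // Stop at the first non-unit dimension so only a leading run is collected.
    for(std::size_t i = 0; i < limit && lens[i] == 1; ++i)
        axes.push_back(static_cast<std::int64_t>(i));
    return axes;
}

// Hypothetical helper: the dimensions left after squeezing `axes`
// (assumed sorted in ascending order).
std::vector<std::size_t> squeeze_lens(const std::vector<std::size_t>& lens,
                                      const std::vector<std::int64_t>& axes)
{
    std::vector<std::size_t> out;
    std::size_t next = 0;
    for(std::size_t i = 0; i < lens.size(); ++i)
    {
        if(next < axes.size() && axes[next] == static_cast<std::int64_t>(i))
        {
            assert(lens[i] == 1); // only unit dimensions can be squeezed
            ++next;
            continue;
        }
        out.push_back(lens[i]);
    }
    return out;
}
```

With the dims from the comment, {1, 1, 640, 640} yields axes {0, 1} and squeezed dims {640, 640}, which a later multi-axis broadcast can still expand to {2, 32, 640, 640}.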
This build is not recommended to merge 🔴
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
This simplifies many reshape -> dot -> reshape patterns that are not handled by the find_reshape_reshape_dot pass (i.e. in gemms where one input is a constant). This also simplifies the reshape found in #2736.
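As a rough illustration of why the surrounding reshapes are redundant when one input is a constant, the sketch below compares a gemm on collapsed batch dims ([b*m, k] times [k, n]) with a per-batch gemm that reuses the same constant. This is plain standalone C++ rather than the pass itself; `gemm`, `dot_collapsed`, and `dot_batched` are hypothetical names, and row-major storage is assumed.

```cpp
#include <cstddef>
#include <vector>

// Row-major GEMM: c[rows x n] = a[rows x k] * w[k x n].
std::vector<float> gemm(const std::vector<float>& a,
                        const std::vector<float>& w,
                        std::size_t rows,
                        std::size_t k,
                        std::size_t n)
{
    std::vector<float> c(rows * n, 0.0f);
    for(std::size_t i = 0; i < rows; ++i)
        for(std::size_t j = 0; j < n; ++j)
            for(std::size_t p = 0; p < k; ++p)
                c[i * n + j] += a[i * k + p] * w[p * n + j];
    return c;
}

// reshape -> dot -> reshape: view [b, m, k] as [b*m, k] and run a single gemm.
// The output buffer can be read back as [b, m, n] without moving data.
std::vector<float> dot_collapsed(const std::vector<float>& a,
                                 const std::vector<float>& w,
                                 std::size_t b,
                                 std::size_t m,
                                 std::size_t k,
                                 std::size_t n)
{
    return gemm(a, w, b * m, k, n);
}

// Batched dot: keep the batch dim and reuse the constant w for every batch.
std::vector<float> dot_batched(const std::vector<float>& a,
                               const std::vector<float>& w,
                               std::size_t b,
                               std::size_t m,
                               std::size_t k,
                               std::size_t n)
{
    std::vector<float> c;
    c.reserve(b * m * n);
    for(std::size_t bi = 0; bi < b; ++bi)
    {
        const auto first = a.begin() + static_cast<std::ptrdiff_t>(bi * m * k);
        const auto last  = first + static_cast<std::ptrdiff_t>(m * k);
        const std::vector<float> slice(first, last);
        const auto part = gemm(slice, w, m, k, n);
        c.insert(c.end(), part.begin(), part.end());
    }
    return c;
}
```

Both functions return the same buffer for the same inputs, so a dot against a constant can keep its batch dims instead of being wrapped in a pair of reshapes.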