-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pass to rewrite pow2 div #2844
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## develop #2844 +/- ##
===========================================
+ Coverage 91.75% 91.76% +0.01%
===========================================
Files 473 475 +2
Lines 17958 17982 +24
===========================================
+ Hits 16478 16502 +24
Misses 1480 1480 ☔ View full report in Codecov by Sentry. |
Check results before merge 🔆 |
🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output |
This should go into a seperate pass. Maybe it could be called |
@pfultz2 I've updated the PR
|
c5977a2
to
8973085
Compare
This pass is needed for Llama2 fp16 where RMSNorm calc can go out of bounds and output inf values. The pass rewrites x^2/n to (x/sqrt(n))^2.
Rewrite operators in low precision types to avoid going out of precision bounds.
e72105b
to
2e0a5c3
Compare
@pfultz2 I've made the proposed chanegs |
auto ins = r.result; | ||
auto n = r.instructions["n"]; | ||
auto x = r.instructions["x"]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If instead of:
x^2/n --> (x/sqrt(n))^2,
If the following were applied, would there be any loss of accuracy?
x^2/n --> (x/n) * x
Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's actually works, the accuracy is the same with that. This solution saves us one instruction, also removes the sqrt
. Thanks for the idea.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
@pfultz2 @umangyadav @lakhinderwalia I've updated the PR with what @lakhinderwalia proposed to use x^2/n --> (x/n) * x. The tests are passing with that change. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
This pass is needed for Llama2 fp16 where RMSNorm calc can go out of bounds outputting inf values.
The pass rewrites
x^2/n
to(x/sqrt(n))^2
.