Add propagate_precision pass #2853
base: develop
@@ -0,0 +1,20 @@
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
Add the license header.
bool is_integral() const { return std::is_integral<type>{}; }
bool is_signed() const { return std::is_signed<type>{}; }
bool is_unsigned() const { return std::is_unsigned<type>{}; }
Curious why auto was changed to bool?
I think this is for readability/clarity. I don't see why these would resolve to anything but bool, unless we want to use the value specified by the STL here?
Doing x.is_integral() != y.is_integral() will fail to compile because, with auto, the two calls return different types. So I explicitly convert to bool instead.
friend bool operator>=(const precision& xp, const precision& yp)
{
    return (xp > yp) or (xp == yp);
}
This might seem like an odd ask, but why not make these xor instead of or? If one is true, the other shouldn't also be true. Could using xor instead of or here speed things up or check for errors?
xor won't short-circuit.
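The short-circuit difference can be checked with a small sketch. `counting_operand` and the two helper functions are hypothetical names introduced for illustration. In C++, `or` is the alternative token for `||`, which skips its right operand when the left is true, while `xor` is the alternative token for `^`, which always evaluates both sides.

```cpp
// Hypothetical operand that records how often it is evaluated.
struct counting_operand
{
    int calls = 0;
    bool operator()(bool value)
    {
        ++calls;
        return value;
    }
};

// Returns how many times the right-hand operand of `or` was evaluated.
inline int rhs_evaluations_with_or(bool lhs)
{
    counting_operand rhs;
    bool result = lhs or rhs(false); // right side skipped when lhs is true
    static_cast<void>(result);
    return rhs.calls;
}

// Returns how many times the right-hand operand of `xor` was evaluated.
inline int rhs_evaluations_with_xor(bool lhs)
{
    counting_operand rhs;
    bool result = lhs xor rhs(false); // both sides are always evaluated
    static_cast<void>(result);
    return rhs.calls;
}
```

In the `operator>=` above, this means `(xp == yp)` is only evaluated when `(xp > yp)` is false, which `xor` would not preserve.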
@@ -0,0 +1,191 @@
#include <migraphx/propagate_precision.hpp>
Add license here too
@@ -0,0 +1,158 @@
#include <migraphx/propagate_precision.hpp>
Add the license header.
auto mul = m2.add_instruction(migraphx::make_op("mul"), sqrt, y);
m2.add_return({mul});
}
EXPECT(m1.sort() == m2.sort());
I understand this is done to preserve precision throughout all these divides/square roots, but are we not worried about the added overhead here now? We've just converted the fp16 set of ops to double. Or is compute not a concern here?
But we are converting to double anyway. For elementwise ops, the compute shouldn't add much overhead since these are all essentially unary operators.
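As a rough illustration of why keeping the wider type across neighboring elementwise ops helps, here is a sketch that uses float and double as stand-ins for fp16 and the wider type (standard C++ before C++23 has no half type). The function names are hypothetical, and this is not how the pass is implemented. It only compares rounding the intermediate back to the narrow type against staying wide for the whole chain.

```cpp
#include <cmath>

// Narrow chain: the sqrt runs in double, but its result is rounded back to
// float before the next op, as if a convert sat between the two ops.
inline double sqrt_then_square_narrow(float x)
{
    float s = static_cast<float>(std::sqrt(static_cast<double>(x)));
    return static_cast<double>(s) * static_cast<double>(s);
}

// Wide chain: the intermediate stays in double for the whole chain.
inline double sqrt_then_square_wide(float x)
{
    double s = std::sqrt(static_cast<double>(x));
    return s * s;
}
```

For an input of 2.0f, the wide chain lands much closer to 2 than the narrow chain, because the narrow chain carries the float rounding error of the square root into the multiply.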
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           develop    #2853      +/-   ##
===========================================
+ Coverage    91.75%   91.76%   +0.01%
===========================================
  Files          473      475       +2
  Lines        17958    18043      +85
===========================================
+ Hits         16478    16558      +80
- Misses        1480     1485       +5
☔ View full report in Codecov by Sentry.
This build is not recommended to merge 🔴
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
return result;
}

void propagate_precision::apply(module_pass_manager& mpm) const
Can you write docstrings for all of these functions that describe what they are supposed to do and how they work?
Also describe how the pass is supposed to work and how it helps with precision, accuracy, or performance.
We can read the code, but that is not a time-efficient way for everyone to get a high-level understanding.
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(propagate_reduce)
Can you add a test where the pass doesn't do anything?
For some background: where are we failing accuracy because of precision changes?
@@ -0,0 +1,20 @@
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
Is there a reason why the include guard is named differently?
This is related to the fp16 inaccuracy with llamav2 (see #2556). #2883 will use FP32 for large reduce_means, but it still isn't enough to get accurate results (or avoid NaNs). So this will use FP32 for the