FP8 GPU implementation #2455
@@ -118,7 +118,7 @@ struct highest
     template <class T>
     constexpr operator T() const
     {
-        return numeric_max<vec_type<T>>();
+        return numeric_max<vec_type<T>, void>();
Why do you need `void` here?
To make it always call `numeric_max()`. I had to specialize `numeric_max` for fp8 and, to avoid a redefinition error, set the second template parameter to `void`. Passing `void` here makes the call work for all types.
I think you can add a `class Enable=void` parameter to avoid this. The `void` parameter is an implementation detail that consumers shouldn't touch, as it's just there for `enable_if`.
Made changes.
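The `class Enable=void` suggestion above can be sketched in isolation. This is a minimal illustration of the idiom, not the MIGraphX code: `toy_fp8`, `numeric_max_impl`, and the `0x7F` bit pattern are hypothetical stand-ins, and the dispatch goes through a class template so that the `enable_if` slot has a default and never appears at the call site.

```cpp
#include <limits>
#include <type_traits>

// Hypothetical stand-in for an fp8 type; not the MIGraphX type.
struct toy_fp8
{
    unsigned char bits;
};

// Primary template: the defaulted Enable parameter exists only for enable_if.
template <class T, class Enable = void>
struct numeric_max_impl
{
    static constexpr T value() { return std::numeric_limits<T>::max(); }
};

// Specialization selected through the Enable slot.
template <class T>
struct numeric_max_impl<T, std::enable_if_t<std::is_same<T, toy_fp8>::value>>
{
    // 0x7F is an arbitrary placeholder, not a real fp8 maximum.
    static constexpr T value() { return toy_fp8{0x7F}; }
};

// Callers name only T; Enable never appears at the call site.
template <class T>
constexpr T numeric_max()
{
    return numeric_max_impl<T>::value();
}

static_assert(numeric_max<int>() == std::numeric_limits<int>::max(), "generic path");
static_assert(numeric_max<toy_fp8>().bits == 0x7F, "toy_fp8 path");
```

With this shape, a caller such as `highest` would write `numeric_max<vec_type<T>>()` with no explicit `void`.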
constexpr T numeric_max<T, void>() \
{ \
    return fp8::numeric_limits<T>::max(); \
} \
This should just be a templated function instead of using a macro:
template <class T, MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or
                                     is_same<T, fp8::fp8e5m2fnuz>{} or
                                     is_same<T, fp8::fp8e4m3fn>{} or
                                     is_same<T, fp8::fp8e5m2>{})>
constexpr T numeric_max<T>()
{
    return fp8::numeric_limits<T>::max();
}
That doesn't work; it would be considered a redefinition of `numeric_max()`. I need to specialize the `numeric_max` template that is defined in `type_traits.hpp`.
OK, then you should be able to specialize it like this:
template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ, class Enable=void>
constexpr float8<T, FNUZ> numeric_max<float8<T, FNUZ>, Enable>()
{
    return fp8::numeric_limits<float8<T, FNUZ>>::max();
}
That should work for all fp8 types. Plus, the `class Enable=void` won't require us to pass `void` into the function.
Made changes.
constexpr T numeric_lowest<T>() \
{ \
    return fp8::numeric_limits<T>::lowest(); \
}
Is this overload necessary?
Yes, I think it's used inside `ops::lowest{}`.
Removed it. It is taken care of by `numeric_lowest` inside type_traits.
OK, you can specialize this using the `float8<T, FNUZ>` template.
Regarding `float8<T, FNUZ>`: partial specialization that way is not allowed. It runs into compilation errors.
Made changes.
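The compile errors mentioned above are consistent with a general C++ rule: function templates cannot be partially specialized, while class templates can. One common workaround, sketched below, is to have the function forward to a class template that is partially specialized over `float8<T, FNUZ>`. This is an illustration under assumptions, not the change that was made in the PR: the `f8_type` enumerators, the toy `float8`, `numeric_max_impl`, and the `0x7E` bit pattern are all hypothetical.

```cpp
#include <limits>

// Toy stand-ins; the real migraphx::fp8 types are assumed, not reproduced.
enum class f8_type
{
    bf8,
    fp8
};

template <f8_type T, bool FNUZ>
struct float8
{
    unsigned char bits;
};

// Class templates, unlike function templates, can be partially specialized.
template <class T>
struct numeric_max_impl
{
    static constexpr T value() { return std::numeric_limits<T>::max(); }
};

// One partial specialization covers every float8<T, FNUZ> combination.
template <f8_type T, bool FNUZ>
struct numeric_max_impl<float8<T, FNUZ>>
{
    // Placeholder bit pattern for illustration only; not a real fp8 maximum.
    static constexpr float8<T, FNUZ> value() { return float8<T, FNUZ>{0x7E}; }
};

// The function template itself is never specialized; it only forwards.
template <class T>
constexpr T numeric_max()
{
    return numeric_max_impl<T>::value();
}

static_assert(numeric_max<float8<f8_type::fp8, true>>().bits == 0x7E, "fp8 fnuz");
static_assert(numeric_max<float8<f8_type::bf8, false>>().bits == 0x7E, "bf8");
static_assert(numeric_max<long>() == std::numeric_limits<long>::max(), "generic");
```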
What's the usage of `implicit_conversion`?
It looks like a lot of the kernels were changed to be more explicit about the types.
Add throw message for DNNL Co-authored-by: Charlie Lin <[email protected]>
From what I understand, some of the internal functions/kernels return their output in a different type from the one desired as the JIT kernel's output type, and implicit conversion takes care of such type mismatches. The float8 constructors are marked `explicit` to avoid implicit conversions, so I had to call the constructor explicitly in many places. The other reason is to avoid warnings about narrowing conversion from ...
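The interaction described above, between an `explicit` constructor and a conversion helper, can be sketched as follows. This is a simplified illustration and not the MIGraphX implementation: `toy_fp8`, its placeholder encoding, `conversion_wrapper`, `wrap_conversion`, `internal_op`, and `kernel_like` are all hypothetical names.

```cpp
// Toy type with an explicit constructor, loosely modeled on the description above.
struct toy_fp8
{
    unsigned char bits;
    // Placeholder encoding: 1 for positive inputs, 0 otherwise.
    constexpr explicit toy_fp8(float v) : bits(v > 0.0f ? 1 : 0) {}
};

// Hypothetical wrapper: holds a value and converts on demand via static_cast,
// so the explicit constructor is invoked deliberately in one place.
template <class T>
struct conversion_wrapper
{
    T x;

    template <class U>
    constexpr operator U() const
    {
        return static_cast<U>(x);
    }
};

template <class T>
constexpr conversion_wrapper<T> wrap_conversion(T x)
{
    return {x};
}

// Returns float internally, while the caller wants toy_fp8.
constexpr float internal_op(float a) { return a * 2.0f; }

constexpr toy_fp8 kernel_like(float a)
{
    // `return internal_op(a);` would not compile: the constructor is explicit.
    return wrap_conversion(internal_op(a));
}

static_assert(kernel_like(1.5f).bits == 1, "positive input maps to 1 in this toy encoding");
static_assert(kernel_like(-1.5f).bits == 0, "non-positive input maps to 0");
```

The wrapper keeps accidental conversions a compile error while still letting a function return a mismatched intermediate type without repeating the target type at every return statement.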
LGTM. No more questions/concerns and looks like you've handled all comments
Looks good! We should probably look into having an env variable to enable fp8 emulation, since it's probably much slower on unsupported hardware than just using fp32 or fp16 directly.
I like this idea too, but let's get this in so we can hammer the others out. Good for a smaller final PR.
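As a rough sketch of the environment-variable idea discussed above, an opt-in check could look like the following. The variable name `EXAMPLE_ENABLE_FP8_EMULATION` and both helper functions are hypothetical; the PR does not define them, and the project may well read environment variables through a different mechanism.

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical variable name, chosen for illustration; the PR does not define one.
constexpr const char* fp8_emulation_env = "EXAMPLE_ENABLE_FP8_EMULATION";

// Pure helper so the parsing rule can be checked without touching the environment.
inline bool is_enabled_value(const char* value)
{
    return value != nullptr && std::strcmp(value, "1") == 0;
}

// Emulation stays off unless the variable is set to exactly "1".
inline bool fp8_emulation_enabled()
{
    return is_enabled_value(std::getenv(fp8_emulation_env));
}
```

Defaulting to off matches the concern raised above that emulation on unsupported hardware is probably slower than using fp32 or fp16 directly.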
Tested on MI300.
Added verify tests for math ops, reduce ops, and a few pointwise ops.
Some FP8 tests are commented out because they run into compile errors. They need to be fixed and enabled.