
FP8 GPU implementation #2455

Merged
merged 105 commits into from
Dec 1, 2023
Conversation

umangyadav
Member

Tested on MI300.

Added verify tests for math ops, reduce ops, and a few pointwise ops.

Some FP8 tests are commented out because they run into compile errors; these need to be fixed and enabled.

@@ -118,7 +118,7 @@ struct highest
template <class T>
constexpr operator T() const
{
return numeric_max<vec_type<T>>();
return numeric_max<vec_type<T>, void>();
Collaborator

Why do you need void here?

Member Author

To make it call numeric_max() always. I had to specialize numeric_max for fp8 to avoid redefinition, by setting the second template parameter to void. By adding void, this works for all the types.

Collaborator

I think you can add a class Enable = void parameter to avoid this. The void parameter is an implementation detail that consumers shouldn't touch, as it's just there for enable_if.
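The pattern under discussion can be sketched in isolation. The fp8 type and its max value below are stand-ins for illustration, not the MIGraphX definitions:

```cpp
#include <limits>

// Primary template: the trailing Enable parameter defaults to void, so
// callers can still write numeric_max<T>() without spelling it out.
template <class T, class Enable = void>
constexpr T numeric_max()
{
    return std::numeric_limits<T>::max();
}

// Hypothetical fp8 stand-in for illustration.
struct fp8e4m3
{
    float v;
};

// Full specialization for the fp8 type; Enable stays void, so a plain
// numeric_max<fp8e4m3>() call selects this via the default argument.
template <>
constexpr fp8e4m3 numeric_max<fp8e4m3, void>()
{
    return fp8e4m3{448.0f}; // 448 is the usual e4m3 max, used as a placeholder
}
```

With the defaulted parameter, call sites never mention void: numeric_max&lt;float&gt;() and numeric_max&lt;fp8e4m3&gt;() both work unchanged.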

Member Author

Made changes.

constexpr T numeric_max<T, void>() \
{ \
return fp8::numeric_limits<T>::max(); \
} \
Collaborator

This should just be a templated function instead of using a macro:

template <class T, MIGRAPHX_REQUIRES(is_same<T, fp8::fp8e4m3fnuz>{} or
                                     is_same<T, fp8::fp8e5m2fnuz>{} or
                                     is_same<T, fp8::fp8e4m3fn>{} or
                                     is_same<T, fp8::fp8e5m2>{})>
constexpr T numeric_max<T>()
{
    return fp8::numeric_limits<T>::max();
}

Member Author

That doesn't work; it would be considered a redefinition of numeric_max(). I need to specialize the numeric_max template that is defined in type_traits.hpp.

Collaborator

Ok then you should be able to specialize it like this:

template <migraphx::fp8::f8_type T = migraphx::fp8::f8_type::fp8, bool FNUZ, class Enable=void>
constexpr float8<T, FNUZ> numeric_max<float8<T, FNUZ>, Enable>()
{
    return fp8::numeric_limits<float8<T, FNUZ>>::max();
}

That should work for all fp8 types. Plus, the class Enable = void won't require us to pass void into the function.

Member Author

Made changes.

constexpr T numeric_lowest<T>() \
{ \
return fp8::numeric_limits<T>::lowest(); \
}
Collaborator

Is this overload necessary?

Member Author

Yes, I think it's used inside ops::lowest{}.

Member Author

@umangyadav umangyadav Nov 30, 2023

Removed it. It is taken care of by numeric_lowest inside type_traits.

Collaborator

Ok, then you can specialize this using the float8<T, FNUZ> template.

Member Author

Partial specialization with float8<T, FNUZ> is not allowed that way; it runs into compilation errors.
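Since C++ forbids partial specialization of function templates, the route that works is one full specialization per concrete fp8 type, which a macro can stamp out. A minimal sketch under assumed type names and limit values:

```cpp
#include <limits>

// Primary template with a defaulted Enable parameter.
template <class T, class Enable = void>
constexpr T numeric_lowest()
{
    return std::numeric_limits<T>::lowest();
}

// Hypothetical fp8 stand-ins for illustration.
struct fp8e4m3 { float v; };
struct fp8e5m2 { float v; };

// Function templates cannot be partially specialized, so each concrete
// fp8 type gets its own full specialization, generated by a macro.
#define FP8_NUMERIC_LOWEST(T, value)      \
    template <>                           \
    constexpr T numeric_lowest<T, void>() \
    {                                     \
        return T{value};                  \
    }

FP8_NUMERIC_LOWEST(fp8e4m3, -448.0f)   // placeholder lowest for e4m3
FP8_NUMERIC_LOWEST(fp8e5m2, -57344.0f) // placeholder lowest for e5m2
```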

Member Author

Made changes.

Collaborator

@CharlieL7 CharlieL7 left a comment

What's the usage of implicit_conversion?
Looks like a lot of the kernels were changed to be more explicit about the types.

Add throw message for DNNL

Co-authored-by: Charlie Lin <[email protected]>
@umangyadav
Member Author

umangyadav commented Nov 30, 2023

What's the usage of implicit_conversion?
Looks like a lot of the kernels were changed to be more explicit about the types.

From what I understand, some of the internal functions/kernels return output in a different type than the JIT kernel's desired output type, and implicit_conversion takes care of such type mismatches.

Float8 constructors are marked explicit to avoid implicit conversions, so the constructor had to be called explicitly in many places.

The other reason is to avoid warnings about narrowing conversions from float to __Float16 or fp8.
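A minimal sketch of how such a helper might look (the name implicit_conversion matches the discussion, but this body and the fp8 type are assumptions, not the MIGraphX source): the wrapper's templated conversion operator performs an explicit cast to whatever type the destination demands, so explicit constructors remain usable at assignment sites.

```cpp
// Wraps a value so it converts to whichever type the destination wants.
template <class T>
struct implicit_converter
{
    T value;
    template <class U>
    constexpr operator U() const
    {
        return U(value); // explicit cast, so explicit constructors still work
    }
};

template <class T>
constexpr implicit_converter<T> implicit_conversion(T x)
{
    return {x};
}

// Hypothetical fp8 type whose constructor is explicit.
struct fp8
{
    float v;
    explicit constexpr fp8(float f) : v(f) {}
};
```

With this, fp8 y = implicit_conversion(1.5f); compiles even though fp8(float) is explicit, whereas plain fp8 y = 1.5f; would not.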

@umangyadav umangyadav added the FP8 (issues related to FP8 implementation) label Dec 1, 2023
Collaborator

@TedThemistokleous TedThemistokleous left a comment

LGTM. No more questions/concerns, and it looks like you've handled all comments.

Collaborator

@pfultz2 pfultz2 left a comment

Looks good! We should probably look into having an env variable to enable fp8 emulation, since it's probably much slower on unsupported hardware than just using fp32 or fp16 directly.
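Such a switch could be as simple as an environment lookup. The variable name below is hypothetical, and a real implementation would go through MIGraphX's own env-var declaration machinery rather than raw getenv:

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical flag name, for illustration only.
inline bool fp8_emulation_enabled()
{
    const char* e = std::getenv("MIGRAPHX_ENABLE_FP8_EMULATION");
    return e != nullptr && std::strcmp(e, "1") == 0;
}
```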

@TedThemistokleous
Collaborator

Looks good! We should probably look into having an env variable to enable fp8 emulation, since it's probably much slower on unsupported hardware than just using fp32 or fp16 directly.

I like this idea too, but let's get this in so we can hammer the others out. Good for a smaller final PR.

@causten causten merged commit eafd55d into develop Dec 1, 2023
14 of 15 checks passed
@causten causten deleted the gpu_fp8 branch December 1, 2023 23:16
Labels
FP8 (issues related to FP8 implementation), high priority (a PR with high priority for review and merging)