Skip to content

Commit

Permalink
Enable GEMM/dot for FP8 using hipblasLT (#3577)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Nov 13, 2024
1 parent ca70d73 commit 495d3eb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/targets/gpu/hip_gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,14 @@ hipDataType get_type_hipblas(shape::type_t type)
case shape::int32_type: return HIP_R_32I;
case shape::uint32_type: return HIP_R_32U;
case shape::fp8e4m3fnuz_type: return HIP_R_8F_E4M3_FNUZ;
// TODO can remove this preprocessor conditional when hip verison defaults to have these types
#ifdef ROCM_USE_FLOAT8
case shape::fp8e4m3fn_type: return HIP_R_8F_E4M3;
case shape::fp8e5m2_type: return HIP_R_8F_E5M2;
#else
case shape::fp8e4m3fn_type:
case shape::fp8e5m2_type:
#endif
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
Expand Down
12 changes: 9 additions & 3 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM)

std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
{
Expand Down Expand Up @@ -129,9 +130,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_fp8e4m3fnuz_ops.insert("argmin");

std::set<std::string> unsupported_fp8ocp_ops = {};
// TODO update with hipBLASLt support
unsupported_fp8ocp_ops.insert("dot");
unsupported_fp8ocp_ops.insert("quant_dot");
// TODO: remove this when the flag is removed
if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{}))
{
unsupported_fp8ocp_ops.insert("dot");
unsupported_fp8ocp_ops.insert("quant_dot");
}
#if MIGRAPHX_USE_MIOPEN
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8ocp_ops.insert("pooling");
Expand All @@ -140,6 +144,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
{
unsupported_fp8ocp_ops.insert("convolution");
unsupported_fp8ocp_ops.insert("quant_convolution");
unsupported_fp8ocp_ops.insert("dot");
unsupported_fp8ocp_ops.insert("quant_dot");
}
// add all device kernels
unsupported_fp8ocp_ops.insert("logsoftmax");
Expand Down

0 comments on commit 495d3eb

Please sign in to comment.