From 495d3eb14a9946969d712df22b5187d6ee890d5f Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Wed, 13 Nov 2024 10:39:43 -0500 Subject: [PATCH] Enable GEMM/dot for FP8 using hipblasLT (#3577) --- src/targets/gpu/hip_gemm_impl.cpp | 6 ++++++ src/targets/gpu/target.cpp | 12 +++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index f5ec898d8d5..4e282cc01ff 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -70,8 +70,14 @@ hipDataType get_type_hipblas(shape::type_t type) case shape::int32_type: return HIP_R_32I; case shape::uint32_type: return HIP_R_32U; case shape::fp8e4m3fnuz_type: return HIP_R_8F_E4M3_FNUZ; +// TODO can remove this preprocessor conditional when hip verison defaults to have these types +#ifdef ROCM_USE_FLOAT8 + case shape::fp8e4m3fn_type: return HIP_R_8F_E4M3; + case shape::fp8e5m2_type: return HIP_R_8F_E5M2; +#else case shape::fp8e4m3fn_type: case shape::fp8e5m2_type: +#endif case shape::tuple_type: case shape::bool_type: case shape::uint16_type: diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index db76c5a24ac..fa01c514c0c 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -81,6 +81,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -129,9 +130,12 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_fp8e4m3fnuz_ops.insert("argmin"); std::set unsupported_fp8ocp_ops = {}; - // TODO update with hipBLASLt support - unsupported_fp8ocp_ops.insert("dot"); - unsupported_fp8ocp_ops.insert("quant_dot"); + // TODO: remove this when the flag is removed + if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{})) + { + unsupported_fp8ocp_ops.insert("dot"); + unsupported_fp8ocp_ops.insert("quant_dot"); + } #if MIGRAPHX_USE_MIOPEN // MIOpen doesn't have support for fp8 pooling yet. unsupported_fp8ocp_ops.insert("pooling"); @@ -140,6 +144,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti { unsupported_fp8ocp_ops.insert("convolution"); unsupported_fp8ocp_ops.insert("quant_convolution"); + unsupported_fp8ocp_ops.insert("dot"); + unsupported_fp8ocp_ops.insert("quant_dot"); } // add all device kernels unsupported_fp8ocp_ops.insert("logsoftmax");