From 2a62c7215419a859321460c7fb9e2da272f4d003 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 5 Jun 2024 07:45:04 -0700 Subject: [PATCH] [FP8][Codegen] Add make_fp8 vector constructors (#17065) * [FP8][Codegen] Add make_fp8 vector constructors. Allows vectorized fp8 loading. --------- Co-authored-by: Chris Sullivan --- src/target/source/codegen_cuda.cc | 25 +++++++++---------- src/target/source/literal/cuda_half_t.h | 20 +++++++++++++++ .../codegen/test_target_codegen_cuda_fp8.py | 2 +- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index ecb095761189..bd2804830172 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -48,21 +48,22 @@ std::string GetFP8Type(DataType type) { if (type.is_scalar()) { vec = ""; } else if (lanes == 2) { - vec = "_2"; + vec = "x2"; } else if (lanes == 4) { - vec = "_4"; - } else if (lanes == 8) { - vec = "_8"; + vec = "x4"; } else { LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8"; } + stream << "__nv_fp8"; + std::string suffix; if (type.code() == DataType::kE4M3Float) { - stream << "fp8_e4" << vec << "_t"; + suffix = "_e4m3"; } else if (type.code() == DataType::kE5M2Float) { - stream << "fp8_e5" << vec << "_t"; + suffix = "_e5m2"; } else { LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; } + stream << vec << suffix; return stream.str(); } @@ -146,12 +147,6 @@ std::string CodeGenCUDA::Finish() { if (enable_fp8_) { decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n"; decl_stream << "#include \n"; - decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n"; - decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n"; - decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n"; - decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n"; - decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n"; - decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n"; decl_stream << "#endif\n\n"; } declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_); @@ -299,7 +294,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (!fail) return; } else if (t.is_float8()) { enable_fp8_ = true; - os << GetFP8Type(t); + if (t.lanes() <= 4) { + os << GetFP8Type(t); + } else { + os << "uint" << t.lanes() / 4; + } return; } else if (t == DataType::Bool()) { os << "bool"; diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index 27d44d9f7f4a..c5ecda07a4d3 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -431,6 +431,26 @@ struct __align__(8) half4 { (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); return result; } + __device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x, __nv_fp8_storage_t y) { + __nv_fp8x2_e5m2 result; + result.__x = (x) | (y << 8); + return result; + } + __device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) { + __nv_fp8x4_e5m2 result; + result.__x = (a) | (b << 8) | (c << 16) | (d << 24); + return result; + } + __device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x, __nv_fp8_storage_t y) { + __nv_fp8x2_e4m3 result; + result.__x = (x) | (y << 8); + return result; + } + __device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) { + __nv_fp8x4_e4m3 result; + result.__x = (a) | (b << 8) | (c << 16) | (d << 24); + return result; + } )"; } stream << R"( diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index 5566ae243477..adcb05839bc9 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -64,7 +64,7 @@ def add( fadd = tvm.build(sch.mod, target=target) cuda_src = fadd.imported_modules[0].get_source() - assert "fp8_e4_t" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA" + assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA" dev = tvm.device(target, 0)