[FP8][Codegen] Add make_fp8 vector constructors (#17065)
* [FP8][Codegen] Add make_fp8 vector constructors.

Allows vectorized fp8 loading.

---------

Co-authored-by: Chris Sullivan <[email protected]>
vinx13 and csullivan authored Jun 5, 2024
1 parent 1400627 commit 2a62c72
Showing 3 changed files with 33 additions and 14 deletions.
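
Before the per-file diffs, a minimal sketch of what "vectorized fp8 loading" enables on the consumer side: one 32-bit load brings in four e4m3 values at once instead of four separate byte loads. This is illustrative only, not code this commit emits verbatim; the kernel name and buffers are hypothetical, and it assumes an sm_89+ target with <cuda_fp8.h> available.

#include <cuda_fp8.h>

// Hypothetical kernel: widen a buffer of packed e4m3 quadruples to float4.
__global__ void fp8x4_to_float(const __nv_fp8x4_e4m3* in, float4* out, int n4) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n4) {
    __nv_fp8x4_e4m3 v = in[i];        // single 4-byte vector load
    out[i] = static_cast<float4>(v);  // cuda_fp8.h provides the float4 conversion
  }
}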
25 changes: 12 additions & 13 deletions src/target/source/codegen_cuda.cc
@@ -48,21 +48,22 @@ std::string GetFP8Type(DataType type) {
   if (type.is_scalar()) {
     vec = "";
   } else if (lanes == 2) {
-    vec = "_2";
+    vec = "x2";
   } else if (lanes == 4) {
-    vec = "_4";
-  } else if (lanes == 8) {
-    vec = "_8";
+    vec = "x4";
   } else {
     LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
   }
+  stream << "__nv_fp8";
+  std::string suffix;
   if (type.code() == DataType::kE4M3Float) {
-    stream << "fp8_e4" << vec << "_t";
+    suffix = "_e4m3";
   } else if (type.code() == DataType::kE5M2Float) {
-    stream << "fp8_e5" << vec << "_t";
+    suffix = "_e5m2";
   } else {
     LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
   }
+  stream << vec << suffix;
   return stream.str();
 }
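
For reference, a minimal standalone sketch (not TVM code; the enum and function name below are made up for illustration) of the mapping the revised GetFP8Type produces. The emitted names now match the vector types declared in <cuda_fp8.h>, so the fp8_e4_t/fp8_e5_t aliases removed in the next hunk are no longer needed.

#include <cassert>
#include <string>

// Hypothetical stand-in for TVM's DataType code; only the two FP8 codes matter here.
enum class Fp8Code { kE4M3Float, kE5M2Float };

std::string Fp8TypeName(Fp8Code code, int lanes) {
  // lanes must be 1, 2, or 4 after this commit; wider fp8 vectors fall back to uintN.
  std::string vec = (lanes == 1) ? "" : (lanes == 2) ? "x2" : "x4";
  std::string suffix = (code == Fp8Code::kE4M3Float) ? "_e4m3" : "_e5m2";
  return "__nv_fp8" + vec + suffix;
}

int main() {
  assert(Fp8TypeName(Fp8Code::kE4M3Float, 1) == "__nv_fp8_e4m3");
  assert(Fp8TypeName(Fp8Code::kE4M3Float, 2) == "__nv_fp8x2_e4m3");
  assert(Fp8TypeName(Fp8Code::kE5M2Float, 4) == "__nv_fp8x4_e5m2");
  return 0;
}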

@@ -146,12 +146,6 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp8_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
     decl_stream << "#include <cuda_fp8.h>\n";
-    decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
-    decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n";
-    decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n";
-    decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
-    decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
-    decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
     decl_stream << "#endif\n\n";
   }
   declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
@@ -299,7 +294,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
     if (!fail) return;
   } else if (t.is_float8()) {
     enable_fp8_ = true;
-    os << GetFP8Type(t);
+    if (t.lanes() <= 4) {
+      os << GetFP8Type(t);
+    } else {
+      os << "uint" << t.lanes() / 4;
+    }
     return;
   } else if (t == DataType::Bool()) {
     os << "bool";
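
The lanes > 4 fallback prints a plain uintN (e.g. an 8-lane fp8 vector becomes uint2, two 32-bit words), since cuda_fp8.h stops at the 4-lane __nv_fp8x4_* types. A hedged sketch of how such a value can be unpacked on the device side; the helper name is hypothetical and not emitted by the codegen:

#include <cuda_fp8.h>

// Hypothetical helper: split a uint2 (eight packed fp8 bytes) into storage bytes.
__device__ void unpack_fp8x8(uint2 packed, __nv_fp8_storage_t out[8]) {
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    out[i]     = static_cast<__nv_fp8_storage_t>((packed.x >> (8 * i)) & 0xFF);  // low word
    out[i + 4] = static_cast<__nv_fp8_storage_t>((packed.y >> (8 * i)) & 0xFF);  // high word
  }
}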
20 changes: 20 additions & 0 deletions src/target/source/literal/cuda_half_t.h
@@ -431,6 +431,26 @@ struct __align__(8) half4 {
       (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
   return result;
 }
+__device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
+  __nv_fp8x2_e5m2 result;
+  result.__x = (x) | (y << 8);
+  return result;
+}
+__device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
+  __nv_fp8x4_e5m2 result;
+  result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
+  return result;
+}
+__device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
+  __nv_fp8x2_e4m3 result;
+  result.__x = (x) | (y << 8);
+  return result;
+}
+__device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
+  __nv_fp8x4_e4m3 result;
+  result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
+  return result;
+}
 )";
 }
 stream << R"(
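
A hedged usage sketch for the new constructors: convert four floats to fp8 storage bytes with the cuda_fp8.h conversion intrinsic, then pack them into one __nv_fp8x4_e4m3 so the store is a single 32-bit transaction. The kernel itself is hypothetical and assumes the make_fp8x4_e4m3 helper from the hunk above is in scope.

#include <cuda_fp8.h>

__global__ void pack_fp8x4(const float* in, __nv_fp8x4_e4m3* out, int n4) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n4) {
    // __nv_cvt_float_to_fp8 returns the raw __nv_fp8_storage_t byte for one value.
    out[i] = make_fp8x4_e4m3(
        __nv_cvt_float_to_fp8(in[4 * i + 0], __NV_SATFINITE, __NV_E4M3),
        __nv_cvt_float_to_fp8(in[4 * i + 1], __NV_SATFINITE, __NV_E4M3),
        __nv_cvt_float_to_fp8(in[4 * i + 2], __NV_SATFINITE, __NV_E4M3),
        __nv_cvt_float_to_fp8(in[4 * i + 3], __NV_SATFINITE, __NV_E4M3));  // one 4-byte store
  }
}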
2 changes: 1 addition & 1 deletion tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -64,7 +64,7 @@ def add(
     fadd = tvm.build(sch.mod, target=target)

     cuda_src = fadd.imported_modules[0].get_source()
-    assert "fp8_e4_t" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA"
+    assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA"

     dev = tvm.device(target, 0)
