Skip to content

Commit

Permalink
[TIR][CUDA] Add native FP8 support to codegen (#16548)
Browse files Browse the repository at this point in the history
* [TIR][CUDA] Add native FP8 support to codegen

Adds native FP8 type support for CUDA. The e4m3/e5m2 struct types provide explicit type conversions that target hardware native conversion ops.

* Conditionally run Storage and Compute legalization for targets that don't support FP8. This could be changed to only support conversion operators and do legalization on any compute operations other than builtin wmma calls.

* Implement support for float16x4 (half4) for use with e4m3_float8x4 (__nv_fp8x4_e4m3)

* Add test for e4m3 <-> half conversion which lowers to ptx intrins.

* Introduce half4 and support native fp8 vector types (1, 2, 4), and
conversion between float and half vector types with equal lanes

* Only cast to half2 for vector loads/stores of non native half struct types (lanes > 4).

* Test e4m3 x4 vector quant/dequant

---------

Co-authored-by: Joseph McMahan <[email protected]>
  • Loading branch information
csullivan and JosephTheOctonaut authored Mar 15, 2024
1 parent 45df124 commit feb1043
Show file tree
Hide file tree
Showing 8 changed files with 957 additions and 45 deletions.
6 changes: 4 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,17 +398,19 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
/*!
* \brief Legalize bf16 compute Ops. Add a cast to fp32
* before Ops, then add a cast back to bf16.
* \param target The target used for checking native bf16 support
* \return The pass.
*/
TVM_DLL Pass BF16ComputeLegalize();

/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
* before Ops, then add a cast back to fp8.
* \param target The target used for checking native fp8 support
* \param promote_dtype_str The data type used for type promotion, defaults to float16
* \return The pass.
*/
TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16");

/*!
* \brief Legalize bf16 storage types to u16.
Expand All @@ -420,7 +422,7 @@ TVM_DLL Pass BF16StorageLegalize();
* \brief Legalize fp8 storage types to u8.
* \return The pass.
*/
TVM_DLL Pass FP8StorageLegalize();
TVM_DLL Pass FP8StorageLegalize(Target target);

/*!
* \brief Inline calls to private functions
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def callback_libdevice_path(arch):
return ""


@tvm._ffi.register_func("tvm.contrib.nvcc.get_compute_version")
def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target.
Expand Down Expand Up @@ -406,6 +407,7 @@ def have_cudagraph():
return False


@tvm._ffi.register_func("tvm.contrib.nvcc.supports_bf16")
def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not
Expand All @@ -421,6 +423,7 @@ def have_bf16(compute_version):
return False


@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp8")
def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not
Expand Down
5 changes: 3 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::TransformMmaBufferLayout());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::FP8ComputeLegalize());
pass_list.push_back(tir::transform::BF16ComputeLegalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
Expand Down Expand Up @@ -570,6 +569,8 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

Array<Pass> mixed_pass_list;

mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
Expand Down Expand Up @@ -619,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
} else {
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
Expand Down
2 changes: 2 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
default:
LOG(FATAL) << "do not support " << dtype;
}
} else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
etype = llvm::Type::getInt8Ty(*ctx);
}
if (!dtype.is_scalar()) {
#if TVM_LLVM_VERSION >= 110
Expand Down
113 changes: 77 additions & 36 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@
namespace tvm {
namespace codegen {

std::string GetFP8Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
if (type.is_scalar()) {
vec = "";
} else if (lanes == 2) {
vec = "_2";
} else if (lanes == 4) {
vec = "_4";
} else if (lanes == 8) {
vec = "_8";
} else {
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
}
if (type.code() == DataType::kE4M3Float) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kE5M2Float) {
stream << "fp8_e5" << vec << "_t";
} else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
}
return stream.str();
}

CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }

void CodeGenCUDA::Init(bool output_ssa) {
Expand Down Expand Up @@ -121,8 +146,15 @@ std::string CodeGenCUDA::Finish() {
if (enable_fp8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
decl_stream << "#include <cuda_fp8.h>\n";
decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n";
decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n";
decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
decl_stream << "#endif\n\n";
}
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
Expand Down Expand Up @@ -214,17 +246,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
if (t.is_scalar()) {
os << "half";
} else if (lanes <= 8) {
// Emit CUDA code to access fp16 vector elements.
//
// half4 is stored as uint2
//
// h4.x is emitted as *(half2*)(&(u2.x)).x
// h4.y is emitted as *(half2*)(&(u2.x)).y
// h4.z is emitted as *(half2*)(&(u2.y)).x
// h4.w is emitted as *(half2*)(&(u2.y)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2;
ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for half type";
if (lanes <= 4) {
os << "half" << lanes;
} else {
os << "uint" << lanes / 2;
}
} else {
fail = true;
}
Expand Down Expand Up @@ -271,16 +298,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
if (!fail) return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail) return;
enable_fp8_ = true;
os << GetFP8Type(t);
return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
Expand Down Expand Up @@ -446,7 +466,7 @@ void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) {

void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
// Delcare the result.
// Declare the result.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(t, stream);
Expand Down Expand Up @@ -497,7 +517,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
if (t.lanes() <= 4) {
os << vec << "." << access[i];
} else {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
}
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
Expand Down Expand Up @@ -543,8 +567,13 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
}
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
if (t.lanes() <= 4) {
stream << vec << "." << access[i] << " = " << value << ";\n";
} else {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
}

} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
Expand Down Expand Up @@ -648,6 +677,16 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
// Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);

if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float ||
from_ty.code() == DataType::kE4M3Float || from_ty.code() == DataType::kE5M2Float) {
std::ostringstream val;
val << "(";
PrintType(target_ty, val);
val << ")(" << PrintExpr(op->value) << ")";
os << val.str();
return;
}

// We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops.
std::string sret = name_supply_->FreshName("_");
Expand Down Expand Up @@ -1194,9 +1233,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
std::string v = PrintExpr(op->value);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
if (lanes <= 4) {
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << v << ", " << v;
}
} else {
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
}
}
os << ')';
return;
Expand Down Expand Up @@ -1448,15 +1494,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
PrintVecConstructor(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_half2(" << value;
if (i == t.lanes() - 1) {
os << value << ")";
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
os << value << ",";
}
return;
}
Expand Down
42 changes: 42 additions & 0 deletions src/target/source/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
#define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_

#include <string>

static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
Expand Down Expand Up @@ -379,4 +381,44 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
)";

void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8) {
if (enable_fp16 || enable_fp8) {
stream << R"(
struct __align__(8) half4 {
__half x, y, z, w;
__host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
__host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
)";
if (enable_fp8) {
stream << R"(
__host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
__nv_fp8x2_e4m3 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
__half2 lo_half2 = static_cast<__half2>(lo_part);
__half2 hi_half2 = static_cast<__half2>(hi_part);
x = reinterpret_cast<__half*>(&lo_half2)[0];
y = reinterpret_cast<__half*>(&lo_half2)[1];
z = reinterpret_cast<__half*>(&hi_half2)[0];
w = reinterpret_cast<__half*>(&hi_half2)[1];
}
__host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
__nv_fp8x4_e4m3 result;
__half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
__half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
__nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
})";
}
stream << R"(
};
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
return half4(x, y, z, w);
}
)";
}
}

#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
28 changes: 23 additions & 5 deletions src/tir/transforms/unsupported_dtype_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,20 @@ class FP8StorageLegalizer : public StorageLegalizer {

namespace transform {

bool CheckDataTypeSupport(const Target& target, const std::string& support_func_name) {
bool has_native_support = false;
if (target->kind->name == "cuda") {
if (const PackedFunc* get_cv =
tvm::runtime::Registry::Get("tvm.contrib.nvcc.get_compute_version")) {
std::string compute_version = (*get_cv)(target);
if (const PackedFunc* check_support = tvm::runtime::Registry::Get(support_func_name)) {
has_native_support = (*check_support)(compute_version);
}
}
}
return has_native_support;
}

Pass BF16ComputeLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
// TODO(tvm-team): skip if the target supports bf16
Expand All @@ -713,19 +727,23 @@ Pass BF16StorageLegalize() {

TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);

Pass FP8ComputeLegalize(String promote_dtype_str) {
Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
// TODO(tvm-team): skip if the target supports fp8
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {});
}

TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize);

Pass FP8StorageLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
// TODO(tvm-team): skip if the target supports fp8
Pass FP8StorageLegalize(Target target) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
return FP8StorageLegalizer().Legalize(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {});
Expand Down
Loading

0 comments on commit feb1043

Please sign in to comment.