diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 934c2756f69d..e219cc684657 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -398,6 +398,7 @@ TVM_DLL Pass ForceNarrowIndexToInt32(); /*! * \brief Legalize bf16 compute Ops. Add a cast to fp32 * before Ops, then add a cast back to bf16. + * \param target The target used for checking native bf16 support * \return The pass. */ TVM_DLL Pass BF16ComputeLegalize(); @@ -405,10 +406,11 @@ TVM_DLL Pass BF16ComputeLegalize(); /*! * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32 * before Ops, then add a cast back to fp8. + * \param target The target used for checking native fp8 support * \param promote_dtype_str The data type used for type promotion, defaults to float16 * \return The pass. */ -TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16"); +TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16"); /*! * \brief Legalize bf16 storage types to u16. @@ -420,7 +422,7 @@ TVM_DLL Pass BF16StorageLegalize(); * \brief Legalize fp8 storage types to u8. * \return The pass. */ -TVM_DLL Pass FP8StorageLegalize(); +TVM_DLL Pass FP8StorageLegalize(Target target); /*! * \brief Inline calls to private functions diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index d203007dd182..b1f042c1a597 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -270,6 +270,7 @@ def callback_libdevice_path(arch): return "" +@tvm._ffi.register_func("tvm.contrib.nvcc.get_compute_version") def get_target_compute_version(target=None): """Utility function to get compute capability of compilation target. @@ -406,6 +407,7 @@ def have_cudagraph(): return False +@tvm._ffi.register_func("tvm.contrib.nvcc.supports_bf16") def have_bf16(compute_version): """Either bf16 support is provided in the compute capability or not @@ -421,6 +423,7 @@ def have_bf16(compute_version): return False +@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp8") def have_fp8(compute_version): """Whether fp8 support is provided in the specified compute capability or not diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bdadb6db0fb4..33b4514e6b29 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -216,7 +216,6 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::TransformMmaBufferLayout()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::FP8ComputeLegalize()); pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -570,6 +569,8 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) Array mixed_pass_list; + mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target)); + // VerifyVTCMLimit must occur before LowerVtcmAlloc mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations @@ -619,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); + mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target)); mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index bba1488274e2..8fe740dad197 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -586,6 +586,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { default: LOG(FATAL) << "do not support " << dtype; } + } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) { + etype = llvm::Type::getInt8Ty(*ctx); } if (!dtype.is_scalar()) { #if TVM_LLVM_VERSION >= 110 diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 15905b030433..d352616f55fa 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -41,6 +41,31 @@ namespace tvm { namespace codegen { +std::string GetFP8Type(DataType type) { + std::stringstream stream; + int32_t lanes = type.lanes(); + std::string vec; + if (type.is_scalar()) { + vec = ""; + } else if (lanes == 2) { + vec = "_2"; + } else if (lanes == 4) { + vec = "_4"; + } else if (lanes == 8) { + vec = "_8"; + } else { + LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8"; + } + if (type.code() == DataType::kE4M3Float) { + stream << "fp8_e4" << vec << "_t"; + } else if (type.code() == DataType::kE5M2Float) { + stream << "fp8_e5" << vec << "_t"; + } else { + LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; + } + return stream.str(); +} + CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } void CodeGenCUDA::Init(bool output_ssa) { @@ -121,8 +146,15 @@ std::string CodeGenCUDA::Finish() { if (enable_fp8_) { decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n"; decl_stream << "#include \n"; + decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n"; + decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n"; + decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n"; + decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n"; + decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n"; + decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n"; decl_stream << "#endif\n\n"; } + declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_); if (enable_warp_shuffle_) { decl_stream << _cuda_warp_intrinsic_util; @@ -214,17 +246,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (t.is_scalar()) { os << "half"; } else if (lanes <= 8) { - // Emit CUDA code to access fp16 vector elements. - // - // half4 is stored as uint2 - // - // h4.x is emitted as *(half2*)(&(u2.x)).x - // h4.y is emitted as *(half2*)(&(u2.x)).y - // h4.z is emitted as *(half2*)(&(u2.y)).x - // h4.w is emitted as *(half2*)(&(u2.y)).y - // - ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - os << "uint" << lanes / 2; + ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for half type"; + if (lanes <= 4) { + os << "half" << lanes; + } else { + os << "uint" << lanes / 2; + } } else { fail = true; } @@ -271,16 +298,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } if (!fail) return; } else if (t.is_float8()) { - if (t.is_scalar()) { - os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char - } else if (lanes == 2) { - os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short - } else if (lanes == 4) { - os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int - } else { - fail = true; - } - if (!fail) return; + enable_fp8_ = true; + os << GetFP8Type(t); + return; } else if (t == DataType::Bool()) { os << "bool"; return; @@ -446,7 +466,7 @@ void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) { void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) - // Delcare the result. + // Declare the result. std::string sret = name_supply_->FreshName("_"); this->PrintIndent(); this->PrintType(t, stream); @@ -497,7 +517,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; } } else if (t.is_float16()) { - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + if (t.lanes() <= 4) { + os << vec << "." << access[i]; + } else { + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + } } else if (t.is_bfloat16()) { os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else if (t.lanes() > 4 && t.lanes() <= 8) { @@ -543,8 +567,13 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, stream << "(" << value << " << " << i % 4 * 8 << ");\n"; } } else if (t.is_float16()) { - stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " - << value << ";\n"; + if (t.lanes() <= 4) { + stream << vec << "." << access[i] << " = " << value << ";\n"; + } else { + stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " + << value << ";\n"; + } + } else if (t.is_bfloat16()) { stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; @@ -648,6 +677,16 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { // Emit simple C-style type conversion. if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); + if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float || + from_ty.code() == DataType::kE4M3Float || from_ty.code() == DataType::kE5M2Float) { + std::ostringstream val; + val << "("; + PrintType(target_ty, val); + val << ")(" << PrintExpr(op->value) << ")"; + os << val.str(); + return; + } + // We could emit make_float4 like calls, but the emitted code looks // too compact to read. Emit this as vectorized unary ops. std::string sret = name_supply_->FreshName("_"); @@ -1194,9 +1233,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO std::string v = PrintExpr(op->value); PrintVecConstructor(op->dtype, os); os << '('; - for (int i = 0; i < lanes / 2; ++i) { - if (i != 0) os << ", "; - os << "__pack_half2(" << v << ", " << v << ")"; + if (lanes <= 4) { + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) os << ", "; + os << v << ", " << v; + } + } else { + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) os << ", "; + os << "__pack_half2(" << v << ", " << v << ")"; + } } os << ')'; return; @@ -1448,15 +1494,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val PrintVecConstructor(t, os); os << '('; } - if (i % 2 == 0) { - os << "__pack_half2(" << value; + if (i == t.lanes() - 1) { + os << value << ")"; } else { - os << "," << value << ")"; - if (i != t.lanes() - 1) { - os << ","; - } else { - os << ")"; - } + os << value << ","; } return; } diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index 67471daf82c4..bf3e83928ed7 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -24,6 +24,8 @@ #ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ #define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ +#include + static constexpr const char* _cuda_half_t_def = R"( typedef unsigned short uint16_t; typedef unsigned char uint8_t; @@ -379,4 +381,44 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"( )"; +void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8) { + if (enable_fp16 || enable_fp8) { + stream << R"( +struct __align__(8) half4 { + __half x, y, z, w; + __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {} + __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {} +)"; + if (enable_fp8) { + stream << R"( + __host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) { + __nv_fp8x2_e4m3 lo_part, hi_part; + lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF); + hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF); + __half2 lo_half2 = static_cast<__half2>(lo_part); + __half2 hi_half2 = static_cast<__half2>(hi_part); + x = reinterpret_cast<__half*>(&lo_half2)[0]; + y = reinterpret_cast<__half*>(&lo_half2)[1]; + z = reinterpret_cast<__half*>(&hi_half2)[0]; + w = reinterpret_cast<__half*>(&hi_half2)[1]; + } + __host__ __device__ explicit operator __nv_fp8x4_e4m3() const { + __nv_fp8x4_e4m3 result; + __half2 lo_half2 = *reinterpret_cast(&x); + __half2 hi_half2 = *reinterpret_cast(&z); + __nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2); + result.__x = + (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); + return result; + })"; + } + stream << R"( +}; +__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) { + return half4(x, y, z, w); +} +)"; + } +} + #endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 030dbd01badf..c0378790740f 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -693,6 +693,20 @@ class FP8StorageLegalizer : public StorageLegalizer { namespace transform { +bool CheckDataTypeSupport(const Target& target, const std::string& support_func_name) { + bool has_native_support = false; + if (target->kind->name == "cuda") { + if (const PackedFunc* get_cv = + tvm::runtime::Registry::Get("tvm.contrib.nvcc.get_compute_version")) { + std::string compute_version = (*get_cv)(target); + if (const PackedFunc* check_support = tvm::runtime::Registry::Get(support_func_name)) { + has_native_support = (*check_support)(compute_version); + } + } + } + return has_native_support; +} + Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { // TODO(tvm-team): skip if the target supports bf16 @@ -713,9 +727,11 @@ Pass BF16StorageLegalize() { TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); -Pass FP8ComputeLegalize(String promote_dtype_str) { +Pass FP8ComputeLegalize(Target target, String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports fp8 + if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { + return f; + } return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); @@ -723,9 +739,11 @@ Pass FP8ComputeLegalize(String promote_dtype_str) { TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); -Pass FP8StorageLegalize() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports fp8 +Pass FP8StorageLegalize(Target target) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { + return f; + } return FP8StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py new file mode 100644 index 000000000000..dade970418f9 --- /dev/null +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -0,0 +1,803 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +import pytest + +import tvm +from tvm.script import tir as T +import numpy as np +import tvm.testing + + +from typing import List, Tuple +from tvm import DataType, DataTypeCode, IRModule +from tvm import dlight as dl +from tvm import relax, te, tir, topi +from tvm.relax.frontend import nn +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.topi.utils import get_const_tuple + + +@tvm.testing.requires_cuda_compute_version(9) +def test_e4m3_conversions(): + dtype = "e4m3_float8" + + @T.prim_func + def add( + A: T.Buffer((64,), dtype), + B: T.Buffer((64,), dtype), + C: T.Buffer((64,), dtype), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i in range(64): + with T.block("C"): + v_i = T.axis.spatial(64, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast(dtype, T.Cast("float16", A[v_i]) + T.Cast("float16", B[v_i])) + + sch = tvm.tir.Schedule(add) + block = sch.get_block("C") + b = sch.get_loops(block) + bx, tx = sch.split(b[0], factors=[None, 32]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + target = "cuda" + fadd = tvm.build(sch.mod, target=target) + + cuda_src = fadd.imported_modules[0].get_source() + assert "fp8_e4_t" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA" + + dev = tvm.device(target, 0) + + numpytype = "float8_e4m3fn" + a = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(numpytype), dev) + b = tvm.nd.array(np.random.uniform(low=0, high=5, size=64).astype(numpytype), dev) + c = tvm.nd.array(np.zeros(64, dtype=numpytype), dev) + fadd(a, b, c) + + tvm.testing.assert_allclose( + c.numpy().astype("float16"), (a.numpy() + b.numpy()).astype("float16") + ) + + +@tvm.testing.requires_cuda_compute_version(9) +def test_e4m3_packing(): + length = 64 + vector_length = 4 + native_dtype, packed_dtype = ("e4m3_float8x4", "uint32") + + @T.prim_func + def add( + A: T.Buffer((length,), native_dtype), + R: T.Buffer((length,), packed_dtype), + B: T.Buffer((length,), native_dtype), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i in range(length): + with T.block("R"): + v_i = T.axis.spatial(length, i) + T.reads(A[v_i]) + T.writes(R[v_i]) + R[v_i] = T.reinterpret(packed_dtype, A[v_i]) + for i in range(length): + with T.block("B"): + v_i = T.axis.spatial(length, i) + T.reads(R[v_i]) + T.writes(B[v_i]) + B[v_i] = T.reinterpret(native_dtype, R[v_i]) + + sch = tvm.tir.Schedule(add) + block = sch.get_block("R") + b = sch.get_loops(block) + bx, tx = sch.split(b[0], factors=[None, 32]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + block = sch.get_block("B") + b = sch.get_loops(block) + bx, tx = sch.split(b[0], factors=[None, 32]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + target = "cuda" + f = tvm.build(sch.mod, target=target) + dev = tvm.device(target, 0) + + numpytype = "float8_e4m3fn" + np_shape = (length, vector_length) + a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) + a = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev) + r = tvm.nd.empty(shape=(length,), dtype=packed_dtype, device=dev) + b = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev) + a.copyfrom(a_np) + f(a, r, b) + tvm.testing.assert_allclose(a.numpy().astype("float16"), b.numpy().astype("float16")) + + +native_dtype, promoted_dtype = tvm.testing.parameters( + ("e4m3_float8", "float32"), + ("e4m3_float8", "float16"), + ("e4m3_float8x2", "float32x2"), + ("e4m3_float8x2", "float16x2"), + ("e4m3_float8x4", "float32x4"), + # Supported via half4 vector type extension in codegen + ("e4m3_float8x4", "float16x4"), +) + + +@tvm.testing.requires_cuda_compute_version(9) +def test_e4m3_vector_conversions(native_dtype, promoted_dtype): + vector_length = 64 + + @T.prim_func + def add( + A: T.Buffer((vector_length,), native_dtype), + B: T.Buffer((vector_length,), native_dtype), + C: T.Buffer((vector_length,), native_dtype), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i in range(vector_length): + with T.block("C"): + v_i = T.axis.spatial(vector_length, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = T.Cast( + native_dtype, T.Cast(promoted_dtype, A[v_i]) + T.Cast(promoted_dtype, B[v_i]) + ) + + sch = tvm.tir.Schedule(add) + block = sch.get_block("C") + b = sch.get_loops(block) + bx, tx = sch.split(b[0], factors=[None, 32]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + target = "cuda" + fadd = tvm.build(sch.mod, target=target) + cuda_src = fadd.imported_modules[0].get_source() + dev = tvm.device(target, 0) + + numpytype = "float8_e4m3fn" + if "x" in native_dtype: + lanes = int(native_dtype.split("x")[-1]) + else: + lanes = 1 + + if "x" in promoted_dtype: + promoted_base_dtype = promoted_dtype.split("x")[0] + else: + promoted_base_dtype = promoted_dtype + + np_shape = (vector_length, lanes) if lanes > 1 else (vector_length,) + a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) + a = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + a.copyfrom(a_np) + b_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype) + b = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + b.copyfrom(b_np) + c = tvm.nd.empty(shape=(vector_length,), dtype=native_dtype, device=dev) + fadd(a, b, c) + + tvm.testing.assert_allclose( + c.numpy().astype(promoted_base_dtype), (a_np + b_np).astype(promoted_base_dtype) + ) + + +bcast_length = tvm.testing.parameter(2, 4, 6, 8) + + +@tvm.testing.requires_cuda_compute_version(8) +def test_half_broadcast(bcast_length): + dtype = "float16" + + @T.prim_func + def vector_broadcast(a: T.Buffer[(), dtype], vec: T.Buffer[(bcast_length,), dtype]): + for t in range(1): + with T.block("broadcast"): + vec[0:bcast_length] = T.broadcast(a[()], bcast_length) + + sch = tvm.tir.Schedule(vector_broadcast) + block = sch.get_block("broadcast") + b = sch.get_loops(block) + bx, tx = sch.split(b[0], factors=[None, 1]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + target = "cuda" + func = tvm.build(sch.mod, target=target) + dev = tvm.device(target, 0) + + a_np = np.random.uniform(low=0, high=4, size=()).astype(dtype) + a = tvm.nd.array(a_np, device=dev) + b = tvm.nd.empty((bcast_length,), dtype=dtype, device=dev) + + func(a, b) + + b_np = np.full((bcast_length,), a_np) + + tvm.testing.assert_allclose(b.numpy(), b_np) + + +vector_length = tvm.testing.parameter(2, 4) + + +@tvm.testing.requires_cuda_compute_version(8) +def test_half_misaligned_vector_load(vector_length): + dtype = "float16" + vec_dtype = dtype + "x" + str(vector_length) + length = 256 + + @T.prim_func + def vector_load( + A: T.Buffer[(length,), dtype], B: T.Buffer[(length // vector_length,), vec_dtype] + ): + for b in T.thread_binding(1, thread="blockIdx.x"): + for i in T.thread_binding(length // vector_length, thread="threadIdx.x"): + vec_index = T.ramp((i + 1) * vector_length - 1, -1, vector_length) + B[i] = A[vec_index] + + target = "cuda" + f = tvm.build(vector_load, target=target) + + dev = tvm.device(target, 0) + a_np = np.random.uniform(low=0, high=1, size=(length,)).astype(dtype) + a = tvm.nd.array(a_np, device=dev) + + b = tvm.nd.empty((length // vector_length,), dtype=vec_dtype, device=dev) + + f(a, b) + + b_np = np.empty((length // vector_length, vector_length), dtype=dtype) + + for i in range(length // vector_length): + start_index = (i + 1) * vector_length - 1 + b_np[i, :] = a_np[start_index - vector_length + 1 : start_index + 1][::-1] + + tvm.testing.assert_allclose(b.numpy(), b_np) + + +@tvm.testing.requires_cuda_compute_version(8) +def test_half4_vector_add(): + dtype = "float16" + length = 64 + vector_length = 4 + vec_dtype = dtype + "x" + str(vector_length) + + @T.prim_func + def add( + A: T.Buffer((length,), vec_dtype), + B: T.Buffer((length,), vec_dtype), + C: T.Buffer((length,), vec_dtype), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i in range(length): + with T.block("C"): + v_i = T.axis.spatial(length, i) + T.reads(A[v_i], B[v_i]) + T.writes(C[v_i]) + C[v_i] = A[v_i] + B[v_i] + + sch = tvm.tir.Schedule(add) + block = sch.get_block("C") + b = sch.get_loops(block) + bx, tx = sch.split(b[0], factors=[None, 32]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + target = "cuda" + fadd = tvm.build(sch.mod, target=target) + dev = tvm.device(target, 0) + + a_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype) + a = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev) + a.copyfrom(a_np) + b_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype) + b = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev) + b.copyfrom(b_np) + c = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev) + + fadd(a, b, c) + c_expected = a_np + b_np + tvm.testing.assert_allclose(c.numpy(), c_expected, atol=1e-5, rtol=1e-5) + + +class BaseFP8E4M3QuantScaleOnly: + @classmethod + def create_quantize_func( + cls, + weight_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_elem_per_storage, + max_int_value, + axis, + output_transpose, + ) -> IRModule: + if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float: + quantize_func = cls.quantize_fp8x4_e4m3 + else: + assert NotImplementedError() + + bb = relax.BlockBuilder() # pylint: disable=invalid-name + weight_var = relax.Var("weight", relax.TensorStructInfo(weight_shape, model_dtype)) + compute_scale, compute_quantize, compute_transpose = quantize_func( + weight_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_elem_per_storage, + max_int_value, + axis, + output_transpose, + ) + with bb.function(name="main", params=[weight_var]): + with bb.dataflow(): + lv_scale = bb.emit_te(compute_scale, weight_var) + lv_quantized_weight = compute_quantize(bb, (weight_var, lv_scale)) + if compute_transpose: + lv_output = bb.emit_te(compute_transpose, lv_quantized_weight, lv_scale) + lv_quantized_weight = lv_output[0] + lv_scale = lv_output[1] + tuple_output = bb.emit((lv_quantized_weight, lv_scale)) + gv = bb.emit_output(tuple_output) + bb.emit_func_output(gv) + return bb.finalize() + + @classmethod + def create_dequantize_func( + cls, + packed_weight_shape, + scale_shape, + dequantized_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_elem_per_storage, + axis, + ) -> IRModule: + if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float: + dequantize_func = cls.dequantize_fp8x4_e4m3 + else: + assert NotImplementedError() + + bb = relax.BlockBuilder() # pylint: disable=invalid-name + packed_weight_var = relax.Var( + "weight", relax.TensorStructInfo(packed_weight_shape, storage_dtype) + ) + scale_var = relax.Var("scale", relax.TensorStructInfo(scale_shape, model_dtype)) + compute_dequantize = dequantize_func( + packed_weight_shape, + scale_shape, + dequantized_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_elem_per_storage, + axis, + ) + with bb.function(name="main", params=[packed_weight_var, scale_var]): + with bb.dataflow(): + lv = compute_dequantize(bb, (packed_weight_var, scale_var)) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.finalize() + + @classmethod + def quantize_fp8x4_e4m3( # pylint: disable=too-many-locals + cls, + weight_shape: List[tir.PrimExpr], + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_elem_per_storage, + max_int_value, + axis: int = -1, + output_transpose: bool = False, + ) -> Tuple[te.Tensor, te.Tensor]: + """Group quantization for weight tensor, defined in tensor expression.""" + max_int = tir.const(max_int_value, model_dtype) + shape = weight_shape # pylint: disable=invalid-name + axis = axis if axis >= 0 else len(shape) + axis + k = shape[axis] + quantize_dtype = DataType(quantize_dtype) + # compute scale per group + r = te.reduce_axis((0, group_size), name="r") # pylint: disable=invalid-name + num_group = tir.ceildiv(k, group_size) + # (4096, 4096) -> quantize axis = 0, group size = 32 -> (128, 4096) + # for channel quant group_size = 4096 -> (1, 4096) + scale_shape = (*shape[:axis], num_group, *shape[axis + 1 :]) + + def compute_scale(weight: te.Tensor): + min_scaling_factor = tir.const(1.0 / (max_int_value * 512.0), model_dtype) + max_abs = te.compute( + shape=scale_shape, + fcompute=lambda *idx: te.max( + tir.if_then_else( + idx[axis] * group_size + r < k, + te.abs(weight(*idx[:axis], idx[axis] * group_size + r, *idx[axis + 1 :])), + te.min_value(model_dtype), + ), + axis=r, + ), + name="max_abs_value", + ) + scale = te.compute( + scale_shape, + lambda *idx: te.max( + max_abs(*idx).astype(model_dtype) / max_int, min_scaling_factor + ), + name="scale", + ) + return scale + + def compute_quantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr): + # compute scaled weight + packed_shape = (weight_shape[0], weight_shape[1] // num_elem_per_storage) + quant = cls.quant_and_pack_fp8x4_e4m3_sm90( + weight_shape, + packed_shape, + scale_shape, + group_size, + axis, + model_dtype, + storage_dtype, + quantize_dtype, + ) + # quant.show() + + global_var = bb.add_func(quant, "quantized_weight") + lv_quantized_weight = bb.emit( + relax.call_tir( + global_var, args, relax.TensorStructInfo(packed_shape, storage_dtype) + ) + ) + return lv_quantized_weight + + compute_transpose = None + if output_transpose: + + def compute_transpose(quantized_weight: te.Tensor, scale: te.Tensor): + if len(quantized_weight.shape) != 2 or len(scale.shape) != 2: + raise ValueError( + "Does not support transpose output quantized weight with ndim != 2" + ) + + quantized_weight = topi.transpose(quantized_weight) + scale = topi.transpose(scale) + return quantized_weight, scale + + return compute_scale, compute_quantize_weight, compute_transpose + + @classmethod + def dequantize_fp8x4_e4m3( # pylint: disable=too-many-locals + cls, + packed_weight_shape: List[tir.PrimExpr], + scale_shape, + dequant_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_elem_per_storage, + axis: int = -1, + ) -> Tuple[te.Tensor, te.Tensor]: + """Group quantization for weight tensor, defined in tensor expression.""" + axis = axis if axis >= 0 else len(shape) + axis + + def compute_dequantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr): + dequant = cls.dequant_fp8x4_e4m3_sm90( + packed_weight_shape, + scale_shape, + dequant_shape, + group_size, + axis, + model_dtype, + storage_dtype, + quantize_dtype, + ) + + global_var = bb.add_func(dequant, "dequantize_weight") + lv_dequantized_weight = bb.emit( + relax.call_tir(global_var, args, relax.TensorStructInfo(dequant_shape, model_dtype)) + ) + return lv_dequantized_weight + + return compute_dequantize_weight + + @classmethod + def quant_and_pack_fp8x4_e4m3_sm90( + cls, + weight_shape, + packed_shape, + scale_shape, + group_size, + axis, + model_dtype, + storage_dtype, + quantized_dtype, + ): + vector_length = 4 + vec_quantized_dtype = f"{quantized_dtype}x{vector_length}" + vec_model_dtype = f"{model_dtype}x{vector_length}" + num_elem_per_storage = vector_length + # TODO(csullivan) assert on storage dtype / quantize type bytes == vector length + assert ( + group_size % vector_length == 0 + ), f"Number of elements in a group must be divisible by fp8 vector length {vector_length}" + + @T.prim_func(private=True) + def quant_pack( + A: T.Buffer(weight_shape, model_dtype), + scale: T.Buffer(scale_shape, model_dtype), + compute: T.Buffer( + packed_shape, + storage_dtype, + ), + ): + # with T.block("root"): + # test = T.alloc_buffer(1, dtype=vec_model_dtype, scope="local") + for i0, i1 in T.grid( + T.int64(weight_shape[0]), T.int64(weight_shape[1] // vector_length) + ): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads( + A[v_i0, v_i1 : v_i1 + vector_length], + scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)], + ) + T.writes(compute[v_i0, v_i1 * vector_length]) + compute[v_i0, v_i1] = T.reinterpret( + storage_dtype, + T.Cast( + vec_quantized_dtype, + A[v_i0, T.ramp(v_i1 * vector_length, 1, vector_length)] + / scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)], + ), + ) + + return quant_pack + + @classmethod + def dequant_fp8x4_e4m3_sm90( + cls, + packed_weight_shape, + scale_shape, + out_shape, + group_size, + axis, + model_dtype, + storage_dtype, + quantized_dtype, + ): + vector_length = 4 + vec_quantized_dtype = f"{quantized_dtype}x{vector_length}" + vec_model_dtype = f"{model_dtype}x{vector_length}" + num_elem_per_storage = vector_length + + @T.prim_func + def dequant( + packed_weight: T.Buffer(packed_weight_shape, storage_dtype), + scale: T.Buffer(scale_shape, model_dtype), + dequantize: T.Buffer(out_shape, model_dtype), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i0, i1 in T.grid(T.int64(packed_weight_shape[0]), T.int64(packed_weight_shape[1])): + with T.block("dequantize"): + v_i0 = T.axis.spatial(T.int64(packed_weight_shape[0]), i0) + v_i1 = T.axis.spatial(T.int64(packed_weight_shape[1]), i1) + T.reads( + packed_weight[v_i0, v_i1], + scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)], + ) + + dequantize[v_i0, T.ramp(v_i1 * vector_length, 1, vector_length)] = T.Cast( + vec_model_dtype, + T.reinterpret(vec_quantized_dtype, packed_weight[v_i0, v_i1]), + ) * T.Broadcast( + scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)], + vector_length, + ) + + return dequant + + @classmethod + def compile_quant_and_dequant_by_scale( + cls, + weight_shape, + scales_shape, + quant_weight_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_el_per_storage, + max_int_value, + axis, + target_str, + dev, + ): + quant_mod = cls.create_quantize_func( + weight_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_el_per_storage, + max_int_value, + axis, + output_transpose=False, + ) + # quant_mod.show() + + target = tvm.target.Target(target_str) + with target: + quant_mod = dl.ApplyDefaultSchedule( + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + )(quant_mod) + ex_1 = relax.build(quant_mod, target=target) + vm_1 = relax.VirtualMachine(ex_1, dev) + + dequant_mod = cls.create_dequantize_func( + quant_weight_shape, + scales_shape, + weight_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_el_per_storage, + axis, + ) + # dequant_mod.show() + + with target: + dequant_mod = dl.ApplyDefaultSchedule( + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + )(dequant_mod) + dequant_mod.show() + + ex_2 = relax.build(dequant_mod, target=target) + vm_2 = relax.VirtualMachine(ex_2, dev) + + def print_cuda(target, mod, name=None): + if name: + mod = mod[name] + f = tvm.build(mod, target=target) + cuda_src = f.imported_modules[0].get_source() + print(cuda_src) + + print_cuda(target, dequant_mod, name="dequant") + + return vm_1["main"], vm_2["main"] + + +class TestFP8e4x4QuantDequantScale(BaseFP8E4M3QuantScaleOnly): + # weight_shape = tvm.testing.parameter((32000, 4096), (4096, 14336)) + weight_shape = tvm.testing.parameter((128, 256), (128, 64)) + + @tvm.testing.fixture + def group_size(self): + return 64 + + @tvm.testing.fixture + def axis(self): + return 1 + + @tvm.testing.fixture + def model_dtype(self): + return "float16" + + @tvm.testing.fixture + def storage_dtype(self): + return "uint32" + + @tvm.testing.fixture + def quantize_dtype(self): + return "e4m3_float8" + + @tvm.testing.fixture + def num_el_per_storage(self): + return 4 + + @tvm.testing.fixture + def max_int_value(self): + return 448 + + @tvm.testing.fixture + def target_str(self): + return "cuda" + + @tvm.testing.fixture + def scale_shape(self, weight_shape, group_size, axis): + return [ + (d + group_size - 1) // group_size if axis == i else d + for i, d in enumerate(weight_shape) + ] + + @tvm.testing.fixture + def quant_weight_shape(self, weight_shape, num_el_per_storage, axis): + return [ + (d + num_el_per_storage - 1) // num_el_per_storage if axis == i else d + for i, d in enumerate(weight_shape) + ] + + @tvm.testing.fixture + def compiled_functions( + self, + weight_shape, + scale_shape, + quant_weight_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_el_per_storage, + max_int_value, + axis, + target_str, + ): + dev = tvm.device(target_str, 0) + return self.compile_quant_and_dequant_by_scale( + weight_shape, + scale_shape, + quant_weight_shape, + model_dtype, + quantize_dtype, + storage_dtype, + group_size, + num_el_per_storage, + max_int_value, + axis, + target_str, + dev, + ) + + @tvm.testing.requires_cuda_compute_version(9) + def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): + quant, dequant = compiled_functions + dev = tvm.device(target_str, 0) + + weight_np = np.random.uniform(-100, 100, weight_shape).astype(model_dtype) + weight = tvm.nd.array(weight_np, device=dev) + quant_weight, scales = quant(weight) + quant_weight_np, scales_np = quant_weight.numpy(), scales.numpy() + + dequant_weight = dequant(quant_weight, scales) + dequant_weight_np = dequant_weight.numpy() + tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2) + + +if __name__ == "__main__": + tvm.testing.main()