From f79e64f7a6d18828d1808c5e931878676e2897c3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 4 Jun 2024 16:48:21 -0700 Subject: [PATCH] [CUDA] Print FP8 vectors with more than 4 lanes as packed uint types --- src/target/source/codegen_cuda.cc | 6 +++++- tests/python/codegen/test_target_codegen_cuda_fp8.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 5e87031939a5..bd2804830172 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -294,7 +294,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (!fail) return; } else if (t.is_float8()) { enable_fp8_ = true; - os << GetFP8Type(t); + if (t.lanes() <= 4) { + os << GetFP8Type(t); + } else { + os << "uint" << t.lanes() / 4; + } return; } else if (t == DataType::Bool()) { os << "bool"; diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index 5566ae243477..adcb05839bc9 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -64,7 +64,7 @@ def add( fadd = tvm.build(sch.mod, target=target) cuda_src = fadd.imported_modules[0].get_source() - assert "fp8_e4_t" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA" + assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA" dev = tvm.device(target, 0)