From d404170535348f530182f82322a0ee55b4729e08 Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Tue, 19 Nov 2024 06:32:01 -0600 Subject: [PATCH] Compile FP8 fnuz comparison kernel bodies only for gfx940/gfx941/gfx942 --- xla/service/gpu/buffer_comparator.cu.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xla/service/gpu/buffer_comparator.cu.cc b/xla/service/gpu/buffer_comparator.cu.cc index e3c8b4310542e..15b7ece1db097 100644 --- a/xla/service/gpu/buffer_comparator.cu.cc +++ b/xla/service/gpu/buffer_comparator.cu.cc @@ -108,6 +108,7 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a, float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) { +#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) int mcount = 0; uint64_t unroll = 128 / sizeof(*buffer_a); uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll; @@ -130,6 +131,9 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a, } if (mcount) atomicAdd(mismatch_count, mcount); +#else + abort(); +#endif // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) } __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a, @@ -137,6 +141,7 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a, float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) { +#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) int mcount = 0; uint64_t unroll = 128 / sizeof(*buffer_a); uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll; @@ -159,6 +164,9 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a, } if (mcount) atomicAdd(mismatch_count, mcount); +#else + abort(); +#endif // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) } #endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200