Skip to content

Commit

Permalink
fix fp8 in buffercomparator (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
ScXfjiang authored Nov 26, 2024
1 parent e02959b commit a7fcd09
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions xla/service/gpu/buffer_comparator.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
int mcount = 0;
uint64_t unroll = 128 / sizeof(*buffer_a);
uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
Expand All @@ -130,13 +131,17 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
}
if (mcount)
atomicAdd(mismatch_count, mcount);
#else
abort();
#endif // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
}

__global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
__hip_fp8_storage_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
int mcount = 0;
uint64_t unroll = 128 / sizeof(*buffer_a);
uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
Expand All @@ -159,6 +164,9 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
}
if (mcount)
atomicAdd(mismatch_count, mcount);
#else
abort();
#endif // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
}
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

Expand Down

0 comments on commit a7fcd09

Please sign in to comment.