diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc index 39f0d211e1611..40e55a9a07c5d 100644 --- a/xla/stream_executor/cuda/cuda_driver.cc +++ b/xla/stream_executor/cuda/cuda_driver.cc @@ -1079,10 +1079,10 @@ absl::Status GpuDriver::SynchronousMemsetUint32(Context* context, absl::Status GpuDriver::AsynchronousMemsetUint8(Context* context, CUdeviceptr location, uint8_t value, - size_t uint32_count, + size_t uint8_count, CUstream stream) { ScopedActivateContext activation(context); - return cuda::ToStatus(cuMemsetD8Async(location, value, uint32_count, stream), + return cuda::ToStatus(cuMemsetD8Async(location, value, uint8_count, stream), "Failed to enqueue async memset operation"); } diff --git a/xla/stream_executor/gpu/gpu_driver.h b/xla/stream_executor/gpu/gpu_driver.h index 2b299e544b307..08531dd0f9f18 100644 --- a/xla/stream_executor/gpu/gpu_driver.h +++ b/xla/stream_executor/gpu/gpu_driver.h @@ -499,7 +499,7 @@ class GpuDriver { static absl::Status AsynchronousMemsetUint8(Context* context, GpuDevicePtr location, uint8_t value, - size_t uint32_count, + size_t uint8_count, GpuStreamHandle stream); // Performs an asynchronous memset of the device memory segment via diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc index 81f7f3d76fbd7..1755c3a63ff1c 100644 --- a/xla/stream_executor/rocm/rocm_driver.cc +++ b/xla/stream_executor/rocm/rocm_driver.cc @@ -956,11 +956,11 @@ absl::Status GpuDriver::SynchronousMemsetUint32(Context* context, absl::Status GpuDriver::AsynchronousMemsetUint8(Context* context, hipDeviceptr_t location, uint8 value, - size_t uint32_count, + size_t uint8_count, GpuStreamHandle stream) { ScopedActivateContext activation{context}; RETURN_IF_ROCM_ERROR( - wrap::hipMemsetAsync(location, value, uint32_count, stream), + wrap::hipMemsetAsync(location, value, uint8_count, stream), "Failed to enqueue async memset operation"); return absl::OkStatus(); }