diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip index ca44bba42..034f66fa6 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip @@ -12,11 +12,14 @@ #include #include -#include +#if !defined(USE_ROCM) +#include +#endif + #include #if defined(USE_ROCM) - +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip index 3450d2cad..066a84c6d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip @@ -15,11 +15,14 @@ #include #include -#include +#if !defined(USE_ROCM) +#include +#endif + #include #if defined(USE_ROCM) - +#include #include "kernels/fp8_rowwise_kernel_manifest.h" namespace fbgemm_gpu { diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip index 04b2ba3b0..4e925aed2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip @@ -12,11 +12,14 @@ #include #include -#include +#if !defined(USE_ROCM) +#include +#endif + #include #if defined(USE_ROCM) - +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_common.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_common.h index 717720580..da4087989 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_common.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_common.h @@ -12,10 +12,13 @@ #include #include +#if !defined(USE_ROCM) #include +#endif #include #if defined(USE_ROCM) +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"