diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
index 60d40b71d..f64787739 100644
--- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp
@@ -593,6 +593,29 @@ void _layer_norm_backward_kernel(
       dY, X, mean_data, var_data, dgamma, dbeta, config_w);
 }
 
+template <typename input_t>
+void bridge_layer_norm_kernel(
+    const Tensor& X,
+    const Tensor& gamma,
+    const Tensor& beta,
+    int64_t M,
+    int64_t N,
+    double eps,
+    Tensor& Y,
+    Tensor& mean,
+    Tensor& rstd) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      gamma.scalar_type(),
+      "layer_norm_xpu",
+      [&]() {
+        using acc_t = acc_type_device<scalar_t, kXPU>;
+        _layer_norm_kernel<scalar_t, acc_t>(
+            X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
+      });
+}
+
 std::tuple<Tensor, Tensor, Tensor> layer_norm_kernel(
     const Tensor& X,
     const Tensor& gamma,
@@ -608,11 +631,10 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_kernel(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       X.scalar_type(),
-      "layer_norm_xpu",
+      "bridge_layer_norm_xpu",
       [&]() {
-        using acc_t = acc_type_device<scalar_t, kXPU>;
-        _layer_norm_kernel<scalar_t, acc_t>(
-            X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
+        bridge_layer_norm_kernel<scalar_t>(
+            X, gamma, beta, M, N, eps, Y, mean, rstd);
       });
   return std::make_tuple(Y, mean, rstd);
 }
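
For readers unfamiliar with the pattern: AT_DISPATCH_FLOATING_TYPES_AND2 maps a runtime ScalarType onto a compile-time scalar_t inside the lambda it is given. The patch chains two of these dispatches: layer_norm_kernel dispatches on X.scalar_type() and then hands off to the new bridge_layer_norm_kernel, which dispatches a second time on gamma.scalar_type(), so the kernel instantiation is selected by the weight dtype as well, even when it differs from the input dtype. Below is a minimal, self-contained sketch of that two-level dispatch pattern in plain C++ (C++17); Dtype, dispatch_dtype, norm_impl, input_t, and weight_t are illustrative stand-ins, not names from the patch or from ATen, and the switch stands in for what the AT_DISPATCH_* macros expand to.

#include <cstdio>

// Runtime dtype tag, playing the role of at::ScalarType (illustrative).
enum class Dtype { Float, Half };

// Compile-time stand-ins for the real scalar types.
struct FloatTag { static constexpr const char* name = "float"; };
struct HalfTag  { static constexpr const char* name = "half";  };

// Map a runtime Dtype to a compile-time type and invoke the functor with
// it -- the job AT_DISPATCH_FLOATING_TYPES_AND2 does in ATen.
template <typename F>
void dispatch_dtype(Dtype d, F&& f) {
  switch (d) {
    case Dtype::Float: f(FloatTag{}); break;
    case Dtype::Half:  f(HalfTag{});  break;
  }
}

// The fully-typed kernel, templated on both input and weight types.
template <typename input_t, typename weight_t>
void norm_impl() {
  std::printf("kernel<input=%s, weight=%s>\n", input_t::name, weight_t::name);
}

int main() {
  Dtype input_dtype  = Dtype::Half;   // plays the role of X.scalar_type()
  Dtype weight_dtype = Dtype::Float;  // plays the role of gamma.scalar_type()

  // Outer dispatch binds the input dtype to a template parameter...
  dispatch_dtype(input_dtype, [&](auto in_tag) {
    using input_t = decltype(in_tag);
    // ...then the "bridge" level dispatches again on the weight dtype.
    dispatch_dtype(weight_dtype, [&](auto w_tag) {
      using weight_t = decltype(w_tag);
      norm_impl<input_t, weight_t>();  // prints kernel<input=half, weight=float>
    });
  });
}

The cost of the nesting is combinatorial: the fully-typed kernel is instantiated once per (input dtype, weight dtype) pair, which is presumably why the second dispatch level is factored into a separate bridge function rather than written inline at every call site.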