From 170bfaf3563d634f5433e822b03785fbef36465f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Fri, 26 Apr 2024 11:30:07 +0200 Subject: [PATCH] gpu: sycl: binary: add support for remaining post ops --- src/gpu/sycl/binary_kernels.hpp | 111 ++++++++++++++++++++++++++- src/gpu/sycl/ref_binary.cpp | 22 +++++- src/gpu/sycl/ref_binary.hpp | 19 ++--- src/gpu/sycl/sycl_primitive_conf.hpp | 2 + 4 files changed, 137 insertions(+), 17 deletions(-) diff --git a/src/gpu/sycl/binary_kernels.hpp b/src/gpu/sycl/binary_kernels.hpp index 8fb556fab63..49c35a9c575 100644 --- a/src/gpu/sycl/binary_kernels.hpp +++ b/src/gpu/sycl/binary_kernels.hpp @@ -36,14 +36,24 @@ struct binary_kernel_vec_t { xpu::sycl::in_memory_arg_t &src0, xpu::sycl::in_memory_arg_t &src1, xpu::sycl::out_memory_arg_t &dst, xpu::sycl::in_memory_arg_t &src0_scale, - xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt) + xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt, + xpu::sycl::in_memory_arg_t &po1_src, + xpu::sycl::in_memory_arg_t &po2_src, + xpu::sycl::in_memory_arg_t &po3_src, + xpu::sycl::in_memory_arg_t &po4_src, + xpu::sycl::in_memory_arg_t &po5_src) : conf_(conf) , src0_(src0) , src1_(src1) , dst_(dst) , src0_scale_(src0_scale) , src1_scale_(src1_scale) - , scales_dt_(scales_dt) {} + , scales_dt_(scales_dt) + , po1_src_(po1_src) + , po2_src_(po2_src) + , po3_src_(po3_src) + , po4_src_(po4_src) + , po5_src_(po5_src) {} void operator()(::sycl::nd_item<1> item) const { auto sg = item.get_sub_group(); @@ -73,7 +83,7 @@ struct binary_kernel_vec_t { any_broadcast |= conf_.broadcast_dims[i]; } } - if (!any_broadcast + if (!any_broadcast && conf_.post_ops.get_post_op() == 0 && sg_base_idx + (sg.get_local_range()[0] * conf_.block_size) < conf_.wk_size) { for (int i = 0; i < conf_.block_size / vec_len; i++) { @@ -123,7 +133,8 @@ struct binary_kernel_vec_t { if (conf_.do_scale_src1) src1 *= sm_1; auto acc = compute_alg_n(src0, src1, conf_.alg_kind); - acc = conf_.post_ops.apply(acc, dst); + ::sycl::vec post_po_sr = post_op_src_val(idx); + acc = conf_.post_ops.apply(acc, dst, post_po_sr); store_float_value( dst_md().data_type(), acc, dst_ptr(), idx); } @@ -146,6 +157,93 @@ struct binary_kernel_vec_t { return static_cast(src1_scale_.get_pointer()); } + inline ::sycl::vec post_op_src_val(dim_t data_l_off) const { + ::sycl::vec post_po_sr; + const auto maxPostPo = conf_.post_ops.get_post_op(); + + for (dim_t po_idx = 0; po_idx < maxPostPo; po_idx++) { + float res = 0.0f; + if (po_idx == 0) + res = get_post_op_val(po1_src_, po_idx, data_l_off); + else if (po_idx == 1) + res = get_post_op_val(po2_src_, po_idx, data_l_off); + else if (po_idx == 2) + res = get_post_op_val(po3_src_, po_idx, data_l_off); + else if (po_idx == 3) + res = get_post_op_val(po4_src_, po_idx, data_l_off); + else if (po_idx == 4) + res = get_post_op_val(po5_src_, po_idx, data_l_off); + + post_po_sr[po_idx] = res; + } + return post_po_sr; + } + + float get_post_op_val(const xpu::sycl::in_memory_arg_t &bin_src_op, + dim_t &idx, dim_t offset) const { + auto src1_desc = conf_.binary_src_arr[idx]; + + const auto off = get_binary_src1_off( + src1_desc, offset, dst_md().dims(), dst_md().ndims()); + + auto dst = load_float_value( + src1_desc.data_type(), bin_src_op.get_pointer(), off); + return dst; + } + + dim_t get_binary_src1_off(const xpu::sycl::md_t &src1_md, dim_t l_offset, + const xpu::sycl::md_t::dims32_t &dst_dims, + const xpu::sycl::md_t::dim32_t &dst_ndims) const { + const dim_t mask_binary_po + = get_dims_mask(dst_dims, src1_md.dims(), dst_ndims); + return get_po_tensor_off( + src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po); + } + + inline dim_t get_dims_mask(const xpu::sycl::md_t::dims32_t &dims1, + const xpu::sycl::md_t::dims32_t &dims2, const dim_t &ndims, + bool skip_dim_of_one = false) const { + dim_t mask = 0; + for (dim_t d = 0; d < ndims; ++d) { + // Disable mask_bit for dimensions of `1` by request. + dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d); + mask += dims1[d] == dims2[d] ? mask_bit : 0; + } + return mask; + } + + inline dim_t get_po_tensor_off(const xpu::sycl::md_t &tensor_md, + dim_t l_offset, const xpu::sycl::md_t::dims32_t &dst_dims, + const dim_t &dst_ndims, const dim_t &mask) const { + dims_t l_dims_po {}; + get_l_dims_po(l_dims_po, l_offset, dst_dims, dst_ndims, mask); + + return tensor_md.off_v(l_dims_po); + } + + inline void get_l_dims_po(dims_t l_dims_po, dim_t l_offset, + const xpu::sycl::md_t::dims32_t &dst_dims, const dim_t &dst_ndims, + const dim_t &mask) const { + + l_dims_by_l_offset(l_dims_po, l_offset, dst_dims, dst_ndims); + utils::apply_mask_on_dims(l_dims_po, dst_ndims, mask); + } + + inline void l_dims_by_l_offset(dims_t dims_pos, dim_t l_offset, + const xpu::sycl::md_t::dims32_t &dims, const dim_t &ndims) const { + for (dim_t rd = 0; rd < ndims; ++rd) { + const dim_t d = ndims - 1 - rd; + /* switch to faster 32-bit division when possible. */ + if (l_offset <= INT32_MAX && dims[d] <= INT32_MAX) { + dims_pos[d] = (int32_t)l_offset % (int32_t)dims[d]; + l_offset = (int32_t)l_offset / (int32_t)dims[d]; + } else { + dims_pos[d] = l_offset % dims[d]; + l_offset /= dims[d]; + } + } + } + template ::sycl::vec compute_alg(::sycl::vec src0, ::sycl::vec src1, alg_kind_t alg) const { @@ -199,6 +297,11 @@ struct binary_kernel_vec_t { xpu::sycl::in_memory_arg_t src0_scale_; xpu::sycl::in_memory_arg_t src1_scale_; data_type_t scales_dt_; + xpu::sycl::in_memory_arg_t po1_src_; + xpu::sycl::in_memory_arg_t po2_src_; + xpu::sycl::in_memory_arg_t po3_src_; + xpu::sycl::in_memory_arg_t po4_src_; + xpu::sycl::in_memory_arg_t po5_src_; }; } // namespace sycl diff --git a/src/gpu/sycl/ref_binary.cpp b/src/gpu/sycl/ref_binary.cpp index 5a89b009afd..26882e08c9d 100644 --- a/src/gpu/sycl/ref_binary.cpp +++ b/src/gpu/sycl/ref_binary.cpp @@ -52,6 +52,13 @@ status_t ref_binary_t::pd_t::init_conf() { conf_.post_ops = sycl_post_ops_t(attr()); + for (auto i = 0; i < conf_.post_ops.get_post_op(); ++i) { + const auto &e = attr()->post_ops_.entry_[i]; + if (e.is_binary() || e.is_prelu()) { + conf_.binary_src_arr[i] = xpu::sycl::md_t( + arg_md(DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1)); + } + } return status::success; } @@ -62,6 +69,7 @@ status_t ref_binary_t::init(engine_t *engine) { } status_t ref_binary_t::execute(const exec_ctx_t &ctx) const { + parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) { auto src0_mem_arg = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0); auto src1_mem_arg = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_1); @@ -76,9 +84,21 @@ status_t ref_binary_t::execute(const exec_ctx_t &ctx) const { .data_type() : data_type_t::dnnl_f32; + auto src_mem_po_1 = CTX_IN_SYCL_KERNEL_MEMORY( + (DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1)); + auto src_mem_po_2 = CTX_IN_SYCL_KERNEL_MEMORY( + (DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1)); + auto src_mem_po_3 = CTX_IN_SYCL_KERNEL_MEMORY( + (DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1)); + auto src_mem_po_4 = CTX_IN_SYCL_KERNEL_MEMORY( + (DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1)); + auto src_mem_po_5 = CTX_IN_SYCL_KERNEL_MEMORY( + (DNNL_ARG_ATTR_MULTIPLE_POST_OP(4) | DNNL_ARG_SRC_1)); + binary_kernel_vec_t binary_kernel(pd()->conf_, src0_mem_arg, src1_mem_arg, dst_mem_arg, src0_scale_mem_arg, - src1_scale_mem_arg, scales_dt); + src1_scale_mem_arg, scales_dt, src_mem_po_1, src_mem_po_2, + src_mem_po_3, src_mem_po_4, src_mem_po_5); const int block_size = pd()->conf_.block_size; const int wg_size = pd()->conf_.wg_size; diff --git a/src/gpu/sycl/ref_binary.hpp b/src/gpu/sycl/ref_binary.hpp index c7c4f90fe52..09bde014c6d 100644 --- a/src/gpu/sycl/ref_binary.hpp +++ b/src/gpu/sycl/ref_binary.hpp @@ -48,6 +48,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t { const memory_desc_wrapper dst_d(dst_md()); const bool ok = set_default_params() == status::success + && attr_.set_default_formats(dst_md()) == status::success && check_data_types(src0_d, src1_d, dst_d) && check_formats(src0_d, src1_d, dst_d) && attr()->has_default_values( @@ -72,18 +73,12 @@ struct ref_binary_t : public sycl_gpu_primitive_t { } bool post_ops_ok() const { - for (int i = 0; i < attr()->post_ops_.len(); i++) { - const auto &e = attr()->post_ops_.entry_[i]; - if (!IMPLICATION(e.is_eltwise(), - utils::one_of(e.eltwise.alg, alg_kind::eltwise_relu, - alg_kind::eltwise_linear))) { - return false; - } - } - // Binary, prelu and dw conv post-ops are not supported. + // Dw conv post-ops are not supported. return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops && attr()->post_ops_.has_default_values( - {primitive_kind::eltwise}); + {primitive_kind::eltwise, primitive_kind::binary, + primitive_kind::prelu, + primitive_kind::sum}); } static bool check_data_types(const memory_desc_wrapper &src0, @@ -100,7 +95,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t { } return IMPLICATION(utils::one_of(bf16, src0_dt, src1_dt, dst_dt), - src0_dt == src1_dt == dst_dt); + src0_dt == dst_dt && src1_dt == dst_dt); } static bool check_formats(const memory_desc_wrapper &src0, @@ -109,7 +104,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t { using namespace format_tag; for (const auto &mdw : {src0, src1, dst}) { - if (mdw.matches_one_of_tag(ab, abc, abcd, abcde) == undef) { + if (mdw.matches_one_of_tag(a, ab, abc, abcd, abcde) == undef) { return false; } } diff --git a/src/gpu/sycl/sycl_primitive_conf.hpp b/src/gpu/sycl/sycl_primitive_conf.hpp index 2adcfb2034c..d809134905c 100644 --- a/src/gpu/sycl/sycl_primitive_conf.hpp +++ b/src/gpu/sycl/sycl_primitive_conf.hpp @@ -44,6 +44,8 @@ struct sycl_binary_conf_t { int wg_size; int wk_size; + xpu::sycl::md_t binary_src_arr[8]; + sycl_post_ops_t post_ops; };