Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gpu: sycl: binary: add support for remaining post ops #1880

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 107 additions & 4 deletions src/gpu/sycl/binary_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,24 @@ struct binary_kernel_vec_t {
xpu::sycl::in_memory_arg_t &src0, xpu::sycl::in_memory_arg_t &src1,
xpu::sycl::out_memory_arg_t &dst,
xpu::sycl::in_memory_arg_t &src0_scale,
xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt)
xpu::sycl::in_memory_arg_t &src1_scale, data_type_t scales_dt,
xpu::sycl::in_memory_arg_t &po1_src,
xpu::sycl::in_memory_arg_t &po2_src,
xpu::sycl::in_memory_arg_t &po3_src,
xpu::sycl::in_memory_arg_t &po4_src,
xpu::sycl::in_memory_arg_t &po5_src)
mgouicem marked this conversation as resolved.
Show resolved Hide resolved
: conf_(conf)
, src0_(src0)
, src1_(src1)
, dst_(dst)
, src0_scale_(src0_scale)
, src1_scale_(src1_scale)
, scales_dt_(scales_dt) {}
, scales_dt_(scales_dt)
, po1_src_(po1_src)
, po2_src_(po2_src)
, po3_src_(po3_src)
, po4_src_(po4_src)
, po5_src_(po5_src) {}

void operator()(::sycl::nd_item<1> item) const {
auto sg = item.get_sub_group();
Expand Down Expand Up @@ -73,7 +83,7 @@ struct binary_kernel_vec_t {
any_broadcast |= conf_.broadcast_dims[i];
}
}
if (!any_broadcast
if (!any_broadcast && conf_.post_ops.get_post_op() == 0
&& sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
< conf_.wk_size) {
for (int i = 0; i < conf_.block_size / vec_len; i++) {
Expand Down Expand Up @@ -123,7 +133,8 @@ struct binary_kernel_vec_t {
if (conf_.do_scale_src1) src1 *= sm_1;

auto acc = compute_alg_n(src0, src1, conf_.alg_kind);
acc = conf_.post_ops.apply(acc, dst);
::sycl::vec<float, 16> post_po_sr = post_op_src_val(idx);
acc = conf_.post_ops.apply(acc, dst, post_po_sr);
store_float_value(
dst_md().data_type(), acc, dst_ptr(), idx);
}
Expand All @@ -146,6 +157,93 @@ struct binary_kernel_vec_t {
return static_cast<float *>(src1_scale_.get_pointer());
}

inline ::sycl::vec<float, 16> post_op_src_val(dim_t data_l_off) const {
::sycl::vec<float, 16> post_po_sr;
const auto maxPostPo = conf_.post_ops.get_post_op();

for (dim_t po_idx = 0; po_idx < maxPostPo; po_idx++) {
float res = 0.0f;
if (po_idx == 0)
res = get_post_op_val(po1_src_, po_idx, data_l_off);
else if (po_idx == 1)
res = get_post_op_val(po2_src_, po_idx, data_l_off);
else if (po_idx == 2)
res = get_post_op_val(po3_src_, po_idx, data_l_off);
else if (po_idx == 3)
res = get_post_op_val(po4_src_, po_idx, data_l_off);
else if (po_idx == 4)
res = get_post_op_val(po5_src_, po_idx, data_l_off);

mgouicem marked this conversation as resolved.
Show resolved Hide resolved
post_po_sr[po_idx] = res;
}
return post_po_sr;
}

float get_post_op_val(const xpu::sycl::in_memory_arg_t &bin_src_op,
dim_t &idx, dim_t offset) const {
auto src1_desc = conf_.binary_src_arr[idx];

const auto off = get_binary_src1_off(
src1_desc, offset, dst_md().dims(), dst_md().ndims());

auto dst = load_float_value(
src1_desc.data_type(), bin_src_op.get_pointer(), off);
return dst;
}

dim_t get_binary_src1_off(const xpu::sycl::md_t &src1_md, dim_t l_offset,
const xpu::sycl::md_t::dims32_t &dst_dims,
const xpu::sycl::md_t::dim32_t &dst_ndims) const {
const dim_t mask_binary_po
= get_dims_mask(dst_dims, src1_md.dims(), dst_ndims);
return get_po_tensor_off(
src1_md, l_offset, dst_dims, dst_ndims, mask_binary_po);
}

inline dim_t get_dims_mask(const xpu::sycl::md_t::dims32_t &dims1,
const xpu::sycl::md_t::dims32_t &dims2, const dim_t &ndims,
bool skip_dim_of_one = false) const {
dim_t mask = 0;
for (dim_t d = 0; d < ndims; ++d) {
// Disable mask_bit for dimensions of `1` by request.
dim_t mask_bit = skip_dim_of_one && dims1[d] == 1 ? 0 : (1 << d);
mask += dims1[d] == dims2[d] ? mask_bit : 0;
}
return mask;
}

inline dim_t get_po_tensor_off(const xpu::sycl::md_t &tensor_md,
dim_t l_offset, const xpu::sycl::md_t::dims32_t &dst_dims,
const dim_t &dst_ndims, const dim_t &mask) const {
dims_t l_dims_po {};
get_l_dims_po(l_dims_po, l_offset, dst_dims, dst_ndims, mask);

return tensor_md.off_v(l_dims_po);
}

    // Fills `l_dims_po` with the logical per-dimension coordinates of linear
    // offset `l_offset` in the destination tensor, then adjusts them per
    // `mask` via utils::apply_mask_on_dims (presumably neutralizing the
    // coordinates of broadcast dimensions — TODO confirm against utils).
    inline void get_l_dims_po(dims_t l_dims_po, dim_t l_offset,
            const xpu::sycl::md_t::dims32_t &dst_dims, const dim_t &dst_ndims,
            const dim_t &mask) const {

        l_dims_by_l_offset(l_dims_po, l_offset, dst_dims, dst_ndims);
        utils::apply_mask_on_dims(l_dims_po, dst_ndims, mask);
    }

inline void l_dims_by_l_offset(dims_t dims_pos, dim_t l_offset,
const xpu::sycl::md_t::dims32_t &dims, const dim_t &ndims) const {
for (dim_t rd = 0; rd < ndims; ++rd) {
const dim_t d = ndims - 1 - rd;
/* switch to faster 32-bit division when possible. */
if (l_offset <= INT32_MAX && dims[d] <= INT32_MAX) {
dims_pos[d] = (int32_t)l_offset % (int32_t)dims[d];
l_offset = (int32_t)l_offset / (int32_t)dims[d];
} else {
dims_pos[d] = l_offset % dims[d];
l_offset /= dims[d];
}
}
}

template <int width>
::sycl::vec<float, width> compute_alg(::sycl::vec<float, width> src0,
::sycl::vec<float, width> src1, alg_kind_t alg) const {
Expand Down Expand Up @@ -199,6 +297,11 @@ struct binary_kernel_vec_t {
xpu::sycl::in_memory_arg_t src0_scale_;
xpu::sycl::in_memory_arg_t src1_scale_;
data_type_t scales_dt_;
xpu::sycl::in_memory_arg_t po1_src_;
xpu::sycl::in_memory_arg_t po2_src_;
xpu::sycl::in_memory_arg_t po3_src_;
xpu::sycl::in_memory_arg_t po4_src_;
xpu::sycl::in_memory_arg_t po5_src_;
};

} // namespace sycl
Expand Down
22 changes: 21 additions & 1 deletion src/gpu/sycl/ref_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ status_t ref_binary_t::pd_t::init_conf() {

conf_.post_ops = sycl_post_ops_t(attr());

for (auto i = 0; i < conf_.post_ops.get_post_op(); ++i) {
const auto &e = attr()->post_ops_.entry_[i];
if (e.is_binary() || e.is_prelu()) {
conf_.binary_src_arr[i] = xpu::sycl::md_t(
arg_md(DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1));
}
}
return status::success;
}

Expand All @@ -62,6 +69,7 @@ status_t ref_binary_t::init(engine_t *engine) {
}

status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {

parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
auto src0_mem_arg = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_0);
auto src1_mem_arg = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_SRC_1);
Expand All @@ -76,9 +84,21 @@ status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {
.data_type()
: data_type_t::dnnl_f32;

auto src_mem_po_1 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1));
auto src_mem_po_2 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1));
auto src_mem_po_3 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1));
auto src_mem_po_4 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1));
auto src_mem_po_5 = CTX_IN_SYCL_KERNEL_MEMORY(
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(4) | DNNL_ARG_SRC_1));

binary_kernel_vec_t binary_kernel(pd()->conf_, src0_mem_arg,
src1_mem_arg, dst_mem_arg, src0_scale_mem_arg,
src1_scale_mem_arg, scales_dt);
src1_scale_mem_arg, scales_dt, src_mem_po_1, src_mem_po_2,
src_mem_po_3, src_mem_po_4, src_mem_po_5);

const int block_size = pd()->conf_.block_size;
const int wg_size = pd()->conf_.wg_size;
Expand Down
19 changes: 7 additions & 12 deletions src/gpu/sycl/ref_binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
const memory_desc_wrapper dst_d(dst_md());

const bool ok = set_default_params() == status::success
&& attr_.set_default_formats(dst_md()) == status::success
&& check_data_types(src0_d, src1_d, dst_d)
&& check_formats(src0_d, src1_d, dst_d)
&& attr()->has_default_values(
Expand All @@ -72,18 +73,12 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
}

bool post_ops_ok() const {
for (int i = 0; i < attr()->post_ops_.len(); i++) {
const auto &e = attr()->post_ops_.entry_[i];
if (!IMPLICATION(e.is_eltwise(),
utils::one_of(e.eltwise.alg, alg_kind::eltwise_relu,
alg_kind::eltwise_linear))) {
return false;
}
}
// Binary, prelu and dw conv post-ops are not supported.
// Dw conv post-ops are not supported.
return attr()->post_ops_.len() <= sycl_post_ops_t::max_post_ops
&& attr()->post_ops_.has_default_values(
{primitive_kind::eltwise});
{primitive_kind::eltwise, primitive_kind::binary,
primitive_kind::prelu,
primitive_kind::sum});
}

static bool check_data_types(const memory_desc_wrapper &src0,
Expand All @@ -100,7 +95,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
}

return IMPLICATION(utils::one_of(bf16, src0_dt, src1_dt, dst_dt),
src0_dt == src1_dt == dst_dt);
src0_dt == dst_dt && src1_dt == dst_dt);
}

static bool check_formats(const memory_desc_wrapper &src0,
Expand All @@ -109,7 +104,7 @@ struct ref_binary_t : public sycl_gpu_primitive_t {
using namespace format_tag;

for (const auto &mdw : {src0, src1, dst}) {
if (mdw.matches_one_of_tag(ab, abc, abcd, abcde) == undef) {
if (mdw.matches_one_of_tag(a, ab, abc, abcd, abcde) == undef) {
return false;
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/gpu/sycl/sycl_primitive_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct sycl_binary_conf_t {
int wg_size;
int wk_size;

xpu::sycl::md_t binary_src_arr[8];

sycl_post_ops_t post_ops;
};

Expand Down