Skip to content

Commit

Permalink
aarch64: shuffle: fix segv for bf16 cases
Browse files Browse the repository at this point in the history
  • Loading branch information
kawakami-k committed May 13, 2024
1 parent df30226 commit 0e2ca13
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
9 changes: 6 additions & 3 deletions src/cpu/aarch64/shuffle/jit_uni_shuffle.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2022 FUJITSU LIMITED
* Copyright 2020-2024 Intel Corporation
* Copyright 2022-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -58,7 +58,10 @@ status_t jit_uni_shuffle_t<isa>::pd_t::init(engine_t *engine) {
if (blocked_format == format_tag::undef) return status::unimplemented;

conf_.blk_size = src_d.blocking_desc().strides[ndims() - 1];
conf_.simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
/* Because "LD1H { <Zt>.S }, <Pg>/Z, [<Xn|SP>, <Zm>.S, <mod> #1]" is used
to gather data for bf16, each element occupies one 32-bit lane, so simd_w
must be calculated with sizeof(unsigned). */
conf_.simd_w = cpu_isa_traits<isa>::vlen / sizeof(unsigned);

const bool has_spatial = utils::one_of(ndims(), 3, 4, 5);
const dim_t HW = H() * W();
Expand Down
45 changes: 28 additions & 17 deletions src/cpu/aarch64/shuffle/jit_uni_shuffle_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
* Copyright 2022 FUJITSU LIMITED
* Copyright 2021-2024 Intel Corporation
* Copyright 2022-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -48,8 +48,10 @@ template <cpu_isa_t isa>
/* Emits the code that initializes the tail predicate k_tail_mask_.
 *
 * Nothing is emitted when conf_.simd_tail == 0. Otherwise the predicate is
 * set for every 32-bit lane whose index is below conf_.simd_tail. */
void jit_uni_shuffle_kernel_t<isa>::prepare_mask() {
    if (conf_.simd_tail > 0) {
        /* bf16 data is gathered with ld1h into 32-bit (.S) lanes (see
           gather_data), so one element occupies one 32-bit lane for every
           supported data type. The tail bound is therefore counted in
           32-bit units (sizeof(unsigned)), independently of the element
           size; no data-type restriction applies here. */
        assert(conf_.simd_tail < isa_sveLen / sizeof(unsigned));
        // vmm_tmp_.s = {0, 1, 2, ...}: the lane index of every 32-bit lane.
        index(vmm_tmp_.s, 0, 1);
        // Lane i becomes active iff i < simd_tail.
        cmplt(k_tail_mask_.s, P_ALL_ONE / T_z, vmm_tmp_.s, conf_.simd_tail);
    }
}
Expand All @@ -68,13 +70,17 @@ void jit_uni_shuffle_kernel_t<asimd>::prepare_mask() {}
/* Emits a gather load of one vector of source elements.
 *
 * reg_src_addr - base address register of the source buffer.
 * indices_idx  - index of the vector register holding the per-lane offsets.
 *                NOTE: this register is modified in place (shifted right).
 * data_idx     - index of the vector register receiving the gathered data.
 * is_tail      - selects k_tail_mask_ instead of k_full_mask_ as the
 *                governing predicate.
 *
 * Both loads use zeroing predication (T_z) and 32-bit (.S) lanes; for bf16
 * each 16-bit element is loaded with ld1h into its own 32-bit lane. */
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::gather_data(const XReg &reg_src_addr,
        const int indices_idx, const int data_idx, const bool is_tail) {
    using namespace data_type;
    const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;

    // The offsets are shifted right by log2(element size) and the gather
    // addressing mode scales them back by the same amount.
    // NOTE(review): presumably the incoming offsets are in bytes - confirm
    // against the code that fills the indices register.
    if (utils::one_of(conf_.data_type, f32, s32)) {
        lsr(TRegS(indices_idx), TRegS(indices_idx), 2);
        ld1w(TRegS(data_idx), mask / T_z,
                ptr(reg_src_addr, TRegS(indices_idx), UXTW, 2));
    } else if (conf_.data_type == bf16) {
        lsr(TRegS(indices_idx), TRegS(indices_idx), 1);
        ld1h(TRegS(data_idx), mask / T_z,
                ptr(reg_src_addr, TRegS(indices_idx), UXTW, 1));
    } else {
        // Keep an explicit failure instead of silently emitting no load.
        assert(!"unsupported emu_gather_data");
    }
}

Expand All @@ -97,21 +103,26 @@ void jit_uni_shuffle_kernel_t<asimd>::gather_data(const XReg &addr,
/* Emits a contiguous store of one vector of shuffled data, followed by the
 * zero-padding step.
 *
 * data_idx     - index of the vector register holding the data (32-bit lanes).
 * reg_dst_addr - base address register of the destination buffer.
 * offset       - offset added to reg_dst_addr to form the store address.
 * is_tail      - true when only the first conf_.simd_tail lanes are valid.
 *
 * f32/s32 are stored with st1w; any other type is treated as bf16 and stored
 * with st1h from the 32-bit lanes. */
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::store_data(const int data_idx,
        const XReg &reg_dst_addr, const int offset, const bool is_tail) {
    using namespace data_type;
    // When the tail plus the padding fills at least a whole vector, the
    // padding zeros are merged into the data and written by a full store.
    const auto extend_for_padding
            = is_tail && padding_size_ + conf_.simd_tail >= conf_.simd_w;
    const PReg &mask = is_tail ? k_tail_mask_ : P_ALL_ONE;

    // The store address is identical on every path: compute it once.
    add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);

    if (extend_for_padding) {
        // Tail lanes keep the data, the remaining lanes take vmm_zero_.
        sel(vmm_tmp_.s, k_tail_mask_, TRegS(data_idx), vmm_zero_.s);
        if (utils::one_of(conf_.data_type, f32, s32))
            st1w(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
        else // bf16
            st1h(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
    } else {
        if (utils::one_of(conf_.data_type, f32, s32))
            st1w(TRegS(data_idx), mask, ptr(X_DEFAULT_ADDR));
        else // bf16
            st1h(TRegS(data_idx), mask, ptr(X_DEFAULT_ADDR));
    }

    // NOTE(review): the flag is forwarded only when isa_sveLen > 128;
    // its exact meaning is defined by append_zero_padding - confirm there.
    append_zero_padding(
            reg_dst_, isa_sveLen > 128 ? extend_for_padding : false);
}
Expand Down

0 comments on commit 0e2ca13

Please sign in to comment.