Various masked operations
mazimkhan committed Nov 18, 2024
1 parent d77be29 commit d68360c
Showing 7 changed files with 510 additions and 0 deletions.
43 changes: 43 additions & 0 deletions g3doc/quick_reference.md
@@ -886,6 +886,9 @@ not a concern, these are equivalent to, and potentially more efficient than,
  <code>V **MaskedSatSubOr**(V no, M m, V a, V b)</code>: returns `a[i] -
  b[i]` saturated to the minimum/maximum representable value, or `no[i]` if
  `m[i]` is false.
* `V`: `{i,f}` \
  <code>V **MaskedAbsOr**(M m, V a, V b)</code>: returns the absolute value of
  `a[i]` where `m[i]` is true, and `b[i]` otherwise; see the sketch below.
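
For illustration only (not part of the reference): a minimal sketch of the
documented behavior, assuming static dispatch, the conventional
`namespace hn = hwy::HWY_NAMESPACE` alias, and a demo function name of our own
choosing. It mirrors the generic fallback `IfThenElse(m, Abs(a), b)`.

```c++
#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Absolute value in the first two lanes of a; remaining lanes come from b.
hn::Vec<hn::ScalableTag<int32_t>> DemoMaskedAbsOr() {
  const hn::ScalableTag<int32_t> d;
  const auto a = hn::Iota(d, -3);   // -3, -2, -1, 0, ...
  const auto b = hn::Zero(d);
  const auto m = hn::FirstN(d, 2);  // lanes 0 and 1 are active
  return hn::MaskedAbsOr(m, a, b);  // 3, 2, 0, 0, ...
}
```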

#### Shifts

@@ -1050,6 +1053,9 @@ types, and on SVE/RVV.

* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`.

* <code>V **MaskedOrOrZero**(M m, V a, V b)</code>: returns `a[i] | b[i]`,
  or zero if `m[i]` is false; see the example below.
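
A fragment continuing the sketch above (same alias and include; the values are
our own). The generic fallback added in this commit is
`IfThenElseZero(m, Or(a, b))`.

```c++
const hn::ScalableTag<uint32_t> du;
const auto x = hn::Set(du, 0xF0u);
const auto y = hn::Set(du, 0x0Fu);
const auto m = hn::FirstN(du, 1);            // only lane 0 is active
const auto r = hn::MaskedOrOrZero(m, x, y);  // 0xFF in lane 0, zero elsewhere
```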

The following three-argument functions may be more efficient than assembling
them from 2-argument functions:

@@ -1756,6 +1762,9 @@ All functions except `Stream` are defined in cache_control.h.
`DemoteToNearestInt(d, v)` is more efficient on some targets, including x86
and RVV.

* <code>Vec&lt;D&gt; **MaskedConvertToOrZero**(M m, D d, V v)</code>: returns
  `v[i]` converted to `TFromD<D>` where `m[i]` is true, and zero otherwise; see
  the example below.
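
A usage sketch (same `hn` alias as above; in this commit the op is implemented
for SVE, so the exact mask/tag types here are illustrative):

```c++
const hn::ScalableTag<float> df;
const hn::RebindToSigned<decltype(df)> di;
const auto vi = hn::Iota(di, 1);                       // 1, 2, 3, ...
const auto m = hn::FirstN(df, 2);                      // lanes 0 and 1
const auto vf = hn::MaskedConvertToOrZero(m, df, vi);  // 1.0f, 2.0f, 0, ...
```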

#### Single vector demotion

These functions demote a full vector (or parts thereof) into a vector of half
@@ -2237,6 +2246,22 @@ The following `ReverseN` must not be called if `Lanes(D()) < N`:
must be in the range `[0, 2 * Lanes(d))` but need not be unique. The index
type `TI` must be an integer of the same size as `TFromD<D>`.
* <code>V **TableLookupLanesOr**(M m, V a, V b, unspecified)</code> returns the
result of `TableLookupLanes(a, unspecified)` where `m[i]` is true, and returns
`b[i]` where `m[i]` is false.
* <code>V **TableLookupLanesOrZero**(M m, V a, unspecified)</code> returns
the result of `TableLookupLanes(a, unspecified)` where `m[i]` is true, and
returns zero where `m[i]` is false.
* <code>V **TwoTablesLookupLanesOr**(D d, M m, V a, V b, unspecified)</code>
  returns the result of `TwoTablesLookupLanes(d, a, b, unspecified)` where
  `m[i]` is true, and `a[i]` where `m[i]` is false.
* <code>V **TwoTablesLookupLanesOrZero**(D d, M m, V a, V b, unspecified)</code>
  returns the result of `TwoTablesLookupLanes(d, a, b, unspecified)` where
  `m[i]` is true, and zero where `m[i]` is false.
* <code>V **Per4LaneBlockShuffle**&lt;size_t kIdx3, size_t kIdx2, size_t
kIdx1, size_t kIdx0&gt;(V v)</code> does a per 4-lane block shuffle of `v`
if `Lanes(DFromV<V>())` is greater than or equal to 4 or a shuffle of the
@@ -2377,6 +2402,24 @@ more efficient on some targets.
* <code>T **ReduceMin**(D, V v)</code>: returns the minimum of all lanes.
* <code>T **ReduceMax**(D, V v)</code>: returns the maximum of all lanes.

### Masked reductions

**Note**: Horizontal operations (across lanes of the same vector) such as
reductions are slower than normal SIMD operations and are typically used
outside critical loops.

All ops in this section ignore lanes where `mask=false`. These are equivalent
to, and potentially more efficient than, `GetLane(SumOfLanes(d,
IfThenElseZero(m, v)))` etc. The result is implementation-defined when all mask
elements are false. See the example after this list.

* <code>T **MaskedReduceSum**(D, M m, V v)</code>: returns the sum of all lanes
  where `m[i]` is `true`.
* <code>T **MaskedReduceMin**(D, M m, V v)</code>: returns the minimum of all
  lanes where `m[i]` is `true`.
* <code>T **MaskedReduceMax**(D, M m, V v)</code>: returns the maximum of all
  lanes where `m[i]` is `true`.
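
A minimal example (same `hn` alias as the earlier sketches; the values are our
own):

```c++
const hn::ScalableTag<float> d;
const auto v = hn::Iota(d, 1.0f);                // 1, 2, 3, ...
const auto m = hn::FirstN(d, 2);                 // only lanes 0 and 1
const float sum = hn::MaskedReduceSum(d, m, v);  // 1 + 2 = 3
// Reference behavior:
// hn::GetLane(hn::SumOfLanes(d, hn::IfThenElseZero(m, v)))
```
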
### Crypto

Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
119 changes: 119 additions & 0 deletions hwy/ops/arm_sve-inl.h
@@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
// vector = f(mask, vector, vector): unary OP applied to a; mask=false lanes
// take their value from b.
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP)             \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_m(b, m, a);                             \
  }
// vector = f(mask, vector): unary OP applied to a; mask=false lanes are zero.
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP)              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
    return sv##OP##_##CHAR##BITS##_z(m, a);                                 \
  }
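
// For reference (a sketch, not generated code): instantiated for int32_t with
// NAME=MaskedAbsOr and OP=abs, HWY_SVE_RETV_ARGMV_M yields roughly
//   HWY_API svint32_t MaskedAbsOr(svbool_t m, svint32_t a, svint32_t b) {
//     return svabs_s32_m(b, m, a);  // |a[i]| where m is true, else b[i]
//   }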

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \
@@ -252,6 +261,17 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_x(m, a, b); \
}
// User-specified mask. Mask=false lanes take their value from a.
#define HWY_SVE_RETV_ARGMVV_M(BASE, CHAR, BITS, HALF, NAME, OP)            \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_m(m, a, b);                             \
  }
// User-specified mask. Mask=false value is zero.
#define HWY_SVE_RETV_ARGMVVZ(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_z(m, a, b); \
}
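
// For reference (a sketch, not generated code): instantiated for uint32_t with
// NAME=MaskedOrOrZero and OP=orr, HWY_SVE_RETV_ARGMVVZ yields roughly
//   HWY_API svuint32_t MaskedOrOrZero(svbool_t m, svuint32_t a, svuint32_t b) {
//     return svorr_u32_z(m, a, b);  // a[i] | b[i] where m is true, else zero
//   }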

#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
@@ -260,6 +280,13 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
return sv##OP##_##CHAR##BITS(a, b, c); \
}

// User-specified mask. Mask=false lanes take their value from a.
#define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP)            \
  HWY_API HWY_SVE_V(BASE, BITS)                                           \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
           HWY_SVE_V(BASE, BITS) c) {                                    \
    return sv##OP##_##CHAR##BITS##_m(m, a, b, c);                        \
  }

// ------------------------------ Lanes

namespace detail {
@@ -727,6 +754,9 @@ HWY_API V Or(const V a, const V b) {
return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ MaskedOrOrZero
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVVZ, MaskedOrOrZero, orr)

// ------------------------------ Xor

namespace detail {
@@ -862,6 +892,12 @@ HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)
HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs)
#endif // HWY_SVE_HAVE_2

// ------------------------------ MaskedAbsOr
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs)

// ------------------------------ MaskedAbsOrZero
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbsOrZero, abs)

// ================================================== ARITHMETIC

// Per-target flags to prevent generic_ops-inl.h defining Add etc.
@@ -1272,6 +1308,11 @@ HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)

#undef HWY_SVE_FMA

// ------------------------------ MaskedMulAdd
namespace detail {
// a * b + c where m is true; mask=false lanes take their value from a.
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad)
}  // namespace detail

// ------------------------------ Round etc.

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
@@ -1515,6 +1556,7 @@ HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) {
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMin, min)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMax, max)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVZ, MaskedMaxOrZero, max)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul)
@@ -2849,6 +2891,41 @@ HWY_API svfloat32_t DemoteTo(Simd<float, N, kPow2> d, const svuint64_t v) {
HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt)
#undef HWY_SVE_CONVERT

// ------------------------------ MaskedConvertToOrZero F

#define HWY_SVE_MASKED_CONVERT_TO_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \
/* Float from signed */ \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
HWY_SVE_V(int, BITS) v) { \
return sv##OP##_##CHAR##BITS##_s##BITS##_z(m, v); \
} \
/* Float from unsigned */ \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
HWY_SVE_V(uint, BITS) v) { \
return sv##OP##_##CHAR##BITS##_u##BITS##_z(m, v); \
} \
/* Signed from float, rounding toward zero */ \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(int, BITS) \
NAME(svbool_t m, HWY_SVE_D(int, BITS, N, kPow2) /* d */, \
HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_s##BITS##_##CHAR##BITS##_z(m, v); \
} \
/* Unsigned from float, rounding toward zero */ \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(uint, BITS) \
NAME(svbool_t m, HWY_SVE_D(uint, BITS, N, kPow2) /* d */, \
HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_u##BITS##_##CHAR##BITS##_z(m, v); \
}

HWY_SVE_FOREACH_F(HWY_SVE_MASKED_CONVERT_TO_OR_ZERO, MaskedConvertToOrZero, cvt)
#undef HWY_SVE_MASKED_CONVERT_TO_OR_ZERO
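
// For example (illustrative, not generated here): for the f32/i32 pair the
// macro above produces wrappers around svcvt_f32_s32_z(m, v) and
// svcvt_s32_f32_z(m, v), i.e. the usual SVE conversions with zeroing
// predication. Usage sketch (m: svbool_t, vi: svint32_t):
//   const svfloat32_t f = MaskedConvertToOrZero(m, ScalableTag<float>(), vi);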

// ------------------------------ NearestInt (Round, ConvertTo)
template <class VF, class DI = RebindToSigned<DFromV<VF>>>
HWY_API VFromD<DI> NearestInt(VF v) {
@@ -3288,6 +3365,25 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
return detail::MaxOfLanesM(detail::MakeMask(d), v);
}

#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
#else
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

template <class D, class M>
HWY_API TFromD<D> MaskedReduceSum(D /*d*/, M m, VFromD<D> v) {
return detail::SumOfLanesM(m, v);
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D /*d*/, M m, VFromD<D> v) {
return detail::MinOfLanesM(m, v);
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D /*d*/, M m, VFromD<D> v) {
return detail::MaxOfLanesM(m, v);
}

// ------------------------------ SumOfLanes

template <class D, HWY_IF_LANES_GT_D(D, 1)>
@@ -4755,6 +4851,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) {
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float");
return IfThenElse(IsNegative(v), yes, no);
}

// ------------------------------ IfNegativeThenNegOrUndefIfZero

#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#else
#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#endif

#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \
}

HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg)

#undef HWY_SVE_NEG_IF
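
// For reference (a sketch, not generated code): instantiated for int32_t, the
// macro above yields roughly
//   HWY_API svint32_t IfNegativeThenNegOrUndefIfZero(svint32_t mask,
//                                                    svint32_t v) {
//     return svneg_s32_m(v, IsNegative(mask), v);
//   }
// i.e. v[i] is negated where mask[i] < 0; other lanes keep v[i], which is
// acceptable because the result there is unspecified.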

// ------------------------------ AverageRound (ShiftRight)

@@ -6291,13 +6404,19 @@ HWY_API V HighestSetBitIndex(V v) {
#undef HWY_SVE_IF_NOT_EMULATED_D
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGMVV
#undef HWY_SVE_RETV_ARGMVVZ
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGMV
#undef HWY_SVE_RETV_ARGMV_M
#undef HWY_SVE_RETV_ARGMV_Z
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_RETV_ARGMVV_M
#undef HWY_SVE_RETV_ARGVVV
#undef HWY_SVE_RETV_ARGMVVV
#undef HWY_SVE_T
#undef HWY_SVE_UNDEFINED
#undef HWY_SVE_V
67 changes: 67 additions & 0 deletions hwy/ops/generic_ops-inl.h
@@ -672,6 +672,18 @@ HWY_API V SaturatedAbs(V v) {

#endif

// ------------------------------ MaskedAbsOr
template <class V, HWY_IF_SIGNED_V(V), class M>
HWY_API V MaskedAbsOr(M m, V v, V no) {
return IfThenElse(m, Abs(v), no);
}

// ------------------------------ MaskedAbsOrZero
template <class V, HWY_IF_SIGNED_V(V), class M>
HWY_API V MaskedAbsOrZero(M m, V v) {
return IfThenElseZero(m, Abs(v));
}

// ------------------------------ Reductions

// Targets follow one of two strategies. If HWY_NATIVE_REDUCE_SCALAR is toggled,
@@ -882,6 +894,28 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
}
#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8

#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
#else
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

template <class D, class M>
HWY_API TFromD<D> MaskedReduceSum(D d, M m, VFromD<D> v) {
return ReduceSum(d, IfThenElseZero(m, v));
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D d, M m, VFromD<D> v) {
return ReduceMin(d, IfThenElse(m, v, MaxOfLanes(d, v)));
}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D d, M m, VFromD<D> v) {
  // Replace inactive lanes with a value that cannot exceed the active maximum
  // (zero would be wrong if all active lanes are negative).
  return ReduceMax(d, IfThenElse(m, v, MinOfLanes(d, v)));
}

#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR

// ------------------------------ IsEitherNaN
#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_IS_EITHER_NAN
@@ -6444,6 +6478,30 @@ HWY_API V ReverseBits(V v) {
}
#endif // HWY_NATIVE_REVERSE_BITS_UI16_32_64

// ------------------------------ TableLookupLanesOr
template <class V, class M>
HWY_API V TableLookupLanesOr(M m, V a, V b, IndicesFromD<DFromV<V>> idx) {
return IfThenElse(m, TableLookupLanes(a, idx), b);
}

// ------------------------------ TableLookupLanesOrZero
template <class V, class M>
HWY_API V TableLookupLanesOrZero(M m, V a, IndicesFromD<DFromV<V>> idx) {
return IfThenElseZero(m, TableLookupLanes(a, idx));
}

// ------------------------------ TwoTablesLookupLanesOr
template <class D, class V, class M>
HWY_API V TwoTablesLookupLanesOr(D d, M m, V a, V b, IndicesFromD<D> idx) {
return IfThenElse(m, TwoTablesLookupLanes(d, a, b, idx), a);
}

// ------------------------------ TwoTablesLookupLanesOrZero
template <class D, class V, class M>
HWY_API V TwoTablesLookupLanesOrZero(D d, M m, V a, V b, IndicesFromD<D> idx) {
return IfThenElse(m, TwoTablesLookupLanes(d, a, b, idx), Zero(d));
}
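
// Example (illustrative; assumes Lanes(d) == 4 and an index array of our own):
//   static constexpr int32_t kIdx[4] = {1, 0, 3, 2};
//   const auto idx = SetTableIndices(d, kIdx);
//   const V r = TableLookupLanesOr(m, a, b, idx);  // pair-swapped a where m is
//                                                  // true, else b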

// ------------------------------ Per4LaneBlockShuffle

#if (defined(HWY_NATIVE_PER4LANEBLKSHUF_DUP32) == defined(HWY_TARGET_TOGGLE))
@@ -7299,6 +7357,15 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

// ------------------------------ MaskedMaxOrZero
template <class V, class M>
HWY_API V MaskedMaxOrZero(M m, V a, V b) {
  return IfThenElseZero(m, Max(a, b));
}

// ------------------------------ MaskedOrOrZero
template <class V, class M>
HWY_API V MaskedOrOrZero(M m, V a, V b) {
  return IfThenElseZero(m, Or(a, b));
}

// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined
