Commit: Load/Store, masked set and counting operations

mazimkhan committed Nov 18, 2024
1 parent d77be29 · commit c1f7768

Showing 7 changed files with 421 additions and 0 deletions.
28 changes: 28 additions & 0 deletions g3doc/quick_reference.md
@@ -429,6 +429,10 @@ for comparisons, for example `Lt` instead of `operator<`.
the result, with `t0` in the least-significant (lowest-indexed) lane of each
128-bit block and `tK` in the most-significant (highest-indexed) lane of
each 128-bit block: `{t0, t1, ..., tK}`
* <code>V **SetOr**(V no, M m, T a)</code>: returns an N-lane vector with lane
  `i` equal to `a` if `m[i]` is true, else `no[i]` (see the sketch below).
* <code>V **SetOrZero**(D d, M m, T a)</code>: returns an N-lane vector with
  lane `i` equal to `a` if `m[i]` is true, else zero.
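
A hedged usage sketch (not from this commit; `hn` is an assumed alias for
`hwy::HWY_NAMESPACE` and `ReplaceNegative` is illustrative):

```c++
namespace hn = hwy::HWY_NAMESPACE;

// Lanes where v is negative become -1; all others keep their value.
// Assumes SetOr behaves as documented above.
template <class D, class V = hn::Vec<D>>
V ReplaceNegative(D d, V v) {
  const hn::Mask<D> is_neg = hn::Lt(v, hn::Zero(d));
  return hn::SetOr(v, is_neg, static_cast<hn::TFromD<D>>(-1));
}
```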

### Getting/setting lanes

@@ -1015,6 +1019,10 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
leading zeros in each lane. For any lanes where ```a[i]``` is zero,
```sizeof(TFromV<V>) * 8``` is returned in the corresponding result lanes.

* `V`: `{u,i}` \
  <code>V **MaskedLeadingZeroCountOrZero**(M m, V a)</code>: returns the
  result of `LeadingZeroCount(a)` in lanes where `m[i]` is true, and zero
  otherwise.

* `V`: `{u,i}` \
<code>V **TrailingZeroCount**(V a)</code>: returns the number of
trailing zeros in each lane. For any lanes where ```a[i]``` is zero,
@@ -1029,6 +1037,12 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
```HighestValue<MakeSigned<TFromV<V>>>()``` is returned in the
corresponding result lanes.

* <code>bool **AllOnes**(D, V v)</code>: returns whether all bits in every
  lane of `v` are set.

* <code>bool **AllZeros**(D, V v)</code>: returns whether all bits in every
  lane of `v` are clear.
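
A minimal usage sketch (alias `hn` for `hwy::HWY_NAMESPACE` is assumed;
`AnyFlagSet` is illustrative and reads one full vector of flags):

```c++
namespace hn = hwy::HWY_NAMESPACE;

// Returns true unless every bit of every loaded lane is clear.
bool AnyFlagSet(const uint64_t* flags) {
  const hn::ScalableTag<uint64_t> d;
  const auto v = hn::LoadU(d, flags);  // reads Lanes(d) values
  return !hn::AllZeros(d, v);
}
```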

The following operate on individual bits within each lane. Note that the
non-operator functions (`And` instead of `&`) must be used for floating-point
types, and on SVE/RVV.
@@ -1508,6 +1522,9 @@ aligned memory at indices which are not a multiple of the vector length):

* <code>Vec&lt;D&gt; **LoadU**(D, const T* p)</code>: returns `p[i]`.

* <code>Vec&lt;D&gt; **MaskedLoadU**(D, M m, const T* p)</code>: returns `p[i]`
  in lanes where `m[i]` is true, and zero otherwise (see the sketch after
  this list).

* <code>Vec&lt;D&gt; **LoadDup128**(D, const T* p)</code>: returns one 128-bit
block loaded from `p` and broadcast into all 128-bit block\[s\]. This may
be faster than broadcasting single values, and is more convenient than
Expand Down Expand Up @@ -1538,6 +1555,10 @@ aligned memory at indices which are not a multiple of the vector length):
lanes from `p` to the first (lowest-index) lanes of the result vector and
fills the remaining lanes with `no`. Like LoadN, this does not fault.

* <code>Vec&lt;D&gt; **LoadHigher**(D d, V v, const T* p)</code>: loads
  `Lanes(d)/2` lanes from `p` into the upper lanes of the result vector and
  copies the lower half of `v` into the lower lanes.
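
A hedged sketch of `MaskedLoadU` (alias `hn` and `SumFirstN` are assumed,
not part of this commit). Because the generic fallback performs a full-width
`LoadU`, this assumes the buffer is padded to a whole number of vectors:

```c++
namespace hn = hwy::HWY_NAMESPACE;

// Sums the first `count` floats of a padded buffer without a scalar tail.
float SumFirstN(const float* p, size_t count) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  auto sum = hn::Zero(d);
  for (size_t i = 0; i < count; i += N) {
    // All-true except on the final, possibly partial, iteration.
    const auto m = hn::FirstN(d, count - i);
    sum = hn::Add(sum, hn::MaskedLoadU(d, m, p + i));
  }
  return hn::ReduceSum(d, sum);
}
```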

#### Store

* <code>void **Store**(Vec&lt;D&gt; v, D, T* aligned)</code>: copies `v[i]`
@@ -1577,6 +1598,13 @@ aligned memory at indices which are not a multiple of the vector length):
StoreN does not modify any memory past
`p + HWY_MIN(Lanes(d), max_lanes_to_store) - 1`.

* <code>void **StoreTruncated**(Vec&lt;DFrom&gt; v, DFrom d, To* HWY_RESTRICT
  p)</code>: truncates the elements of `v` to type `To` and stores them to
  `p`. Equivalent to `TruncateTo` followed by `StoreN` with
  `max_lanes_to_store` equal to `Lanes(d)`.

  StoreTruncated does not modify any memory past `p + Lanes(d) - 1`.
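
A short sketch (alias `hn` assumed; note `TruncateTo` discards upper bits
rather than saturating, so this assumes the values already fit in a byte):

```c++
namespace hn = hwy::HWY_NAMESPACE;

// Narrows one vector of u16 to u8, storing Lanes(d) bytes at dst.
void NarrowStore(const uint16_t* src, uint8_t* dst) {
  const hn::ScalableTag<uint16_t> d;
  hn::StoreTruncated(hn::LoadU(d, src), d, dst);
}
```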

#### Interleaved

* <code>void **LoadInterleaved2**(D, const T* p, Vec&lt;D&gt;&amp; v0,
80 changes: 80 additions & 0 deletions hwy/ops/arm_sve-inl.h
@@ -412,6 +412,27 @@ using VFromD = decltype(Set(D(), TFromD<D>()));

using VBF16 = VFromD<ScalableTag<bfloat16_t>>;

// ------------------------------ SetOr/SetOrZero

#define HWY_SVE_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) inactive, \
svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_m(inactive, m, op); \
}

HWY_SVE_FOREACH(HWY_SVE_SET_OR, SetOr, dup_n)
#undef HWY_SVE_SET_OR

#define HWY_SVE_SET_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_z(m, op); \
}

HWY_SVE_FOREACH(HWY_SVE_SET_OR_ZERO, SetOrZero, dup_n)
#undef HWY_SVE_SET_OR_ZERO
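// For reference, an illustrative expansion (not generated source) of
// HWY_SVE_SET_OR for signed 32-bit lanes:
//   HWY_API svint32_t SetOr(svint32_t inactive, svbool_t m, int32_t op) {
//     return svdup_n_s32_m(inactive, m, op);  // merging form
//   }
// SetOrZero uses the zeroing form, svdup_n_s32_z(m, op), instead.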

// ------------------------------ Zero

template <class D>
@@ -1856,6 +1877,18 @@ HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,

#undef HWY_SVE_MEM

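// Masked (zeroing) unaligned load: svld1 with a governing predicate zeroes
// inactive lanes, which matches the documented MaskedLoadU semantics.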
#define HWY_SVE_MASKED_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) \
MaskedLoadU(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m, \
const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \
}

HWY_SVE_FOREACH(HWY_SVE_MASKED_MEM, _, _)

#undef HWY_SVE_MASKED_MEM

#if HWY_TARGET != HWY_SVE2_128
namespace detail {
#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \
@@ -1902,6 +1935,37 @@ HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {

#endif // HWY_TARGET != HWY_SVE2_128

// Truncate to smaller size and store
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \
template <size_t N, int kPow2> \
HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \
const HWY_SVE_D(BASE, BITS, N, kPow2) d, \
HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \
sv##OP##_##CHAR##BITS(detail::PTrue(d), detail::NativeLanePointer(p), v); \
}

#define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8)
#define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16)
#define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32)

HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, StoreTruncated, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, StoreTruncated, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, StoreTruncated, st1w)

#undef HWY_SVE_STORE_TRUNCATED
#undef HWY_SVE_STORE_TRUNCATED_BYTE
#undef HWY_SVE_STORE_TRUNCATED_HALF
#undef HWY_SVE_STORE_TRUNCATED_WORD
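// For reference, an illustrative expansion (not generated source): for u64
// lanes stored to a uint8_t* destination, StoreTruncated is roughly
//   svst1b_u64(detail::PTrue(d), detail::NativeLanePointer(p), v);
// i.e. the low byte of each lane is stored contiguously.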

// ------------------------------ Load/Store

// SVE only requires lane alignment, not natural alignment of the entire
@@ -6258,6 +6322,22 @@ HWY_API V HighestSetBitIndex(V v) {
return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v)));
}

#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

#define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
const DFromV<decltype(v)> d; \
return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v)); \
}

HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT,
MaskedLeadingZeroCountOrZero, clz)
#undef HWY_SVE_MASKED_LEADING_ZERO_COUNT

// ================================================== END MACROS
#undef HWY_SVE_ALL_PTRUE
#undef HWY_SVE_D
97 changes: 97 additions & 0 deletions hwy/ops/generic_ops-inl.h
@@ -97,6 +97,21 @@ HWY_API Vec<D> Inf(D d) {
return BitCast(d, Set(du, max_x2 >> 1));
}

// ------------------------------ SetOr/SetOrZero

template <class V, typename T = TFromV<V>, typename D = DFromV<V>,
typename M = MFromD<D>>
HWY_API V SetOr(V no, M m, T a) {
const D d;
return IfThenElse(m, Set(d, a), no);
}

template <class D, typename V = VFromD<D>, typename M = MFromD<D>,
typename T = TFromD<D>>
HWY_API V SetOrZero(D d, M m, T a) {
return IfThenElseZero(m, Set(d, a));
}

// ------------------------------ ZeroExtendResizeBitCast

// The implementation of detail::ZeroExtendResizeBitCast for the HWY_EMU128
@@ -336,6 +351,20 @@ HWY_API Mask<DTo> DemoteMaskTo(DTo d_to, DFrom d_from, Mask<DFrom> m) {

#endif // HWY_NATIVE_DEMOTE_MASK_TO

// ------------------------------ LoadHigher
#if (defined(HWY_NATIVE_LOAD_HIGHER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_LOAD_HIGHER
#undef HWY_NATIVE_LOAD_HIGHER
#else
#define HWY_NATIVE_LOAD_HIGHER
#endif

template <class D, typename T, class V = VFromD<D>, HWY_IF_LANES_GT_D(D, 1)>
HWY_API V LoadHigher(D d, V a, const T* HWY_RESTRICT p) {
  const V b = LoadU(d, p);
  return ConcatLowerLower(d, b, a);
}
#endif // HWY_NATIVE_LOAD_HIGHER

// ------------------------------ CombineMasks

#if (defined(HWY_NATIVE_COMBINE_MASKS) == defined(HWY_TARGET_TOGGLE))
@@ -1118,6 +1147,12 @@ HWY_API V MulByFloorPow2(V v, V exp) {

#endif // HWY_NATIVE_MUL_BY_POW2

// ------------------------------ MaskedLoadU
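// NOTE: this generic fallback performs a full-width LoadU and then zeroes
// inactive lanes, so all Lanes(d) elements at `unaligned` must be readable
// even where the mask is false. Native overrides (e.g. the SVE version) may
// avoid touching inactive lanes.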
template <class D, class M>
HWY_API VFromD<D> MaskedLoadU(D d, M m,
const TFromD<D>* HWY_RESTRICT unaligned) {
return IfThenElseZero(m, LoadU(d, unaligned));
}

// ------------------------------ LoadInterleaved2

#if HWY_IDE || \
@@ -2574,6 +2609,25 @@ HWY_API void StoreN(VFromD<D> v, D d, T* HWY_RESTRICT p,

#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ StoreTruncated
#if (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

template <class DFrom, class To, class DTo = Rebind<To, DFrom>,
HWY_IF_T_SIZE_GT_D(DFrom, sizeof(To)),
HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(VFromD<DFrom>)>
HWY_API void StoreTruncated(VFromD<DFrom> v, const DFrom d,
To* HWY_RESTRICT p) {
const DTo dsmall;
StoreN(TruncateTo(dsmall, v), dsmall, p, Lanes(d));
}

#endif // (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ Scatter

#if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE))
@@ -3808,6 +3862,21 @@ HWY_API V TrailingZeroCount(V v) {
}
#endif // HWY_NATIVE_LEADING_ZERO_COUNT

// ------------------------------ MaskedLeadingZeroCountOrZero
#if (defined(HWY_NATIVE_MASKED_LEADING_ZERO_COUNT) == \
defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), class M>
HWY_API V MaskedLeadingZeroCountOrZero(M m, V v) {
return IfThenElseZero(m, LeadingZeroCount(v));
}
#endif // HWY_NATIVE_MASKED_LEADING_ZERO_COUNT

// ------------------------------ AESRound

// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes.
@@ -7299,6 +7368,34 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

// ------------------------------ AllOnes/AllZeros
#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLONES
#undef HWY_NATIVE_ALLONES
#else
#define HWY_NATIVE_ALLONES
#endif

template <class D>
HWY_API bool AllOnes(D /*d*/, VFromD<D> a) {
  const RebindToUnsigned<D> du;
  return AllTrue(du, Eq(Not(BitCast(du, a)), Zero(du)));
}
#endif // HWY_NATIVE_ALLONES

#if (defined(HWY_NATIVE_ALLZEROS) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLZEROS
#undef HWY_NATIVE_ALLZEROS
#else
#define HWY_NATIVE_ALLZEROS
#endif

template <class D>
HWY_API bool AllZeros(D /*d*/, VFromD<D> a) {
  const RebindToUnsigned<D> du;
  return AllTrue(du, Eq(BitCast(du, a), Zero(du)));
}
#endif // HWY_NATIVE_ALLZEROS

// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined
43 changes: 43 additions & 0 deletions hwy/tests/count_test.cc
@@ -132,6 +132,48 @@ HWY_NOINLINE void TestAllLeadingZeroCount() {
ForIntegerTypes(ForPartialVectors<TestLeadingZeroCount>());
}

struct TestMaskedLeadingZeroCount {
template <class T, class D>
HWY_ATTR_NO_MSAN HWY_NOINLINE void operator()(T /*unused*/, D d) {
using TU = MakeUnsigned<T>;
const RebindToUnsigned<decltype(d)> du;
const size_t N = Lanes(d);
const MFromD<D> first_3 = FirstN(d, 3);
auto lzcnt = AllocateAligned<T>(N);
HWY_ASSERT(lzcnt);

constexpr T kNumOfBitsInT = static_cast<T>(sizeof(T) * 8);
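// Expected: LeadingZeroCount(2) == (bits in T) - 2 in the first three
// lanes (covered by the mask), and zero in all remaining lanes.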
for (size_t j = 0; j < N; j++) {
if (j < 3) {
lzcnt[j] = static_cast<T>(kNumOfBitsInT - 2);
} else {
lzcnt[j] = static_cast<T>(0);
}
}
HWY_ASSERT_VEC_EQ(
d, lzcnt.get(),
MaskedLeadingZeroCountOrZero(first_3, Set(d, static_cast<T>(2))));

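// Expected: a lane value of 1 << (bits - 2) has exactly one leading zero.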
for (size_t j = 0; j < N; j++) {
if (j < 3) {
lzcnt[j] = static_cast<T>(1);
} else {
lzcnt[j] = static_cast<T>(0);
}
}
HWY_ASSERT_VEC_EQ(
d, lzcnt.get(),
MaskedLeadingZeroCountOrZero(
first_3, BitCast(d, Set(du, TU{1} << (kNumOfBitsInT - 2)))));
}
};

HWY_NOINLINE void TestAllMaskedLeadingZeroCount() {
ForIntegerTypes(ForPartialVectors<TestMaskedLeadingZeroCount>());
}

template <class T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))>
static HWY_INLINE T TrailingZeroCountOfValue(T val) {
@@ -303,6 +345,7 @@ namespace {
HWY_BEFORE_TEST(HwyCountTest);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllPopulationCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllLeadingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllMaskedLeadingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllTrailingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllHighestSetBitIndex);
HWY_AFTER_TEST();