Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load/Store, masked set and counting operations #11

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,10 @@ for comparisons, for example `Lt` instead of `operator<`.
the result, with `t0` in the least-significant (lowest-indexed) lane of each
128-bit block and `tK` in the most-significant (highest-indexed) lane of
each 128-bit block: `{t0, t1, ..., tK}`
* <code>V **SetOr**(V no, M m, T a)</code>: returns N-lane vector with lane
`i` equal to `a` if `m[i]` is true else `no[i]`.
* <code>V **SetOrZero**(D d, M m, T a)</code>: returns N-lane vector with lane
`i` equal to `a` if `m[i]` is true else 0.

### Getting/setting lanes

Expand Down Expand Up @@ -1015,6 +1019,10 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
leading zeros in each lane. For any lanes where ```a[i]``` is zero,
```sizeof(TFromV<V>) * 8``` is returned in the corresponding result lanes.

* `V`: `{u,i}` \
<code>V **MaskedLeadingZeroCountOrZero**(M m, V a)</code>: returns the
result of LeadingZeroCount where `m[i]` is true, and zero otherwise.

* `V`: `{u,i}` \
<code>V **TrailingZeroCount**(V a)</code>: returns the number of
trailing zeros in each lane. For any lanes where ```a[i]``` is zero,
Expand All @@ -1029,6 +1037,12 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
```HighestValue<MakeSigned<TFromV<V>>>()``` is returned in the
corresponding result lanes.

* <code>bool **AllOnes**(D, V v)</code>: returns whether all bits in every
lane of `v` are set.

* <code>bool **AllZeros**(D, V v)</code>: returns whether all bits in every
lane of `v` are clear.

The following operate on individual bits within each lane. Note that the
non-operator functions (`And` instead of `&`) must be used for floating-point
types, and on SVE/RVV.
Expand Down Expand Up @@ -1508,6 +1522,9 @@ aligned memory at indices which are not a multiple of the vector length):

* <code>Vec&lt;D&gt; **LoadU**(D, const T* p)</code>: returns `p[i]`.

* <code>Vec&lt;D&gt; **MaskedLoadU**(D, M m, const T* p)</code>: returns `p[i]`
in lanes where `m[i]` is true, and zero otherwise.

* <code>Vec&lt;D&gt; **LoadDup128**(D, const T* p)</code>: returns one 128-bit
block loaded from `p` and broadcasted into all 128-bit block\[s\]. This may
be faster than broadcasting single values, and is more convenient than
Expand Down Expand Up @@ -1538,6 +1555,10 @@ aligned memory at indices which are not a multiple of the vector length):
lanes from `p` to the first (lowest-index) lanes of the result vector and
fills the remaining lanes with `no`. Like LoadN, this does not fault.

* <code> Vec&lt;D&gt; **LoadHigher**(D d, V v, T* p)</code>: Loads `Lanes(d)/2` lanes from
`p` into the upper lanes of the result vector and the lower half of `v` into
the lower lanes.

#### Store

* <code>void **Store**(Vec&lt;D&gt; v, D, T* aligned)</code>: copies `v[i]`
Expand Down Expand Up @@ -1577,6 +1598,13 @@ aligned memory at indices which are not a multiple of the vector length):
StoreN does not modify any memory past
`p + HWY_MIN(Lanes(d), max_lanes_to_store) - 1`.

* <code>void **StoreTruncated**(Vec&lt;DFrom&gt; v, DFrom d, To* HWY_RESTRICT
p)</code>: Truncates elements of `v` to type `To` and stores them to `p`. It is
similar to performing `TruncateTo` followed by `StoreN` where
`max_lanes_to_store` is `Lanes(d)`.

StoreTruncated does not modify any memory past `p + Lanes(d) - 1`.

#### Interleaved

* <code>void **LoadInterleaved2**(D, const T* p, Vec&lt;D&gt;&amp; v0,
Expand Down
80 changes: 80 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,27 @@ using VFromD = decltype(Set(D(), TFromD<D>()));

using VBF16 = VFromD<ScalableTag<bfloat16_t>>;

// ------------------------------ SetOr/SetOrZero

// Defines SetOr for one type: lane i is `op` where m[i] is true, else
// inactive[i]. Expands to the merging (_m) form of the svdup_n intrinsic.
#define HWY_SVE_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) inactive, \
svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_m(inactive, m, op); \
}

// Instantiate for all SVE lane types, then clean up the helper macro.
HWY_SVE_FOREACH(HWY_SVE_SET_OR, SetOr, dup_n)
#undef HWY_SVE_SET_OR

// Defines SetOrZero for one type: lane i is `op` where m[i] is true, else 0.
// Uses the zeroing (_z) form of svdup_n; the unused D argument only carries
// the lane type and vector length.
#define HWY_SVE_SET_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
svbool_t m, HWY_SVE_T(BASE, BITS) op) { \
return sv##OP##_##CHAR##BITS##_z(m, op); \
}

// Instantiate for all SVE lane types, then clean up the helper macro.
HWY_SVE_FOREACH(HWY_SVE_SET_OR_ZERO, SetOrZero, dup_n)
#undef HWY_SVE_SET_OR_ZERO

// ------------------------------ Zero

template <class D>
Expand Down Expand Up @@ -1856,6 +1877,18 @@ HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,

#undef HWY_SVE_MEM

// Defines MaskedLoadU for one type: a predicated unaligned load via svld1.
// Inactive lanes are zeroed. NAME/OP are unused; the name and intrinsic are
// spelled out directly.
#define HWY_SVE_MASKED_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \
template <size_t N, int kPow2> \
HWY_API HWY_SVE_V(BASE, BITS) \
MaskedLoadU(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m, \
const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \
return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \
}

HWY_SVE_FOREACH(HWY_SVE_MASKED_MEM, _, _)

#undef HWY_SVE_MASKED_MEM

#if HWY_TARGET != HWY_SVE2_128
namespace detail {
#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \
Expand Down Expand Up @@ -1902,6 +1935,37 @@ HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {

#endif // HWY_TARGET != HWY_SVE2_128

// ------------------------------ StoreTruncated
// Truncate each lane to a narrower integer type and store, using SVE's
// native truncating stores (st1b/st1h/st1w).
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

// Defines StoreTruncated for one source type: stores the low TO_BITS bits of
// each lane of `v` to `p` under an all-true predicate.
#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \
  template <size_t N, int kPow2> \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \
                    const HWY_SVE_D(BASE, BITS, N, kPow2) d, \
                    HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \
    sv##OP##_##CHAR##BITS(detail::PTrue(d), detail::NativeLanePointer(p), v); \
  }

// Wrappers that fix the destination width (8/16/32 bits).
#define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8)
#define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16)
#define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32)

HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b)
HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, StoreTruncated, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, StoreTruncated, st1h)
HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, StoreTruncated, st1w)

// Also #undef the width wrappers; previously only the base macro was
// #undef'd, leaking the *_BYTE/_HALF/_WORD helpers into the rest of the TU.
#undef HWY_SVE_STORE_TRUNCATED_BYTE
#undef HWY_SVE_STORE_TRUNCATED_HALF
#undef HWY_SVE_STORE_TRUNCATED_WORD
#undef HWY_SVE_STORE_TRUNCATED

// ------------------------------ Load/Store

// SVE only requires lane alignment, not natural alignment of the entire
Expand Down Expand Up @@ -6258,6 +6322,22 @@ HWY_API V HighestSetBitIndex(V v) {
return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v)));
}

#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

// Defines MaskedLeadingZeroCountOrZero for one type: CLZ in active lanes,
// zero elsewhere, via the zeroing (_z) form of svclz. BitCast restores the
// input's signedness since svclz yields an unsigned vector.
#define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
    const DFromV<decltype(v)> d; \
    return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v)); \
  }

HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT,
                   MaskedLeadingZeroCountOrZero, clz)
// Bug fix: previously this #undef'd HWY_SVE_LEADING_ZERO_COUNT (wrong name),
// leaving HWY_SVE_MASKED_LEADING_ZERO_COUNT defined for the rest of the TU.
#undef HWY_SVE_MASKED_LEADING_ZERO_COUNT

// ================================================== END MACROS
#undef HWY_SVE_ALL_PTRUE
#undef HWY_SVE_D
Expand Down
97 changes: 97 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ HWY_API Vec<D> Inf(D d) {
return BitCast(d, Set(du, max_x2 >> 1));
}

// ------------------------------ SetOr/SetOrZero

// Returns a vector whose lane i is `a` where m[i] is true, else no[i].
template <class V, typename T = TFromV<V>, typename D = DFromV<V>,
          typename M = MFromD<D>>
HWY_API V SetOr(V no, M m, T a) {
  const D d;
  const V yes = Set(d, a);
  return IfThenElse(m, yes, no);
}

// Returns a vector whose lane i is `a` where m[i] is true, else zero.
template <class D, typename V = VFromD<D>, typename M = MFromD<D>,
          typename T = TFromD<D>>
HWY_API V SetOrZero(D d, M m, T a) {
  const V yes = Set(d, a);
  return IfThenElseZero(m, yes);
}

// ------------------------------ ZeroExtendResizeBitCast

// The implementation of detail::ZeroExtendResizeBitCast for the HWY_EMU128
Expand Down Expand Up @@ -336,6 +351,20 @@ HWY_API Mask<DTo> DemoteMaskTo(DTo d_to, DFrom d_from, Mask<DFrom> m) {

#endif // HWY_NATIVE_DEMOTE_MASK_TO

// ------------------------------ LoadHigher
#if (defined(HWY_NATIVE_LOAD_HIGHER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_LOAD_HIGHER
#undef HWY_NATIVE_LOAD_HIGHER
#else
#define HWY_NATIVE_LOAD_HIGHER
#endif
// Returns a vector with the lower half of `a` in its lower lanes and
// Lanes(d)/2 lanes loaded from `p` in its upper lanes.
// Bug fixes: the default for V was `VFromD<D>()` (a function type) instead of
// `VFromD<D>`; and the previous LoadU(d, p) read a full Lanes(d) lanes from
// `p` even though the contract only permits reading Lanes(d)/2, risking an
// out-of-bounds read. We now load only a half vector.
template <class D, typename T, class V = VFromD<D>, HWY_IF_LANES_GT_D(D, 1)>
HWY_API V LoadHigher(D d, V a, const T* HWY_RESTRICT p) {
  const Half<D> dh;
  const VFromD<Half<D>> hi = LoadU(dh, p);
  return Combine(d, hi, LowerHalf(dh, a));
}
#endif  // HWY_NATIVE_LOAD_HIGHER

// ------------------------------ CombineMasks

#if (defined(HWY_NATIVE_COMBINE_MASKS) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -1118,6 +1147,12 @@ HWY_API V MulByFloorPow2(V v, V exp) {

#endif // HWY_NATIVE_MUL_BY_POW2

// ------------------------------ MaskedLoadU
// Generic fallback: returns p[i] in lanes where m[i] is true, else zero.
// Note: the underlying LoadU is unmasked, so all Lanes(d) lanes are read.
template <class D, class M>
HWY_API VFromD<D> MaskedLoadU(D d, M m,
                              const TFromD<D>* HWY_RESTRICT unaligned) {
  const VFromD<D> loaded = LoadU(d, unaligned);
  return IfThenElseZero(m, loaded);
}
// ------------------------------ LoadInterleaved2

#if HWY_IDE || \
Expand Down Expand Up @@ -2574,6 +2609,25 @@ HWY_API void StoreN(VFromD<D> v, D d, T* HWY_RESTRICT p,

#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ StoreTruncated
#if (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_STORE_TRUNCATED
#undef HWY_NATIVE_STORE_TRUNCATED
#else
#define HWY_NATIVE_STORE_TRUNCATED
#endif

// Truncates each lane of `v` to the narrower type `To` and stores the
// Lanes(d) results to `p`. Does not write past p + Lanes(d) - 1.
template <class DFrom, class To, class DTo = Rebind<To, DFrom>,
          HWY_IF_T_SIZE_GT_D(DFrom, sizeof(To)),
          HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(VFromD<DFrom>)>
HWY_API void StoreTruncated(VFromD<DFrom> v, const DFrom d,
                            To* HWY_RESTRICT p) {
  const DTo d_to;
  const VFromD<DTo> narrowed = TruncateTo(d_to, v);
  StoreN(narrowed, d_to, p, Lanes(d));
}

#endif  // (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE))

// ------------------------------ Scatter

#if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -3808,6 +3862,21 @@ HWY_API V TrailingZeroCount(V v) {
}
#endif // HWY_NATIVE_LEADING_ZERO_COUNT

// ------------------------------ MaskedLeadingZeroCountOrZero
#if (defined(HWY_NATIVE_MASKED_LEADING_ZERO_COUNT) == \
     defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#else
#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT
#endif

// Returns LeadingZeroCount(v) in lanes where m[i] is true, else zero.
template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), class M>
HWY_API V MaskedLeadingZeroCountOrZero(M m, V v) {
  const V counts = LeadingZeroCount(v);
  return IfThenElseZero(m, counts);
}
#endif  // HWY_NATIVE_MASKED_LEADING_ZERO_COUNT

// ------------------------------ AESRound

// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes.
Expand Down Expand Up @@ -7299,6 +7368,34 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

// ------------------------------ AllOnes/AllZeros
#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLONES
#undef HWY_NATIVE_ALLONES
#else
#define HWY_NATIVE_ALLONES
#endif

// Returns whether every bit of every lane of `a` is set.
template <class V>
HWY_API bool AllOnes(V a) {
  const DFromV<V> d;
  // Compare the complement against zero rather than `a` against an all-ones
  // constant, so the bitwise check also behaves for float lanes.
  const V inverted = Not(a);
  return AllTrue(d, Eq(inverted, Zero(d)));
}
#endif  // HWY_NATIVE_ALLONES

#if (defined(HWY_NATIVE_ALLZEROS) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLZEROS
#undef HWY_NATIVE_ALLZEROS
#else
#define HWY_NATIVE_ALLZEROS
#endif

// Returns whether every bit of every lane of `a` is clear.
template <class V>
HWY_API bool AllZeros(V a) {
  const DFromV<V> d;
  const V zero = Zero(d);
  return AllTrue(d, Eq(a, zero));
}
#endif  // HWY_NATIVE_ALLZEROS
// ================================================== Operator wrapper

// SVE* and RVV currently cannot define operators and have already defined
Expand Down
43 changes: 43 additions & 0 deletions hwy/tests/count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,48 @@ HWY_NOINLINE void TestAllLeadingZeroCount() {
ForIntegerTypes(ForPartialVectors<TestLeadingZeroCount>());
}

struct TestMaskedLeadingZeroCount {
  // Verifies that MaskedLeadingZeroCountOrZero returns the CLZ of active
  // lanes (first three) and zero in inactive lanes.
  // Fix: removed unused `RandomState rng` and the unused `data` allocation.
  template <class T, class D>
  HWY_ATTR_NO_MSAN HWY_NOINLINE void operator()(T /*unused*/, D d) {
    using TU = MakeUnsigned<T>;
    const RebindToUnsigned<decltype(d)> du;
    const size_t N = Lanes(d);
    const MFromD<D> first_3 = FirstN(d, 3);
    auto expected = AllocateAligned<T>(N);
    HWY_ASSERT(expected);

    constexpr T kNumOfBitsInT = static_cast<T>(sizeof(T) * 8);
    // CLZ of 2 is (bits - 2) in the active lanes; inactive lanes are zero.
    for (size_t j = 0; j < N; j++) {
      expected[j] =
          (j < 3) ? static_cast<T>(kNumOfBitsInT - 2) : static_cast<T>(0);
    }
    HWY_ASSERT_VEC_EQ(
        d, expected.get(),
        MaskedLeadingZeroCountOrZero(first_3, Set(d, static_cast<T>(2))));

    // A value with bit (bits - 2) set has exactly one leading zero.
    for (size_t j = 0; j < N; j++) {
      expected[j] = (j < 3) ? static_cast<T>(1) : static_cast<T>(0);
    }
    HWY_ASSERT_VEC_EQ(
        d, expected.get(),
        MaskedLeadingZeroCountOrZero(
            first_3, BitCast(d, Set(du, TU{1} << (kNumOfBitsInT - 2)))));
  }
};

// Runs TestMaskedLeadingZeroCount over all integer types and partial vectors.
HWY_NOINLINE void TestAllMaskedLeadingZeroCount() {
  ForIntegerTypes(ForPartialVectors<TestMaskedLeadingZeroCount>());
}

template <class T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))>
static HWY_INLINE T TrailingZeroCountOfValue(T val) {
Expand Down Expand Up @@ -303,6 +345,7 @@ namespace {
HWY_BEFORE_TEST(HwyCountTest);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllPopulationCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllLeadingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllMaskedLeadingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllTrailingZeroCount);
HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllHighestSetBitIndex);
HWY_AFTER_TEST();
Expand Down
Loading
Loading