From c1f77686fb63719b1679c07d98df62d0e0995663 Mon Sep 17 00:00:00 2001 From: Mohammad Azim Khan Date: Fri, 15 Nov 2024 16:03:14 +0000 Subject: [PATCH] Load/Store, masked set and counting operations --- g3doc/quick_reference.md | 28 +++++++++++ hwy/ops/arm_sve-inl.h | 80 ++++++++++++++++++++++++++++++++ hwy/ops/generic_ops-inl.h | 97 +++++++++++++++++++++++++++++++++++++++ hwy/tests/count_test.cc | 43 +++++++++++++++++ hwy/tests/logical_test.cc | 28 +++++++++++ hwy/tests/mask_test.cc | 64 ++++++++++++++++++++++++++ hwy/tests/memory_test.cc | 81 ++++++++++++++++++++++++++++++++ 7 files changed, 421 insertions(+) diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md index 8220e9b718..c19e9f2c9f 100644 --- a/g3doc/quick_reference.md +++ b/g3doc/quick_reference.md @@ -429,6 +429,10 @@ for comparisons, for example `Lt` instead of `operator<`. the result, with `t0` in the least-significant (lowest-indexed) lane of each 128-bit block and `tK` in the most-significant (highest-indexed) lane of each 128-bit block: `{t0, t1, ..., tK}` +* V **SetOr**(V no, M m, T a): returns N-lane vector with lane + `i` equal to `a` if `m[i]` is true else `no[i]`. +* V **SetOrZero**(D d, M m, T a): returns N-lane vector with lane + `i` equal to `a` if `m[i]` is true else 0. ### Getting/setting lanes @@ -1015,6 +1019,10 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2): leading zeros in each lane. For any lanes where ```a[i]``` is zero, ```sizeof(TFromV) * 8``` is returned in the corresponding result lanes. +* `V`: `{u,i}` \ + V **MaskedLeadingZeroCountOrZero**(M m, `V a): returns the + result of LeadingZeroCount where `m[i]` is true, and zero otherwise. + * `V`: `{u,i}` \ V **TrailingZeroCount**(V a): returns the number of trailing zeros in each lane. For any lanes where ```a[i]``` is zero, @@ -1029,6 +1037,12 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2): ```HighestValue>>()``` is returned in the corresponding result lanes. +* bool **AllOnes**(D, V v): returns whether all bits in `v[i]` + are set. + +* bool **AllZeros**(D, V v): returns whether all bits in `v[i]` + are clear. + The following operate on individual bits within each lane. Note that the non-operator functions (`And` instead of `&`) must be used for floating-point types, and on SVE/RVV. @@ -1508,6 +1522,9 @@ aligned memory at indices which are not a multiple of the vector length): * Vec<D> **LoadU**(D, const T* p): returns `p[i]`. +* Vec<D> **MaskedLoadU**(D, M m, const T* p): returns `p[i]` + where mask is true and returns zero otherwise. + * Vec<D> **LoadDup128**(D, const T* p): returns one 128-bit block loaded from `p` and broadcasted into all 128-bit block\[s\]. This may be faster than broadcasting single values, and is more convenient than @@ -1538,6 +1555,10 @@ aligned memory at indices which are not a multiple of the vector length): lanes from `p` to the first (lowest-index) lanes of the result vector and fills the remaining lanes with `no`. Like LoadN, this does not fault. +* Vec<D> **LoadHigher**(D d, V v, T* p): Loads `Lanes(d)/2` lanes from + `p` into the upper lanes of the result vector and the lower half of `v` into + the lower lanes. + #### Store * void **Store**(Vec<D> v, D, T* aligned): copies `v[i]` @@ -1577,6 +1598,13 @@ aligned memory at indices which are not a multiple of the vector length): StoreN does not modify any memory past `p + HWY_MIN(Lanes(d), max_lanes_to_store) - 1`. 
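As an illustrative sketch (not part of this patch) of how the masked ops documented above compose, assuming static dispatch, `count <= Lanes(d)`, at least `Lanes(d)` readable elements at `in`, and a hypothetical helper name `PrefixLzcnt`:

```c++
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// out[i] = LeadingZeroCount(in[i]) for i < count, 0 for the remaining lanes.
// Note: the generic MaskedLoadU fallback still reads a full vector, so `in`
// must be readable for Lanes(d) elements even when count is smaller.
void PrefixLzcnt(const uint32_t* HWY_RESTRICT in, size_t count,
                 uint32_t* HWY_RESTRICT out) {
  const hn::ScalableTag<uint32_t> d;
  const auto m = hn::FirstN(d, count);       // lanes [0, count) are active
  const auto v = hn::MaskedLoadU(d, m, in);  // inactive lanes are zero
  hn::StoreU(hn::MaskedLeadingZeroCountOrZero(m, v), d, out);
}
```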
+* void **StoreTruncated**(Vec<DFrom> v, DFrom d, To* HWY_RESTRICT + p): Truncates elements of `v` to type `To` and stores on `p`. It is + similar to performing `TruncateTo` followed by `StoreN` where + `max_lanes_to_store` is `Lanes(d)`. + + StoreTruncated does not modify any memory past `p + Lanes(d) - 1`. + #### Interleaved * void **LoadInterleaved2**(D, const T* p, Vec<D>& v0, diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h index 2dde1479de..37ba9bf17a 100644 --- a/hwy/ops/arm_sve-inl.h +++ b/hwy/ops/arm_sve-inl.h @@ -412,6 +412,27 @@ using VFromD = decltype(Set(D(), TFromD())); using VBF16 = VFromD>; +// ------------------------------ SetOr/SetOrZero + +#define HWY_SVE_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) inactive, \ + svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ + return sv##OP##_##CHAR##BITS##_m(inactive, m, op); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET_OR, SetOr, dup_n) +#undef HWY_SVE_SET_OR + +#define HWY_SVE_SET_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ + return sv##OP##_##CHAR##BITS##_z(m, op); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET_OR_ZERO, SetOrZero, dup_n) +#undef HWY_SVE_SET_OR_ZERO + // ------------------------------ Zero template @@ -1856,6 +1877,18 @@ HWY_API void BlendedStore(VFromD v, MFromD m, D d, #undef HWY_SVE_MEM +#define HWY_SVE_MASKED_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + MaskedLoadU(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \ + } + +HWY_SVE_FOREACH(HWY_SVE_MASKED_MEM, _, _) + +#undef HWY_SVE_MASKED_MEM + #if HWY_TARGET != HWY_SVE2_128 namespace detail { #define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ @@ -1902,6 +1935,37 @@ HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { #endif // HWY_TARGET != HWY_SVE2_128 +// Truncate to smaller size and store +#ifdef HWY_NATIVE_STORE_TRUNCATED +#undef HWY_NATIVE_STORE_TRUNCATED +#else +#define HWY_NATIVE_STORE_TRUNCATED +#endif + +#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + const HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(detail::PTrue(d), detail::NativeLanePointer(p), v); \ + } + +#define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8) +#define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16) +#define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32) + +HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b) +HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, StoreTruncated, st1b) +HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, StoreTruncated, st1h) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, StoreTruncated, st1h) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, StoreTruncated, st1w) + +#undef HWY_SVE_STORE_TRUNCATED + // ------------------------------ Load/Store // SVE only requires lane alignment, not natural alignment 
of the entire @@ -6258,6 +6322,22 @@ HWY_API V HighestSetBitIndex(V v) { return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v))); } +#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#endif + +#define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + const DFromV d; \ + return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, + MaskedLeadingZeroCountOrZero, clz) +#undef HWY_SVE_LEADING_ZERO_COUNT + // ================================================== END MACROS #undef HWY_SVE_ALL_PTRUE #undef HWY_SVE_D diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h index 99b518d99c..8b000eb453 100644 --- a/hwy/ops/generic_ops-inl.h +++ b/hwy/ops/generic_ops-inl.h @@ -97,6 +97,21 @@ HWY_API Vec Inf(D d) { return BitCast(d, Set(du, max_x2 >> 1)); } +// ------------------------------ SetOr/SetOrZero + +template , typename D = DFromV, + typename M = MFromD> +HWY_API V SetOr(V no, M m, T a) { + D d; + return IfThenElse(m, Set(d, a), no); +} + +template , typename M = MFromD, + typename T = TFromD> +HWY_API V SetOrZero(D d, M m, T a) { + return IfThenElseZero(m, Set(d, a)); +} + // ------------------------------ ZeroExtendResizeBitCast // The implementation of detail::ZeroExtendResizeBitCast for the HWY_EMU128 @@ -336,6 +351,20 @@ HWY_API Mask DemoteMaskTo(DTo d_to, DFrom d_from, Mask m) { #endif // HWY_NATIVE_DEMOTE_MASK_TO +// ------------------------------ LoadHigher +#if (defined(HWY_NATIVE_LOAD_HIGHER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_HIGHER +#undef HWY_NATIVE_LOAD_HIGHER +#else +#define HWY_NATIVE_LOAD_HIGHER +#endif +template (), HWY_IF_LANES_GT_D(D, 1)> +HWY_API V LoadHigher(D d, V a, T* p) { + const V b = LoadU(d, p); + return ConcatLowerLower(d, b, a); +} +#endif // HWY_NATIVE_LOAD_HIGHER + // ------------------------------ CombineMasks #if (defined(HWY_NATIVE_COMBINE_MASKS) == defined(HWY_TARGET_TOGGLE)) @@ -1118,6 +1147,12 @@ HWY_API V MulByFloorPow2(V v, V exp) { #endif // HWY_NATIVE_MUL_BY_POW2 +// ------------------------------ MaskedLoadU +template +HWY_API VFromD MaskedLoadU(D d, M m, + const TFromD* HWY_RESTRICT unaligned) { + return IfThenElseZero(m, LoadU(d, unaligned)); +} // ------------------------------ LoadInterleaved2 #if HWY_IDE || \ @@ -2574,6 +2609,25 @@ HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, #endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) +// ------------------------------ StoreTruncated +#if (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_STORE_TRUNCATED +#undef HWY_NATIVE_STORE_TRUNCATED +#else +#define HWY_NATIVE_STORE_TRUNCATED +#endif + +template , + HWY_IF_T_SIZE_GT_D(DFrom, sizeof(To)), + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(VFromD)> +HWY_API void StoreTruncated(VFromD v, const DFrom d, + To* HWY_RESTRICT p) { + DTo dsmall; + StoreN(TruncateTo(dsmall, v), dsmall, p, Lanes(d)); +} + +#endif // (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE)) + // ------------------------------ Scatter #if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) @@ -3808,6 +3862,21 @@ HWY_API V TrailingZeroCount(V v) { } #endif // HWY_NATIVE_LEADING_ZERO_COUNT +// ------------------------------ MaskedLeadingZeroCountOrZero +#if (defined(HWY_NATIVE_MASKED_LEADING_ZERO_COUNT) == \ + 
defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#endif + +template +HWY_API V MaskedLeadingZeroCountOrZero(M m, V v) { + return IfThenElseZero(m, LeadingZeroCount(v)); +} +#endif // HWY_NATIVE_MASKED_LEADING_ZERO_COUNT + // ------------------------------ AESRound // Cannot implement on scalar: need at least 16 bytes for TableLookupBytes. @@ -7299,6 +7368,34 @@ HWY_API V BitShuffle(V v, VI idx) { #endif // HWY_NATIVE_BITSHUFFLE +// ------------------------------ AllOnes/AllZeros +#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ALLONES +#undef HWY_NATIVE_ALLONES +#else +#define HWY_NATIVE_ALLONES +#endif + +template +HWY_API bool AllOnes(V a) { + DFromV d; + return AllTrue(d, Eq(Not(a), Zero(d))); +} +#endif // HWY_NATIVE_ALLONES + +#if (defined(HWY_NATIVE_ALLZEROS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ALLZEROS +#undef HWY_NATIVE_ALLZEROS +#else +#define HWY_NATIVE_ALLZEROS +#endif + +template +HWY_API bool AllZeros(V a) { + DFromV d; + return AllTrue(d, Eq(a, Zero(d))); +} +#endif // HWY_NATIVE_ALLZEROS // ================================================== Operator wrapper // SVE* and RVV currently cannot define operators and have already defined diff --git a/hwy/tests/count_test.cc b/hwy/tests/count_test.cc index cc2d841122..40939d949c 100644 --- a/hwy/tests/count_test.cc +++ b/hwy/tests/count_test.cc @@ -132,6 +132,48 @@ HWY_NOINLINE void TestAllLeadingZeroCount() { ForIntegerTypes(ForPartialVectors()); } +struct TestMaskedLeadingZeroCount { + template + HWY_ATTR_NO_MSAN HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + using TU = MakeUnsigned; + const RebindToUnsigned du; + size_t N = Lanes(d); + const MFromD first_3 = FirstN(d, 3); + auto data = AllocateAligned(N); + auto lzcnt = AllocateAligned(N); + HWY_ASSERT(data && lzcnt); + + constexpr T kNumOfBitsInT = static_cast(sizeof(T) * 8); + for (size_t j = 0; j < N; j++) { + if (j < 3) { + lzcnt[j] = static_cast(kNumOfBitsInT - 2); + } else { + lzcnt[j] = static_cast(0); + } + } + HWY_ASSERT_VEC_EQ( + d, lzcnt.get(), + MaskedLeadingZeroCountOrZero(first_3, Set(d, static_cast(2)))); + + for (size_t j = 0; j < N; j++) { + if (j < 3) { + lzcnt[j] = static_cast(1); + } else { + lzcnt[j] = static_cast(0); + } + } + HWY_ASSERT_VEC_EQ( + d, lzcnt.get(), + MaskedLeadingZeroCountOrZero( + first_3, BitCast(d, Set(du, TU{1} << (kNumOfBitsInT - 2))))); + } +}; + +HWY_NOINLINE void TestAllMaskedLeadingZeroCount() { + ForIntegerTypes(ForPartialVectors()); +} + template static HWY_INLINE T TrailingZeroCountOfValue(T val) { @@ -303,6 +345,7 @@ namespace { HWY_BEFORE_TEST(HwyCountTest); HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllPopulationCount); HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllLeadingZeroCount); +HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllMaskedLeadingZeroCount); HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllTrailingZeroCount); HWY_EXPORT_AND_TEST_P(HwyCountTest, TestAllHighestSetBitIndex); HWY_AFTER_TEST(); diff --git a/hwy/tests/logical_test.cc b/hwy/tests/logical_test.cc index ecd7589c9e..6036594413 100644 --- a/hwy/tests/logical_test.cc +++ b/hwy/tests/logical_test.cc @@ -146,6 +146,32 @@ HWY_NOINLINE void TestAllTestBit() { ForIntegerTypes(ForPartialVectors()); } +struct TestAllOnes { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + auto v0s = Zero(d); + HWY_ASSERT(AllZeros(v0s)); + auto v1s = Not(v0s); + HWY_ASSERT(AllOnes(v1s)); + 
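+    // A lane with only one or two bits set is neither all-ones nor all-zeros;
+    // both predicates must reject every such pattern below.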
const size_t kNumBits = sizeof(T) * 8; + for (size_t i = 0; i < kNumBits; ++i) { + const Vec bit1 = Set(d, static_cast(1ull << i)); + const Vec bit2 = Set(d, static_cast(1ull << ((i + 1) % kNumBits))); + const Vec bits12 = Or(bit1, bit2); + HWY_ASSERT(!AllOnes(bit1)); + HWY_ASSERT(!AllZeros(bit1)); + HWY_ASSERT(!AllOnes(bit2)); + HWY_ASSERT(!AllZeros(bit2)); + HWY_ASSERT(!AllOnes(bits12)); + HWY_ASSERT(!AllZeros(bits12)); + } + } +}; + +HWY_NOINLINE void TestAllAllOnes() { + ForIntegerTypes(ForPartialVectors()); +} + } // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -159,6 +185,8 @@ HWY_BEFORE_TEST(HwyLogicalTest); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical); HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllOnes); + HWY_AFTER_TEST(); } // namespace } // namespace hwy diff --git a/hwy/tests/mask_test.cc b/hwy/tests/mask_test.cc index 3ad55f5ced..407c8a498d 100644 --- a/hwy/tests/mask_test.cc +++ b/hwy/tests/mask_test.cc @@ -547,6 +547,68 @@ HWY_NOINLINE void TestAllDup128MaskFromMaskBits() { ForAllTypes(ForPartialVectors()); } +struct TestSetOr { + template + void testWithMask(D d, MFromD m) { + TFromD a = 1; + auto yes = Set(d, a); + auto no = Set(d, 2); + auto expected = IfThenElse(m, yes, no); + auto actual = SetOr(no, m, a); + HWY_ASSERT_VEC_EQ(d, expected, actual); + } + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // All False + testWithMask(d, MaskFalse(d)); + auto N = Lanes(d); + // All True + testWithMask(d, FirstN(d, N)); + // Lower half + testWithMask(d, FirstN(d, N / 2)); + // Upper half + testWithMask(d, Not(FirstN(d, N / 2))); + // Interleaved + testWithMask(d, + MaskFromVec(InterleaveLower(Zero(d), Set(d, (TFromD)-1)))); + } +}; + +HWY_NOINLINE void TestAllSetOr() { + ForAllTypes(ForShrinkableVectors()); +} + +struct TestSetOrZero { + template + void testWithMask(D d, MFromD m) { + TFromD a = 1; + auto yes = Set(d, a); + auto no = Zero(d); + auto expected = IfThenElse(m, yes, no); + auto actual = SetOrZero(d, m, a); + HWY_ASSERT_VEC_EQ(d, expected, actual); + } + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // All False + testWithMask(d, MaskFalse(d)); + auto N = Lanes(d); + // All True + testWithMask(d, FirstN(d, N)); + // Lower half + testWithMask(d, FirstN(d, N / 2)); + // Upper half + testWithMask(d, Not(FirstN(d, N / 2))); + // Interleaved + testWithMask(d, + MaskFromVec(InterleaveLower(Zero(d), Set(d, (TFromD)-1)))); + } +}; + +HWY_NOINLINE void TestAllSetOrZero() { + ForAllTypes(ForShrinkableVectors()); +} + } // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -571,6 +633,8 @@ HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllSetAtOrBeforeFirst); HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllSetOnlyFirst); HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllSetAtOrAfterFirst); HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllDup128MaskFromMaskBits); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllSetOr); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllSetOrZero); HWY_AFTER_TEST(); } // namespace } // namespace hwy diff --git a/hwy/tests/memory_test.cc b/hwy/tests/memory_test.cc index 8b698f3308..6406201d36 100644 --- a/hwy/tests/memory_test.cc +++ b/hwy/tests/memory_test.cc @@ -67,6 +67,18 @@ struct TestLoadStore { HWY_ASSERT_EQ(i + 2, lanes3[i]); } + // Unaligned masked load + const MFromD first_3 = FirstN(d, 3); + const VFromD vu2 = MaskedLoadU(d, 
first_3, &lanes[1]); + Store(vu2, d, lanes3.get()); + for (size_t i = 0; i < N; ++i) { + if (i < 3) { + HWY_ASSERT_EQ(i + 2, lanes3[i]); + } else { + HWY_ASSERT_EQ(0, lanes3[i]); + } + } + // Unaligned store StoreU(lo2, d, &lanes2[N / 2]); size_t i = 0; @@ -559,6 +571,73 @@ HWY_NOINLINE void TestAllStoreN() { ForAllTypesAndSpecial(ForPartialVectors()); } +template +constexpr bool IsSupportedTruncation() { + return (sizeof(To) < sizeof(From) && Rebind().Pow2() >= -3 && + Rebind().Pow2() + 4 >= static_cast(CeilLog2(sizeof(To)))); +} + +struct TestStoreTruncated { + template ()>* = nullptr> + HWY_NOINLINE void testTo(From, To, const D) { + // do nothing + } + + template ()>* = nullptr> + HWY_NOINLINE void testTo(From, To, const D d) { + constexpr uint32_t base = 0xFA578D00; + const Vec src = Iota(d, base & hwy::LimitsMax()); + const Rebind dTo; + const Vec v_expected = + Iota(dTo, base & hwy::LimitsMax()); + const size_t NFrom = Lanes(d); + auto expected = AllocateAligned(NFrom); + StoreN(v_expected, dTo, expected.get(), NFrom); + auto actual = AllocateAligned(NFrom); + StoreTruncated(src, d, actual.get()); + HWY_ASSERT_ARRAY_EQ(expected.get(), actual.get(), NFrom); + } + + template + HWY_NOINLINE void operator()(T from, const D d) { + testTo(from, uint8_t(), d); + testTo(from, uint16_t(), d); + testTo(from, uint32_t(), d); + } +}; + +HWY_NOINLINE void TestAllStoreTruncated() { + ForU163264(ForPartialVectors()); +} + +struct TestLoadHigher { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const Vec a = Set(d, 1); + + // Generate a generic vector, then extract the pointer to the first entry + AlignedFreeUniquePtr pa = AllocateAligned(N); + std::fill(pa.get(), pa.get() + N, 20.0); + T* pointer = pa.get(); + + const Vec b = Set(d, 20); + const Vec expected_output_lanes = ConcatLowerLower(d, b, a); + + HWY_ASSERT_VEC_EQ(d, expected_output_lanes, LoadHigher(d, a, pointer)); + } + template + HWY_NOINLINE void operator()(T /*unused*/, D d) { + (void)d; + } +}; + +HWY_NOINLINE void TestAllLoadHigher() { + ForAllTypes(ForPartialVectors()); +} + } // namespace // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -579,6 +658,8 @@ HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllCache); HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadN); HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadNOr); HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStoreN); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStoreTruncated); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadHigher); HWY_AFTER_TEST(); } // namespace } // namespace hwy
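As a minimal usage sketch of the new store/load ops exercised by the tests above (illustrative only; `PackLowBytes` and `MergeHalves` are hypothetical names, and static dispatch with the usual Highway boilerplate is assumed):

```c++
#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Narrows one vector of uint64_t lanes to their low bytes; per the
// StoreTruncated contract, exactly Lanes(d) uint8_t values are written.
void PackLowBytes(const uint64_t* HWY_RESTRICT in, uint8_t* HWY_RESTRICT out) {
  const hn::ScalableTag<uint64_t> d;
  hn::StoreTruncated(hn::LoadU(d, in), d, out);
}

// Keeps the lower half of `lo` and fills the upper lanes with Lanes(d)/2
// elements loaded from `p`, mirroring what TestLoadHigher checks.
hn::Vec<hn::ScalableTag<int32_t>> MergeHalves(
    hn::Vec<hn::ScalableTag<int32_t>> lo, int32_t* HWY_RESTRICT p) {
  const hn::ScalableTag<int32_t> d;
  return hn::LoadHigher(d, lo, p);
}
```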