Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Masked compare and floating point classifications #8

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,24 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
neither NaN nor infinity, i.e. normal, subnormal or zero. Equivalent to
`Not(Or(IsNaN(v), IsInf(v)))`.

#### Masked floating-point classification

All ops in this section return `false` for `mask=false` lanes. These are
equivalent to, and potentially more efficient than, `And(m, Eq(a, b));` etc.

* `V`: `{f}` \
<code>M **MaskedIsNaN**(V v)</code>: returns mask indicating whether `v[i]`
is "not a number" (unordered) or `false` if `m[i]` is false.

* `V`: `{f}` \
<code>M **MaskedIsInf**(V v)</code>: returns mask indicating whether `v[i]`
is positive or negative infinity or `false` if `m[i]` is false.

* `V`: `{f}` \
<code>M **MaskedIsFinite**(V v)</code>: returns mask indicating whether
`v[i]` is neither NaN nor infinity, i.e. normal, subnormal or zero or
`false` if `m[i]` is false. Equivalent to `Not(Or(IsNaN(v), IsInf(v)))`.

### Logical

* `V`: `{u,i}` \
Expand Down Expand Up @@ -1477,6 +1495,29 @@ These return a mask (see above) indicating whether the condition is true.
for comparing 64-bit keys alongside 64-bit values. Only available if
`HWY_TARGET != HWY_SCALAR`.

#### Masked comparison

All ops in this section return `false` for `mask=false` lanes. These are
equivalent to, and potentially more efficient than, `And(m, Eq(a, b));` etc.

* <code>M **MaskedCompEq**(M m, V a, V b)</code>: returns `a[i] == b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompNe**(M m, V a, V b)</code>: returns `a[i] != b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompLt**(M m, V a, V b)</code>: returns `a[i] < b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompGt**(M m, V a, V b)</code>: returns `a[i] > b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompLe**(M m, V a, V b)</code>: returns `a[i] <= b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedCompGe**(M m, V a, V b)</code>: returns `a[i] >= b[i]` or
`false` if `m[i]` is false.

### Memory

Memory operands are little-endian, otherwise their order would depend on the
Expand Down
71 changes: 71 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,77 @@ HWY_API svbool_t IsFinite(const V v) {
return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>()));
}

// ------------------------------ MaskedCompEq etc.
#ifdef HWY_NATIVE_MASKED_COMP
#undef HWY_NATIVE_MASKED_COMP
#else
#define HWY_NATIVE_MASKED_COMP
#endif

// mask = f(mask, vector, vector)
#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \
HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS(m, a, b); \
}

namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple)

} // namespace detail

#undef HWY_SVE_COMPARE_Z

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompEq(M m, V a, V b) {
return detail::MaskedEq(m, a, b);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompNe(M m, V a, V b) {
return detail::MaskedNe(m, a, b);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompLt(M m, V a, V b) {
return detail::MaskedLt(m, a, b);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompGt(M m, V a, V b) {
// Swap args to reverse comparison
return detail::MaskedLt(m, b, a);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompLe(M m, V a, V b) {
return detail::MaskedLe(m, a, b);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedCompGe(M m, V a, V b) {
// Swap args to reverse comparison
return detail::MaskedLe(m, b, a);
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsInf(const M m, const V v) {
return And(m, IsInf(v));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsFinite(const M m, const V v) {
return And(m, IsFinite(v));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
return detail::MaskedNe(m, v, v);
}

// ================================================== MEMORY

// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
Expand Down
54 changes: 54 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,60 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif // HWY_NATIVE_MASKED_ARITH

// ------------------------------ MaskedCompEq etc.
#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_COMP
#undef HWY_NATIVE_MASKED_COMP
#else
#define HWY_NATIVE_MASKED_COMP
#endif

template <class V, class M>
HWY_API auto MaskedCompEq(M m, V a, V b) -> decltype(a == b) {
return And(m, Eq(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompNe(M m, V a, V b) -> decltype(a == b) {
return And(m, Ne(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompLt(M m, V a, V b) -> decltype(a == b) {
return And(m, Lt(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompGt(M m, V a, V b) -> decltype(a == b) {
return And(m, Gt(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompLe(M m, V a, V b) -> decltype(a == b) {
return And(m, Le(a, b));
}

template <class V, class M>
HWY_API auto MaskedCompGe(M m, V a, V b) -> decltype(a == b) {
return And(m, Ge(a, b));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsInf(const M m, const V v) {
return And(m, IsInf(v));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsFinite(const M m, const V v) {
return And(m, IsFinite(v));
}

template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
return And(m, IsNaN(v));
}
#endif // HWY_NATIVE_MASKED_COMP

// ------------------------------ IfNegativeThenNegOrUndefIfZero

#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
Expand Down
174 changes: 173 additions & 1 deletion hwy/tests/compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,176 @@ HWY_NOINLINE void TestAllEq128Upper() {
ForGEVectors<128, TestEq128Upper>()(uint64_t());
}

} // namespace
struct TestMaskedComparision {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;

const Vec<D> v0 = Zero(d);
const Vec<D> v2 = Iota(d, 2);
const Vec<D> v2b = Iota(d, 2);
const Vec<D> v3 = Iota(d, 3);
const size_t N = Lanes(d);

const Mask<D> mask_false = MaskFalse(d);
const Mask<D> mask_true = MaskTrue(d);

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask_true, v2, v3));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask_true, v3, v2));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompEq(mask_true, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompEq(mask_true, v2, v2b));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompNe(mask_true, v2, v3));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompNe(mask_true, v3, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask_true, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask_true, v2, v2b));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v2, v2));

HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGt(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLt(mask_true, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask_true, v0, v2));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask_true, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v2, v2));

HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompGe(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedCompLe(mask_true, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask_true, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask_true, v0, v2));

auto bool_lanes = AllocateAligned<T>(N);
HWY_ASSERT(bool_lanes);

for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
for (size_t i = 0; i < N; ++i) {
bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
}

const Vec<D> mask_i = Load(d, bool_lanes.get());
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask, v2, v3));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompEq(mask, v3, v2));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompEq(mask, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompEq(mask, v2, v2b));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompNe(mask, v2, v3));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompNe(mask, v3, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompNe(mask, v2, v2b));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v2, v2));

HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGt(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLt(mask, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLt(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGt(mask, v0, v2));

HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v0, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v2, v2));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v2, v2));

HWY_ASSERT_MASK_EQ(d, mask, MaskedCompGe(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedCompLe(mask, v0, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompLe(mask, v2, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedCompGe(mask, v0, v2));
}
}
};

HWY_NOINLINE void TestAllMaskedComparision() {
ForAllTypes(ForPartialVectors<TestMaskedComparision>());
}

struct TestMaskedFloatClassification {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
RandomState rng;

const Vec<D> v0 = Zero(d);
const Vec<D> v1 = Iota(d, 2);
const Vec<D> v2 = Inf(d);
const Vec<D> v3 = NaN(d);
const size_t N = Lanes(d);

const Mask<D> mask_false = MaskFalse(d);
const Mask<D> mask_true = MaskTrue(d);

// Test against all zeros
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v0));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsFinite(mask_true, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v0));

// Test against finite values
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v1));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsFinite(mask_true, v1));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v1));

// Test against infinite values
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsInf(mask_true, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask_true, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v2));

// Test against NaN values
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask_true, v3));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask_true, v3));
HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsNaN(mask_true, v3));

auto bool_lanes = AllocateAligned<T>(N);
HWY_ASSERT(bool_lanes);

for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
for (size_t i = 0; i < N; ++i) {
bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
}

const Vec<D> mask_i = Load(d, bool_lanes.get());
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));

// Test against all zeros
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v0));
HWY_ASSERT_MASK_EQ(d, mask, MaskedIsFinite(mask, v0));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v0));

// Test against finite values
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v1));
HWY_ASSERT_MASK_EQ(d, mask, MaskedIsFinite(mask, v1));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v1));

// Test against infinite values
HWY_ASSERT_MASK_EQ(d, mask, MaskedIsInf(mask, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask, v2));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v2));

// Test against NaN values
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsInf(mask, v3));
HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsFinite(mask, v3));
HWY_ASSERT_MASK_EQ(d, mask, MaskedIsNaN(mask, v3));
}
}
};

HWY_NOINLINE void TestAllMaskedFloatClassification() {
ForFloatTypes(ForPartialVectors<TestMaskedFloatClassification>());
}
} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
Expand All @@ -695,6 +864,9 @@ HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper);

HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedComparision);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedFloatClassification);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down
Loading