Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complex number operations #2

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ HWY_TESTS = [
("hwy/tests/", "combine_test"),
("hwy/tests/", "compare_test"),
("hwy/tests/", "compress_test"),
("hwy/tests/", "complex_arithmetic_test"),
("hwy/tests/", "concat_test"),
("hwy/tests/", "convert_test"),
("hwy/tests/", "count_test"),
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ set(HWY_TEST_FILES
hwy/tests/cast_test.cc
hwy/tests/combine_test.cc
hwy/tests/compare_test.cc
hwy/tests/complex_arithmetic_test.cc
hwy/tests/compress_test.cc
hwy/tests/concat_test.cc
hwy/tests/convert_test.cc
Expand Down
26 changes: 26 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,32 @@ not a concern, these are equivalent to, and potentially more efficient than,
<code>V **MaskedSatSubOr**(V no, M m, V a, V b)</code>: returns `a[i] +
b[i]` saturated to the minimum/maximum representable value, or `no[i]` if
`m[i]` is false.
#### Complex number operations

Complex types are represented as complex value pairs of real and imaginary
components, with the real components in even-indexed lanes and the imaginary
components in odd-indexed lanes.

All multiplies in this section are performing complex multiplication,
i.e. `(a + ib)(c + id)`.

Take `j` to be the even values of `i`.

* <code>V **CplxConj**(V v)</code>: returns the complex conjugate of the vector,
this negates the imaginary lanes. This is equivalent to `OddEven(Neg(a), a)`.
* <code>V **MulCplx**(V a, V b)</code>: returns `(a[j] + i.a[j + 1])(b[j] + i.b[j + 1])`
* <code>V **MulCplxConj**(V a, V b)</code>: returns `(a[j] + i.a[j + 1])(b[j] - i.b[j + 1])`
* <code>V **MulCplxAdd**(V a, V b, V c)</code>: returns
`(a[j] + i.a[j + 1])(b[j] + i.b[j + 1]) + (c[j] + i.c[j + 1])`
* <code>V **MulCplxConjAdd**(V a, V b, V c)</code>: returns
`(a[j] + i.a[j + 1])(b[j] - i.b[j + 1]) + (c[j] + i.c[j + 1])`
* <code>V **MaskedMulCplxConjAddOrZero**(M mask, V a, V b, V c)</code>: returns
`(a[j] + i.a[j + 1])(b[j] - i.b[j + 1]) + (c[j] + i.c[j + 1])` or `0` if
`mask[i]` is false.
* <code>V **MaskedMulCplxConjOrZero**(M mask, V a, V b)</code>: returns
`(a[j] + i.a[j + 1])(b[j] - i.b[j + 1])` or `0` if `mask[i]` is false.
* <code>V **MaskedMulCplxOr**(M mask, V a, V b, V c)</code>: returns `(a[j] +
i.a[j + 1])(b[j] + i.b[j + 1])` or `c[i]` if `mask[i]` is false.

#### Shifts

Expand Down
124 changes: 124 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6013,6 +6013,130 @@ HWY_API VFromD<DU64> SumOfMulQuadAccumulate(DU64 /*du64*/, svuint16_t a,
return svdot_u64(sum, a, b);
}

// ------------------------------ MulCplx* / MaskedMulCplx*

// Per-target flag to prevent generic_ops-inl.h from defining MulCplx*.
#ifdef HWY_NATIVE_CPLX
#undef HWY_NATIVE_CPLX
#else
#define HWY_NATIVE_CPLX
#endif

template <class V, HWY_IF_NOT_UNSIGNED(TFromV<V>)>
HWY_API V CplxConj(V a) {
return OddEven(Neg(a), a);
}

namespace detail {
#define HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, ROT) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME##ROT(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
HWY_SVE_V(BASE, BITS) c) { \
return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b, c, ROT); \
} \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME##Z##ROT(svbool_t m, HWY_SVE_V(BASE, BITS) a, \
HWY_SVE_V(BASE, BITS) b, HWY_SVE_V(BASE, BITS) c) { \
return sv##OP##_##CHAR##BITS##_z(m, a, b, c, ROT); \
}

#define HWY_SVE_CPLX_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 0) \
HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 90) \
HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 180) \
HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 270)

// Only SVE2 has complex multiply add for integer types
// and these do not include masked variants
HWY_SVE_FOREACH_F(HWY_SVE_CPLX_FMA, CplxMulAdd, cmla)
#undef HWY_SVE_CPLX_FMA
#undef HWY_SVE_CPLX_FMA_ROT
} // namespace detail

template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulCplxConjAddOrZero(M mask, V a, V b, V c) {
return detail::CplxMulAddZ270(mask, detail::CplxMulAddZ0(mask, c, b, a), b,
a);
}

template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulCplxConjOrZero(M mask, V a, V b) {
return MaskedMulCplxConjAddOrZero(mask, a, b, Zero(DFromV<V>()));
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulCplxAdd(V a, V b, V c) {
return detail::CplxMulAdd90(detail::CplxMulAdd0(c, a, b), a, b);
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulCplx(V a, V b) {
return MulCplxAdd(a, b, Zero(DFromV<V>()));
}

template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulCplxOr(M mask, V a, V b, V c) {
return IfThenElse(mask, MulCplx(a, b), c);
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulCplxConjAdd(V a, V b, V c) {
return detail::CplxMulAdd270(detail::CplxMulAdd0(c, b, a), b, a);
}

template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulCplxConj(V a, V b) {
return MulCplxConjAdd(a, b, Zero(DFromV<V>()));
}

// TODO SVE2 does have intrinsics for integers but not masked variants
template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulCplx(V a, V b) {
// a = u + iv, b = x + iy
const auto u = DupEven(a);
const auto v = DupOdd(a);
const auto x = DupEven(b);
const auto y = DupOdd(b);

return OddEven(Add(Mul(u, y), Mul(v, x)), Sub(Mul(u, x), Mul(v, y)));
}

template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulCplxConj(V a, V b) {
// a = u + iv, b = x + iy
const auto u = DupEven(a);
const auto v = DupOdd(a);
const auto x = DupEven(b);
const auto y = DupOdd(b);

return OddEven(Sub(Mul(v, x), Mul(u, y)), Add(Mul(u, x), Mul(v, y)));
}

template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulCplxAdd(V a, V b, V c) {
return Add(MulCplx(a, b), c);
}

template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulCplxConjAdd(V a, V b, V c) {
return Add(MulCplxConj(a, b), c);
}

template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulCplxConjAddOrZero(M mask, V a, V b, V c) {
return IfThenElseZero(mask, MulCplxConjAdd(a, b, c));
}

template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulCplxConjOrZero(M mask, V a, V b) {
return IfThenElseZero(mask, MulCplxConj(a, b));
}

template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulCplxOr(M mask, V a, V b, V c) {
return IfThenElse(mask, MulCplx(a, b), c);
}

// ------------------------------ AESRound / CLMul

// Static dispatch with -march=armv8-a+sve2+aes, or dynamic dispatch WITHOUT a
Expand Down
65 changes: 65 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4304,6 +4304,71 @@ HWY_API V MulSub(V mul, V x, V sub) {
return Sub(Mul(mul, x), sub);
}
#endif // HWY_NATIVE_INT_FMA
// ------------------------------ MulCplx* / MaskedMulCplx*

#if (defined(HWY_NATIVE_CPLX) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_CPLX
#undef HWY_NATIVE_CPLX
#else
#define HWY_NATIVE_CPLX
#endif

#if HWY_TARGET != HWY_SCALAR || HWY_IDE

template <class V, HWY_IF_NOT_UNSIGNED(TFromV<V>)>
HWY_API V CplxConj(V a) {
return OddEven(Neg(a), a);
}

template <class V>
HWY_API V MulCplx(V a, V b) {
// a = u + iv, b = x + iy
const auto u = DupEven(a);
const auto v = DupOdd(a);
const auto x = DupEven(b);
const auto y = DupOdd(b);

return OddEven(Add(Mul(u, y), Mul(v, x)), Sub(Mul(u, x), Mul(v, y)));
}

template <class V>
HWY_API V MulCplxConj(V a, V b) {
// a = u + iv, b = x + iy
const auto u = DupEven(a);
const auto v = DupOdd(a);
const auto x = DupEven(b);
const auto y = DupOdd(b);

return OddEven(Sub(Mul(v, x), Mul(u, y)), Add(Mul(u, x), Mul(v, y)));
}

template <class V>
HWY_API V MulCplxAdd(V a, V b, V c) {
return Add(MulCplx(a, b), c);
}

template <class V>
HWY_API V MulCplxConjAdd(V a, V b, V c) {
return Add(MulCplxConj(a, b), c);
}

template <class V, class M>
HWY_API V MaskedMulCplxConjAddOrZero(M mask, V a, V b, V c) {
return IfThenElseZero(mask, MulCplxConjAdd(a, b, c));
}

template <class V, class M>
HWY_API V MaskedMulCplxConjOrZero(M mask, V a, V b) {
return IfThenElseZero(mask, MulCplxConj(a, b));
}

template <class V, class M>
HWY_API V MaskedMulCplxOr(M mask, V a, V b, V c) {
return IfThenElse(mask, MulCplx(a, b), c);
}
#endif // HWY_TARGET != HWY_SCALAR

#endif // HWY_NATIVE_CPLX

// ------------------------------ Integer MulSub / NegMulSub
#if (defined(HWY_NATIVE_INT_FMSUB) == defined(HWY_TARGET_TOGGLE))
Expand Down
Loading
Loading