From 5c6faced86dacac55377a80e8eb63486dfddf1b6 Mon Sep 17 00:00:00 2001 From: Erik Schnetter Date: Tue, 25 Jul 2023 19:19:19 -0400 Subject: [PATCH] Arith: Correct type conversions for vector/scalar binary operations --- Arith/src/vect.hxx | 158 +++++++++++++++++++++++++++++---------------- 1 file changed, 102 insertions(+), 56 deletions(-) diff --git a/Arith/src/vect.hxx b/Arith/src/vect.hxx index 29073b24e..40edcc8a6 100644 --- a/Arith/src/vect.hxx +++ b/Arith/src/vect.hxx @@ -101,6 +101,12 @@ construct_array(const F &f) { // arithmetic operations, which is most useful for multi-dimensional // array indices. +template struct vect; + +template struct is_vect : std::false_type {}; +template struct is_vect > : std::true_type {}; +template constexpr bool is_vect_v = is_vect::value; + template struct vect { array elts; @@ -322,112 +328,152 @@ template struct vect { y); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator+(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() + std::declval()), D> + operator+(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a + b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator-(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() - std::declval()), D> + operator-(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a - b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator*(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() - std::declval()), D> + operator*(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a * b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator/(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() / std::declval()), D> + operator/(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a / b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator%(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() % std::declval()), D> + operator%(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a % b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - div_floor(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect(), std::declval())), D> + div_floor(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return div_floor(a, b); }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - mod_floor(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect(), std::declval())), D> + mod_floor(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return mod_floor(a, b); }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator&(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() & std::declval()), D> + operator&(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a & b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator|(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() | std::declval()), D> + operator|(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a | b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator^(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() ^ std::declval()), D> + operator^(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a ^ b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator<<(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() << std::declval()), D> + operator<<(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a << b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator>>(const T &a, const vect &x) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() >> std::declval()), D> + operator>>(const U &a, const vect &x) { return fmap([&](const T &b) ARITH_INLINE { return a >> b; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator+(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() + std::declval()), D> + operator+(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b + a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator-(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() - std::declval()), D> + operator-(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b - a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator*(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() * std::declval()), D> + operator*(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b * a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator/(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() / std::declval()), D> + operator/(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b / a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator%(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() % std::declval()), D> + operator%(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b % a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - div_floor(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect(), std::declval())), D> + div_floor(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return div_floor(b, a); }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - mod_floor(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect(), std::declval())), D> + mod_floor(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return mod_floor(b, a); }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator&(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() & std::declval()), D> + operator&(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b & a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator|(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() | std::declval()), D> + operator|(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b | a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator^(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() ^ std::declval()), D> + operator^(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b ^ a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator<<(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() << std::declval()), D> + operator<<(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b << a; }, x); } - friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator>>(const vect &x, const T &a) { + template > * = nullptr> + friend constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST + vect() >> std::declval()), D> + operator>>(const vect &x, const U &a) { return fmap([&](const T &b) ARITH_INLINE { return b >> a; }, x); } - constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator+=(const vect &x) { - return *this = *this + x; - } - constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect - operator-=(const vect &x) { - return *this = *this - x; - } template constexpr ARITH_INLINE ARITH_DEVICE ARITH_HOST vect operator+=(const vect &x) {