diff --git a/.clang-files b/.clang-files index 6a96950bd..891664468 100644 --- a/.clang-files +++ b/.clang-files @@ -12,11 +12,13 @@ ./src/cnfizers/Tseitin.h ./src/common/numbers/Integer.h +./src/common/numbers/Number.cc ./src/common/numbers/Number.h ./src/common/numbers/Real.h ./src/common/ApiException.h ./src/common/FlaPartitionMap.cc ./src/common/FlaPartitionMap.h +./src/common/FunUtils.h ./src/common/InternalException.h ./src/common/PartitionInfo.cc ./src/common/PartitionInfo.h diff --git a/src/common/FunUtils.h b/src/common/FunUtils.h new file mode 100644 index 000000000..51e1516ad --- /dev/null +++ b/src/common/FunUtils.h @@ -0,0 +1,56 @@ +#ifndef OPENSMT_FUNUTILS_H +#define OPENSMT_FUNUTILS_H + +#include +#include +#include +#include + +#define FORWARD(arg) std::forward(arg) + +namespace opensmt { +template +struct CompoundAssignOf; + +template<> +struct CompoundAssignOf> { + constexpr auto & operator()(auto & lhs, auto const & rhs) const { + lhs += rhs; + return lhs; + } +}; +template<> +struct CompoundAssignOf> { + constexpr auto & operator()(auto & lhs, auto const & rhs) const { + lhs -= rhs; + return lhs; + } +}; +template<> +struct CompoundAssignOf> { + constexpr auto & operator()(auto & lhs, auto const & rhs) const { + lhs *= rhs; + return lhs; + } +}; +template<> +struct CompoundAssignOf> { + constexpr auto & operator()(auto & lhs, auto const & rhs) const { + lhs *= rhs; + return lhs; + } +}; + +template +inline O intToOrdering(std::integral auto cmp) { + if (cmp == 0) { + return O::equivalent; + } else if (cmp < 0) { + return O::less; + } else { + return O::greater; + } +} +} // namespace opensmt + +#endif // OPENSMT_FUNUTILS_H diff --git a/src/common/TypeUtils.h b/src/common/TypeUtils.h index c5845654d..02d8cf43f 100644 --- a/src/common/TypeUtils.h +++ b/src/common/TypeUtils.h @@ -41,6 +41,12 @@ class span { T * _beg; uint32_t _size; }; + +// This is useful e.g. for std::visit(..., std::variant) +template +struct Overload : Ts... { + using Ts::operator()...; +}; } // namespace opensmt #endif // OPENSMT_TYPEUTILS_H diff --git a/src/common/numbers/CMakeLists.txt b/src/common/numbers/CMakeLists.txt index e103a21fd..ec2163e9a 100644 --- a/src/common/numbers/CMakeLists.txt +++ b/src/common/numbers/CMakeLists.txt @@ -1,7 +1,9 @@ target_sources(common PUBLIC + "${CMAKE_CURRENT_LIST_DIR}/FastInteger.cc" "${CMAKE_CURRENT_LIST_DIR}/FastRational.cc" "${CMAKE_CURRENT_LIST_DIR}/Integer.h" + "${CMAKE_CURRENT_LIST_DIR}/Number.cc" "${CMAKE_CURRENT_LIST_DIR}/Number.h" "${CMAKE_CURRENT_LIST_DIR}/Real.h" PRIVATE @@ -11,6 +13,7 @@ PRIVATE install(FILES ${CMAKE_CURRENT_LIST_DIR}/Integer.h ${CMAKE_CURRENT_LIST_DIR}/Number.h + ${CMAKE_CURRENT_LIST_DIR}/FastInteger.h ${CMAKE_CURRENT_LIST_DIR}/FastRational.h ${CMAKE_CURRENT_LIST_DIR}/Real.h ${CMAKE_CURRENT_LIST_DIR}/NumberUtils.h diff --git a/src/common/numbers/FastInteger.cc b/src/common/numbers/FastInteger.cc new file mode 100644 index 000000000..0d1bf721b --- /dev/null +++ b/src/common/numbers/FastInteger.cc @@ -0,0 +1,88 @@ +#include "FastInteger.h" + +namespace opensmt { +FastInteger::FastInteger(const char* str, const int base) : FastRational(str, base) { + assert(isIntegerValue()); +} + +FastInteger gcd(FastInteger const & a, FastInteger const & b) +{ + assert(a.isIntegerValue() and b.isIntegerValue()); + if (a.wordPartValid() && b.wordPartValid()) { + // Refers to FastRational.h:gcd + return FastInteger(gcd(a.num, b.num)); + } + else { + a.ensure_mpq_valid(); + b.ensure_mpq_valid(); + mpz_gcd(FastInteger::mpz(), mpq_numref(a.mpq), mpq_numref(b.mpq)); + return FastInteger(FastInteger::mpz()); + } +} + +FastInteger lcm(FastInteger const & a, FastInteger const & b) +{ + assert(a.isIntegerValue() and b.isIntegerValue()); + if (a.wordPartValid() && b.wordPartValid()) { + // Refers to FastRational.h:lcm + return lcm(a.num, b.num); + } + else { + a.ensure_mpq_valid(); + b.ensure_mpq_valid(); + mpz_lcm(FastInteger::mpz(), mpq_numref(a.mpq), mpq_numref(b.mpq)); + return FastInteger(FastInteger::mpz()); + } +} + +// The quotient of n and d using the fast rationals. +// Divide n by d, forming a quotient q. +// Rounds q down towards -infinity, and q will satisfy n = q*d + r for some 0 <= abs(r) <= abs(d) +FastInteger fastint_fdiv_q(FastInteger const & n, FastInteger const & d) { + assert(n.isIntegerValue() && d.isIntegerValue()); + if (n.wordPartValid() && d.wordPartValid()) { + word num = n.num; + word den = d.num; + word quo; + if (num == INT_MIN) // The abs is guaranteed to overflow. Otherwise this is always fine + goto overflow; + // After this -INT_MIN+1 <= numerator <= INT_MAX, and therefore the result always fits into a word. + quo = num / den; + if (num % den != 0 && ((num < 0 && den >=0) || (den < 0 && num >= 0))) // The result should be negative + quo--; // INT_MAX-1 >= quo >= -INT_MIN + + return quo; + } +overflow: + n.ensure_mpq_valid(); + d.ensure_mpq_valid(); + mpz_fdiv_q(FastInteger::mpz(), mpq_numref(n.mpq), mpq_numref(d.mpq)); + return FastInteger(FastInteger::mpz()); +} + +//void mpz_divexact (mpz_ptr, mpz_srcptr, mpz_srcptr); +FastInteger divexact(FastInteger const & n, FastInteger const & d) { + assert(d != 0); + assert(n.isIntegerValue() && d.isIntegerValue()); + if (n.wordPartValid() && d.wordPartValid()) { + word num = n.num; + word den = d.num; + word quo; + if (den != 0){ + quo = num / den; + return quo; + } + else { + // Division by zero + assert(false); + return FastInteger(0); + } + } else { + assert(n.mpqPartValid() || d.mpqPartValid()); + n.ensure_mpq_valid(); + d.ensure_mpq_valid(); + mpz_divexact(FastInteger::mpz(), mpq_numref(n.mpq), mpq_numref(d.mpq)); + return FastInteger(FastInteger::mpz()); + } +} +} diff --git a/src/common/numbers/FastInteger.h b/src/common/numbers/FastInteger.h new file mode 100644 index 000000000..d6b423160 --- /dev/null +++ b/src/common/numbers/FastInteger.h @@ -0,0 +1,121 @@ +// +// Created by Tomas Kolarik in 08/2024. +// + +#ifndef OPENSMT_FAST_INTEGER_H +#define OPENSMT_FAST_INTEGER_H + +#include "FastRational.h" + +#include + +namespace opensmt { +// TODO: inefficient, rational representation & uses mpq instead of mpz +class FastInteger : public FastRational { +public: + using FastRational::FastRational; + explicit FastInteger(FastRational rat) : FastRational(std::move(rat)) { assert(isIntegerValue()); } + explicit FastInteger(char const *, int const base = 10); + FastInteger & operator=(FastRational const & other) { + assert(this != &other); + assert(other.isIntegerValue()); + FastRational::operator=(other); + return *this; + } + FastInteger & operator=(FastRational && other) { + assert(other.isIntegerValue()); + FastRational::operator=(std::move(other)); + return *this; + } + FastInteger & operator=(std::integral auto i) { return operator=(FastInteger(i)); } + + FastInteger ceil() const noexcept { return *this; } + FastInteger floor() const noexcept { return *this; } + + FastInteger operator-() const { return static_cast(FastRational::operator-()); } + FastInteger operator+(FastInteger const & b) const { + return static_cast(FastRational::operator+(b)); + } + FastInteger operator-(FastInteger const & b) const { + return static_cast(FastRational::operator-(b)); + } + FastInteger operator*(FastInteger const & b) const { + return static_cast(FastRational::operator*(b)); + } + FastInteger & operator+=(FastInteger const & b) { + FastRational::operator+=(b); + return *this; + } + FastInteger & operator-=(FastInteger const & b) { + FastRational::operator-=(b); + return *this; + } + FastInteger & operator*=(FastInteger const & b) { + FastRational::operator*=(b); + return *this; + } + FastInteger & operator+=(std::integral auto i) { return operator+=(FastInteger(i)); } + FastInteger & operator-=(std::integral auto i) { return operator-=(FastInteger(i)); } + FastInteger & operator*=(std::integral auto i) { return operator*=(FastInteger(i)); } + FastRational & operator+=(FastRational const &) = delete; + FastRational & operator-=(FastRational const &) = delete; + FastRational & operator*=(FastRational const &) = delete; + + FastRational operator/(FastInteger const & b) const { return FastRational::operator/(b); } + void operator/=(FastInteger const &) = delete; + FastRational & operator/=(FastRational const &) = delete; + + // The return value will have the sign of d + FastInteger operator%(FastInteger const & d) const { + assert(isIntegerValue() && d.isIntegerValue()); + if (wordPartValid() && d.wordPartValid()) { + uword w = absVal(num % d.num); // Largest value is absVal(INT_MAX % INT_MIN) = INT_MAX + return (word)(d.num > 0 ? w : -w); // No overflow since 0 <= w <= INT_MAX + } + FastRational r = operator/(d); + auto i = FastInteger(r.floor()); + return operator-(i * d); + } + FastInteger & operator%=(FastInteger const & d) { + //+ it would be more efficient the other way around + return operator=(operator%(d)); + } + + std::optional tryGetValue() const { + if (!wordPartValid()) return {}; + return num; + } + +private: + FastInteger(std::integral auto, std::integral auto) = delete; + + using FastRational::get_d; + using FastRational::get_den; + using FastRational::get_num; + using FastRational::tryGetNumDen; + + friend FastInteger gcd(FastInteger const &, FastInteger const &); + friend FastInteger lcm(FastInteger const &, FastInteger const &); + friend FastInteger fastint_fdiv_q(FastInteger const &, FastInteger const &); + friend FastInteger divexact(FastInteger const &, FastInteger const &); +}; + +static_assert(!std::integral); + +// The result could not fit into integer -> FastInteger +template +FastInteger lcm(integer a, integer b) { + if (a == 0) return 0; + if (b == 0) return 0; + FastRational rat = (b > a) ? FastRational(b / gcd(a, b)) * a : FastRational(a / gcd(a, b)) * b; + assert(rat.isIntegerValue()); + return static_cast(rat); +} + +FastInteger gcd(FastInteger const &, FastInteger const &); +FastInteger lcm(FastInteger const &, FastInteger const &); +FastInteger fastint_fdiv_q(FastInteger const & n, FastInteger const & d); +FastInteger divexact(FastInteger const & n, FastInteger const & d); +} // namespace opensmt + +#endif // OPENSMT_FAST_INTEGER_H diff --git a/src/common/numbers/FastRational.cc b/src/common/numbers/FastRational.cc index 68566c05f..e4879f562 100644 --- a/src/common/numbers/FastRational.cc +++ b/src/common/numbers/FastRational.cc @@ -6,6 +6,7 @@ Copyright (c) 2008, 2009 Centre national de la recherche scientifique (CNRS) */ #include "FastRational.h" +#include "FastInteger.h" #include #include @@ -105,95 +106,16 @@ void FastRational::print_(std::ostream & out) const } } -std::string FastRational::get_str() const +std::string FastRational::toString() const { std::ostringstream os; print_(os); return os.str(); } -FastRational gcd(FastRational const & a, FastRational const & b) -{ - assert(a.isInteger() and b.isInteger()); - if (a.wordPartValid() && b.wordPartValid()) { - return FastRational(gcd(a.num, b.num)); - } - else { - a.ensure_mpq_valid(); - b.ensure_mpq_valid(); - mpz_gcd(FastRational::mpz(), mpq_numref(a.mpq), mpq_numref(b.mpq)); - return FastRational(FastRational::mpz()); - } -} - -FastRational lcm(FastRational const & a, FastRational const & b) -{ - assert(a.isInteger() and b.isInteger()); - if (a.wordPartValid() && b.wordPartValid()) { - return lcm(a.num, b.num); - } - else { - a.ensure_mpq_valid(); - b.ensure_mpq_valid(); - mpz_lcm(FastRational::mpz(), mpq_numref(a.mpq), mpq_numref(b.mpq)); - return FastRational(FastRational::mpz()); - } -} - FastRational fastrat_round_to_int(const FastRational& n) { FastRational res = n + FastRational(1, 2); - return fastrat_fdiv_q(res.get_num(), res.get_den()); -} - -// The quotient of n and d using the fast rationals. -// Divide n by d, forming a quotient q. -// Rounds q down towards -infinity, and q will satisfy n = q*d + r for some 0 <= abs(r) <= abs(d) -FastRational fastrat_fdiv_q(FastRational const & n, FastRational const & d) { - assert(n.isInteger() && d.isInteger()); - if (n.wordPartValid() && d.wordPartValid()) { - word num = n.num; - word den = d.num; - word quo; - if (num == INT_MIN) // The abs is guaranteed to overflow. Otherwise this is always fine - goto overflow; - // After this -INT_MIN+1 <= numerator <= INT_MAX, and therefore the result always fits into a word. - quo = num / den; - if (num % den != 0 && ((num < 0 && den >=0) || (den < 0 && num >= 0))) // The result should be negative - quo--; // INT_MAX-1 >= quo >= -INT_MIN - - return quo; - } -overflow: - n.ensure_mpq_valid(); - d.ensure_mpq_valid(); - mpz_fdiv_q(FastRational::mpz(), mpq_numref(n.mpq), mpq_numref(d.mpq)); - return FastRational(FastRational::mpz()); -} - -//void mpz_divexact (mpz_ptr, mpz_srcptr, mpz_srcptr); -FastRational divexact(FastRational const & n, FastRational const & d) { - assert(d != 0); - assert(n.isInteger() && d.isInteger()); - if (n.wordPartValid() && d.wordPartValid()) { - word num = n.num; - word den = d.num; - word quo; - if (den != 0){ - quo = num / den; - return quo; - } - else { - // Division by zero - assert(false); - return FastRational(0); - } - } else { - assert(n.mpqPartValid() || d.mpqPartValid()); - n.ensure_mpq_valid(); - d.ensure_mpq_valid(); - mpz_divexact(FastRational::mpz(), mpq_numref(n.mpq), mpq_numref(d.mpq)); - return FastRational(FastRational::mpz()); - } + return fastint_fdiv_q(static_cast(res.get_num()), static_cast(res.get_den())); } // Given as input the sequence Reals, return the smallest number m such that for each r in Reals, r*m is an integer @@ -201,7 +123,7 @@ FastRational get_multiplicand(const std::vector& reals) { std::vector dens; for (const auto & r : reals) { - if (!r.isInteger()) { + if (!r.isIntegerValue()) { dens.push_back(r.get_den()); } } @@ -219,7 +141,7 @@ FastRational get_multiplicand(const std::vector& reals) char *buf_new; for (int j = 0; j < dens.size(); j++) { - asprintf(&buf_new, "%s%s%s", buf, dens[j].get_str().c_str(), + asprintf(&buf_new, "%s%s%s", buf, dens[j].toString().c_str(), j == dens.size() - 1 ? "" : ", "); free(buf); buf = buf_new; @@ -234,7 +156,7 @@ FastRational get_multiplicand(const std::vector& reals) else { // We filter in place the integers in dens. The last two are guaranteed to be ; int k = 0; - FastRational m = lcm(dens[dens.size()-1], dens[dens.size()-2]); + FastRational m = lcm(static_cast(dens[dens.size()-1]), static_cast(dens[dens.size()-2])); mult *= m; for (size_t j = 0; j < dens.size()-2; j++) { FastRational n = (m/dens[j]).get_den(); @@ -245,7 +167,7 @@ FastRational get_multiplicand(const std::vector& reals) } } #ifdef PRINTALOT - printf("Multiplicand is %s\n", mult.get_str().c_str()); + printf("Multiplicand is %s\n", mult.toString().c_str()); #endif return mult; } diff --git a/src/common/numbers/FastRational.h b/src/common/numbers/FastRational.h index 7948137bf..03ece6af9 100644 --- a/src/common/numbers/FastRational.h +++ b/src/common/numbers/FastRational.h @@ -7,6 +7,8 @@ Copyright (c) 2008, 2009 Centre national de la recherche scientifique (CNRS) #ifndef FAST_RATIONALS_H #define FAST_RATIONALS_H +#include "NumberConcept.h" + #include #include #include @@ -15,6 +17,7 @@ Copyright (c) 2008, 2009 Centre national de la recherche scientifique (CNRS) #include #include #include +#include namespace opensmt { @@ -73,6 +76,7 @@ inline ulword absVal(lword x) { class FastRational { +protected: class mpqPool { std::stack store; // uses deque as storage to avoid realloc @@ -90,7 +94,7 @@ class FastRational inline static thread_local mpz_class temp; inline static mpz_ptr mpz() { return temp.get_mpz_t(); } - +private: // Bit masks for questioning state: static const unsigned char wordValidMask = 0x1; static const unsigned char mpqMemoryAllocatedMask = 0x2; @@ -155,6 +159,7 @@ class FastRational state = State::WORD_VALID; } } +protected: void ensure_mpq_valid() const { if (!mpqPartValid()) { assert(wordPartValid()); @@ -202,10 +207,6 @@ class FastRational friend inline void subtractionAssign (FastRational &, const FastRational &); friend inline void multiplicationAssign(FastRational &, const FastRational &); friend inline void divisionAssign (FastRational &, const FastRational &); - friend FastRational gcd (FastRational const &, FastRational const &); - friend FastRational lcm (FastRational const &, FastRational const &); - friend FastRational fastrat_fdiv_q (FastRational const & n, FastRational const & d); - friend FastRational divexact (FastRational const & n, FastRational const & d); static inline int compare(lword a, lword b) { if (a < b) return -1; @@ -218,7 +219,7 @@ class FastRational public: void print (std::ostream &) const; - std::string get_str() const; + std::string toString() const; inline double get_d () const; @@ -268,7 +269,7 @@ class FastRational bool operator!=( const FastRational & b ) const { return !(*this == b); } inline unsigned size() const; - uint32_t getHashValue() const { + uint32_t hash() const { if (wordPartValid()) { return 37*(uint32_t)num + 13*(uint32_t)den; } @@ -287,7 +288,13 @@ class FastRational } } - bool isInteger() const { + struct Hash { + uint32_t operator() (const FastRational& s) const { + return (uint32_t)s.hash(); + } + }; + + bool isIntegerValue() const { if (wordPartValid()) return den == 1; else { @@ -297,7 +304,7 @@ class FastRational } inline FastRational ceil() const { - if (isInteger()) + if (isIntegerValue()) return *this; if (wordPartValid()) { @@ -312,7 +319,7 @@ class FastRational } inline FastRational floor() const { - if (isInteger()) return *this; + if (isIntegerValue()) return *this; return ceil() - 1; } bool isWellFormed() const; @@ -395,36 +402,10 @@ class FastRational assert(wordPartValid() or not fitsWord()); return wordPartValid() && num == 1 && den == 1; } - - - // Return *this % d. The return value will have the sign of d - FastRational operator%(const FastRational& d) { - assert(isInteger() && d.isInteger()); - if (wordPartValid() && d.wordPartValid()) { - uword w = absVal(num % d.num); // Largest value is absVal(INT_MAX % INT_MIN) = INT_MAX - return (word)(d.num > 0 ? w : -w); // No overflow since 0 <= w <= INT_MAX - } - FastRational r = (*this) / d; - r = r.floor(); - r = (*this) - r*d; - return r; - } }; -FastRational fastrat_fdiv_q(FastRational const & n, FastRational const & d); FastRational fastrat_round_to_int(const FastRational& n); -struct FastRationalHash { - uint32_t operator() (const FastRational& s) const { - return (uint32_t)s.getHashValue(); - } -}; - -inline std::ostream & operator<<(std::ostream & out, const FastRational & r) -{ - r.print(out); - return out; -} inline FastRational::FastRational(const FastRational& x) { if (x.wordPartValid()) { num = x.num; @@ -527,7 +508,10 @@ inline int FastRational::sign() const { } } -template integer gcd(integer a, integer b) { +static_assert(!std::integral); + +template +integer gcd(integer a, integer b) { if (a==0) return b; if (b==0) return a; if (b > a) { @@ -543,16 +527,6 @@ template integer gcd(integer a, integer b) { } } -template -FastRational lcm(integer a, integer b) { - if (a == 0) return 0; - if (b == 0) return 0; - if (b > a) - return FastRational(b / gcd(a, b)) * a; - else - return FastRational(a / gcd(a, b)) * b; -} - // Return 1 if |op1| > |op2|, -1 if |op1| < |op2|, and 0 if op1 = op2 inline int cmpabs(FastRational op1, FastRational op2) { @@ -562,8 +536,6 @@ inline int cmpabs(FastRational op1, FastRational op2) op2 = -op2; return op1.compare(op2); }; -template ulword gcd(ulword a, ulword b); -template uword gcd(uword a, uword b); #define CHECK_WORD(var, value) \ do { \ lword tmp = value; \ @@ -1047,14 +1019,6 @@ inline FastRational FastRational::inverse() const { return dest; } -inline FastRational abs(FastRational const & x) { - if (x.sign() >= 0) { - return x; - } else { - return -x; - } -} - FastRational get_multiplicand(const std::vector& reals); } diff --git a/src/common/numbers/Integer.h b/src/common/numbers/Integer.h index 7d8e6c563..bac829aed 100644 --- a/src/common/numbers/Integer.h +++ b/src/common/numbers/Integer.h @@ -5,10 +5,10 @@ #ifndef OPENSMT_INTEGER_H #define OPENSMT_INTEGER_H -#include "Number.h" +#include "FastInteger.h" namespace opensmt { -typedef Number Integer2; +using Integer = FastInteger; } #endif // OPENSMT_INTEGER_H diff --git a/src/common/numbers/Number.cc b/src/common/numbers/Number.cc new file mode 100644 index 000000000..071ea4851 --- /dev/null +++ b/src/common/numbers/Number.cc @@ -0,0 +1,7 @@ +#include "Number.h" + +namespace opensmt { +FlexibleNumber::FlexibleNumber(char const * str) : FlexibleNumber(Real{str}) { + tryTurnToInteger(); +} +} // namespace opensmt diff --git a/src/common/numbers/Number.h b/src/common/numbers/Number.h index 77e755f70..db01e6778 100644 --- a/src/common/numbers/Number.h +++ b/src/common/numbers/Number.h @@ -5,37 +5,376 @@ #ifndef OPENSMT_NUMBER_H #define OPENSMT_NUMBER_H -#define FAST_RATIONALS 1 +#include "Integer.h" +#include "NumberConcept.h" +#include "Real.h" -#ifdef FAST_RATIONALS -#include "FastRational.h" -#else -#include -#endif +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace opensmt { -#ifdef FAST_RATIONALS -using Number = FastRational; -using NumberHash = FastRationalHash; -#else -using Number = mpq_class; -#endif +class StrictNumber : public std::variant { +public: + struct Hash { + std::size_t operator()(StrictNumber const & n) const { return n.hash(); } + }; + + using variant::variant; + // No constructors from non-class values + + bool isInteger() const noexcept { return std::holds_alternative(*this); } + bool isReal() const noexcept { return std::holds_alternative(*this); } + + Integer const * tryGetInteger() const noexcept { return std::get_if(this); } + Integer * tryGetInteger() noexcept { return std::get_if(this); } + Real const * tryGetReal() const noexcept { return std::get_if(this); } + Real * tryGetReal() noexcept { return std::get_if(this); } + + int sign() const { + return std::visit([](auto & x) { return x.sign(); }, *this); + } + + bool isZero() const { + return std::visit([](auto & x) { return x.isZero(); }, *this); + } + bool isOne() const { + return std::visit([](auto & x) { return x.isOne(); }, *this); + } + bool isIntegerValue() const { + return std::visit([](auto & x) { return x.isIntegerValue(); }, *this); + } + + bool operator==(StrictNumber const & rhs) const { return std::is_eq(operator<=>(rhs)); } + std::strong_ordering operator<=>(StrictNumber const & rhs) const { return compareTp(rhs); } + + StrictNumber operator+(StrictNumber const & rhs) const { return arithOperatorTp>(rhs); } + StrictNumber operator-(StrictNumber const & rhs) const { return arithOperatorTp>(rhs); } + StrictNumber operator*(StrictNumber const & rhs) const { return arithOperatorTp>(rhs); } + StrictNumber operator/(StrictNumber const & rhs) const { return arithOperatorTp>(rhs); } + StrictNumber & operator+=(StrictNumber const & rhs) { return arithAssignOperatorTp>(rhs); } + StrictNumber & operator-=(StrictNumber const & rhs) { return arithAssignOperatorTp>(rhs); } + StrictNumber & operator*=(StrictNumber const & rhs) { return arithAssignOperatorTp>(rhs); } + StrictNumber & operator/=(StrictNumber const & rhs) { return arithAssignOperatorTp>(rhs); } + StrictNumber operator+(std::integral auto val) const { return arithOperatorTp>(val); } + StrictNumber operator-(std::integral auto val) const { return arithOperatorTp>(val); } + StrictNumber operator*(std::integral auto val) const { return arithOperatorTp>(val); } + StrictNumber operator/(std::integral auto val) const { return arithOperatorTp>(val); } + StrictNumber & operator+=(std::integral auto val) { return arithAssignOperatorTp>(val); } + StrictNumber & operator-=(std::integral auto val) { return arithAssignOperatorTp>(val); } + StrictNumber & operator*=(std::integral auto val) { return arithAssignOperatorTp>(val); } + StrictNumber & operator/=(std::integral auto val) { return arithAssignOperatorTp>(val); } + + StrictNumber ceil() const { return ceilTp(); } + StrictNumber floor() const { return floorTp(); } + + StrictNumber & negate() { return negateTp(); } + StrictNumber & reset() noexcept { return resetTp(); } + + std::optional tryMakeInteger() const { + using Ret = std::optional; + return std::visit(Overload{ + [](Integer const & x) -> Ret { return x; }, + [](Real const & x) { return tryMakeInteger(x); }, + }, + *this); + } + + Real makeReal() const { + return std::visit([](auto & x) -> Real { return x; }, *this); + } + + bool tryTurnToInteger() { + return std::visit(Overload{ + [](Integer & x) { return true; }, + [this](Real & x) { return tryTurnToInteger(x); }, + }, + *this); + } + + void turnToReal() { + return std::visit(Overload{ + [this](Integer & x) { emplace(std::move(x)); }, + [](Real & x) {}, + }, + *this); + } + + std::string toString() const { + return std::visit([](auto & x) { return x.toString(); }, *this); + } + void print(std::ostream & os) const { + std::visit([&os](auto & x) { return x.print(os); }, *this); + } + + std::size_t hash() const { + return std::visit([](auto & x) { return x.hash(); }, *this); + } + +protected: + // The following functions currently make this assumption + static_assert(std::derived_from); + + template + std::strong_ordering compareTp(T const & rhs) const { + constexpr bool const isStrict = std::is_same_v; + return intToOrdering(std::visit( + [](U const & x, V const & y) { + if constexpr (std::is_same_v) { + return x.compare(y); + } else if constexpr (isStrict) { + // if strict, then using distinct arguments is UB + assert(false); + return 1; + } else { + return static_cast(x).compare(static_cast(y)); + } + }, + *this, rhs)); + } + + template + T arithOperatorTp(T const & rhs) const { + constexpr bool const isStrict = std::is_same_v; + return std::visit( + [](U const & x, V const & y) -> T { + if constexpr (std::is_same_v) { + return Op{}(x, y); + } else if constexpr (isStrict) { + // if strict, then using distinct arguments is UB + assert(false); + return {}; + } else { + return Op{}(static_cast(x), static_cast(y)); + } + }, + *this, rhs); + } + + template + T & arithAssignOperatorTp(T const & rhs) { + using AssignOp = CompoundAssignOf; + constexpr bool const isStrict = std::is_same_v; + constexpr bool const isDiv = std::is_same_v>; + std::visit( + [](U & x, V const & y) { + if constexpr (std::is_same_v and not isDiv) { + AssignOp{}(x, y); + } else if constexpr (isStrict or isDiv) { + // UB + assert(false); + } else { + AssignOp{}(static_cast(x), static_cast(y)); + } + }, + *this, rhs); + + return static_cast(*this); + } + + template + T arithOperatorTp(std::integral auto val) const { + return std::visit([val](auto & x) -> T { return Op{}(x, val); }, *this); + } + + template + T & arithAssignOperatorTp(std::integral auto val) { + using AssignOp = CompoundAssignOf; + constexpr bool const isDiv = std::is_same_v>; + if constexpr (isDiv) { + // UB + assert(false); + } else { + std::visit([val](auto & x) { AssignOp{}(x, val); }, *this); + } + + return static_cast(*this); + } + + template + T ceilTp() const { + return std::visit([](auto & x) -> T { return x.ceil(); }, *this); + } + template + T floorTp() const { + return std::visit([](auto & x) -> T { return x.floor(); }, *this); + } -inline bool isNegative(Number const & num) { - return num.sign() < 0; + template + T & negateTp() { + std::visit([](auto & x) { x.negate(); }, *this); + return static_cast(*this); + } + template + T & resetTp() noexcept { + std::visit([](auto & x) { x.reset(); }, *this); + return static_cast(*this); + } + + static std::optional tryMakeInteger(Real const & x) { + if (not x.isIntegerValue()) { return std::nullopt; } + return static_cast(x); + } + bool tryTurnToInteger(Real & x) { + if (not x.isIntegerValue()) { return false; } + emplace(std::move(x)); + return true; + } +}; + +auto operator==(StrictNumber::variant const &, StrictNumber::variant const &) = delete; +auto operator!=(StrictNumber::variant const &, StrictNumber::variant const &) = delete; +auto operator<(StrictNumber::variant const &, StrictNumber::variant const &) = delete; +auto operator>(StrictNumber::variant const &, StrictNumber::variant const &) = delete; +auto operator<=(StrictNumber::variant const &, StrictNumber::variant const &) = delete; +auto operator>=(StrictNumber::variant const &, StrictNumber::variant const &) = delete; +auto operator<=>(StrictNumber::variant const &, StrictNumber::variant const &) = delete; + +static_assert(std::constructible_from); +static_assert(std::constructible_from); +static_assert(not std::convertible_to); +static_assert(not std::convertible_to); + +class FlexibleNumber : public StrictNumber { +public: + using StrictNumber::StrictNumber; + FlexibleNumber(std::integral auto val) : FlexibleNumber(Integer{val}) {} + explicit FlexibleNumber(std::integral auto den, std::integral auto num) : FlexibleNumber({Real{den, num}}) {} + explicit FlexibleNumber(char const *); + + operator Real() const { return makeReal(); } + + bool operator==(FlexibleNumber const & rhs) const { return std::is_eq(operator<=>(rhs)); } + std::strong_ordering operator<=>(FlexibleNumber const & rhs) const { return compareTp(rhs); } + + FlexibleNumber operator+(FlexibleNumber const & rhs) const { return arithOperatorTp>(rhs); } + FlexibleNumber operator-(FlexibleNumber const & rhs) const { return arithOperatorTp>(rhs); } + FlexibleNumber operator*(FlexibleNumber const & rhs) const { return arithOperatorTp>(rhs); } + FlexibleNumber operator/(FlexibleNumber const & rhs) const { return arithOperatorTp>(rhs); } + FlexibleNumber & operator+=(FlexibleNumber const & rhs) { return arithAssignOperatorTp>(rhs); } + FlexibleNumber & operator-=(FlexibleNumber const & rhs) { return arithAssignOperatorTp>(rhs); } + FlexibleNumber & operator*=(FlexibleNumber const & rhs) { + return arithAssignOperatorTp>(rhs); + } + FlexibleNumber & operator/=(FlexibleNumber const & rhs) { return arithAssignOperatorTp>(rhs); } + FlexibleNumber operator+(std::integral auto val) const { + return arithOperatorTp, FlexibleNumber>(val); + } + FlexibleNumber operator-(std::integral auto val) const { + return arithOperatorTp, FlexibleNumber>(val); + } + FlexibleNumber operator*(std::integral auto val) const { + return arithOperatorTp, FlexibleNumber>(val); + } + FlexibleNumber operator/(std::integral auto val) const { + return arithOperatorTp, FlexibleNumber>(val); + } + FlexibleNumber & operator+=(std::integral auto val) { + return arithAssignOperatorTp, FlexibleNumber>(val); + } + FlexibleNumber & operator-=(std::integral auto val) { + return arithAssignOperatorTp, FlexibleNumber>(val); + } + FlexibleNumber & operator*=(std::integral auto val) { + return arithAssignOperatorTp, FlexibleNumber>(val); + } + FlexibleNumber & operator/=(std::integral auto val) { + return arithAssignOperatorTp, FlexibleNumber>(val); + } + + FlexibleNumber ceil() const { return ceilTp(); } + FlexibleNumber floor() const { return floorTp(); } + + FlexibleNumber & negate() { return negateTp(); } + FlexibleNumber & reset() noexcept { return resetTp(); } +}; + +static_assert(std::constructible_from); +static_assert(std::constructible_from); +static_assert(not std::convertible_to); +static_assert(std::convertible_to); + +static_assert(std::derived_from); + +// This way it works both for StrictNumber and FlexibleNumber +constexpr StrictNumber const & castStrict(StrictNumber const & n) { + return static_cast(n); +} +constexpr StrictNumber & castStrict(StrictNumber & n) { + return static_cast(n); +} +constexpr StrictNumber && castStrict(StrictNumber && n) { + return static_cast(n); +} + +constexpr FlexibleNumber const & castFlexible(StrictNumber const & n) { + return static_cast(n); +} +constexpr FlexibleNumber & castFlexible(StrictNumber & n) { + return static_cast(n); +} +constexpr FlexibleNumber && castFlexible(StrictNumber && n) { + return static_cast(n); } -inline bool isPositive(Number const & num) { - return num.sign() > 0; +inline Integer const & castInteger(StrictNumber const & n) { + assert(n.isInteger()); + return *n.tryGetInteger(); +} +inline Integer & castInteger(StrictNumber & n) { + return const_cast(castInteger(std::as_const(n))); } -inline bool isNonNegative(Number const & num) { - return num.sign() >= 0; +inline Real const & castReal(StrictNumber const & n) { + assert(n.isReal()); + return *n.tryGetReal(); +} +inline Real & castReal(StrictNumber & n) { + return const_cast(castReal(std::as_const(n))); } -inline bool isNonPositive(Number const & num) { - return num.sign() <= 0; +// To have more control over the code, we use the stricter variant by default +// If flexibility is desired, explicitly use `castFlexible` +using Number = StrictNumber; + +static_assert(number); +static_assert(number); +static_assert(number); +static_assert(number); +static_assert(number); + +template +inline T makeNumber(Integer x, bool makeReal = false) { + T n{std::move(x)}; + assert(n.isInteger()); + if (makeReal) { n.turnToReal(); } + return n; +} +template +inline T makeNumber(Real x, bool makeInt = false) { + constexpr bool const isStrict = std::is_same_v; + T n{std::move(x)}; + assert(n.isReal()); + if (makeInt) { + [[maybe_unused]] bool const success = n.tryTurnToInteger(); + assert(not isStrict or success); + } + return n; } } // namespace opensmt +namespace std { +template<> +struct hash : opensmt::StrictNumber::Hash {}; +} // namespace std + #endif // OPENSMT_NUMBER_H diff --git a/src/common/numbers/NumberConcept.h b/src/common/numbers/NumberConcept.h new file mode 100644 index 000000000..4bd605525 --- /dev/null +++ b/src/common/numbers/NumberConcept.h @@ -0,0 +1,65 @@ +// +// Created by Tomas Kolarik on 06.11.24 +// + +#ifndef OPENSMT_NUMBERCONCEPT_H +#define OPENSMT_NUMBERCONCEPT_H + +#include +#include +#include + +namespace opensmt { +template +concept number = requires(T & t, std::ostream & os) { + { t.sign() } -> std::convertible_to; + { t.isZero() } -> std::convertible_to; + { t.isOne() } -> std::convertible_to; + { t.isIntegerValue() } -> std::convertible_to; + { t.ceil() } -> std::same_as; + { t.floor() } -> std::same_as; + t.negate(); + t.reset(); + { t.toString() } -> std::convertible_to; + t.print(os); + { t.hash() } -> std::convertible_to; +}; + +inline bool isNegative(number auto const & x) { + return x.sign() < 0; +} + +inline bool isPositive(number auto const & x) { + return x.sign() > 0; +} + +inline bool isNonNegative(number auto const & x) { + return x.sign() >= 0; +} + +inline bool isNonPositive(number auto const & x) { + return x.sign() <= 0; +} + +template +inline T operator-(T x) { + x.negate(); + return x; +} + +template +inline T abs(T const & x) { + if (isNonNegative(x)) { + return x; + } else { + return -x; + } +} + +inline std::ostream & operator<<(std::ostream & os, number auto const & x) { + x.print(os); + return os; +} +} // namespace opensmt + +#endif // OPENSMT_NUMBERCONCEPT_H diff --git a/src/common/numbers/NumberUtils.h b/src/common/numbers/NumberUtils.h index bd3c331ee..eeda8eaa1 100644 --- a/src/common/numbers/NumberUtils.h +++ b/src/common/numbers/NumberUtils.h @@ -12,24 +12,12 @@ #include namespace opensmt { - typedef mpz_class Integer; //PS. related to BV logic - - void static inline wordToBinary(const Integer x, char *&bin, const int width) { - bin = (char *) malloc(width + 1); - - int p = 0; - Integer one = 1; - for (Integer i = (one << (width - 1)); i > 0; i >>= 1) - bin[p++] = ((x & i) == i) ? '1' : '0'; - bin[p] = '\0'; - } - - void static inline wordToBinary(const unsigned x, char *&bin, const int width) { + void static inline wordToBinary(const auto x, char *&bin, const int width) { bin = (char *) malloc(width + 1); int p = 0; - Integer one = 1; - for (Integer i = (one << (width - 1)); i > 0; i >>= 1) + mpz_class one = 1; + for (mpz_class i = (one << (width - 1)); i > 0; i >>= 1) bin[p++] = ((x & i) == i) ? '1' : '0'; bin[p] = '\0'; } diff --git a/src/common/numbers/Real.h b/src/common/numbers/Real.h index 63ddd22dc..bef280bb3 100644 --- a/src/common/numbers/Real.h +++ b/src/common/numbers/Real.h @@ -5,10 +5,20 @@ #ifndef OPENSMT_REAL_H #define OPENSMT_REAL_H -#include "Number.h" +#define FAST_RATIONALS + +#ifdef FAST_RATIONALS +#include "FastRational.h" +#else +#include +#endif namespace opensmt { -typedef Number Real; -} +#ifdef FAST_RATIONALS +using Real = FastRational; +#else +using Real = mpq_class; +#endif +} // namespace opensmt #endif // OPENSMT_REAL_H diff --git a/src/logics/ArithLogic.cc b/src/logics/ArithLogic.cc index 74c8b86c7..f9713be39 100644 --- a/src/logics/ArithLogic.cc +++ b/src/logics/ArithLogic.cc @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -28,7 +29,6 @@ namespace { protected: ArithLogic & l; virtual void Op(Number & s, Number const & v) const = 0; - virtual Number getIdOp() const = 0; virtual void constSimplify(SymRef s, vec const & terms, SymRef & s_new, vec & terms_new) const = 0; @@ -40,7 +40,6 @@ namespace { class SimplifyConstSum : public SimplifyConst { void Op(Number & s, Number const & v) const { s += v; } - Number getIdOp() const { return 0; } void constSimplify(SymRef s, vec const & terms, SymRef & s_new, vec & terms_new) const; public: @@ -49,7 +48,6 @@ namespace { class SimplifyConstTimes : public SimplifyConst { void Op(Number & s, Number const & v) const { s *= v; } - Number getIdOp() const { return 1; } void constSimplify(SymRef s, vec const & terms, SymRef & s_new, vec & terms_new) const; public: @@ -58,10 +56,9 @@ namespace { class SimplifyConstDiv : public SimplifyConst { void Op(Number & s, Number const & v) const { - if (v == 0) { printf("explicit div by zero\n"); } + if (v.isZero()) { printf("explicit div by zero\n"); } s /= v; } - Number getIdOp() const { return 1; } void constSimplify(SymRef s, vec const & terms, SymRef & s_new, vec & terms_new) const; public: @@ -199,6 +196,14 @@ Number const & ArithLogic::getNumConst(PTRef tr) const { return *numbers[id]; } +Integer const & ArithLogic::getIntConst(PTRef tr) const { + return castInteger(getNumConst(tr)); +} + +Real const & ArithLogic::getRealConst(PTRef tr) const { + return castReal(getNumConst(tr)); +} + pair> ArithLogic::getConstantAndFactors(PTRef sum) const { assert(isPlus(sum)); vec varFactors; @@ -245,7 +250,7 @@ pair ArithLogic::splitTermToVarAndConst(PTRef term) const { PTRef ArithLogic::normalizeMul(PTRef mul) { assert(isTimes(mul)); auto [v, c] = splitTermToVarAndConst(mul); - if (getNumConst(c) < 0) { + if (isNegative(getNumConst(c))) { return mkNeg(v); } else { return v; @@ -319,7 +324,7 @@ namespace { auto val = substitutions[var]; if (not logic.isZero(val)) { poly_t constantPoly; - constantPoly.addTerm(PTRef_Undef, coeff * logic.getNumConst(val)); + constantPoly.addTerm(PTRef_Undef, coeff * logic.getNumConst(val).makeReal()); poly.merge(constantPoly, 1); } changed = true; @@ -404,13 +409,13 @@ lbool ArithLogic::arithmeticElimination(vec const & top_level_arith, Subs assert(logic.isLinearTerm(polyTerm)); if (logic.isLinearFactor(polyTerm)) { auto [var, c] = logic.splitTermToVarAndConst(polyTerm); - auto coeff = logic.getNumConst(c); + auto coeff = logic.getNumConst(c).makeReal(); poly.addTerm(var, std::move(coeff)); } else { assert(logic.isPlus(polyTerm)); for (PTRef factor : logic.getPterm(polyTerm)) { auto [var, c] = logic.splitTermToVarAndConst(factor); - auto coeff = logic.getNumConst(c); + auto coeff = logic.getNumConst(c).makeReal(); poly.addTerm(var, std::move(coeff)); } } @@ -543,7 +548,8 @@ PTRef ArithLogic::mkNeg(PTRef tr) { } PTRef ArithLogic::mkConst(SRef sort, Number const & c) { - std::string str = c.get_str(); // MB: I cannot store c.get_str().c_str() directly, since that is a pointer + assert(c.isInteger() or sort == sort_REAL or c.isIntegerValue()); + std::string str = c.toString(); // MB: I cannot store c.toString().c_str() directly, since that is a pointer // inside temporary object -> crash. char const * val = str.c_str(); PTRef ptr = PTRef_Undef; @@ -553,8 +559,12 @@ PTRef ArithLogic::mkConst(SRef sort, Number const & c) { for (auto i = numbers.size(); i <= id; i++) { numbers.emplace_back(); } - if (numbers[id] == nullptr) { numbers[id] = new Number(val); } - assert(c == *numbers[id]); + if (numbers[id] == nullptr) { + assert(sort == sort_REAL or sort == sort_INT); + if (sort == sort_REAL) { numbers[id] = new Number{Real{val}}; } + else { numbers[id] = new Number{Integer{val}}; } + } + assert(castFlexible(c) == castFlexible(*numbers[id])); markConstant(id); return ptr; } @@ -800,12 +810,11 @@ PTRef ArithLogic::mkMod(vec && args) { if (isZero(divisor)) { throw ArithDivisionByZeroException(); } if (isOne(divisor) or isMinusOne(divisor)) { return getTerm_IntZero(); } if (isConstant(dividend)) { - auto const & dividendValue = getNumConst(dividend); - auto const & divisorValue = getNumConst(divisor); - assert(dividendValue.isInteger() and divisorValue.isInteger()); + auto const & dividendValue = getIntConst(dividend); + auto const & divisorValue = getIntConst(divisor); // evaluate immediately the operation on two constants auto realDiv = dividendValue / divisorValue; - auto intDiv = divisorValue.sign() > 0 ? realDiv.floor() : realDiv.ceil(); + auto intDiv = static_cast(isPositive(divisorValue) ? realDiv.floor() : realDiv.ceil()); auto intMod = dividendValue - intDiv * divisorValue; assert(intMod.sign() >= 0 and intMod < abs(divisorValue)); return mkIntConst(intMod); @@ -824,12 +833,11 @@ PTRef ArithLogic::mkIntDiv(vec && args) { if (isMinusOne(divisor)) { return mkNeg(dividend); } if (isConstant(divisor) and isConstant(dividend)) { - auto const & dividendValue = getNumConst(dividend); - auto const & divisorValue = getNumConst(divisor); - assert(dividendValue.isInteger() and divisorValue.isInteger()); + auto const & dividendValue = getIntConst(dividend); + auto const & divisorValue = getIntConst(divisor); // evaluate immediately the operation on two constants auto realDiv = dividendValue / divisorValue; - auto intDiv = divisorValue.sign() > 0 ? realDiv.floor() : realDiv.ceil(); + auto intDiv = static_cast(isPositive(divisorValue) ? realDiv.floor() : realDiv.ceil()); return mkIntConst(intDiv); } return mkFun(sym_Int_DIV, std::move(args)); @@ -848,7 +856,7 @@ PTRef ArithLogic::mkRealDiv(vec && args) { simp.simplify(get_sym_Real_DIV(), args, s_new, args_new); if (isRealDiv(s_new)) { assert((isNumTerm(args_new[0]) || isPlus(args_new[0])) && isConstant(args_new[1])); - args_new[1] = mkRealConst(getNumConst(args_new[1]).inverse()); // mkConst(1/getRealConst(args_new[1])); + args_new[1] = mkRealConst(getRealConst(args_new[1]).inverse()); // mkConst(1/getRealConst(args_new[1])); return mkTimes(args_new); } PTRef tr = mkFun(s_new, std::move(args_new)); @@ -875,20 +883,24 @@ PTRef ArithLogic::mkConst(SRef s, char const * name) { PTRef ptr = PTRef_Undef; if (s == sort_REAL or s == sort_INT) { char * rat; - if (s == sort_REAL) + Number * numPtr; + if (s == sort_REAL) { + // TK: Currently this must be consistent with Real::toString, otherwise it will make duplicates in term_store stringToRational(rat, name); - else { + numPtr = new Number{Real{rat}}; + } else { if (not isIntString(name)) throw ApiException("Not parseable as an integer"); rat = strdup(name); + numPtr = new Number{Integer{rat}}; } ptr = mkVar(s, rat, true); + free(rat); // Store the value of the number as a real SymId id = sym_store[getPterm(ptr).symb()].getId(); for (auto i = numbers.size(); i <= id; i++) numbers.emplace_back(nullptr); if (numbers[id] != nullptr) { delete numbers[id]; } - numbers[id] = new Number(rat); - free(rat); + numbers[id] = numPtr; markConstant(id); } else ptr = Logic::mkConst(s, name); @@ -1131,12 +1143,12 @@ std::string ArithLogic::printTerm_(PTRef tr, bool ext, bool safe) const { bool is_neg = false; char * tmp_str; stringToRational(tmp_str, sym_store.getName(getPterm(tr).symb())); - Number v(tmp_str); + Real v{tmp_str}; if (!isNonNegative(v)) { v.negate(); is_neg = true; } - std::string rat_str = v.get_str(); + std::string rat_str = v.toString(); free(tmp_str); bool is_div = false; unsigned i = 0; @@ -1194,32 +1206,31 @@ std::string ArithLogic::printTerm_(PTRef tr, bool ext, bool safe) const { pair ArithLogic::sumToNormalizedIntPair(PTRef sum) { auto [constantValue, varFactors] = getConstantAndFactors(sum); + auto & constantValueInt = castInteger(constantValue); vec vars; vars.capacity(varFactors.size()); - std::vector coeffs; + std::vector coeffs; coeffs.reserve(varFactors.size()); for (PTRef factor : varFactors) { auto [var, coeff] = splitTermToVarAndConst(factor); assert(ArithLogic::isNumVarLike(var) and isNumConst(coeff)); vars.push(var); - coeffs.push_back(getNumConst(coeff)); + coeffs.push_back(getIntConst(coeff)); } bool changed = false; // Keep track if any change to varFactors occurs bool allIntegers = - std::all_of(coeffs.begin(), coeffs.end(), [](Number const & coeff) { return coeff.isInteger(); }); + std::all_of(coeffs.begin(), coeffs.end(), [](Real const & coeff) { return coeff.isIntegerValue(); }); if (not allIntegers) { // first ensure that all coeffs are integers - // this would probably not work when `Number` is not `FastRational` - using Integer = FastRational; // TODO: change when we have FastInteger auto lcmOfDenominators = Integer(1); - auto accumulateLCMofDenominators = [&lcmOfDenominators](FastRational const & next) { - if (next.isInteger()) { + auto accumulateLCMofDenominators = [&lcmOfDenominators](Real const & next) { + if (next.isIntegerValue()) { // denominator is 1 => lcm of denominators stays the same return; } - Integer den = next.get_den(); + Integer den = static_cast(next.get_den()); if (lcmOfDenominators == 1) { lcmOfDenominators = std::move(den); return; @@ -1229,26 +1240,27 @@ pair ArithLogic::sumToNormalizedIntPair(PTRef sum) { std::for_each(coeffs.begin(), coeffs.end(), accumulateLCMofDenominators); for (auto & coeff : coeffs) { coeff *= lcmOfDenominators; - assert(coeff.isInteger()); + assert(coeff.isIntegerValue()); } // DONT forget to update also the constant factor - constantValue *= lcmOfDenominators; + constantValueInt *= lcmOfDenominators; changed = true; } - assert(std::all_of(coeffs.begin(), coeffs.end(), [](Number const & coeff) { return coeff.isInteger(); })); + assert(std::all_of(coeffs.begin(), coeffs.end(), [](Real const & coeff) { return coeff.isIntegerValue(); })); // Now make sure all coeffs are coprime - auto coeffs_gcd = abs(coeffs[0]); + Integer coeffs_gcd = static_cast(abs(coeffs[0])); for (std::size_t i = 1; i < coeffs.size() && coeffs_gcd != 1; ++i) { - coeffs_gcd = gcd(coeffs_gcd, abs(coeffs[i])); - assert(coeffs_gcd.isInteger()); + coeffs_gcd = gcd(coeffs_gcd, static_cast(abs(coeffs[i]))); } if (coeffs_gcd != 1) { for (auto & coeff : coeffs) { coeff /= coeffs_gcd; - assert(coeff.isInteger()); + assert(coeff.isIntegerValue()); } // DONT forget to update also the constant factor - constantValue /= coeffs_gcd; + auto tmpReal = std::move(constantValueInt) / coeffs_gcd; + assert(tmpReal.isIntegerValue()); + constantValueInt = static_cast(tmpReal); changed = true; } // update the factors @@ -1259,7 +1271,7 @@ pair ArithLogic::sumToNormalizedIntPair(PTRef sum) { } PTRef normalizedSum = varFactors.size() == 1 ? varFactors[0] : mkFun(get_sym_Int_PLUS(), std::move(varFactors)); // 0 <= normalizedSum + constantValue - constantValue.negate(); + constantValueInt.negate(); return {std::move(constantValue), normalizedSum}; } @@ -1275,20 +1287,21 @@ pair ArithLogic::sumToNormalizedIntPair(PTRef sum) { pair ArithLogic::sumToNormalizedRealPair(PTRef sum) { auto [constantValue, varFactors] = getConstantAndFactors(sum); + auto & constantValueReal = castReal(constantValue); PTRef leadingFactor = varFactors[0]; // normalize the sum according to the leading factor auto [var, coeff] = splitTermToVarAndConst(leadingFactor); - Number normalizationCoeff = abs(getNumConst(coeff)); + Real normalizationCoeff = abs(getRealConst(coeff)); // varFactors come from a normalized sum, no need to call normalization code again PTRef normalizedSum = varFactors.size() == 1 ? varFactors[0] : mkFun(get_sym_Real_PLUS(), std::move(varFactors)); - if (normalizationCoeff != 1) { + if (not normalizationCoeff.isOne()) { // normalize the whole sum normalizedSum = mkTimes(normalizedSum, mkRealConst(normalizationCoeff.inverse())); // DON'T forget to update also the constant factor! - constantValue /= normalizationCoeff; + constantValueReal /= normalizationCoeff; } - constantValue.negate(); // moving the constant to the LHS of the inequality + constantValueReal.negate(); // moving the constant to the LHS of the inequality return {std::move(constantValue), normalizedSum}; } @@ -1308,7 +1321,7 @@ PTRef ArithLogic::sumToNormalizedInequality(PTRef sum) { PTRef ArithLogic::sumToNormalizedEquality(PTRef sum) { auto [lhsVal, rhs] = sumToNormalizedPair(sum); SRef sort = getSortRef(sum); - if (isSortInt(sort) and not lhsVal.isInteger()) { return getTerm_false(); } + if (isSortInt(sort) and not lhsVal.isIntegerValue()) { return getTerm_false(); } // Ensure that in equality we always have positive leading variable if (hasNegativeLeadingVariable(rhs)) { rhs = mkNeg(rhs); diff --git a/src/logics/ArithLogic.h b/src/logics/ArithLogic.h index ecbdc9348..10ed1d7b1 100644 --- a/src/logics/ArithLogic.h +++ b/src/logics/ArithLogic.h @@ -47,13 +47,19 @@ class ArithLogic : public Logic { PTRef mkConst(SRef s, std::string const & name) { return mkConst(s, name.c_str()); } PTRef mkConst(SRef s, Number const & c); PTRef mkIntConst(Number const & c) { + assert(c.isInteger()); if (not hasIntegers()) { throw ApiException("Create Int constant in non-integral logic"); } return mkConst(getSort_int(), c); } + // To allow in-place construction from Integer-specific arguments + PTRef mkIntConst(Integer c) { return mkIntConst(Number{std::move(c)}); } PTRef mkRealConst(Number const & c) { + assert(c.isReal()); if (not hasReals()) { throw ApiException("Create Real constant in non-real logic"); } return mkConst(getSort_real(), c); } + // To allow in-place construction from Real-specific arguments + PTRef mkRealConst(Real c) { return mkRealConst(Number{std::move(c)}); } PTRef mkIntVar(char const * name) { if (not hasIntegers()) { throw ApiException("Create Int var in non-integral logic"); } return mkVar(sort_INT, name, false); @@ -95,6 +101,8 @@ class ArithLogic : public Logic { bool yieldsSortNum(PTRef tr) const { return yieldsSortInt(tr) or yieldsSortReal(tr); } Number const & getNumConst(PTRef tr) const; + Integer const & getIntConst(PTRef tr) const; + Real const & getRealConst(PTRef tr) const; bool isUFEquality(PTRef tr) const override { return !isNumEq(tr) && Logic::isUFEquality(tr); } bool isAtom(PTRef tr) const override { diff --git a/src/logics/BVLogic.h b/src/logics/BVLogic.h index 9103642dc..7d3af801d 100644 --- a/src/logics/BVLogic.h +++ b/src/logics/BVLogic.h @@ -93,8 +93,8 @@ class BVLogic: public Logic virtual std::string const getName() const override { return "QF_BV"; } // virtual PTRef insertTerm(SymRef sym, vec& terms, char** msg); - PTRef mkBVConst (const int c) { char* num; wordToBinary(c, num, getBitWidth()); PTRef tr = Logic::mkConst(sort_BVNUM, num); free(num); return tr; } // Convert the int c to binary - PTRef mkBVConst (const char* c) { char* num; wordToBinary(Integer(c), num, getBitWidth()); PTRef tr = Logic::mkConst(sort_BVNUM, num); free(num); return tr; } // Convert the string c to binary + PTRef mkBVConst (const int c) { char* num; wordToBinary(unsigned(c), num, getBitWidth()); PTRef tr = Logic::mkConst(sort_BVNUM, num); free(num); return tr; } // Convert the int c to binary + PTRef mkBVConst (const char* c) { char* num; wordToBinary(mpz_class(c), num, getBitWidth()); PTRef tr = Logic::mkConst(sort_BVNUM, num); free(num); return tr; } // Convert the string c to binary virtual PTRef mkBVNumVar (const char* name) { return mkVar(sort_BVNUM, name); } virtual bool isBuiltinSortSym(SSymRef ssr) const override { return (ssr == sort_store.getSortSym(sort_BVNUM)); } virtual bool isBuiltinSort(SRef sr) const override { return (sr == sort_BVNUM); } diff --git a/src/rewriters/DivModRewriter.h b/src/rewriters/DivModRewriter.h index 2aa90038e..011a89181 100644 --- a/src/rewriters/DivModRewriter.h +++ b/src/rewriters/DivModRewriter.h @@ -38,8 +38,7 @@ class DivModConfig : public DefaultRewriterConfig { if (not inCache) { // collect the definitions to add assert(logic.isConstant(divisor)); - auto divisorVal = logic.getNumConst(divisor); - assert(divisorVal.isInteger()); + auto divisorVal = logic.getIntConst(divisor); // general case auto upperBound = abs(divisorVal) - 1; // dividend = divVar * divisor + modVar diff --git a/src/simplifiers/LA.cc b/src/simplifiers/LA.cc index 682c64d61..b5f18f248 100644 --- a/src/simplifiers/LA.cc +++ b/src/simplifiers/LA.cc @@ -58,7 +58,7 @@ void LAExpression::initialize(PTRef e, bool do_canonize) { // If it is times, then one side must be constant, other // is enqueued with a new constant auto [var, constant] = logic.splitTermToVarAndConst(t); - Real new_c = logic.getNumConst(constant); + Real new_c = logic.getRealConst(constant); new_c *= c; curr_term.emplace_back(var); curr_const.emplace_back(std::move(new_c)); @@ -66,7 +66,7 @@ void LAExpression::initialize(PTRef e, bool do_canonize) { // Otherwise it is a variable, Ite, UF or constant assert(logic.isNumVarLike(t) || logic.isConstant(t) || logic.isUF(t)); if (logic.isConstant(t)) { - const Real tval = logic.getNumConst(t); + const Real tval = logic.getRealConst(t); polynome[PTRef_Undef] += tval * c; } else { auto it = polynome.find(t); diff --git a/src/tsolvers/bvsolver/BitBlaster.cc b/src/tsolvers/bvsolver/BitBlaster.cc index b4f31cd18..d2e4b36cc 100644 --- a/src/tsolvers/bvsolver/BitBlaster.cc +++ b/src/tsolvers/bvsolver/BitBlaster.cc @@ -2121,7 +2121,7 @@ void BitBlaster::computeModel( ) value = value + coeff * bit; coeff = Real( 2 ) * coeff; } - model[e] = logic.mkBVNumVar(value.get_str().c_str()); + model[e] = logic.mkBVNumVar(value.toString().c_str()); } has_model = true; } diff --git a/src/tsolvers/egraph/EgraphModelBuilder.cc b/src/tsolvers/egraph/EgraphModelBuilder.cc index ce8f549fe..4d4b37c9a 100644 --- a/src/tsolvers/egraph/EgraphModelBuilder.cc +++ b/src/tsolvers/egraph/EgraphModelBuilder.cc @@ -18,8 +18,8 @@ Map EgraphModelBuilder::computeNumericValues(ModelBuilder ArithLogic & arithLogic = dynamic_cast(logic); Map updatedValues; std::unordered_set delayedNumericTerms; - Number maxModelValue = 0; - auto updateMaxValue = [&maxModelValue](Number const & newVal) { + FlexibleNumber maxModelValue = 0; + auto updateMaxValue = [&maxModelValue](FlexibleNumber const & newVal) { if (newVal > maxModelValue) { maxModelValue = newVal; } }; for (ERef eref : enode_store.getTermEnodes()) { @@ -31,7 +31,7 @@ Map EgraphModelBuilder::computeNumericValues(ModelBuilder if (updatedValues.has(root)) { continue; } if (arithLogic.isNumConst(ptref_root)) { updatedValues.insert(root, ptref_root); - updateMaxValue(arithLogic.getNumConst(ptref_root)); + updateMaxValue(castFlexible(arithLogic.getNumConst(ptref_root))); continue; } PTRef ptref = enode_store.getPTRef(eref); @@ -39,7 +39,7 @@ Map EgraphModelBuilder::computeNumericValues(ModelBuilder PTRef value = model.getVarVal(ptref); assert(arithLogic.isNumConst(value)); updatedValues.insert(root, value); - updateMaxValue(arithLogic.getNumConst(value)); + updateMaxValue(castFlexible(arithLogic.getNumConst(value))); delayedNumericTerms.erase(root); continue; } @@ -47,7 +47,7 @@ Map EgraphModelBuilder::computeNumericValues(ModelBuilder // continue with next Enode } for (ERef delayedTerm : delayedNumericTerms) { - Number nextValue = maxModelValue + 1; + FlexibleNumber nextValue = maxModelValue + 1; SRef sort = logic.getSortRef(getEnode(delayedTerm).getTerm()); if (arithLogic.isSortInt(sort)) { nextValue = nextValue.floor(); diff --git a/src/tsolvers/lasolver/CutCreator.cc b/src/tsolvers/lasolver/CutCreator.cc index b3dce4a46..21536e4e1 100644 --- a/src/tsolvers/lasolver/CutCreator.cc +++ b/src/tsolvers/lasolver/CutCreator.cc @@ -28,7 +28,7 @@ CutCreator::Cut CutCreator::makeCut(SparseLinearSystem && system, ColumnMapping for (uint32_t rowIndex = 0; rowIndex < dim; ++rowIndex) { auto const & row = matrixU[rowIndex]; auto product = row.product(varValues); - if (not product.isInteger()) { return {row.toVector(), product}; } + if (not product.isIntegerValue()) { return {row.toVector(), product}; } } return {}; } diff --git a/src/tsolvers/lasolver/Delta.cc b/src/tsolvers/lasolver/Delta.cc index d845907f6..f178601be 100644 --- a/src/tsolvers/lasolver/Delta.cc +++ b/src/tsolvers/lasolver/Delta.cc @@ -50,8 +50,8 @@ char * Delta::printValue() const { char * out; int written = -1; written = asprintf(&out, "(%s | %s)", - r.get_str().c_str(), - d.get_str().c_str()); + r.toString().c_str(), + d.toString().c_str()); assert(written >= 0); (void) written; return out; diff --git a/src/tsolvers/lasolver/FarkasInterpolator.cc b/src/tsolvers/lasolver/FarkasInterpolator.cc index 2628287df..fac81da9b 100644 --- a/src/tsolvers/lasolver/FarkasInterpolator.cc +++ b/src/tsolvers/lasolver/FarkasInterpolator.cc @@ -654,8 +654,8 @@ PTRef FarkasInterpolator::getFlexibleInterpolant(Real strengthFactor) { auto sidesA = extractSides(itpA); auto sidesB = extractSides(itpB); assert(sidesA.first == logic.mkNeg(sidesB.first)); - Real c1 = logic.getNumConst(sidesA.second); - Real c2 = logic.getNumConst(sidesB.second); + Real c1 = logic.getRealConst(sidesA.second); + Real c2 = logic.getRealConst(sidesB.second); Real lowerBound = c1; Real upperBound = -c2; Real strengthDiff = upperBound - lowerBound; diff --git a/src/tsolvers/lasolver/LABounds.cc b/src/tsolvers/lasolver/LABounds.cc index f35b1a206..4a9c4c464 100644 --- a/src/tsolvers/lasolver/LABounds.cc +++ b/src/tsolvers/lasolver/LABounds.cc @@ -41,13 +41,13 @@ LABoundStore::printBound(LABoundRef br) const Real const & s = d.D(); BoundT type = ba[br].getType(); if ((type == bound_l) && (s == 0)) - written = asprintf(&str_out, "%s <= %s", r.get_str().c_str(), v_str); + written = asprintf(&str_out, "%s <= %s", r.toString().c_str(), v_str); if ((type == bound_l) && (s != 0)) - written = asprintf(&str_out, "%s < %s", r.get_str().c_str(), v_str); + written = asprintf(&str_out, "%s < %s", r.toString().c_str(), v_str); if ((type == bound_u) && (s == 0)) - written = asprintf(&str_out, "%s <= %s", v_str, r.get_str().c_str()); + written = asprintf(&str_out, "%s <= %s", v_str, r.toString().c_str()); if ((type == bound_u) && (s != 0)) - written = asprintf(&str_out, "%s < %s", v_str, r.get_str().c_str()); + written = asprintf(&str_out, "%s < %s", v_str, r.toString().c_str()); assert(written >= 0); (void)written; free(v_str); diff --git a/src/tsolvers/lasolver/LASolver.cc b/src/tsolvers/lasolver/LASolver.cc index 1d11eaea0..2207a6b0c 100644 --- a/src/tsolvers/lasolver/LASolver.cc +++ b/src/tsolvers/lasolver/LASolver.cc @@ -38,13 +38,13 @@ LABoundStore::BoundInfo LASolver::addBound(PTRef leq_tr) { LABoundRef br_neg; if (sum_term_is_negated) { - Real constr_neg = -logic.getNumConst(const_tr); + auto constr_neg = -logic.getNumConst(const_tr).makeReal(); bi = boundStore.allocBoundPair(v, this->getBoundsValue(v, constr_neg, false)); br_pos = bi.ub; br_neg = bi.lb; } else { - const Real& constr = logic.getNumConst(const_tr); + auto constr = logic.getNumConst(const_tr).makeReal(); bi = boundStore.allocBoundPair(v, this->getBoundsValue(v, constr, true)); br_pos = bi.lb; br_neg = bi.ub; @@ -219,10 +219,6 @@ void LASolver::setBound(PTRef leq_tr) addBound(leq_tr); } -Number LASolver::getNum(PTRef r) { - return logic.getNumConst(r); -} - void LASolver::notifyVar(LVRef v) { assert(logic.isNumVar(getVarPTRef(v))); if (logic.yieldsSortInt(getVarPTRef(v))) { @@ -264,7 +260,7 @@ std::unique_ptr LASolver::expressionToLVarPoly(PTRef term) for (int i = 0; i < logic.getPterm(term).size(); i++) { auto [v,c] = logic.splitTermToVarAndConst(logic.getPterm(term)[i]); LVRef var = getLAVar_single(v); - Real coeff = getNum(c); + Real coeff = logic.getRealConst(c); if (negated) { coeff.negate(); } @@ -315,7 +311,7 @@ LVRef LASolver::registerArithmeticTerm(PTRef expr) { notifyVar(term.var); simplex.nonbasicVar(term.var); // MB: Notify must be called before the query isIntVar! - isInt &= isIntVar(term.var) && term.coeff.isInteger(); + isInt &= isIntVar(term.var) && term.coeff.isIntegerValue(); } simplex.newRow(x, std::move(poly)); if (isInt) { @@ -769,7 +765,7 @@ TRes LASolver::check(bool complete) { bool LASolver::isModelInteger(LVRef v) const { Delta val = simplex.getValuation(v); - return !( val.hasDelta() || !val.R().isInteger() ); + return !( val.hasDelta() || !val.R().isIntegerValue() ); } PTRef LASolver::interpolateUsingEngine(FarkasInterpolator & interpolator) const { @@ -857,7 +853,7 @@ std::pair> linearSystemFromConstraints(std uint32_t rows = constraints.size(); SparseColMatrix matrixA(RowCount{rows}, ColumnCount{columns}); - std::vector rhs(rows); + std::vector rhs(rows); std::vector columnPolynomials(columns); // Second pass to build the actual matrix @@ -868,7 +864,7 @@ std::pair> linearSystemFromConstraints(std for (PTRef arg : terms) { auto [var, constant] = logic.splitTermToVarAndConst(arg); auto col = varIndices[var]; - columnPolynomials[col].addTerm(IndexType{row}, logic.getNumConst(constant)); + columnPolynomials[col].addTerm(IndexType{row}, logic.getRealConst(constant)); } } for (uint32_t i = 0; i < columnPolynomials.size(); ++i) { @@ -924,7 +920,7 @@ TRes LASolver::cutFromProof() { auto const & val = isOnLower ? simplex.Lb(var) : simplex.Ub(var); assert(not val.hasDelta()); auto const & rhs = val.R(); - assert(rhs.isInteger()); + assert(rhs.isIntegerValue()); if (isOnLower and isOnUpper) { constraints.insert(constraints.begin(), DefiningConstraint{term, rhs}); } else { @@ -955,7 +951,7 @@ TRes LASolver::cutFromProof() { vec LASolver::collectEqualitiesFor(vec const & vars, std::unordered_set const & knownEqualities) { struct DeltaHash { std::size_t operator()(Delta const & d) const { - NumberHash hasher; + Real::Hash hasher; return (hasher(d.R()) ^ hasher(d.D())); } }; @@ -964,7 +960,7 @@ vec LASolver::collectEqualitiesFor(vec const & vars, std::unordere std::unordered_map, DeltaHash> eqClasses; for (PTRef var : vars) { if (logic.isNumConst(var)) { - eqClasses[logic.getNumConst(var)].push(var); + eqClasses[logic.getRealConst(var)].push(var); } else { assert(logic.isNumVar(var)); if (not laVarMapper.hasVar(var)) { // LASolver does not have any constraints on this LA var @@ -1006,7 +1002,7 @@ vec LASolver::collectEqualitiesFor(vec const & vars, std::unordere if (isNonPositive(diff.R()) and isNegative(diff.D())) { continue; } auto ratio = diff.R() / diff.D(); assert(isNegative(ratio)); - if (ratio < Number(-1)) { continue; } // MB: ratio is -delta; hence -1 <= ratio < 0 + if (ratio < Real{-1}) { continue; } // MB: ratio is -delta; hence -1 <= ratio < 0 // They could be equal for the right value of delta, add equalities for cross-product vec const & varsOfFirstVal = eqClasses.at(val); diff --git a/src/tsolvers/lasolver/LASolver.h b/src/tsolvers/lasolver/LASolver.h index eb09515b9..7caec308d 100644 --- a/src/tsolvers/lasolver/LASolver.h +++ b/src/tsolvers/lasolver/LASolver.h @@ -101,8 +101,6 @@ class LASolver : public TSolver { std::unique_ptr expressionToLVarPoly(PTRef term); - Number getNum(PTRef); - bool isIntVar(LVRef v) { return int_vars_map.has(v); } void markVarAsInt(LVRef); diff --git a/src/tsolvers/lasolver/LAVarMapper.cc b/src/tsolvers/lasolver/LAVarMapper.cc index bcf3d95c5..63250ec92 100644 --- a/src/tsolvers/lasolver/LAVarMapper.cc +++ b/src/tsolvers/lasolver/LAVarMapper.cc @@ -53,7 +53,7 @@ bool LAVarMapper::hasVar(PTId i) const { bool LAVarMapper::isNegated(PTRef tr) const { if (logic.isNumConst(tr)) - return logic.getNumConst(tr) < 0; // Case (0a) and (0b) + return isNegative(logic.getNumConst(tr)); // Case (0a) and (0b) if (logic.isNumVar(tr)) return false; // Case (1a) if (logic.isTimes(tr)) { diff --git a/src/tsolvers/lasolver/LIAInterpolator.cc b/src/tsolvers/lasolver/LIAInterpolator.cc index 122d01fb7..96f0e29f2 100644 --- a/src/tsolvers/lasolver/LIAInterpolator.cc +++ b/src/tsolvers/lasolver/LIAInterpolator.cc @@ -31,7 +31,7 @@ LAExplanations LAExplanations::getLIAExplanation(ArithLogic & logic, vec liaExplanations.explanations.push(PtAsgn(positiveInequality, l_True)); } else { // 'not (c <= term)' => 'c > term' => 'term < c' => 'term <= c-1' => -(c-1) <= -term - auto newBoundValue = (logic.getNumConst(boundVal) - 1); + auto newBoundValue = (logic.getIntConst(boundVal) - 1); newBoundValue.negate(); PTRef nInequality = logic.mkLeq(logic.mkIntConst(newBoundValue), logic.mkNeg(boundedTerm)); assert(logic.getTermFromLeq(nInequality) == logic.mkNeg(boundedTerm)); diff --git a/src/tsolvers/lasolver/SparseMatrix.cc b/src/tsolvers/lasolver/SparseMatrix.cc index 9fca963b4..32946fa0d 100644 --- a/src/tsolvers/lasolver/SparseMatrix.cc +++ b/src/tsolvers/lasolver/SparseMatrix.cc @@ -7,6 +7,8 @@ #include "SparseMatrix.h" +#include + namespace opensmt { void SparseColMatrix::Col::negate() { this->poly.negate(); @@ -104,7 +106,8 @@ namespace { uint32_t nextColIndex = 1; while (nextColIndex < activeColumns.size()) { auto const & nextCol = A[activeColumns[nextColIndex]]; - auto quotient = -fastrat_fdiv_q(nextCol.getFirstCoeff(), smallestValue); + auto quotient = -fastint_fdiv_q(static_cast(nextCol.getFirstCoeff()), + static_cast(smallestValue)); assert(not quotient.isZero()); addColumnMultiple(A, activeColumns[0], quotient, activeColumns[nextColIndex], U); if (not nextCol.isFirst( @@ -130,7 +133,8 @@ namespace { for (uint32_t col = 0; col < pivotIndex; ++col) { auto const * otherVal = A[col].tryGetCoeffFor(rowIndex); if (not otherVal) { continue; } - auto quotient = -fastrat_fdiv_q(*otherVal, pivotVal); + auto quotient = -fastint_fdiv_q(static_cast(*otherVal), + static_cast(pivotVal)); if (not quotient.isZero()) { addColumnMultiple(A, pivotIndex, quotient, ColIndex{col}, U); } } } diff --git a/src/tsolvers/stpsolver/IDLSolver.h b/src/tsolvers/stpsolver/IDLSolver.h index 25ccd5227..78ff424f3 100644 --- a/src/tsolvers/stpsolver/IDLSolver.h +++ b/src/tsolvers/stpsolver/IDLSolver.h @@ -15,8 +15,8 @@ class IDLSolver : public STPSolver { template<> SafeInt Converter::getValue(Number const & val) { - assert(val.isInteger()); - return SafeInt(static_cast(val.get_d())); + assert(castInteger(val).tryGetValue()); + return SafeInt(static_cast(*castInteger(val).tryGetValue())); } template<> diff --git a/src/tsolvers/stpsolver/RDLSolver.h b/src/tsolvers/stpsolver/RDLSolver.h index 944bb01a9..01e3430dd 100644 --- a/src/tsolvers/stpsolver/RDLSolver.h +++ b/src/tsolvers/stpsolver/RDLSolver.h @@ -15,12 +15,12 @@ class RDLSolver : public STPSolver { template<> Delta Converter::getValue(Number const & val) { - return Delta(val, 0); + return Delta(castReal(val), 0); } template<> Delta Converter::getValue(ptrdiff_t val) { - return Delta(Number(val, 1), 0); + return Delta(Real{static_cast(val), 1}, 0); } template<> @@ -41,7 +41,7 @@ void STPSolver::fillTheoryFunctions(ModelBuilder & modelBuilder) const { // Now we need to compute the proper values as Rationals, not as \delta-Rationals // Compute the right value for delta: Delta delta; - Number deltaVal; + Real deltaVal; bool deltaSet = false; // I need to iterate over all edges and find the minimum from deltas making the edges true auto const & edges = this->model->getGraph().addedEdges; @@ -60,7 +60,7 @@ void STPSolver::fillTheoryFunctions(ModelBuilder & modelBuilder) const { } } if (not deltaSet || delta > 1) { - deltaVal = Number(1); + deltaVal = Real{1}; } else { deltaVal = delta.R() / 2; } @@ -70,7 +70,7 @@ void STPSolver::fillTheoryFunctions(ModelBuilder & modelBuilder) const { if (var == PTRef_Undef) { continue; } assert(logic.isVar(var)); Delta const & varDeltaValue = entry.second; - Number varValue = varDeltaValue.R() + varDeltaValue.D() * deltaVal; + Real varValue = varDeltaValue.R() + varDeltaValue.D() * deltaVal; PTRef val = logic.mkRealConst(varValue); modelBuilder.addVarValue(var, val); } diff --git a/src/tsolvers/stpsolver/STPSolver_implementations.hpp b/src/tsolvers/stpsolver/STPSolver_implementations.hpp index 6ed18086a..45766fe82 100644 --- a/src/tsolvers/stpsolver/STPSolver_implementations.hpp +++ b/src/tsolvers/stpsolver/STPSolver_implementations.hpp @@ -47,7 +47,7 @@ typename STPSolver::ParsedPTRef STPSolver::parseRef(PTRef ref) const { assert(logic.isTimes(mul)); Pterm &mulPt = logic.getPterm(mul); - assert(logic.isNumConst(mulPt[0]) && logic.getNumConst(mulPt[0]) == -1); + assert(logic.isNumConst(mulPt[0]) && isNegative(logic.getNumConst(mulPt[0])) && abs(logic.getNumConst(mulPt[0])).isOne()); y = mulPt[1]; assert(logic.isNumVar(y)); } diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 5b84844f5..79935970b 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -1,4 +1,5 @@ include(GoogleTest) + add_executable(LRATest) target_sources(LRATest PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/test_Rationals.cc" @@ -10,6 +11,15 @@ target_link_libraries(LRATest OpenSMT gtest gtest_main) gtest_add_tests(TARGET LRATest) +add_executable(LIATest) +target_sources(LIATest + PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/test_Integers.cc" +) + +target_link_libraries(LIATest OpenSMT gtest gtest_main) + +gtest_add_tests(TARGET LIATest) + add_executable(RewritingTest) target_sources(RewritingTest PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/test_SimplifyAssignment.cc" diff --git a/test/unit/test_Integers.cc b/test/unit/test_Integers.cc new file mode 100644 index 000000000..1dd9a57ac --- /dev/null +++ b/test/unit/test_Integers.cc @@ -0,0 +1,276 @@ +#include +#include +#include + +namespace opensmt { + +TEST(Integers_test, test_negate_int32min) { + // INT32_MIN + Integer i {"-2147483648"}; + i.negate(); + EXPECT_TRUE(i > 0); +} + +TEST(Integers_test, test_negate_minus_int32min) { + // - INT32_MIN = 2^31 + Integer i {"2147483648"}; + Integer neg = -i; + EXPECT_TRUE(neg.isWellFormed()); + EXPECT_TRUE(neg < 0); + i.negate(); + EXPECT_TRUE(i.isWellFormed()); + EXPECT_EQ(i, neg); +} + +TEST(Integers_test, test_additionAssign) { + Integer a {"2147483640"}; + Integer b {"10"}; + additionAssign(a,b); + EXPECT_EQ(a, Integer{"2147483650"} ); +} + +TEST(Integers_test, test_overwrite) +{ + Integer i(INT32_MAX); + Integer q(0); + i *= 10; + // should not compile: + // i *= Real(5, 4); + i = 0; + i = INT32_MAX; + i *= 10; + i = q; +} + +TEST(Integers_test, test_uword) +{ + uint32_t x = 2589903246; + Integer f(x); + ASSERT_TRUE(f.mpqPartValid()); +} + +TEST(Integers_test, test_modulo) +{ + Integer a(-37033300); + Integer b(1); + Integer mod = a % b; + ASSERT_EQ(mod, 0); +} + +TEST(Integers_test, test_creation) +{ + { + Integer a{INT_MIN}; + ASSERT_TRUE(a.wordPartValid()); + ASSERT_FALSE(a.mpqMemoryAllocated()); + ASSERT_EQ(a, Integer{"-2147483648"}); + } + { + Integer a{INT_MAX}; + ASSERT_TRUE(a.wordPartValid()); + ASSERT_FALSE(a.mpqMemoryAllocated()); + ASSERT_EQ(a, Integer{"2147483647"}); + } +} + +TEST(Integers_test, test_addition) +{ + { + Integer a(3); + Integer b(0); + static_assert(std::is_same_v); + ASSERT_EQ(a + b, Integer(3)); + ASSERT_EQ(-a + b, Integer(-3)); + } + + { + Integer a(0); + Integer b(3); + ASSERT_EQ(a + b, Integer(3)); + ASSERT_EQ(a + (-b), Integer(-3)); + } + + { + Integer a(3); + Integer b(-3); + ASSERT_EQ(a+b, 0); + } + + { + Integer a(3); + Integer b(1); + ASSERT_EQ(a+b, Integer(4)); + } + { + Integer a(UINT_MAX); + Integer b(INT_MAX); + Integer sum = a+b; + ASSERT_EQ(sum, Integer("6442450942")); + ASSERT_FALSE(sum.wordPartValid()); + } + { + Integer a(UINT_MAX); + Integer b(INT_MIN); + Integer sum = a+b; + ASSERT_EQ(sum, Integer("2147483647")); + ASSERT_TRUE(sum.wordPartValid()); + } +} + +TEST(Integers_test, test_subtraction) +{ + { + Integer a(10); + Integer b(0); + static_assert(std::is_same_v); + Integer c = a-b; + ASSERT_EQ(c, a); + ASSERT_TRUE(c.wordPartValid()); + ASSERT_FALSE(c.mpqPartValid()); + } + { + Integer a(0); + Integer s = a - Integer(INT_MIN); + ASSERT_FALSE(s.wordPartValid()); + ASSERT_TRUE(s.mpqPartValid()); + ASSERT_EQ(s, Integer(INT_MAX)+1); + } + { + Integer a(INT_MAX); + Integer b(UINT_MAX); + Integer sum = a-b; + ASSERT_EQ(sum, Integer("-2147483648")); + ASSERT_TRUE(sum.wordPartValid()); + } +} + +TEST(Integers_test, test_division) +{ + { + Integer a(-1); + Integer b(-1); + Real c = a / b; + ASSERT_EQ(c, 1); + } + { + Integer a(-3); + Integer b(2); + Real c = a / b; + ASSERT_EQ(c, Real(-3, 2)); + ASSERT_TRUE(c.wordPartValid()); + ASSERT_FALSE(c.mpqMemoryAllocated()); + } +} + +TEST(Integers_test, test_operatorAssign) +{ + { + Integer f(0); + static_assert(std::is_same_v); + f -= Integer(-3) * Integer(-1); + ASSERT_EQ(f, -3); + ASSERT_TRUE(f.wordPartValid()); + ASSERT_FALSE(f.mpqMemoryAllocated()); + } + { + Integer f(0); + f += Integer(-3) * Integer(-1); + ASSERT_EQ(f, 3); + ASSERT_TRUE(f.wordPartValid()); + ASSERT_FALSE(f.mpqMemoryAllocated()); + } +} + +TEST(Integers_test, test_CHECK_WORD) +{ + word a(INT_MAX); + uword b(UINT_MAX); + uword res = 0; + CHECK_WORD(res, lword(a)*b); + ASSERT_EQ(res, (lword)(9223372030412324865)); + overflow: + std::cout << "Overflow" << std::endl; +} + +TEST(Integers_test, test_sub_lword_underflow_min) +{ + lword res; + (void)res; + lword s1 = 0; + lword s2 = LWORD_MIN; + CHECK_SUB_OVERFLOWS_LWORD(res, s1, s2); + ASSERT_TRUE(false); + overflow: + ASSERT_TRUE(true); +} + +TEST(Integers_test, test_sub_lword_nounderflow) +{ + lword res; + (void)res; + lword s1 = 0; + lword s2 = LWORD_MIN+1; + CHECK_SUB_OVERFLOWS_LWORD(res, s1, s2); + return; + overflow: + ASSERT_TRUE(false); +} + +TEST(Integers_test, test_sub_lword_nooverflow) +{ + lword res; + (void)res; + lword s1 = -1; + lword s2 = LWORD_MAX; + CHECK_SUB_OVERFLOWS_LWORD(res, s1, s2); + return; + overflow: + ASSERT_TRUE(false); +} + +TEST(Integers_test, test_sub_lword_overflow) +{ + lword res; + (void)res; + lword s1 = -2; + lword s2 = LWORD_MAX; + CHECK_SUB_OVERFLOWS_LWORD(res, s1, s2); + ASSERT_FALSE(true); + overflow: + ASSERT_TRUE(true); +} + +TEST(Integers_test, test_mod) +{ + Integer a(INT_MAX); + Integer b(INT_MIN); + Integer res = a % b; + ASSERT_EQ(res, (INT_MIN+1)); +} + +TEST(Integers_test, test_addNegated) +{ + { + Integer a(15); + Integer b(-15); + Integer res = a+b; + ASSERT_EQ(res, 0); + } + { + Integer a(INT_MAX); + Integer b(INT_MIN); + Integer res = a + b; + ASSERT_EQ(res, -1); + } +} + +TEST(Integers_test, testWordRepresentation_Negate) { + Integer a(INT_MIN); // a fits into word representation + ASSERT_TRUE(a.wordPartValid()); + a.negate(); // a now does not fit into word representation + ASSERT_FALSE(a.wordPartValid()); + a.negate(); // a now again fits into word representation + ASSERT_TRUE(a.wordPartValid()); +} + +} diff --git a/test/unit/test_LIAInterpolation.cc b/test/unit/test_LIAInterpolation.cc index 635cf7757..90a615e9c 100644 --- a/test/unit/test_LIAInterpolation.cc +++ b/test/unit/test_LIAInterpolation.cc @@ -52,7 +52,7 @@ TEST_F(LIAInterpolationTest, test_InterpolationLRASat){ PTRef dualFarkasItp = interpolator.getDualFarkasInterpolant(); std::cout << logic.pp(dualFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(leq1, leq2, dualFarkasItp)); - PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Number(1,2)); + PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Real(1,2)); std::cout << logic.pp(halfFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(leq1, leq2, halfFarkasItp)); } diff --git a/test/unit/test_LRAInterpolation.cc b/test/unit/test_LRAInterpolation.cc index 5fb93cae1..f0efc6ee5 100644 --- a/test/unit/test_LRAInterpolation.cc +++ b/test/unit/test_LRAInterpolation.cc @@ -53,7 +53,7 @@ TEST_F(LRAInterpolationTest, test_FarkasInterpolation_BothNonstrict){ PTRef dualFarkasItp = interpolator.getDualFarkasInterpolant(); // std::cout << logic.pp(dualFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(logic.mkAnd(leq1, leq2), logic.mkAnd(leq3, leq4), dualFarkasItp)); - PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Number(1,2)); + PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Real(1,2)); // std::cout << logic.pp(halfFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(logic.mkAnd(leq1, leq2), logic.mkAnd(leq3, leq4), halfFarkasItp)); } @@ -81,7 +81,7 @@ TEST_F(LRAInterpolationTest, test_FarkasInterpolation_Astrict){ PTRef dualFarkasItp = interpolator.getDualFarkasInterpolant(); // std::cout << logic.pp(dualFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(logic.mkAnd(leq1, leq2), logic.mkAnd(leq3, leq4), dualFarkasItp)); - PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Number(1,2)); + PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Real(1,2)); // std::cout << logic.pp(halfFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(logic.mkAnd(leq1, leq2), logic.mkAnd(leq3, leq4), halfFarkasItp)); } @@ -109,7 +109,7 @@ TEST_F(LRAInterpolationTest, test_FarkasInterpolation_Bstrict){ PTRef dualFarkasItp = interpolator.getDualFarkasInterpolant(); // std::cout << logic.pp(dualFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(logic.mkAnd(leq1, leq2), logic.mkAnd(leq3, leq4), dualFarkasItp)); - PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Number(1,2)); + PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Real(1,2)); // std::cout << logic.pp(halfFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(logic.mkAnd(leq1, leq2), logic.mkAnd(leq3, leq4), halfFarkasItp)); } @@ -137,7 +137,7 @@ TEST_F(LRAInterpolationTest, test_FarkasInterpolation_BothStrict){ PTRef dualFarkasItp = interpolator.getDualFarkasInterpolant(); std::cout << logic.pp(dualFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(logic.mkAnd(leq1, leq2), logic.mkAnd(leq3, leq4), dualFarkasItp)); - PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Number(1,2)); + PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Real(1,2)); std::cout << logic.pp(halfFarkasItp) << std::endl; EXPECT_TRUE(verifyInterpolant(logic.mkAnd(leq1, leq2), logic.mkAnd(leq3, leq4), halfFarkasItp)); } @@ -164,7 +164,7 @@ TEST_F(LRAInterpolationTest, test_AllInA){ PTRef dualFarkasItp = interpolator.getDualFarkasInterpolant(); EXPECT_TRUE(verifyInterpolant(logic.mkAnd({leq1, leq2, leq3, leq4}), logic.getTerm_true(), dualFarkasItp)); EXPECT_EQ(dualFarkasItp, logic.getTerm_false()); - PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Number(1,2)); + PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Real(1,2)); EXPECT_TRUE(verifyInterpolant(logic.mkAnd({leq1, leq2, leq3, leq4}), logic.getTerm_true(), halfFarkasItp)); EXPECT_EQ(halfFarkasItp, logic.getTerm_false()); PTRef decomposedFarkasItp = interpolator.getDecomposedInterpolant(); @@ -197,7 +197,7 @@ TEST_F(LRAInterpolationTest, test_AllInB){ PTRef dualFarkasItp = interpolator.getDualFarkasInterpolant(); EXPECT_TRUE(verifyInterpolant(logic.getTerm_true(), logic.mkAnd({leq1, leq2, leq3, leq4}), dualFarkasItp)); EXPECT_EQ(dualFarkasItp, logic.getTerm_true()); - PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Number(1,2)); + PTRef halfFarkasItp = interpolator.getFlexibleInterpolant(Real(1,2)); EXPECT_TRUE(verifyInterpolant(logic.getTerm_true(), logic.mkAnd({leq1, leq2, leq3, leq4}), halfFarkasItp)); EXPECT_EQ(halfFarkasItp, logic.getTerm_true()); PTRef decomposedFarkasItp = interpolator.getDecomposedInterpolant(); diff --git a/test/unit/test_LRALogicMkTerms.cc b/test/unit/test_LRALogicMkTerms.cc index 03567c88a..63e50ec8c 100644 --- a/test/unit/test_LRALogicMkTerms.cc +++ b/test/unit/test_LRALogicMkTerms.cc @@ -139,9 +139,9 @@ TEST_F(LRALogicMkTermsTest, test_mkNumNeg) PTRef minus = logic.mkNeg(one); ASSERT_TRUE(logic.isConstant(minus)); ASSERT_TRUE(logic.isNumConst(minus)); - ASSERT_LT(logic.getNumConst(minus), 0); + ASSERT_LT(logic.getRealConst(minus), 0); ASSERT_EQ(logic.mkNeg(minus), one); - ASSERT_EQ(logic.getNumConst(minus), -1); + ASSERT_EQ(logic.getRealConst(minus), -1); } TEST_F(LRALogicMkTermsTest, test_Inequality_Var_WithCoeff) diff --git a/test/unit/test_Model.cc b/test/unit/test_Model.cc index 61f253df6..bd90530c8 100644 --- a/test/unit/test_Model.cc +++ b/test/unit/test_Model.cc @@ -251,7 +251,7 @@ TEST_F(LAModelTest, test_constants) { EXPECT_EQ(model->evaluate(fortytwo), fortytwo); PTRef one = logic.mkRealConst(1); EXPECT_EQ(model->evaluate(one), logic.getTerm_RealOne()); - PTRef zero = logic.mkRealConst(Number(0)); + PTRef zero = logic.mkRealConst(Real(0)); EXPECT_EQ(model->evaluate(zero), logic.getTerm_RealZero()); } diff --git a/test/unit/test_Rationals.cc b/test/unit/test_Rationals.cc index 508a305e7..d727dc70f 100644 --- a/test/unit/test_Rationals.cc +++ b/test/unit/test_Rationals.cc @@ -35,7 +35,7 @@ TEST(Rationals_test, test_normalized) TEST(Rationals_test, test_hash_function) { std::vector hashes; - NumberHash hasher; + Real::Hash hasher; for (int i = 0; i < 10; i++) { Real r((int)random()); hashes.push_back(hasher(r)); @@ -122,6 +122,7 @@ TEST(Rationals_test, test_negate_minus_int32min) { EXPECT_TRUE(neg < 0); r.negate(); EXPECT_TRUE(r.isWellFormed()); + EXPECT_EQ(r, neg); } TEST(Rationals_test, test_additionAssign) { @@ -149,14 +150,6 @@ TEST(Rationals_test, test_uword) ASSERT_TRUE(f.mpqPartValid()); } -TEST(Rationals_test, test_modulo) -{ - Real a(-37033300); - Real b(1); - Real mod = a % b; - ASSERT_EQ(mod, 0); -} - TEST(Rationals_test, test_creation) { { @@ -448,14 +441,6 @@ TEST(Rationals_test, test_ceil) f.ceil(); } -TEST(Rationals_test, test_mod) -{ - Real a(INT_MAX); - Real b(INT_MIN); - Real res = a % b; - ASSERT_EQ(res, (INT_MIN+1)); -} - TEST(Rationals_test, test_addNegated) { {