From 9295bd580a22433896499b40ca0cef25da840435 Mon Sep 17 00:00:00 2001 From: i80287 Date: Wed, 13 Nov 2024 00:53:07 +0300 Subject: [PATCH 1/3] update math_functions.hpp and is_prime.hpp --- number_theory/is_prime.hpp | 24 ++++----- number_theory/math_functions.hpp | 72 +++++++++++++++++---------- number_theory/test_math_functions.cpp | 32 ++++++++++++ 3 files changed, 89 insertions(+), 39 deletions(-) diff --git a/number_theory/is_prime.hpp b/number_theory/is_prime.hpp index c6ccb3e..a6290ff 100644 --- a/number_theory/is_prime.hpp +++ b/number_theory/is_prime.hpp @@ -91,9 +91,9 @@ ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_prp(uint64_t n, uint64_t a) noexce * the Jacobi symbol] **********************************************************************************************/ template -ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_lucas_prp(uint64_t n, uint32_t p, +ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_lucas_prp(uint64_t n, uint16_t p, int32_t q) noexcept { - const int64_t d = static_cast(uint64_t{p} * uint64_t{p}) - int64_t{q} * 4; + const int64_t d = int64_t{uint32_t{p} * uint32_t{p}} - int64_t{q} * 4; if constexpr (BasicChecks) { /* Check if p*p - 4*q == 0. */ if (unlikely(d == 0)) { @@ -139,16 +139,15 @@ ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_lucas_prp(uint64_t n, uint32_t p, * make sure U_s == 0 mod n or V_((2^t)*s) == 0 mod n, * for some t, 0 <= t < r */ - uint64_t uh = 1; // Initial value for U_1 - uint64_t vl = 2; // Initial value for V_0 - uint64_t vh = p; // Initial value for V_1 + uint64_t uh = 1; // Initial value for U_1 + uint64_t vl = 2; // Initial value for V_0 + uint64_t vh = uint64_t{p}; // Initial value for V_1 uint64_t ql = 1; uint64_t qh = 1; // q mod n - const uint64_t widen_q = - (q >= 0 ? static_cast(q) - : (n - (static_cast(-static_cast(q)) % n))) % - n; + const uint64_t widen_q = (q >= 0 ? 
static_cast(q) + : (n - static_cast(-static_cast(q)) % n)) % + n; CONFIG_ASSUME_STATEMENT(widen_q < n); // n >= 3 => n - 1 >= 2 => n - 1 >= 1 => s >= 1 for (uint32_t j = ::math_functions::log2_floor(s); j != 0; j--) { @@ -290,7 +289,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_lucas_prp(uint64_t n, uint32_t p, tmp_vl = vl_vl - ql_2; vl = vl_vl >= ql_2 ? tmp_vl : tmp_vl + n; CONFIG_ASSUME_STATEMENT(vl < n); - CONFIG_ASSUME_STATEMENT(vl == (uint128_t(n) + vl_vl - ql_2) % n); + CONFIG_ASSUME_STATEMENT(vl == (uint128_t{n} + vl_vl - ql_2) % n); if (vl == 0) { return true; @@ -355,6 +354,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_selfridge_prp(uint64_t n) noexcept break; case -1: { CONFIG_ASSUME_STATEMENT(d <= kMaxD + kStep * 2); + CONFIG_ASSUME_STATEMENT(-kMaxD - kStep <= d); CONFIG_ASSUME_STATEMENT((1 - d) % 4 == 0); const std::int32_t q = (1 - d) / 4; CONFIG_ASSUME_STATEMENT(1 - 4 * q == d); @@ -586,8 +586,8 @@ ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_selfridge_prp(uint64_t n) noexcept } [[nodiscard]] ATTRIBUTE_CONST I128_CONSTEXPR bool is_mersenne_prime(const uint128_t n) noexcept { - const auto np1 = n + 1; - if (!is_power_of_two(np1)) { + const uint128_t np1 = n + 1; + if (!::math_functions::is_power_of_two(np1)) { return false; } const auto [q, p] = ::math_functions::extract_pow2(np1); diff --git a/number_theory/math_functions.hpp b/number_theory/math_functions.hpp index e1cee59..766788e 100644 --- a/number_theory/math_functions.hpp +++ b/number_theory/math_functions.hpp @@ -201,7 +201,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR uint64_t bin_pow_mod(uint64_t n, uint64_t p, uint return y; } -[[nodiscard]] ATTRIBUTE_CONST constexpr uint32_t isqrt(uint64_t n) noexcept { +[[nodiscard]] ATTRIBUTE_CONST constexpr uint32_t isqrt(const uint64_t n) noexcept { /** * In the runtime `sqrtl` is used (but not for the msvc prior to the c++20). 
*/ @@ -236,7 +236,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR uint64_t bin_pow_mod(uint64_t n, uint64_t p, uint #if defined(INTEGERS_128_BIT_HPP) -[[nodiscard]] ATTRIBUTE_CONST I128_CONSTEXPR uint64_t isqrt(uint128_t n) noexcept { +[[nodiscard]] ATTRIBUTE_CONST I128_CONSTEXPR uint64_t isqrt(const uint128_t n) noexcept { /** * See Hackers Delight Chapter 11. */ @@ -282,7 +282,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR uint64_t bin_pow_mod(uint64_t n, uint64_t p, uint * `uint32_t(std::cbrt(3375.0))` may be equal to 14 */ -#if defined(__GNUG__) && !defined(__clang__) +#if defined(__GNUG__) && !defined(__clang__) && CONFIG_HAS_AT_LEAST_CXX_17 [[maybe_unused]] const auto n_original_value = n; #endif @@ -341,7 +341,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR uint64_t bin_pow_mod(uint64_t n, uint64_t p, uint /// @note ⌊n^0.25⌋ = ⌊⌊n^0.5⌋^0.5⌋ (see Hackers Delight Chapter 11, ex.1) /// @param[in] n /// @return ⌊n^0.25⌋ -[[nodiscard]] ATTRIBUTE_CONST constexpr uint32_t ifrrt(uint64_t n) noexcept { +[[nodiscard]] ATTRIBUTE_CONST constexpr uint32_t ifrrt(const uint64_t n) noexcept { return ::math_functions::isqrt(::math_functions::isqrt(n)); } @@ -351,7 +351,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR uint64_t bin_pow_mod(uint64_t n, uint64_t p, uint /// It can be shown that ⌊n^0.25⌋ = ⌊⌊n^0.5⌋^0.5⌋ /// @param[in] n /// @return -[[nodiscard]] ATTRIBUTE_CONST I128_CONSTEXPR uint32_t ifrrt(uint128_t n) noexcept { +[[nodiscard]] ATTRIBUTE_CONST I128_CONSTEXPR uint32_t ifrrt(const uint128_t n) noexcept { return ::math_functions::isqrt(::math_functions::isqrt(n)); } @@ -468,7 +468,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR /// @brief This function reverses bits of the @a `b` /// @param[in] b /// @return 8-bit number whose bits are reversed bits of the @a `b`. 
-[[nodiscard]] ATTRIBUTE_CONST constexpr uint8_t bit_reverse(uint8_t b) noexcept { +[[nodiscard]] ATTRIBUTE_CONST constexpr uint8_t bit_reverse(const uint8_t b) noexcept { // See https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 return static_cast(((b * 0x80200802ULL) & 0x0884422110ULL) * 0x0101010101ULL >> 32U); } @@ -1413,7 +1413,9 @@ ATTRIBUTE_ALWAYS_INLINE ATTRIBUTE_CONST constexpr uint32_t base_b_len_impl( template [[nodiscard]] ATTRIBUTE_CONST constexpr uint32_t base_b_len(T value, const uint8_t base = 10) noexcept { - static_assert(::math_functions::detail::is_integral_v); + static_assert(::math_functions::detail::is_integral_v && !std::is_same_v && + !std::is_same_v, + "integral type (not bool or char) expected in the base_b_len"); if constexpr (::math_functions::detail::is_signed_v) { const uint32_t is_negative = uint32_t{value < 0}; @@ -1424,7 +1426,7 @@ template } } -/// @brief For n > 0 returns ⌈log_2(n)⌉. For n = 0 returns (uint32_t)-1 +/// @brief For n > 0 returns ⌊log_2(n)⌋. 
For n = 0 returns (uint32_t)-1 /// @tparam UIntType unsigned integral type (at least unsigned int in size) /// @param[in] n /// @return @@ -1461,7 +1463,7 @@ template #endif [[nodiscard]] ATTRIBUTE_ALWAYS_INLINE ATTRIBUTE_CONST constexpr uint32_t log2_ceil(const UIntType n) noexcept { - return ::math_functions::log2_floor(n) + ((n & (n - 1)) != 0); + return ::math_functions::log2_floor(n) + uint32_t{(n & (n - 1)) != 0}; } template @@ -1680,29 +1682,15 @@ template return (x || y) && (y || z) && (x || z); } +template #if CONFIG_HAS_CONCEPTS - -template <::math_functions::detail::unsigned_integral T> + requires ::math_functions::detail::unsigned_integral +#endif [[nodiscard]] ATTRIBUTE_CONST constexpr T next_even(T n) noexcept { + static_assert(::math_functions::detail::is_unsigned_v, "unsigned integral type expected"); return n + 2 - n % 2; } -#else - -// clang-format off - -template -[[nodiscard]] -ATTRIBUTE_CONST -constexpr -std::enable_if_t<::math_functions::detail::is_unsigned_v, T> next_even(T n) noexcept { - return n + 2 - n % 2; -} - -// clang-format on - -#endif - template struct SumSinCos { FloatType sines_sum; @@ -2923,6 +2911,36 @@ template return ::math_functions::arange(T{0}, n); } +/// @brief Return vector of elements {log2(0), log2(1), log2(2), log2(3), ..., log2(n)} +/// @note Here log2(0) := -1 +/// @param n +/// @return +[[nodiscard]] CONSTEXPR_VECTOR std::vector log2_arange(const uint32_t n) { + std::vector values(size_t{n} + 1); + values[0] = static_cast(-1); + for (size_t i = 1; i <= n; i++) { + values[i] = values[i / 2] + 1; + } + + return values; +} + +/// @brief Return vector of elements {0! mod m, 1! mod m, 2! mod m, 3! mod m, ..., n! mod m} +/// @param n +/// @return +[[nodiscard]] +CONSTEXPR_VECTOR std::vector factorial_arange_mod_m(const uint32_t n, const uint32_t m) { + std::vector values(size_t{n} + 1); + uint32_t current_factorial = m != 1 ? 
1u : 0u; + values[0] = current_factorial; + for (size_t i = 1; i <= n; i++) { + current_factorial = static_cast((uint64_t{current_factorial} * uint64_t{i}) % m); + values[i] = current_factorial; + } + + return values; +} + } // namespace math_functions // NOLINTEND(cppcoreguidelines-avoid-magic-numbers) diff --git a/number_theory/test_math_functions.cpp b/number_theory/test_math_functions.cpp index 8b6627a..8ee92aa 100644 --- a/number_theory/test_math_functions.cpp +++ b/number_theory/test_math_functions.cpp @@ -3471,6 +3471,36 @@ void test_arange() { assert((arange(11, 11, 11).empty())); } +void test_log2_arange() { + log_tests_started(); + + using std::vector; + + assert(log2_arange(0) == vector{static_cast(-1)}); + const uint32_t n = 1'000'000; + const vector range = log2_arange(n); + assert(range.size() == n + 1); + for (uint32_t i = 0; i <= n; i++) { + assert(range[i] == log2_floor(i)); + } +} + +void test_factorial_arange_mod_m() { + log_tests_started(); + + for (const uint32_t m : {2u, 4u, static_cast(1e7) + 9}) { + for (const uint32_t n : {10u, 1000u, 100000u}) { + const std::vector fact_range = factorial_arange_mod_m(n, m); + uint32_t factorial = 1; + assert(fact_range[0] == factorial); + for (uint32_t i = 1; size_t{i} <= size_t{n}; i++) { + factorial = static_cast((uint64_t{factorial} * uint64_t{i}) % m); + assert(fact_range[i] == factorial); + } + } + } +} + void test_masked_popcount_sum() noexcept { log_tests_started(); @@ -3523,5 +3553,7 @@ int main() { test_solve_factorial_congruence(); test_powers_sum(); test_arange(); + test_log2_arange(); + test_factorial_arange_mod_m(); test_masked_popcount_sum(); } From 20d61b552a3990b442a5fa859db7806b79c56c84 Mon Sep 17 00:00:00 2001 From: i80287 Date: Wed, 13 Nov 2024 12:43:55 +0300 Subject: [PATCH 2/3] update math_functions, fibonacci_num.hpp and is_prime.hpp --- number_theory/fibonacci_num.hpp | 4 +- number_theory/is_prime.hpp | 3 +- number_theory/math_functions.hpp | 205 +++++++++++++++++------ 
number_theory/test_fibonacci_num.cpp | 16 +- number_theory/test_kronecker_symbol.cpp | 8 +- number_theory/test_math_functions.cpp | 212 +++++++++++++++++++++++- 6 files changed, 381 insertions(+), 67 deletions(-) diff --git a/number_theory/fibonacci_num.hpp b/number_theory/fibonacci_num.hpp index 0ca0cd4..9387e22 100644 --- a/number_theory/fibonacci_num.hpp +++ b/number_theory/fibonacci_num.hpp @@ -71,7 +71,7 @@ ATTRIBUTE_CONST constexpr fibs_pair fibonacci_nums(std::uint32_t n) noexcept { /// @param n /// @return F_n [[nodiscard]] -ATTRIBUTE_CONST constexpr std::uint64_t fibonacci_num(std::uint32_t n) noexcept { +ATTRIBUTE_CONST constexpr std::uint64_t nth_fibonacci_num(std::uint32_t n) noexcept { return fibonacci_nums(n).fib_n; } @@ -153,7 +153,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR fibs_pair_u128 fibonacci_nums_u128(std::uint32_t /// @param n /// @return F_n [[nodiscard]] -ATTRIBUTE_CONST I128_CONSTEXPR uint128_t fibonacci_num_u128(std::uint32_t n) noexcept { +ATTRIBUTE_CONST I128_CONSTEXPR uint128_t nth_fibonacci_num_u128(std::uint32_t n) noexcept { return fibonacci_nums_u128(n).fib_n; } diff --git a/number_theory/is_prime.hpp b/number_theory/is_prime.hpp index a6290ff..2ede78c 100644 --- a/number_theory/is_prime.hpp +++ b/number_theory/is_prime.hpp @@ -93,7 +93,8 @@ ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_prp(uint64_t n, uint64_t a) noexce template ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_lucas_prp(uint64_t n, uint16_t p, int32_t q) noexcept { - const int64_t d = int64_t{uint32_t{p} * uint32_t{p}} - int64_t{q} * 4; + const uint32_t p2 = uint32_t{p} * uint32_t{p}; + const int64_t d = int64_t{p2} - int64_t{q} * 4; if constexpr (BasicChecks) { /* Check if p*p - 4*q == 0. 
*/ if (unlikely(d == 0)) { diff --git a/number_theory/math_functions.hpp b/number_theory/math_functions.hpp index 766788e..f3be6ad 100644 --- a/number_theory/math_functions.hpp +++ b/number_theory/math_functions.hpp @@ -34,9 +34,13 @@ #include "config_macros.hpp" -#if CONFIG_HAS_AT_LEAST_CXX_20 +#if CONFIG_HAS_AT_LEAST_CXX_20 && CONFIG_HAS_INCLUDE() #include +#define MATH_FUNCTIONS_HAS_BIT +#endif +#if CONFIG_HAS_AT_LEAST_CXX_20 && CONFIG_HAS_INCLUDE() #include +#define MATH_FUNCTIONS_HAS_RANGES #endif #if CONFIG_HAS_INCLUDE("integers_128_bit.hpp") @@ -66,6 +70,7 @@ namespace math_functions { using std::int32_t; using std::int64_t; +using std::ptrdiff_t; using std::size_t; using std::uint32_t; using std::uint64_t; @@ -111,8 +116,7 @@ template #else template #endif -[[nodiscard]] ATTRIBUTE_CONST constexpr T bin_pow(T n, - std::ptrdiff_t p) noexcept(noexcept(n *= n) && +[[nodiscard]] ATTRIBUTE_CONST constexpr T bin_pow(T n, ptrdiff_t p) noexcept(noexcept(n *= n) && noexcept(1 / n)) { const bool not_inverse = p >= 0; const size_t p_u = p >= 0 ? 
static_cast(p) : -static_cast(p); @@ -1010,17 +1014,11 @@ template return sizeof(n) * CHAR_BIT; } -#if defined(INTEGERS_128_BIT_HPP) +#ifdef INTEGERS_128_BIT_HPP if constexpr (std::is_same_v) { const uint64_t low = static_cast(n); if (low != 0) { -#if CONFIG_HAS_AT_LEAST_CXX_20 - return std::countr_zero(low); -#elif defined(__GNUG__) - return __builtin_ctzll(low); -#else - return static_cast(::math_functions::detail::tz_count_64_software(low)); -#endif + return ::math_functions::countr_zero(low); } const uint64_t high = static_cast(n >> 64U); @@ -1029,7 +1027,7 @@ template } else #endif -#if CONFIG_HAS_AT_LEAST_CXX_20 +#ifdef MATH_FUNCTIONS_HAS_BIT { return std::countr_zero(n); } @@ -1080,14 +1078,7 @@ template if constexpr (std::is_same_v) { const uint64_t high = static_cast(n >> 64U); if (high != 0) { - // Avoid recursive call to countl_zero -#if CONFIG_HAS_AT_LEAST_CXX_20 - return std::countl_zero(high); -#elif defined(__GNUG__) - return __builtin_clzll(high); -#else - return static_cast(::math_functions::detail::lz_count_64_software(high)); -#endif + return ::math_functions::countl_zero(high); } const uint64_t low = static_cast(n); @@ -1096,7 +1087,7 @@ template } else #endif -#if CONFIG_HAS_AT_LEAST_CXX_20 +#ifdef MATH_FUNCTIONS_HAS_BIT { return std::countl_zero(n); } @@ -1146,7 +1137,7 @@ template } else #endif -#if CONFIG_HAS_AT_LEAST_CXX_20 +#ifdef MATH_FUNCTIONS_HAS_BIT { return std::popcount(n); } @@ -2389,10 +2380,14 @@ struct InverseResult { namespace detail { +// clang-format off + template -CONSTEXPR_VECTOR typename ::math_functions::InverseResult inv_mod_m_impl(Iter nums_begin, - Iter nums_end, - uint32_t m) { +ATTRIBUTE_NODISCARD +CONSTEXPR_VECTOR +typename ::math_functions::InverseResult inv_mod_m_impl(Iter nums_begin, Iter nums_end, uint32_t m) { + // clang-format on + const auto n = static_cast(std::distance(nums_begin, nums_end)); auto res = ::math_functions::InverseResult{ std::vector(n), @@ -2430,17 +2425,16 @@ CONSTEXPR_VECTOR typename 
::math_functions::InverseResult inv_mod_m_impl(Iter nu } // namespace detail -#if CONFIG_HAS_CONCEPTS +#if CONFIG_HAS_CONCEPTS && defined(MATH_FUNCTIONS_HAS_RANGES) // clang-format off -template - requires ::math_functions::detail::integral> && - (!std::same_as, bool>) +template + requires ::math_functions::detail::integral> && + (!std::same_as, bool>) [[nodiscard]] -CONSTEXPR_VECTOR -typename ::math_functions::InverseResult inv_mod_m(Iter nums_iter_begin, Iter nums_iter_end, uint32_t m) { - return ::math_functions::detail::inv_mod_m_impl(nums_iter_begin, nums_iter_end, m); +CONSTEXPR_VECTOR typename ::math_functions::InverseResult inv_mod_m(Iterator nums_begin, Iterator nums_end, uint32_t m) { + return ::math_functions::detail::inv_mod_m_impl(nums_begin, nums_end, m); } /// @brief Inverse @a nums mod m @@ -2449,10 +2443,10 @@ typename ::math_functions::InverseResult inv_mod_m(Iter nums_iter_begin, Iter nu /// @param nums /// @param m /// @return -template +template [[nodiscard]] -CONSTEXPR_VECTOR -typename ::math_functions::InverseResult inv_mod_m(const Container& nums, uint32_t m) { +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) +CONSTEXPR_VECTOR typename ::math_functions::InverseResult inv_mod_m(Range&& nums, uint32_t m) { return ::math_functions::inv_mod_m(std::begin(nums), std::end(nums), m); } @@ -2460,28 +2454,37 @@ typename ::math_functions::InverseResult inv_mod_m(const Container& nums, uint32 #else +// clang-format off + template -[[nodiscard]] -CONSTEXPR_VECTOR std::enable_if_t< - ::math_functions::detail::is_integral_v::value_type> && - !std::is_same_v::value_type, bool>, - typename ::math_functions::InverseResult> inv_mod_m(Iter nums_iter_begin, Iter nums_iter_end, - uint32_t m) { +ATTRIBUTE_NODISCARD +CONSTEXPR_VECTOR +typename std::enable_if_t< + ::math_functions::detail::is_integral_v::value_type> && + !std::is_same_v::value_type, bool>, + typename ::math_functions::InverseResult> +inv_mod_m(Iter nums_iter_begin, Iter nums_iter_end, 
uint32_t m) { return ::math_functions::detail::inv_mod_m_impl(nums_iter_begin, nums_iter_end, m); } -template -[[nodiscard]] +template +ATTRIBUTE_NODISCARD CONSTEXPR_VECTOR - std::enable_if_t<::math_functions::detail::is_integral_v()))>::value_type>&& :: - math_functions::detail::is_integral_v()))>::value_type>, - typename ::math_functions::InverseResult> inv_mod_m(const Container& nums, - uint32_t m) { +typename std::enable_if_t< + ::math_functions::detail::is_integral_v< + typename std::iterator_traits()))>::value_type + > && + ::math_functions::detail::is_integral_v< + typename std::iterator_traits()))>::value_type + >, + typename ::math_functions::InverseResult> +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) +inv_mod_m(Range&& nums, uint32_t m) { return ::math_functions::inv_mod_m(std::begin(nums), std::end(nums), m); } +// clang-format on + #endif /// @brief Solve congruence 2^k * x ≡ c (mod m), @@ -2931,7 +2934,7 @@ template [[nodiscard]] CONSTEXPR_VECTOR std::vector factorial_arange_mod_m(const uint32_t n, const uint32_t m) { std::vector values(size_t{n} + 1); - uint32_t current_factorial = m != 1 ? 1u : 0u; + uint32_t current_factorial = m != 1 ? 
1U : 0U; values[0] = current_factorial; for (size_t i = 1; i <= n; i++) { current_factorial = static_cast((uint64_t{current_factorial} * uint64_t{i}) % m); @@ -2941,6 +2944,105 @@ CONSTEXPR_VECTOR std::vector factorial_arange_mod_m(const uint32_t n, return values; } +namespace detail { + +// clang-format off + +template +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_SIZED_ACCESS(read_only, 3, 4) +constexpr Iterator find_wmedian_iter(uint64_t weighted_sum, + Iterator iter, + const uint64_t* const prefsums, + const size_t prefsums_size) noexcept { + // clang-format on + + uint64_t min_weighted_sum = weighted_sum; + Iterator min_weighted_sum_iter = iter; + ++iter; + const size_t n = prefsums_size - 1; + const uint64_t max_prefsum = prefsums[n]; + for (size_t j = 1; j < n; ++iter, ++j) { + weighted_sum = weighted_sum + prefsums[j] - (max_prefsum - prefsums[j]); + if (weighted_sum < min_weighted_sum) { + min_weighted_sum = weighted_sum; + min_weighted_sum_iter = iter; + } + } + + return min_weighted_sum_iter; +} + +// clang-format off + +template +ATTRIBUTE_NODISCARD +CONSTEXPR_VECTOR Iterator wmedian_impl(const Iterator begin, const Iterator end) { + // clang-format on + + const ptrdiff_t n_signed = std::distance(begin, end); + if (unlikely(n_signed <= 0)) { + return end; + } + + const size_t n = static_cast(n_signed); + std::vector prefsums(n + 1); + size_t i = 0; + uint64_t weighted_sum = 0; + for (Iterator iter = begin; iter != end; ++iter, ++i) { + const uint32_t val = *iter; + prefsums[i + 1] = prefsums[i] + uint64_t{val}; + weighted_sum += uint64_t{i} * uint64_t{val}; + } + + return ::math_functions::detail::find_wmedian_iter(weighted_sum, begin, + std::as_const(prefsums).data(), n + 1); +} + +} // namespace detail + +#if CONFIG_HAS_CONCEPTS && defined(MATH_FUNCTIONS_HAS_RANGES) + +template + requires std::is_same_v, uint32_t> +[[nodiscard]] CONSTEXPR_VECTOR Iterator weighted_median(Iterator begin, Iterator end) { + return 
::math_functions::detail::wmedian_impl(begin, end); +} + +// clang-format off + +template +[[nodiscard]] +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) +CONSTEXPR_VECTOR typename std::ranges::borrowed_iterator_t weighted_median(Range&& range ATTRIBUTE_LIFETIME_BOUND) { + return ::math_functions::weighted_median(std::begin(range), std::end(range)); +} + +// clang-format on + +#else + +// clang-format off + +template +ATTRIBUTE_NODISCARD +CONSTEXPR_VECTOR typename std::enable_if_t::value_type, uint32_t>, + Iterator> +weighted_median(Iterator begin, Iterator end) { + return ::math_functions::detail::wmedian_impl(begin, end); +} + +// clang-format on + +template +// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) +ATTRIBUTE_NODISCARD CONSTEXPR_VECTOR auto weighted_median(Range&& range ATTRIBUTE_LIFETIME_BOUND) { + return ::math_functions::weighted_median(std::begin(range), std::end(range)); +} + +#endif + } // namespace math_functions // NOLINTEND(cppcoreguidelines-avoid-magic-numbers) @@ -3018,6 +3120,9 @@ namespace std { #endif // INTEGERS_128_BIT_HPP #undef CONSTEXPR_VECTOR +#ifdef MATH_FUNCTIONS_HAS_BIT +#undef MATH_FUNCTIONS_HAS_BIT +#endif #if defined(MATH_FUNCTIONS_HPP_ENABLE_TARGET_OPTIONS) #if defined(__GNUG__) diff --git a/number_theory/test_fibonacci_num.cpp b/number_theory/test_fibonacci_num.cpp index 06ecf80..adef1e9 100644 --- a/number_theory/test_fibonacci_num.cpp +++ b/number_theory/test_fibonacci_num.cpp @@ -15,8 +15,8 @@ template void test_fib_u64() noexcept { log_tests_started(); - static_assert(math_functions::fibonacci_num(0) == 1); - static_assert(math_functions::fibonacci_num(1) == 1); + static_assert(math_functions::nth_fibonacci_num(0) == 1); + static_assert(math_functions::nth_fibonacci_num(1) == 1); uint64_t prev_prev_fib = 1; uint64_t prev_fib = 1; for (uint32_t n = 2; n < k; n++) { @@ -24,7 +24,7 @@ void test_fib_u64() noexcept { const auto [f_n_1, f_n] = math_functions::fibonacci_nums(n); assert(f_n_1 == prev_fib); 
assert(f_n == current_fib); - assert(f_n == math_functions::fibonacci_num(n)); + assert(f_n == math_functions::nth_fibonacci_num(n)); prev_prev_fib = prev_fib; prev_fib = current_fib; } @@ -34,11 +34,11 @@ template void test_fib_u128() noexcept { log_tests_started(); #if defined(HAS_I128_CONSTEXPR) - static_assert(math_functions::fibonacci_num_u128(0) == 1); - static_assert(math_functions::fibonacci_num_u128(1) == 1); + static_assert(math_functions::nth_fibonacci_num_u128(0) == 1); + static_assert(math_functions::nth_fibonacci_num_u128(1) == 1); #endif - assert(math_functions::fibonacci_num_u128(0) == 1); - assert(math_functions::fibonacci_num_u128(1) == 1); + assert(math_functions::nth_fibonacci_num_u128(0) == 1); + assert(math_functions::nth_fibonacci_num_u128(1) == 1); uint128_t prev_prev_fib = 1; uint128_t prev_fib = 1; for (uint32_t n = 2; n < k; n++) { @@ -46,7 +46,7 @@ void test_fib_u128() noexcept { const auto [f_n_1, f_n] = math_functions::fibonacci_nums_u128(n); assert(f_n_1 == prev_fib); assert(f_n == current_fib); - assert(f_n == math_functions::fibonacci_num_u128(n)); + assert(f_n == math_functions::nth_fibonacci_num_u128(n)); prev_prev_fib = prev_fib; prev_fib = current_fib; } diff --git a/number_theory/test_kronecker_symbol.cpp b/number_theory/test_kronecker_symbol.cpp index c396efe..2c95239 100644 --- a/number_theory/test_kronecker_symbol.cpp +++ b/number_theory/test_kronecker_symbol.cpp @@ -13,8 +13,8 @@ #include "math_functions.hpp" #include "test_tools.hpp" -using math_functions::fibonacci_num; using math_functions::kronecker_symbol; +using math_functions::nth_fibonacci_num; #if CONFIG_HAS_INCLUDE() #include @@ -286,10 +286,10 @@ void CheckJacobiBasic() noexcept { } if (p <= math_functions::kMaxFibNonOverflowU64) { - if constexpr (fibonacci_num(1) == 1 && fibonacci_num(2) == 1) { - assert(fibonacci_num(p) % p == j_5_p_mod_p); + if constexpr (nth_fibonacci_num(1) == 1 && nth_fibonacci_num(2) == 1) { + assert(nth_fibonacci_num(p) % p == 
j_5_p_mod_p); } else { - assert(fibonacci_num(p - 1) % p == j_5_p_mod_p); + assert(nth_fibonacci_num(p - 1) % p == j_5_p_mod_p); } } diff --git a/number_theory/test_math_functions.cpp b/number_theory/test_math_functions.cpp index 8ee92aa..da7745a 100644 --- a/number_theory/test_math_functions.cpp +++ b/number_theory/test_math_functions.cpp @@ -3488,8 +3488,8 @@ void test_log2_arange() { void test_factorial_arange_mod_m() { log_tests_started(); - for (const uint32_t m : {2u, 4u, static_cast(1e7) + 9}) { - for (const uint32_t n : {10u, 1000u, 100000u}) { + for (const uint32_t m : {2U, 4U, static_cast(1e7) + 9}) { + for (const uint32_t n : {10U, 1000U, 100000U}) { const std::vector fact_range = factorial_arange_mod_m(n, m); uint32_t factorial = 1; assert(fact_range[0] == factorial); @@ -3528,6 +3528,213 @@ void test_masked_popcount_sum() noexcept { } } +void test_weighted_median() { + using std::vector; + + { + vector v{5, 4, 3, 2, 1}; + assert(weighted_median(v) == v.begin() + 1); + } + { + vector v{5, 2, 3, 1}; + assert(weighted_median(v) == v.begin() + 1); + } + { + const vector v{ + 3499211612, 581869302, 3890346734, 3586334585, 545404204, 4161255391, 3922919429, + 949333985, 2715962298, 1323567403, 418932835, 2350294565, 1196140740, 809094426, + 2348838239, 4264392720, 4112460519, 4279768804, 4144164697, 4156218106, 676943009, + 3117454609, 4168664243, 4213834039, 4111000746, 471852626, 2084672536, 3427838553, + 3437178460, 1275731771, 609397212, 20544909, 1811450929, 483031418, 3933054126, + 2747762695, 3402504553, 3772830893, 4120988587, 2163214728, 2816384844, 3427077306, + 153380495, 1551745920, 3646982597, 910208076, 4011470445, 2926416934, 2915145307, + 1712568902, 3254469058, 3181055693, 3191729660, 2039073006, 1684602222, 1812852786, + 2815256116, 746745227, 735241234, 1296707006, 3032444839, 3424291161, 136721026, + 1359573808, 1189375152, 3747053250, 198304612, 640439652, 417177801, 4269491673, + 3536724425, 3530047642, 2984266209, 537655879, 
1361931891, 3280281326, 4081172609, + 2107063880, 147944788, 2850164008, 1884392678, 540721923, 1638781099, 902841100, + 3287869586, 219972873, 3415357582, 156513983, 802611720, 1755486969, 2103522059, + 1967048444, 1913778154, 2094092595, 2775893247, 3410096536, 3046698742, 3955127111, + 3241354600, 3468319344, 1185518681, 3031277329, 2919300778, 12105075, 2813624502, + 3052449900, 698412071, 2765791248, 511091141, 1958646067, 2140457296, 3323948758, + 4122068897, 2464257528, 1461945556, 3765644424, 2513705832, 3471087299, 961264978, + 76338300, 3226667454, 3527224675, 1095625157, 3525484323, 2173068963, 4037587209, + 3002511655, 1772389185, 3826400342, 1817480335, 4120125281, 2495189930, 2350272820, + 678852156, 595387438, 3271610651, 641212874, 988512770, 1105989508, 3477783405, + 3610853094, 4245667946, 1092133642, 1427854500, 3497326703, 1287767370, 1045931779, + 58150106, 3991156885, 933029415, 1503168825, 3897101788, 844370145, 3644141418, + 1078396938, 4101769245, 2645891717, 3345340191, 2032760103, 4241106803, 1510366103, + 290319951, 3568381791, 3408475658, 2513690134, 2553373352, 2361044915, 3147346559, + 3939316793, 2986002498, 1227669233, 2919803768, 3252150224, 1685003584, 3237241796, + 2411870849, 1634002467, 893645500, 2438775379, 2265043167, 325791709, 1736062366, + 231714000, 1515103006, 2279758133, 2546159170, 3346497776, 1530490810, 4011545318, + 4144499009, 557942923, 663307952, 2443079012, 1696117849, 2016017442, 1663423246, + 51119001, 3122246755, 1447930741, 1668894615, + }; + assert(weighted_median(v) == v.begin() + 98); + } + { + const vector v{ + 985960778, 2860674143, 2968742429, 2594641170, 3050160906, 1696058985, 3122376166, + 2182044559, 2094860131, 3813024814, 800699405, 530565855, 4033017831, 2932007873, + 286351694, 1262478340, 957474756, 1675384708, 4125210577, 3025675706, 2070911595, + 2492594739, 4101999706, 509483035, 1056501017, 2205558691, 3984832071, 3458516866, + 993374347, 1005154904, 2961173510, 2254879989, 388337156, 
4199061715, 1374215613, + 2779100868, 1585196674, 2326601676, 2543518909, 1253161009, 1101452479, 4026085828, + 2444973131, 1559335522, 1567131291, 837011729, 2834214485, 3118342083, 2571080135, + 3213328226, 3531873439, 1856831665, 1310580097, 2442529957, 3046681832, 3271690434, + 3134764498, 3267335484, 630943181, 639509449, 3405734440, 1835539045, 2468594140, + 850053284, 2684650624, 3522616351, 167140491, 3277906793, 3628563340, 3974577599, + 827947059, 2276025901, 1903598325, 115033735, 3364130392, 2846096347, 1714976408, + 3592268746, 1079812253, 2835837984, 497820880, 3072730676, 983366057, 3086900695, + 745691392, 3536460441, 2349084751, 3851548162, 2337109115, 3534080360, 4111766338, + 2646667457, 743711972, 832159134, 133280532, 3997118892, 2690535057, 646521452, + 1023706758, 2329543377, 1668315947, 104158626, 554828811, 3100133738, 3705606252, + 1354648466, 2411997247, 2288011834, 1627662859, 1716592181, 3642635840, 1158238372, + 3815614483, 3483036216, 2543862353, 2972982015, 128026573, 646663579, 1898712001, + 235037311, 1186238169, 1194650207, 778150010, 129405274, 814687594, 3412656806, + 652378048, 969364358, 4250068332, 1268986887, 2005890049, 53103056, 2233089521, + 1464237069, 1491599729, 3876670640, 1492117846, 1371136301, 211399389, 3216043338, + 2215938377, 3649522078, 585770438, 1580256687, 3306553132, 1284491359, 1027219513, + 3917626406, 1184591233, 2985140924, 3124870166, 2089108909, 2382641111, 759907791, + 1651535410, 1196217750, 1416222464, 4158197116, 808607844, 102999428, 2588628451, + 1635689902, 3910486019, 3882050518, 3305854945, 2102411944, 3810033807, 2756991256, + 278790782, 2185169182, 4259320572, 3510109028, 2088155844, 2271923173, 748780304, + 2044695239, 3750116655, 1910836364, 429225931, 708171834, 121868046, 2092095304, + 381622095, 1721736905, 4133026139, 638460672, 516429888, 2801902516, 1618441419, + 1814936805, 938188480, 2537518490, 3479155686, 303590527, 368884417, 2536232817, + 4203933800, 2504593040, 956800143, 
1805633868, + }; + assert(weighted_median(v) == v.begin() + 90); + } + { + const vector v{ + 2198438985, 2698226636, 1968950470, 3315576946, 431429275, 671069548, 3819086165, + 2626023107, 2400866928, 1293181152, 742335012, 1303628936, 1069455802, 3136464212, + 559870608, 1939049197, 1645891156, 512698940, 4193197177, 4225218128, 3699758192, + 3246245406, 1853870250, 2730159865, 2104201649, 3122320074, 981166580, 68354129, + 572726716, 3696106829, 1336207750, 1930546613, 2029720412, 2636162218, 3865237919, + 3780437118, 170842145, 2864042351, 1961591048, 595314277, 3186265701, 3306563953, + 2643727482, 2569952601, 2890998654, 2887077875, 1593400555, 4191162562, 4266346982, + 3692726186, 510989497, 792710458, 1973707153, 1485600190, 2065601451, 691529571, + 4115746598, 672524964, 1376658681, 3062610121, 3671134895, 3103445730, 1187470072, + 3741842869, 2621281472, 1345985168, 2005098775, 1295728899, 1552615950, 1427371637, + 2324291105, 2791379451, 733673891, 3897507786, 3318494682, 2262747389, 3511281556, + 2331636631, 3650039274, 527308565, 1478464269, 2709761083, 3612653589, 3996969744, + 2866897468, 1894233895, 2515599365, 1604073355, 44034912, 652977540, 3150921166, + 433419443, 1191398399, 93103485, 1237949188, 757522497, 1848128832, 2742632524, + 221348416, 3035518529, 1180347579, 2738387411, 2342967856, 527838565, 3086650998, + 391287595, 4149513012, 1743492517, 1738734840, 1187423038, 2259959790, 1813480057, + 926057651, 66467867, 3522468017, 1983811204, 3206798407, 3657074583, 3278890216, + 2455270391, 4060076374, 2717476378, 3302021426, 1284712450, 3296104179, 3228549559, + 3348390349, 3536531104, 3773403126, 3557866222, 352335128, 1034511494, 1422355958, + 3550983429, 972358605, 1868989705, 1576749113, 1266867208, 2263617637, 2871713105, + 3600587925, 1507774735, 1997468526, 2802994620, 3900393299, 384979660, 3144123175, + 3638471436, 3641023835, 67491570, 603500481, 1219502220, 1840861867, 3801372968, + 172883917, 2294240793, 1251141238, 2845964134, 
1300830768, 2206351191, 1161746546, + 3948567717, 565054765, 1327876099, 197083479, 865284907, 1904538794, 3130051700, + 3509608189, 3371107568, 2511734607, 2170372470, 818665103, 2597229755, 4279813875, + 565043806, 2848032329, 4161811179, 238116018, 2474844821, 3312758816, 2428833949, + 4169958513, 3157499672, 2034042263, 2039063941, 649680578, 1936540734, 2774218050, + 2420282375, 4246671605, 2760606335, 741115352, 4229692006, 2442446935, 2432448267, + 1507624674, 4014954826, 172462598, 2798804471, 1478496492, 4158011907, 3578425637, + 2774058699, 3467322652, 1242946269, 3042225433, 4126609958, 1959884083, 100919636, + 4207475824, 2890301089, 1655883188, 620377993, 974373117, 3000729911, 1573593495, + 4082478003, 728493092, 4049868276, 3721543483, 1581267090, 210405669, 3770015983, + 2566039977, 123531824, 3138814883, 766181839, 2868358630, 3604139352, 101030176, + 2911320197, 3621003675, 2725396167, 4216783581, 3209235287, 3428420150, 1423404721, + 2026421510, 3331165437, 1906444927, 1738506295, 1658622337, 4037546934, 62758658, + 282215480, 2454127656, 762175613, 1864784716, 2560513854, 1764751983, 2499637040, + 1696950220, 1172385008, 4099218094, 236157041, 1521161736, 3144072206, 1012305352, + 2948845598, 3642900924, 872432072, 2660189822, 2342292884, 2587157749, 3316289558, + 1092845974, 1346977252, 1835438873, 1073442789, 453044250, 1966546511, 2671788533, + 503967775, 164271669, 4146288346, 3479082420, 1754934937, 118226603, 1362849777, + 2094921149, 3841100871, 3726630165, 2513176589, 3845873146, 1928637450, 3300578133, + 1637309271, 4273189893, 841408673, 780687312, 3068408162, 3723788458, 3186381597, + 680650751, 3443513395, 2097607606, 624135354, 178440730, 1916204003, 2089319372, + 917290619, 4220365102, 2248490289, 1781382199, 4262160900, 503058278, 898015522, + 456784347, 2666132000, 1229430938, 3892502875, 3382423168, 2225727059, 3945769469, + 1305585395, 365541118, 969113939, 1275376623, 590166712, 677877049, 2852115798, + 11096246, 347837303, 
1892692355, 4081285522, 1082772706, 966807766, 3410184102, + 3357104964, 572024428, 4212523929, 4026366629, 2571474160, 2939217316, 3227415384, + 1417864419, 1792685702, 2525083979, 1192614550, 1140731683, 2360860537, 1262733984, + 169923995, 2384991095, 954107734, 3437158942, 697320146, 2069070346, 3170055388, + 3152579461, 12804847, 3140289084, 3909085161, 2415860469, 3067368095, 3025045103, + 4085925520, 2572792253, 2848513251, 3218032082, 3755474655, 3911797686, 2809561970, + 1198924467, 2370141101, 3958440341, 920580168, 4064150086, 1970638698, 4101561404, + 1291020707, 2382134586, 3411254050, 800537068, 2303787400, 3316529912, 522162461, + 280502638, 1040387771, 1417344415, 4056556051, 3920407747, 1357252829, 3512741141, + 2017705677, 1817013305, 804414100, 3589638714, 4031940838, 2568142328, 2323893758, + 4017126129, 487667137, 1958493171, 3854387515, 1458820727, 2412821961, 1877541167, + 3309883082, + }; + assert(weighted_median(v) == v.begin() + 201); + } + { + const vector v{ + 947064119, 2463473553, 1357960775, 3290572628, 2055223605, 904611658, 3310714130, + 3094821715, 4036786951, 4234658425, 2295055999, 2882748084, 1988829139, 1175927272, + 3395829442, 2864897917, 1684171291, 2341017676, 1372347005, 1618354246, 1871625234, + 2493479328, 2785009713, 4211998089, 2549570156, 2079457421, 1520300726, 2509540033, + 2321725416, 402379550, 687707643, 313538677, 1534073743, 1891764817, 938427779, + 778289328, 672019428, 3082532227, 1135121852, 4219884929, 2345588223, 3630728351, + 3144560123, 1974093771, 212457889, 1228122095, 2787861615, 3661564807, 1829513425, + 4030546139, 460921948, 3578746698, 853326091, 546697133, 1129729515, 2034359903, + 3322354990, 2995232049, 517302473, 4135387251, 1526440386, 978428423, 2548803006, + 3467195111, 1266014376, 3721207119, 1687731515, 1525865115, 3308967567, 1223477057, + 1065550786, 2785943281, 215151296, 917353139, 2660145980, 97525125, 3751687123, + 2970458123, 3098344161, 323703063, 3915020632, 2365441842, 
1480229350, 274677396, + 2871267867, 1406180404, 2360820853, 974065880, 2661062989, 2712380412, 376855563, + 1979707277, 613026410, 96957959, 2887827583, 937958203, 3160515197, 1099338925, + 4184413080, 1251133231, 4118092204, 3952282721, 3650692691, 3762410885, 3793396036, + 2523016444, 2586709795, 4209510543, 3558843077, 758666971, 1143913321, 1807376325, + 514573285, 2968944465, 1727839008, 3302163890, 4241621238, 4205539121, 2150505144, + 2425284607, 2469392405, 381450548, 426153632, 1902575637, 724754293, 1087386454, + 2379653075, 3994356063, 615721295, 665602084, 2704617846, 3555476104, 2716274853, + 204320755, 2530031056, 3694326968, 4077637074, 1339388696, 1654876915, 3792303189, + 3820101065, 2245078296, 2412457046, 1862554895, 2299541843, 376910710, 4155736862, + 535327173, 1131307213, 2702553973, 2024074808, 1911180440, 2021545457, 4117127967, + 581706849, 3801421243, 953951874, 888723980, 683410143, 1199410171, 4212758383, + 2870607789, 3208805492, 585187180, 4154047050, 4019506773, 1839741093, 2570149320, + 694307633, 2443945161, 1753173027, 2758314914, 1624550078, 1656759730, 2043455526, + 1526411629, 469788492, 4279411810, 3219828347, 2387017878, 1773045882, 2812209416, + 844080004, 3185535469, 3390085243, 1006249574, 1270373140, 435096366, 1856177700, + 2681867539, 3542313457, 2889161658, 236345377, 941867968, 3915055750, 3697279435, + 2348206425, 3120712108, 662742826, 3687773207, 1806281654, 2852579935, 3178390464, + 156787741, 3925905752, 1409765060, 822899573, 1531010145, 3316184810, 3524156801, + 3957304585, 2610117718, 4032320257, 1175226157, 3412415284, 3720717688, 3449209445, + 188030566, 3913756853, 3729048572, 3372920924, 3744983428, 3757512600, 1299620994, + 442228868, 942078154, 3405276314, 2152813959, 260900849, 170712998, 3038609486, + 2433030535, 958683772, 3146705394, 817891942, 2912674091, 1414322490, 2029766898, + 1930060429, 1564519595, 3538967872, 1497897325, 1849069210, 1217409924, 2188048326, + 4220806131, 1801415580, 
168420499, 542539945, 2900460540, 572821308, 2598917121, + 160899876, 4130302504, 1986510967, 2283719748, 4287957384, 801525063, 2474324718, + 4099396619, 2489831153, 4258657014, 1979851312, 3363689778, 2292922490, 1555343951, + 674249807, 105052069, 2334008542, 3561838536, 2655524681, 678625507, 388597217, + 3775808634, 4117223444, 3477484430, 2381484536, 4083994813, 4218606161, 1767150873, + 3730371373, 1804674850, 2182662111, 1433426463, 4136460195, 2255303744, 2261283108, + 1297775852, 221991570, 940090914, 35213111, 4212071344, 1825481599, 3875060466, + 2967654476, 3945157914, 1101548522, 1556484954, 4259794425, 3920238661, 2794992786, + 2042279931, 4071435559, 4183875156, 4051818214, 2929313615, 1153916099, 3070181984, + 556895225, 1368187501, 3642137283, 4065052736, 1928994759, 3115940252, 981190195, + 919553044, 2264933424, 448508889, 4050756467, 1611212400, 2269945412, 1510686865, + 870492320, 3268872237, 2288717877, 3515597231, 2526603475, 2494890555, 1250519811, + 2178242306, 2656365274, 1187050286, 1312527878, 2599956382, 2259502126, 2709353730, + 3968953185, 1626775623, 2717130741, 1489729986, 580019294, 2671393865, 3724828959, + 2978808725, 288891719, 3953746073, 3228010256, 2671789993, 3973108949, 4233489947, + 188725487, 1769923839, 3107869668, 745456721, 2925588696, 1739708663, 4080708110, + 2361062252, 612296046, 952419774, 3716624962, 1664349889, 304418999, 3862971924, + 2642697924, 3525472098, 3311873261, 1602920940, 3135489949, 654313227, 3827474756, + 3314494162, 2932376072, 1795617831, 3398424874, 138451278, 2531174825, 4218806170, + 115753329, 2369353692, 2375350422, 1965864670, 689448408, 1508746061, 1339592348, + 2941716006, 3038638842, 2551449145, 358699304, 2092893096, 3462181775, 1875672339, + 261293573, 4026018009, 3318634571, 4140835783, 636444917, 2582911950, 440702711, + 4121492823, + }; + assert(weighted_median(v) == v.begin() + 207); + } +} + } // namespace // clang-format off @@ -3556,4 +3763,5 @@ int main() { test_log2_arange(); 
test_factorial_arange_mod_m(); test_masked_popcount_sum(); + test_weighted_median(); } From 67a3cda8e583540172fbb1e96f986427904d0739 Mon Sep 17 00:00:00 2001 From: i80287 Date: Wed, 13 Nov 2024 14:51:33 +0300 Subject: [PATCH 3/3] update gcd() in the math_functions.hpp --- number_theory/is_prime.hpp | 2 +- number_theory/math_functions.hpp | 146 ++++++++++++++++++++------ number_theory/test_math_functions.cpp | 117 +++++++++++++-------- 3 files changed, 191 insertions(+), 74 deletions(-) diff --git a/number_theory/is_prime.hpp b/number_theory/is_prime.hpp index 2ede78c..03a97b0 100644 --- a/number_theory/is_prime.hpp +++ b/number_theory/is_prime.hpp @@ -112,7 +112,7 @@ ATTRIBUTE_CONST I128_CONSTEXPR bool is_strong_lucas_prp(uint64_t n, uint16_t p, // NOLINTNEXTLINE(bugprone-implicit-widening-of-multiplication-result) const int128_t rhs = int128_t{int64_t{2} * int64_t{q}} * int128_t{d}; - if (unlikely(std::gcd(n, rhs) != 1)) { + if (unlikely(math_functions::gcd(n, rhs) != 1)) { // is_strong_lucas_prp requires gcd(n, 2 * q * (p * p - 4 * q)) = 1 return false; } diff --git a/number_theory/math_functions.hpp b/number_theory/math_functions.hpp index f3be6ad..a0d008a 100644 --- a/number_theory/math_functions.hpp +++ b/number_theory/math_functions.hpp @@ -3043,22 +3043,15 @@ ATTRIBUTE_NODISCARD CONSTEXPR_VECTOR auto weighted_median(Range&& range ATTRIBUT #endif -} // namespace math_functions - // NOLINTEND(cppcoreguidelines-avoid-magic-numbers) -#if defined(INTEGERS_128_BIT_HPP) - -namespace std { +#ifdef INTEGERS_128_BIT_HPP -// NOLINTBEGIN(cert-dcl58-cpp) +namespace detail { -/// @brief Computes greaters common divisor of @a `a` and @a `b` -/// using Stein's algorithm (binary gcd). Here gcd(0, 0) = 0. 
-/// @param[in] a -/// @param[in] b -/// @return gcd(a, b) -[[nodiscard]] ATTRIBUTE_CONST I128_CONSTEXPR uint128_t gcd(uint128_t a, uint128_t b) noexcept { +ATTRIBUTE_NODISCARD +ATTRIBUTE_CONST +I128_CONSTEXPR uint128_t gcd(uint128_t a, uint128_t b) noexcept { if (unlikely(a == 0)) { return b; } @@ -3088,37 +3081,126 @@ namespace std { } } -[[nodiscard]] ATTRIBUTE_CONST I128_CONSTEXPR uint128_t gcd(uint64_t a, int128_t b) noexcept { - const uint128_t b0 = ::math_functions::uabs(b); - if (unlikely(b0 == 0)) { - return a; +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR uint128_t gcd(uint128_t a, int128_t b) noexcept { + return ::math_functions::detail::gcd(a, ::math_functions::uabs(b)); +} + +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR uint128_t gcd(int128_t a, uint128_t b) noexcept { + return ::math_functions::detail::gcd(::math_functions::uabs(a), b); +} + +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR int128_t gcd(int128_t a, int128_t b) noexcept { + const int128_t value = + static_cast(::math_functions::detail::gcd(::math_functions::uabs(a), b)); + CONFIG_ASSUME_STATEMENT(value >= 0); + return value; +} + +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR uint128_t gcd(uint128_t a, uint64_t b) noexcept { + if ((config::is_constant_evaluated() || config::is_gcc_constant_p(a <= b)) && a <= b) { + return std::gcd(static_cast(a), b); } - // gcd(a, b) = gcd(a, b0) = gcd(b0, a % b0) = gcd(a1, b1) - const uint128_t a1 = b0; - // b1 = a % b0 - const uint64_t b1 = a < b0 ? 
a : a % static_cast(b0); // a < 2^64 => b1 < 2^64 - if (b1 == 0) { - return a1; + if (unlikely(b == 0)) { + return a; } - // gcd(a1, b1) = gcd(b1, a1 % b1) = gcd(a2, b2) - const uint64_t a2 = b1; // b1 < 2^64 => a2 < 2^64 - // b2 = a1 % b1 - // a1 = b0, b1 = a % b0 => b1 < a1 - const uint64_t b2 = static_cast(a1 % b1); // b1 < 2^64 => b2 = a1 % b1 < 2^64 - return std::gcd(a2, b2); + // gcd(a, b) = gcd(b, a % b) = gcd(a % b, b) + return std::gcd(static_cast(a % b), b); } -[[nodiscard]] ATTRIBUTE_CONST I128_CONSTEXPR uint128_t gcd(int128_t a, uint64_t b) noexcept { - return std::gcd(b, a); +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR uint128_t gcd(uint64_t a, uint128_t b) noexcept { + return ::math_functions::detail::gcd(b, a); } -// NOLINTEND(cert-dcl58-cpp) +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR uint128_t gcd(uint128_t a, int64_t b) noexcept { + return ::math_functions::detail::gcd(a, ::math_functions::uabs(b)); +} -} // namespace std +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR uint128_t gcd(int64_t a, uint128_t b) noexcept { + return ::math_functions::detail::gcd(b, a); +} + +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR int128_t gcd(uint64_t a, int128_t b) noexcept { + const int128_t value = + static_cast(::math_functions::detail::gcd(a, ::math_functions::uabs(b))); + CONFIG_ASSUME_STATEMENT(value >= 0); + return value; +} + +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR int128_t gcd(int128_t a, uint64_t b) noexcept { + return ::math_functions::detail::gcd(b, a); +} + +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR int128_t gcd(int128_t a, int64_t b) noexcept { + return ::math_functions::detail::gcd(a, ::math_functions::uabs(b)); +} + +ATTRIBUTE_NODISCARD +ATTRIBUTE_ALWAYS_INLINE +ATTRIBUTE_CONST +I128_CONSTEXPR int128_t gcd(int64_t a, int128_t b) noexcept { + 
return ::math_functions::detail::gcd(b, a); +} + +} // namespace detail #endif // INTEGERS_128_BIT_HPP +/// @brief Computes greatest common divisor of @a `a` and @a `b` +/// using Stein's algorithm (binary gcd). Here gcd(0, 0) = 0. +/// @param[in] a +/// @param[in] b +/// @return gcd(a, b) +template +[[nodiscard]] +ATTRIBUTE_ALWAYS_INLINE ATTRIBUTE_CONST constexpr std::common_type_t gcd(M m, N n) noexcept { + static_assert( + ::math_functions::detail::is_integral_v && ::math_functions::detail::is_integral_v, + "math_functions::gcd arguments must be integers"); + +#if defined(INTEGERS_128_BIT_HPP) + if constexpr (sizeof(M) <= sizeof(uint64_t) && sizeof(N) <= sizeof(uint64_t)) { +#endif + return std::gcd(m, n); +#if defined(INTEGERS_128_BIT_HPP) + } else { + return ::math_functions::detail::gcd(m, n); + } +#endif +} + +} // namespace math_functions + #undef CONSTEXPR_VECTOR #ifdef MATH_FUNCTIONS_HAS_BIT #undef MATH_FUNCTIONS_HAS_BIT diff --git a/number_theory/test_math_functions.cpp b/number_theory/test_math_functions.cpp index da7745a..8a6b10d 100644 --- a/number_theory/test_math_functions.cpp +++ b/number_theory/test_math_functions.cpp @@ -47,7 +47,6 @@ using namespace math_functions; // NOLINTNEXTLINE(google-build-using-namespace) using namespace test_tools; -using std::gcd; using std::size_t; using std::uint32_t; using std::uint64_t; @@ -1869,52 +1868,88 @@ void test_general_asserts() noexcept { I128_ASSERT_THAT(base_b_len(-int128_t{101}) == 4); I128_ASSERT_THAT(base_b_len(static_cast(uint128_t{1} << 127U)) == 40); - I128_ASSERT_THAT(gcd(uint128_t{1}, uint128_t{1}) == 1); - I128_ASSERT_THAT(gcd(uint128_t{3}, uint128_t{7}) == 1); - I128_ASSERT_THAT(gcd(uint128_t{0}, uint128_t{112378432}) == 112378432); - I128_ASSERT_THAT(gcd(uint128_t{112378432}, uint128_t{0}) == 112378432); - I128_ASSERT_THAT(gcd(uint128_t{429384832}, uint128_t{324884}) == 4); - I128_ASSERT_THAT(gcd(uint128_t{18446744073709551521ULL}, uint128_t{18446744073709551533ULL}) == - 1); -
I128_ASSERT_THAT(gcd(uint128_t{18446744073709551521ULL} * 18446744073709551521ULL, - uint128_t{18446744073709551521ULL}) == 18446744073709551521ULL); - I128_ASSERT_THAT(gcd(uint128_t{23999993441ULL} * 23999993377ULL, - uint128_t{23999992931ULL} * 23999539633ULL) == 1); - I128_ASSERT_THAT(gcd(uint128_t{2146514599U} * 2146514603U * 2146514611U, - uint128_t{2146514611U} * 2146514621U * 2146514647U) == 2146514611ULL); - I128_ASSERT_THAT(gcd(uint128_t{2146514599U} * 2146514603U * 2146514611U * 2, - uint128_t{2146514599U} * 2146514603U * 2146514611U * 3) == - uint128_t{2146514599U} * 2146514603U * 2146514611U); - I128_ASSERT_THAT(gcd(uint128_t{100000000000000003ULL} * 1000000000000000003ULL, - uint128_t{1000000000000000003ULL} * 1000000000000000009ULL) == - 1000000000000000003ULL); - I128_ASSERT_THAT(gcd(uint128_t{3} * 2 * 5 * 7 * 11 * 13 * 17 * 19, - uint128_t{18446744073709551557ULL} * 3) == 3); - I128_ASSERT_THAT(gcd(uint128_t{1000000000000000009ULL}, - uint128_t{1000000000000000009ULL} * 1000000000000000009ULL) == - 1000000000000000009ULL); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{1}, uint128_t{1}) == 1); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{3}, uint128_t{7}) == 1); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{0}, uint128_t{112378432}) == 112378432); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{112378432}, uint128_t{0}) == 112378432); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{429384832}, uint128_t{324884}) == 4); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551521ULL}, + uint128_t{18446744073709551533ULL}) == 1); I128_ASSERT_THAT( - gcd(uint128_t{0}, uint128_t{1000000000000000009ULL} * 1000000000000000009ULL) == - uint128_t{1000000000000000009ULL} * 1000000000000000009ULL); - I128_ASSERT_THAT(gcd(uint128_t{18446744073709551557ULL}, uint128_t{0}) == + math_functions::gcd(uint128_t{18446744073709551521ULL} * 18446744073709551521ULL, + uint128_t{18446744073709551521ULL}) == 18446744073709551521ULL); + 
I128_ASSERT_THAT(math_functions::gcd(uint128_t{23999993441ULL} * 23999993377ULL, + uint128_t{23999992931ULL} * 23999539633ULL) == 1); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{2146514599U} * 2146514603U * 2146514611U, + uint128_t{2146514611U} * 2146514621U * 2146514647U) == + 2146514611ULL); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{2146514599U} * 2146514603U * 2146514611U * 2, + uint128_t{2146514599U} * 2146514603U * 2146514611U * 3) == + uint128_t{2146514599U} * 2146514603U * 2146514611U); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{100000000000000003ULL} * 1000000000000000003ULL, + uint128_t{1000000000000000003ULL} * + 1000000000000000009ULL) == 1000000000000000003ULL); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{3} * 2 * 5 * 7 * 11 * 13 * 17 * 19, + uint128_t{18446744073709551557ULL} * 3) == 3); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{1000000000000000009ULL}, + uint128_t{1000000000000000009ULL} * + 1000000000000000009ULL) == 1000000000000000009ULL); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{0}, uint128_t{1000000000000000009ULL} * + 1000000000000000009ULL) == + uint128_t{1000000000000000009ULL} * 1000000000000000009ULL); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551557ULL}, uint128_t{0}) == 18446744073709551557ULL); - I128_ASSERT_THAT(gcd(uint64_t{2}, int128_t{4}) == 2); - I128_ASSERT_THAT(gcd(uint64_t{2}, int128_t{-4}) == 2); - I128_ASSERT_THAT(gcd(uint64_t{3}, int128_t{7}) == 1); - I128_ASSERT_THAT(gcd(uint64_t{3}, int128_t{-7}) == 1); - I128_ASSERT_THAT(gcd(uint64_t{3}, int128_t{18446744073709551557ULL} * 3) == 3); - I128_ASSERT_THAT(gcd(uint64_t{3}, int128_t{18446744073709551557ULL} * (-3)) == 3); - I128_ASSERT_THAT(gcd(uint64_t{3} * 2 * 5 * 7 * 11 * 13 * 17 * 19, - int128_t{18446744073709551557ULL} * 3) == 3); - I128_ASSERT_THAT(gcd(uint64_t{1000000000000000009ULL}, - int128_t{1000000000000000009LL} * 1000000000000000009LL) == + I128_ASSERT_THAT(math_functions::gcd(uint64_t{2}, int128_t{4}) == 2); 
+ I128_ASSERT_THAT(math_functions::gcd(uint64_t{2}, int128_t{-4}) == 2); + I128_ASSERT_THAT(math_functions::gcd(uint64_t{3}, int128_t{7}) == 1); + I128_ASSERT_THAT(math_functions::gcd(uint64_t{3}, int128_t{-7}) == 1); + I128_ASSERT_THAT(math_functions::gcd(uint64_t{3}, int128_t{18446744073709551557ULL} * 3) == 3); + I128_ASSERT_THAT(math_functions::gcd(uint64_t{3}, int128_t{18446744073709551557ULL} * (-3)) == + 3); + I128_ASSERT_THAT(math_functions::gcd(uint64_t{3} * 2 * 5 * 7 * 11 * 13 * 17 * 19, + int128_t{18446744073709551557ULL} * 3) == 3); + I128_ASSERT_THAT(math_functions::gcd(uint64_t{1000000000000000009ULL}, + int128_t{1000000000000000009LL} * 1000000000000000009LL) == 1000000000000000009ULL); - I128_ASSERT_THAT(gcd(uint64_t{0}, int128_t{1000000000000000009LL} * 1000000000000000009LL) == - uint128_t{1000000000000000009LL} * 1000000000000000009ULL); - I128_ASSERT_THAT(gcd(uint64_t{18446744073709551557ULL}, int128_t{0}) == + I128_ASSERT_THAT( + math_functions::gcd(uint64_t{0}, int128_t{1000000000000000009LL} * 1000000000000000009LL) == + uint128_t{1000000000000000009LL} * 1000000000000000009ULL); + I128_ASSERT_THAT(math_functions::gcd(uint64_t{18446744073709551557ULL}, int128_t{0}) == 18446744073709551557ULL); + // clang-format off + + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551557ULL}, int128_t{18446744073709551521ULL}) == int128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551533ULL}, int128_t{18446744073709551557ULL}) == int128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551521ULL}, int128_t{18446744073709551533ULL}) == int128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551557ULL}, int128_t{18446744073709551557ULL}) == int128_t{18446744073709551557ULL}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551533ULL}, int128_t{18446744073709551533ULL}) == int128_t{18446744073709551533ULL}); + 
I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551521ULL}, int128_t{18446744073709551521ULL}) == int128_t{18446744073709551521ULL}); + + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551557ULL}, int128_t{18446744073709551521ULL}) == uint128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551533ULL}, int128_t{18446744073709551557ULL}) == uint128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551521ULL}, int128_t{18446744073709551533ULL}) == uint128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551557ULL}, int128_t{18446744073709551557ULL}) == uint128_t{18446744073709551557ULL}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551533ULL}, int128_t{18446744073709551533ULL}) == uint128_t{18446744073709551533ULL}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551521ULL}, int128_t{18446744073709551521ULL}) == uint128_t{18446744073709551521ULL}); + + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551557ULL}, uint128_t{18446744073709551521ULL}) == uint128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551533ULL}, uint128_t{18446744073709551557ULL}) == uint128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551521ULL}, uint128_t{18446744073709551533ULL}) == uint128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551557ULL}, uint128_t{18446744073709551557ULL}) == uint128_t{18446744073709551557ULL}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551533ULL}, uint128_t{18446744073709551533ULL}) == uint128_t{18446744073709551533ULL}); + I128_ASSERT_THAT(math_functions::gcd(int128_t{18446744073709551521ULL}, uint128_t{18446744073709551521ULL}) == uint128_t{18446744073709551521ULL}); + + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551557ULL}, uint128_t{18446744073709551521ULL}) == uint128_t{1}); + 
I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551533ULL}, uint128_t{18446744073709551557ULL}) == uint128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551521ULL}, uint128_t{18446744073709551533ULL}) == uint128_t{1}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551557ULL}, uint128_t{18446744073709551557ULL}) == uint128_t{18446744073709551557ULL}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551533ULL}, uint128_t{18446744073709551533ULL}) == uint128_t{18446744073709551533ULL}); + I128_ASSERT_THAT(math_functions::gcd(uint128_t{18446744073709551521ULL}, uint128_t{18446744073709551521ULL}) == uint128_t{18446744073709551521ULL}); + + // clang-format on + ASSERT_THAT(math_functions::popcount(0U) == 0); ASSERT_THAT(math_functions::popcount(1U << 1U) == 1); ASSERT_THAT(math_functions::popcount(1U << 2U) == 1);