Skip to content

Commit

Permalink
Initial version of no-carry Montgomery multiplication.
Browse files Browse the repository at this point in the history
  • Loading branch information
martun committed Jul 18, 2024
1 parent 93edae7 commit 655b98a
Show file tree
Hide file tree
Showing 3 changed files with 295 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,15 @@ namespace boost {

// Bitwise complement of a non-trivial cpp_int_modular backend:
// inverts every limb of 'o' into 'result', then calls normalize()
// (presumably to clear unused high-order bits of the top limb — confirm
// against the backend's definition).
// NOTE(review): the diff residue (duplicated `!is_trivial...` condition and
// the split `>::type`) has been collapsed into the post-commit single-line form.
template<unsigned Bits>
BOOST_MP_FORCEINLINE BOOST_MP_CXX14_CONSTEXPR typename std::enable_if<
    boost::multiprecision::is_unsigned_number<cpp_int_modular_backend<Bits>>::value &&
    !boost::multiprecision::backends::is_trivial_cpp_int_modular<cpp_int_modular_backend<Bits>>::value>::type
eval_complement(cpp_int_modular_backend<Bits>& result, const cpp_int_modular_backend<Bits>& o) noexcept {
    unsigned os = o.size();
    for (unsigned i = 0; i < os; ++i)
        result.limbs()[i] = ~o.limbs()[i];
    result.normalize();
}
#ifndef TVM

// Left shift will throw away upper bits.
// This function must be called only when s % 8 == 0, i.e. we shift bytes.
template<unsigned Bits>
Expand All @@ -129,7 +126,6 @@ namespace boost {
std::memset(pc, 0, bytes);
}
}
#endif

// Left shift will throw away upper bits.
// This function must be called only when s % limb_bits == 0, i.e. we shift limbs, which are normally 64 bit.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,21 @@ namespace boost {
barrett_reduce(result, tmp);
}

// Delegates Montgomery multiplication to one of corresponding algorithms.
template<typename Backend1>
BOOST_MP_CXX14_CONSTEXPR void montgomery_mul(Backend1 &result, const Backend1 &y) const {
return montgomery_mul_impl(result, y, std::integral_constant<bool, is_trivial_cpp_int_modular<Backend1>::value>());
void montgomery_mul(Backend1 &result, const Backend1 &y) const {

if ( is_applicable_for_faster_montgomery_mul())
// Translated from pseudo-code
montgomery_mul_impl__from_pseudocode(
result,
y,
std::integral_constant<bool, is_trivial_cpp_int_modular<Backend>::value>() );
else
montgomery_mul_impl__previous_version(
result,
y,
std::integral_constant<bool, is_trivial_cpp_int_modular<Backend>::value>() );
}

//
Expand All @@ -428,7 +440,9 @@ namespace boost {
//
// A specialization for trivial cpp_int_modular types only.
template<typename Backend1>
BOOST_MP_CXX14_CONSTEXPR void montgomery_mul_impl(Backend1 &result, const Backend1 &y, std::integral_constant<bool, true> const&) const {
BOOST_MP_CXX14_CONSTEXPR void montgomery_mul_impl(
Backend1 &result, const Backend1 &y,
std::integral_constant<bool, true> const&) const {
BOOST_ASSERT(eval_lt(result, m_mod) && eval_lt(y, m_mod));

Backend_padded_limbs A(internal_limb_type(0u));
Expand Down Expand Up @@ -487,9 +501,212 @@ namespace boost {
result = A;
}

// Splits the double-limb value 'X' into two single limbs:
// 'a' receives the high-order half, 'b' receives the low-order half.
static BOOST_MP_CXX14_CONSTEXPR void dbl_limb_to_limbs(
        const internal_double_limb_type& X,
        internal_limb_type& a,
        internal_limb_type& b ) {
    b = static_cast<internal_limb_type>(X);              // low half (truncation)
    a = static_cast<internal_limb_type>(X >> limb_bits); // high half
}

// Packs two limbs into one double-limb value 'X':
// 'a' becomes the high-order half, 'b' the low-order half.
static BOOST_MP_CXX14_CONSTEXPR void limbs_to_dbl_limb(
        const internal_limb_type& a,
        const internal_limb_type& b,
        internal_double_limb_type& X ) {
    X = (static_cast<internal_double_limb_type>(a) << limb_bits) + b;
}

// Tests whether the faster Montgomery multiplication implementation can be
// used with the current modulus. The template argument Backend1 is not
// needed; it exists only to enable specialization on backend triviality.
template<class Backend1 = Backend>
BOOST_MP_CXX14_CONSTEXPR typename boost::enable_if_c<!is_trivial_cpp_int_modular<Backend1>::value, bool>::type
is_applicable_for_faster_montgomery_mul() const {
    // Requirements: fewer than 12 limbs, the modulus' highest bit set,
    // and at least one other bit clear (complement of the modulus != 1).
    const bool few_enough_limbs = m_mod.internal_limb_count < 12;
    const bool top_bit_set      = eval_bit_test(m_mod, Bits - 1);
    const bool some_bit_clear   =
        !eval_eq(m_mod_compliment, Backend(internal_limb_type(1u)));
    return few_enough_limbs && top_bit_set && some_bit_clear;
}

// Specialization for trivial cpp_int_modular backends (plain built-in
// integer representation): the faster Montgomery multiplication path is
// never taken for them.
template<class Backend1 = Backend>
BOOST_MP_CXX14_CONSTEXPR typename boost::enable_if_c<is_trivial_cpp_int_modular<Backend1>::value, bool>::type
is_applicable_for_faster_montgomery_mul() const {
    return false;
}

// Code for multiplication, taken from:
// "https://github.com/arkworks-rs/algebra/blob/065cd24fc5ae17e024c892cee126ad3bd885f01c/ff-macros/src/montgomery/mul.rs"
// Manually translated from Rust

// Calculates "a + b*c"
// Returns lower bits of result
// Updates "carry" to contain the higher bits
BOOST_MP_CXX14_CONSTEXPR internal_limb_type mac(
internal_limb_type a,
internal_limb_type b,
internal_limb_type c,
internal_limb_type& carry ) const {
// Necessary typedefs
typedef internal_limb_type u64;
typedef internal_double_limb_type u128;
// How many bits there are in one limb
constexpr unsigned int limb_digits = std::numeric_limits<internal_limb_type>::digits;

// " let tmp = (a as u128) + widening_mul(b, c);"
u128 tmp( b );
tmp *= c;
tmp += a;

// " *carry = (tmp >> 64) as u64;"
carry = (tmp >> limb_digits);

// " tmp as u64"
return tmp;
}

//
// Computes "a + b*c" and discards the lower limb of the result:
// only the upper limb is stored into 'carry'.
// NOTE: the previous comment claimed "+ carry", but the incoming value of
// 'carry' is NOT added in (matching the Rust original quoted below) —
// use mac_with_carry for that.
BOOST_MP_CXX14_CONSTEXPR void mac_discard(
        internal_limb_type a,
        internal_limb_type b,
        internal_limb_type c,
        internal_limb_type& carry ) const {
    // Double-width type holding the full product (Rust's u128).
    typedef internal_double_limb_type u128;
    // Number of bits in one limb.
    constexpr unsigned int limb_digits = std::numeric_limits<internal_limb_type>::digits;

    // " let tmp = (a as u128) + widening_mul(b, c);"
    u128 tmp( b );
    tmp *= c;
    tmp += a;

    // " *carry = (tmp >> 64) as u64;"
    tmp >>= limb_digits;
    carry = static_cast<internal_limb_type>(tmp);
}

//
// Computes "a + b*c + carry".
// Returns the lower limb of the result; 'carry' is updated to the upper limb.
// Manually translated from the Rust "mac_with_carry" helper quoted below.
BOOST_MP_CXX14_CONSTEXPR internal_limb_type mac_with_carry(
        internal_limb_type a,
        internal_limb_type b,
        internal_limb_type c,
        internal_limb_type& carry ) const {
    // Double-width type holding the full product (Rust's u128).
    // (The unused single-width `u64` typedef was removed.)
    typedef internal_double_limb_type u128;
    // Number of bits in one limb.
    constexpr unsigned int limb_digits = std::numeric_limits<internal_limb_type>::digits;

    // " let tmp = (a as u128) + widening_mul(b, c) + (*carry as u128);"
    // Cannot overflow: (W-1)^2 + 2*(W-1) = W^2 - 1 for W = 2^limb_digits.
    u128 tmp( b );
    tmp *= c;
    tmp += a;
    tmp += carry;

    // " *carry = (tmp >> 64) as u64;"
    carry = static_cast<internal_limb_type>(tmp >> limb_digits);

    // " tmp as u64"
    return static_cast<internal_limb_type>(tmp);
}

// Faster implementation of Montgomery multiplication, translated from the
// pseudo-code at "https://hackmd.io/@gnark/modular_multiplication".
// Computes the Montgomery product of 'c' and 'b' into 'c' (the first
// argument is both input and output; a local copy is taken because the
// algorithm does not work in-place).
// Preconditions (asserted): both inputs are < m_mod and
// is_applicable_for_faster_montgomery_mul() holds.
template< typename Backend1 >
BOOST_MP_CXX14_CONSTEXPR void montgomery_mul_impl__from_pseudocode(
        Backend1& c,
        const Backend1& b,
        std::integral_constant<bool, false> const& ) const {

    BOOST_ASSERT( eval_lt(c, m_mod) && eval_lt(b, m_mod) );
    BOOST_ASSERT( is_applicable_for_faster_montgomery_mul() );

    // Obtain number of limbs.
    constexpr int N = Backend1::internal_limb_count;

    const Backend1 a( c ); // Copy the first argument, as the implemented
                           // algorithm doesn't work in-place.

    // 'c' now doubles as the accumulator t[] of the pseudo-code.
    c = internal_limb_type(0u);
    internal_limb_type A( 0u ), C( 0u );  // running high-limb carries of the two chains
    internal_double_limb_type tmp( 0u );  // double-width scratch for each multiply-add
    internal_limb_type dummy;             // receiver for a discarded low limb

    auto* a_limbs = a.limbs();
    auto* b_limbs = b.limbs();
    auto* c_limbs = c.limbs();
    auto* m_mod_limbs = m_mod.limbs();

    for ( int i = 0; i < N; ++i ) {
        // "(A,t[0]) := t[0] + a[0]*b[i]"
        tmp = a_limbs[0];
        tmp *= b_limbs[i];
        tmp += c_limbs[0];
        modular_functions_fixed::dbl_limb_to_limbs( tmp, A, c_limbs[0] );

        // "m := t[0]*q'[0] mod W" — q'[0] is m_montgomery_p_dash; the
        // truncating assignment to internal_limb_type performs the "mod W".
        tmp = c_limbs[0];
        tmp *= m_montgomery_p_dash;
        internal_limb_type m = tmp;

        // "(C,_) := t[0] + m*q[0]" — the low limb is discarded into 'dummy';
        // by choice of m it is zero, only the carry C is kept.
        tmp = m;
        tmp *= m_mod_limbs[0];
        tmp += c_limbs[0];
        modular_functions_fixed::dbl_limb_to_limbs( tmp, C, dummy );

        for ( int j = 1; j < N; ++j ) {
            // "(A,t[j]) := t[j] + a[j]*b[i] + A"
            tmp = a_limbs[j];
            tmp *= b_limbs[i];
            tmp += c_limbs[j];
            tmp += A;
            modular_functions_fixed::dbl_limb_to_limbs( tmp, A, c_limbs[j] );

            // "(C,t[j-1]) := t[j] + m*q[j] + C" — writing to j-1 performs the
            // one-limb right shift of the reduction step.
            tmp = m;
            tmp *= m_mod_limbs[j];
            tmp += c_limbs[j];
            tmp += C;
            modular_functions_fixed::dbl_limb_to_limbs( tmp, C, c_limbs[j-1] );
        }

        // "t[N-1] = C + A"
        // NOTE(review): no carry out of C + A and no final conditional
        // subtraction — this relies on the "no-carry" preconditions checked
        // by is_applicable_for_faster_montgomery_mul(); confirm against the
        // referenced pseudo-code, since the sum could otherwise wrap.
        c_limbs[N-1] = C + A;
    }
}

// Overload for primitive (trivial) backends: the pseudo-code based fast path
// is not implemented for them, so simply fall back to the previous algorithm.
template< typename Backend1 >
BOOST_MP_CXX14_CONSTEXPR void montgomery_mul_impl__from_pseudocode(
        Backend1& c,
        const Backend1& b,
        std::integral_constant<bool, true> const& ) const {
    montgomery_mul_impl__previous_version( c, b, std::integral_constant<bool, true>() );
}

// A specialization for non-trivial cpp_int_modular types only.
template<typename Backend1>
BOOST_MP_CXX14_CONSTEXPR void montgomery_mul_impl(Backend1 &result, const Backend1 &y,
BOOST_MP_CXX14_CONSTEXPR void montgomery_mul_impl__previous_version(Backend1 &result, const Backend1 &y,
std::integral_constant<bool, false> const&) const {
BOOST_ASSERT(eval_lt(result, m_mod) && eval_lt(y, m_mod));

Expand Down Expand Up @@ -602,6 +819,75 @@ namespace boost {
result = A;
}

//
// WARNING: could be errors here due to trivial backend -- more tests needed
// TODO(martun): optimize this function, it obviously does not need to be this long.
//
// A specialization for trivial cpp_int_modular types only.
// Interleaved (CIOS-style) Montgomery multiplication:
// result <- MontMul(result, y), reduced into [0, m_mod) at the end.
template< typename Backend1 >
BOOST_MP_CXX14_CONSTEXPR void montgomery_mul_impl__previous_version(
        Backend1& result,
        const Backend1& y,
        std::integral_constant<bool, true> const& ) const {

    // Preconditions: both operands are already reduced modulo m_mod.
    BOOST_ASSERT(eval_lt(result, m_mod) && eval_lt(y, m_mod));

    // Accumulator with extra limb(s) of headroom for the running sum.
    Backend_padded_limbs A(internal_limb_type(0u));
    const size_t mod_size = m_mod.size();
    // Limb 0 of the modulus and of y, hoisted out of the loop.
    // (Despite the "_last_" in the names, these are the index-0 limbs.)
    auto mod_last_limb = static_cast<internal_double_limb_type>(get_limb_value(m_mod, 0));
    auto y_last_limb = get_limb_value(y, 0);

    for (size_t i = 0; i < mod_size; i++) {
        auto x_i = get_limb_value(result, i);
        auto A_0 = A.limbs()[0];
        // u_i = (A[0] + x_i*y[0]) * m_montgomery_p_dash mod W, chosen so that
        // adding u_i*m makes the low limb of A vanish
        // (m_montgomery_p_dash is presumably -m_mod^{-1} mod 2^limb_bits —
        // the standard Montgomery constant; confirm at its definition).
        internal_limb_type u_i = (A_0 + x_i * y_last_limb) * m_montgomery_p_dash;

        // A += x[i] * y + u_i * m, followed by a 1-limb shift to the right.
        // k and k2 are the carries of the two interleaved multiply-add chains.
        internal_limb_type k = 0;
        internal_limb_type k2 = 0;

        // Limb 0 handled separately: its result limb is dropped (that is the
        // one-limb right shift); only the two carries are kept.
        internal_double_limb_type z = static_cast<internal_double_limb_type>(y_last_limb) *
                static_cast<internal_double_limb_type>(x_i) +
                A_0 + k;
        internal_double_limb_type z2 = mod_last_limb * static_cast<internal_double_limb_type>(u_i) +
                static_cast<internal_limb_type>(z) + k2;
        k = static_cast<internal_limb_type>(z >> std::numeric_limits<internal_limb_type>::digits);
        k2 = static_cast<internal_limb_type>(z2 >> std::numeric_limits<internal_limb_type>::digits);

        for (size_t j = 1; j < mod_size; ++j) {
            // First chain:  t  = y[j]*x_i + A[j] + k.
            internal_double_limb_type t =
                    static_cast<internal_double_limb_type>(get_limb_value(y, j)) *
                    static_cast<internal_double_limb_type>(x_i) +
                    A.limbs()[j] + k;
            // Second chain: t2 = m[j]*u_i + low(t) + k2.
            internal_double_limb_type t2 =
                    static_cast<internal_double_limb_type>(get_limb_value(m_mod, j)) *
                    static_cast<internal_double_limb_type>(u_i) +
                    static_cast<internal_limb_type>(t) + k2;
            // Store shifted down by one limb.
            A.limbs()[j - 1] = static_cast<internal_limb_type>(t2);
            k = static_cast<internal_limb_type>(t >>
                    std::numeric_limits<internal_limb_type>::digits);
            k2 = static_cast<internal_limb_type>(t2 >>
                    std::numeric_limits<internal_limb_type>::digits);
        }
        // Fold the two outstanding carries into the top limbs of A.
        internal_double_limb_type tmp =
                static_cast<internal_double_limb_type>(
                        custom_get_limb_value<internal_limb_type>(A, mod_size)) +
                k + k2;
        custom_set_limb_value<internal_limb_type>(A, mod_size - 1,
                static_cast<internal_limb_type>(tmp));
        custom_set_limb_value<internal_limb_type>(
                A, mod_size,
                static_cast<internal_limb_type>(tmp >>
                        std::numeric_limits<internal_limb_type>::digits));
    }

    // Final conditional subtraction brings the result into [0, m_mod).
    if (!eval_lt(A, m_mod)) {
        eval_subtract(A, m_mod);
    }

    result = A;
}

template<typename Backend1, typename Backend2, typename Backend3,
/// result should fit in the output parameter
typename = typename boost::enable_if_c<boost::multiprecision::backends::max_precision<Backend1>::value >=
Expand Down Expand Up @@ -709,3 +995,4 @@ namespace boost {
} // namespace boost

#endif // CRYPTO3_MULTIPRECISION_MODULAR_FUNCTIONS_FIXED_PRECISION_HPP

Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ BOOST_AUTO_TEST_CASE(modular_adaptor_montgomery_mult_perf_test) {
std::cout << base_data << std::endl;
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::high_resolution_clock::now() - start);
std::cout << "Multiplication time: " << std::fixed << std::setprecision(3)
std::cout << "Multiplication time (when montgomery_mul is called directly): " << std::fixed << std::setprecision(3)
<< std::dec << elapsed.count() / SAMPLES << " ns" << std::endl;
}

Expand Down Expand Up @@ -162,7 +162,7 @@ BOOST_AUTO_TEST_CASE(modular_adaptor_backend_mult_perf_test) {

auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::high_resolution_clock::now() - start);
std::cout << "Multiplication time: " << std::fixed << std::setprecision(3)
std::cout << "Multiplication time (when called from modular adaptor): " << std::fixed << std::setprecision(3)
<< elapsed.count() / SAMPLES << " ns" << std::endl;

// Print something so the whole computation is not optimized out.
Expand Down

0 comments on commit 655b98a

Please sign in to comment.