bigint: Use a better Montgomery RR doubling-vs-squaring trade-off.
Clarify how the math works.
briansmith committed Nov 16, 2023
1 parent 25112e9 commit fb94369
Showing 1 changed file with 105 additions and 36 deletions.
141 changes: 105 additions & 36 deletions src/arithmetic/bigint.rs
@@ -48,7 +48,6 @@ use crate::{
bits::BitLength,
c, cpu, error,
limb::{self, Limb, LimbMask, LIMB_BITS},
polyfill::u64_from_usize,
};
use alloc::vec;
use core::{marker::PhantomData, num::NonZeroU64};
@@ -276,46 +275,116 @@ impl<M> One<M, RR> {
// is correct because R**2 will still be a multiple of the latter as
// `N0::LIMBS_USED` is either one or two.
fn newRR(m: &Modulus<M>) -> Self {
let m_bits = m.len_bits().as_usize_bits();
let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;

// base = 2**r (mod m) == R (mod m).
let mut base = m.zero();
m.oneR(&mut base.limbs);

// Double `base` so that base == 2*R (mod m), i.e. `2` in Montgomery
// form (`elem_exp_vartime()` requires the base to be in Montgomery
// form). Then compute
// RR = R**2 == base**r == R**r == (2**r)**r (mod m).
let w = m.limbs().len(); // The word length of the numbers involved.

// The length of the numbers involved, in bits. R = 2**r.
let r = w * LIMB_BITS;

// w = d * 2**z.
let z = w.trailing_zeros();
let d = w >> z;

#[allow(non_upper_case_globals)]
const b: u32 = LIMB_BITS.ilog2();
#[allow(clippy::assertions_on_constants)]
const _LIMB_BITS_IS_2_POW_B: () = assert!(LIMB_BITS == 1 << b);
debug_assert_eq!(r, w * (1 << b));

// RR = R**2 (mod m)
// = R * R (mod m)
// = R * 2**r (mod m)
// = R * 2**(w * LIMB_BITS) (mod m)
// = R * 2**(d * 2**z * LIMB_BITS) (mod m)
// = R * 2**(d * 2**z * 2**b) (mod m)
// = R * 2**(d * 2**(z+b)) (mod m)
// = R * (2**d)**(2**(z+b)) (mod m)
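The decomposition above leans on the identity 2*r = r + d * 2**(z+b). A standalone toy check of the exponent bookkeeping, not part of bigint.rs (`mod_pow` and `rr_identity` are invented names; small u128-sized parameters stand in for real limb counts):

```rust
// Toy check (not ring's code): with r = w * LIMB_BITS, w = d * 2**z, and
// LIMB_BITS = 2**b, we have 2*r = r + d * 2**(z+b), so
// R**2 == R * 2**(d * 2**(z+b)) (mod m).

fn mod_pow(mut base: u128, mut exp: u128, m: u128) -> u128 {
    let mut out = 1 % m;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            out = out * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }
    out
}

// Returns (R**2 mod m, R * 2**(d * 2**(z+b)) mod m) for `w` limbs of
// `limb_bits` bits each (limb_bits must be a power of two) and odd `m`.
fn rr_identity(w: u32, limb_bits: u32, m: u128) -> (u128, u128) {
    let r = w * limb_bits;
    let z = w.trailing_zeros();
    let d = (w >> z) as u128; // w = d * 2**z with d odd
    let b = limb_bits.trailing_zeros(); // limb_bits = 2**b
    assert_eq!(limb_bits, 1 << b);
    let lhs = mod_pow(2, 2 * r as u128, m);
    let rhs = mod_pow(2, r as u128 + (d << (z + b)), m);
    (lhs, rhs)
}

fn main() {
    // A scaled-down "3072-bit" shape: w = 6 limbs of 8 bits, so d = 3, z = 1.
    let (lhs, rhs) = rr_identity(6, 8, 0xF00D);
    assert_eq!(lhs, rhs);
}
```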

let mut acc: Elem<M, R> = m.zero();
m.oneR(&mut acc.limbs);

// R * 2**d = 2**r * 2**d (mod m)
//          = 2**(r + d) (mod m)
//          = acc * 2**d (mod m)
//
// Thus we can compute R * 2**d (mod m) by doubling `acc` d times.
//
// Then we'd need to compute acc**(2**(z+b)). Notice above that `acc` would be equal to
R * 2**d. Since that has a Montgomery factor (R) we could do the exponentiation as
// `elem_exp_vartime(acc, 2**(z+b), m)`.
//
// Since the exponent is a power of two, that exponentiation would consist of z + b
// Montgomery squarings and zero Montgomery multiplications, as the binary representation
// of that exponent is a single one bit followed by zeroes.
//
// The first Montgomery squaring of that exponentiation would give:
//
// acc**2/R = (R * 2**d)**2 / R (mod m)
// = (R * 2**d)*(R * 2**d) / R (mod m)
// = R * 2**d * R * 2**d / R (mod m)
// = R*R/R * 2**d * 2**d (mod m)
// = R * 2**d * 2**d (mod m)
// = (R * 2**d) * 2**d (mod m)
// = acc * 2**d (mod m)
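The claim that one Montgomery squaring of acc = R * 2**d amounts to d further doublings can be spot-checked numerically. A standalone sketch with invented names (a real Montgomery squaring works on limbs; here the division by R is modeled with a modular inverse over u128):

```rust
// Toy check (not ring's code) that for acc = R * 2**d one Montgomery
// squaring equals d further doublings: acc**2 / R == acc * 2**d (mod m).

fn mod_pow(mut base: u128, mut exp: u128, m: u128) -> u128 {
    let mut out = 1 % m;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            out = out * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }
    out
}

// Modular inverse by the extended Euclidean algorithm; gcd(a, m) must be 1.
fn mod_inv(a: i128, m: i128) -> i128 {
    let (mut t, mut new_t) = (0i128, 1i128);
    let (mut r, mut new_r) = (m, a.rem_euclid(m));
    while new_r != 0 {
        let q = r / new_r;
        t -= q * new_t;
        core::mem::swap(&mut t, &mut new_t);
        r -= q * new_r;
        core::mem::swap(&mut r, &mut new_r);
    }
    assert_eq!(r, 1); // m odd and a a power of two, so they are coprime
    t.rem_euclid(m)
}

// Returns (acc**2 / R mod m, acc * 2**d mod m) for acc = R * 2**d.
fn first_squaring_pair(d: u32, r_bits: u32, m: u128) -> (u128, u128) {
    let r = mod_pow(2, r_bits as u128, m); // R (mod m)
    let r_inv = mod_inv(r as i128, m as i128) as u128;
    let acc = r * mod_pow(2, d as u128, m) % m; // acc = R * 2**d
    let squared = acc * acc % m * r_inv % m; // Montgomery squaring
    let doubled = acc * mod_pow(2, d as u128, m) % m; // d more doublings
    (squared, doubled)
}

fn main() {
    let (s, dbl) = first_squaring_pair(3, 16, 0xDFC5);
    assert_eq!(s, dbl);
}
```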
//
// Take advantage of the fact that `elem_double` is faster than
// `elem_squared` by replacing some of the early squarings with
// doublings.
// TODO: Benchmark doubling vs. squaring performance to determine the
// optimal value of `LG_BASE`.
const LG_BASE: usize = 2; // Doubling vs. squaring trade-off.
debug_assert_eq!(LG_BASE.count_ones(), 1); // Must be 2**n for n >= 0.

let doublings = LG_BASE;
// `m_bits >= LG_BASE` (for the currently chosen value of `LG_BASE`)
// since we require the modulus to have at least `MODULUS_MIN_LIMBS`
// limbs. `r >= m_bits` as seen above. So `r >= LG_BASE` and thus
// `r / LG_BASE` is non-zero.
// In other words, that first Montgomery squaring would be equivalent to doubling d times.
// Doubling is much faster than squaring. Let's say we can do t doublings in the time it
// takes to do one squaring. Then it would be faster to double `acc` t times before doing
// the first squaring of the exponentiation.
//
// The maximum value of `r` is determined by
// `MODULUS_MAX_LIMBS * LIMB_BITS`. Further `r` is a multiple of
// `LIMB_BITS` so the maximum Hamming Weight is bounded by
// `MODULUS_MAX_LIMBS`. For the common case of {2048, 4096, 8192}-bit
// moduli the Hamming weight is 1. For the other common case of 3072
// the Hamming weight is 2.
let exponent = NonZeroU64::new(u64_from_usize(r / LG_BASE)).unwrap();
for _ in 0..doublings {
elem_double(&mut base, m)
// We must set that threshold to at least d if we want to avoid multiplications (which are
// even more expensive than squarings) in the square-and-multiply exponentiation.
// 1024-, 2048-, 4096-, 8192-, and 16384-bit moduli all have d = 1 (their w is a power of
// two), while 1536- and 3072-bit moduli have d = 3. For these reasons, we don't consider
// setting the threshold to less than d.
//
// David Benjamin did some experiments and concluded that (in BoringSSL) it makes sense to
// set the doubling-vs-squaring performance trade-off threshold to the number of limbs,
// emphasizing that the complexity of doubling is O(n) but the complexity of Montgomery
// squaring is O(n**2). (Below we have more insight into what values of t make sense.)

// Earlier we noted that the first squaring of the exponentiation would have been equivalent
// to doubling `d` times. Now we've doubled `t` times. That same math, substituting `t` for
// `d`, will show us that the first squaring is now equivalent to doubling `t` times. More
// generally, before the `n`th squaring we'll have done the equivalent of 2**(n-1) * t
// doublings, and that `n`th squaring will do the equivalent of that same amount, so that
// after the `n`th squaring we'll have done the equivalent of 2**n * t doublings:
//
// 2 * 2**(n-1) * t
// = 2**(n-1+1) * t
// = 2**n * t.
//
// As long as we set `t` to be the number of limbs, we'll have r = 2**b * t, so we need to
// do `b` squarings. Then we're done.
let t = w;
debug_assert!(d <= t);
debug_assert!(t < r);
for _ in 0..t {
elem_double(&mut acc, m);
}
let RR = elem_exp_vartime(base, exponent, m);
debug_assert_eq!(r, t * (1 << b));
for _ in 0..b {
acc = elem_squared(acc, m);
}
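The whole schedule — t = w plain doublings followed by b Montgomery squarings — can be simulated end-to-end with small numbers. A standalone toy with invented names (`elem_double` is modeled as plain modular doubling and `elem_squared` as acc -> acc**2 / R via a modular inverse):

```rust
// Toy simulation (not ring's code): acc starts as R (mod m); after t = w
// doublings and b Montgomery squarings, acc == RR == R**2 (mod m).

fn mod_pow(mut base: u128, mut exp: u128, m: u128) -> u128 {
    let mut out = 1 % m;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            out = out * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }
    out
}

// Modular inverse by the extended Euclidean algorithm; gcd(a, m) must be 1.
fn mod_inv(a: i128, m: i128) -> i128 {
    let (mut t, mut new_t) = (0i128, 1i128);
    let (mut r, mut new_r) = (m, a.rem_euclid(m));
    while new_r != 0 {
        let q = r / new_r;
        t -= q * new_t;
        core::mem::swap(&mut t, &mut new_t);
        r -= q * new_r;
        core::mem::swap(&mut r, &mut new_r);
    }
    assert_eq!(r, 1);
    t.rem_euclid(m)
}

// Computes RR by w doublings followed by b Montgomery squarings.
fn rr_by_schedule(w: u32, limb_bits: u32, m: u128) -> u128 {
    assert_eq!(limb_bits.count_ones(), 1); // limb_bits = 2**b
    let r_bits = w * limb_bits;
    let b = limb_bits.trailing_zeros();
    let r = mod_pow(2, r_bits as u128, m);
    let r_inv = mod_inv(r as i128, m as i128) as u128;
    let mut acc = r; // acc = R (mod m), like `m.oneR(...)`
    for _ in 0..w {
        acc = acc * 2 % m; // like `elem_double`
    }
    for _ in 0..b {
        acc = acc * acc % m * r_inv % m; // like `elem_squared` (Montgomery)
    }
    acc
}

fn main() {
    let m: u128 = 0xDFC5; // odd toy modulus, w = 2 limbs of 8 bits, r = 16
    assert_eq!(rr_by_schedule(2, 8, m), mod_pow(2, 32, m)); // RR = 2**(2*16)
}
```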

// If we had instead set `t` to be different threshold, we would instead do something like
// this:
// ```
// let mut done = t;
// while done < (r - done) {
// acc = elem_squared(acc, m);
// done *= 2;
// }
// for _ in done..r {
// elem_double(&mut acc, m);
// }
// ```
// Because of this, we'd have to change the threshold `t` by powers of two (halving or
// doubling some number of times) for the change to have an effect other than just deferring
// some of the doubles to the end.
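The commented-out variant can be exercised the same way. A standalone, runnable toy with invented names, checking that any threshold 1 <= t <= r lands on RR once the leftover doublings are deferred to the end:

```rust
// Toy version (not ring's code) of the alternative threshold schedule:
// t initial doublings, squarings while they still undershoot r, then
// leftover plain doublings. Any 1 <= t <= r yields RR = R**2 (mod m).

fn mod_pow(mut base: u128, mut exp: u128, m: u128) -> u128 {
    let mut out = 1 % m;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            out = out * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }
    out
}

// Modular inverse by the extended Euclidean algorithm; gcd(a, m) must be 1.
fn mod_inv(a: i128, m: i128) -> i128 {
    let (mut t, mut new_t) = (0i128, 1i128);
    let (mut r, mut new_r) = (m, a.rem_euclid(m));
    while new_r != 0 {
        let q = r / new_r;
        t -= q * new_t;
        core::mem::swap(&mut t, &mut new_t);
        r -= q * new_r;
        core::mem::swap(&mut r, &mut new_r);
    }
    assert_eq!(r, 1);
    t.rem_euclid(m)
}

fn rr_by_threshold(t: u32, r_bits: u32, m: u128) -> u128 {
    assert!(1 <= t && t <= r_bits);
    let r = mod_pow(2, r_bits as u128, m);
    let r_inv = mod_inv(r as i128, m as i128) as u128;
    let mut acc = r; // acc = R (mod m)
    for _ in 0..t {
        acc = acc * 2 % m; // initial doublings
    }
    let mut done = t;
    while done < (r_bits - done) {
        acc = acc * acc % m * r_inv % m; // Montgomery squaring
        done *= 2;
    }
    for _ in done..r_bits {
        acc = acc * 2 % m; // leftover doublings
    }
    acc
}

fn main() {
    let m: u128 = 0xDFC5; // odd toy modulus, r = 16
    let rr = mod_pow(2, 32, m);
    for t in 1..=16 {
        assert_eq!(rr_by_threshold(t, 16, m), rr);
    }
}
```

Note how thresholds that are not a power-of-two fraction of r just shift work into the trailing doubling loop, which is why only power-of-two changes to t alter the squaring count.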

Self(Elem {
limbs: RR.limbs,
limbs: acc.limbs,
encoding: PhantomData, // PhantomData<RR>
})
}
