Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scmul for WBW Montgomery #847

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

JasonGross
Copy link
Collaborator

@mdempsky
Copy link

Thanks, that was quick! Let me benchmark to confirm the new code is actually faster than repeated additions.

@mdempsky
Copy link

Unfortunately, in my benchmarking, these scmul routines are slower than repeated additions.

fiat_p224_add takes about 12ns; tripling and quadrupling take about 25ns; and octupling takes about 38ns.

The generated scmul routines all take 50+ ns.

@JasonGross
Copy link
Collaborator Author

Hmm, that's unfortunate. I wonder if we can speed things up with rewrite rules that turn multiplications into shifts and masks. Thoughts @davidben @andres-erbsen ?

@mdempsky I take it the bitshifting code in Go is even faster than repeated additions? Could you share a link to it?

@mdempsky
Copy link

mdempsky commented Jul 14, 2020

@JasonGross I haven't benchmarked it specifically, but I assume the bitshifting is faster if only because it's fewer instructions. I can measure though if it would be helpful.

Here's an example of "delta = 8 * beta" being computed: https://github.com/golang/go/blob/e88ea87e7b886815cfdadc4cd3d70bf5ef833bd7/src/crypto/elliptic/p224.go#L618-L620

In my fiat port, I wrote that instead as:

fiat_p224_add(&delta, &beta)
fiat_p224_add(&delta, &delta)
fiat_p224_add(&delta, &delta)

Here's code that computes "t = 3 * t": https://github.com/golang/go/blob/e88ea87e7b886815cfdadc4cd3d70bf5ef833bd7/src/crypto/elliptic/p224.go#L600-L602

I rewrote that to:

fiat_p224_add(&tmp, &t, &t)
fiat_p224_add(&t, &t, &tmp)

Edit: Note that p224.go uses [8]uint32 with radix-2^28, so they've got some extra headroom for doing shifts and delaying carries.

@JasonGross
Copy link
Collaborator Author

Edit: Note that p224.go uses [8]uint32 with radix-2^28, so they've got some extra headroom for doing shifts and delaying carries.

In Montgomery, there is no separate carrying, and limbs are always saturated. It sounds like the current code uses an unsaturated Solinas representation?

If you want me to generate scmul that just inlines a bunch of double+add, I can do that, though it'll take some time, and I think it's probably not worth it to generate such code at this point?

@mdempsky
Copy link

In Montgomery, there is no separate carrying, and limbs are always saturated. It sounds like the current code uses an unsaturated Solinas representation?

I'm not hip with the lingo, but that sounds right.

If you want me to generate scmul that just inlines a bunch of double+add, I can do that, though it'll take some time, and I think it's probably not worth it to generate such code at this point?

Agreed that I don't think it's worth putting much effort into right now. I just knew there was the fixed-multiplication in curve25519, and thought this might be worth looking into if it was easy and provided a win. It seems like it doesn't at the moment.

@mdempsky
Copy link

mdempsky commented Jul 14, 2020

Here's a 64-bit scmul_8 implementation I came up with:

func fiat_p224_scmul_8(out1 *[4]uint64, arg1 *[4]uint64) {
	x1 := arg1[0] << 3
	x2 := arg1[1]<<3 | arg1[0]>>61
	x3 := arg1[2]<<3 | arg1[1]>>61
	x4 := arg1[3]<<3 | arg1[2]>>61

	z := x4 >> 32

	// Now add z * 2^96...
	m1, c1 := x1, uint64(0)
	m2, c2 := bits.Add64(x2, z<<32, c1)
	m3, c3 := bits.Add64(x3, 0, c2)
	m4 := x4 + c3

	// ... and then subtract z * (2^224 + 1).
	n1, b1 := bits.Sub64(m1, z, 0)
	n2, b2 := bits.Sub64(m2, 0, b1)
	n3, b3 := bits.Sub64(m3, 0, b2)
	n4 := m4 - z<<32 - b3

	// Finally, subtract P once more...
	o1, a1 := bits.Sub64(n1, 0x1, 0)
	o2, a2 := bits.Sub64(n2, 0xffffffff00000000, a1)
	o3, a3 := bits.Sub64(n3, 0xffffffffffffffff, a2)
	o4, a4 := bits.Sub64(n4, 0xffffffff, a3)

	// ... and keep the result if positive.
	fiat_p224_cmovznz_u64(&out1[0], a4, o1, n1)
	fiat_p224_cmovznz_u64(&out1[1], a4, o2, n2)
	fiat_p224_cmovznz_u64(&out1[2], a4, o3, n3)
	fiat_p224_cmovznz_u64(&out1[3], a4, o4, n4)
}

This implementation is barely slower than fiat_p224_add. The same idea with different initial rotations should work for scmul_4; and with slightly extra complexity, I think scmul_3 should be doable too.

Caveat: Experimentally, on random inputs, it seems to match the output of three consecutive fiat_p224_add calls; but I haven't spent that long convincing myself that it doesn't have any corner cases that aren't handled correctly.

Edit: I also think with some extra cleverness, the "add z * 2^96" and "sub z * (2^224+1)" steps could be fused. I haven't spent too much time thinking about it yet though.

@mdempsky
Copy link

mdempsky commented Jul 14, 2020

and with slightly extra complexity, I think scmul_3 should be doable too.

Yeah, replacing the first few assignments with this seems to work as scmul_3, and is still measurably faster than two additions:

	x1, d1 := bits.Add64(arg1[0], arg1[0]<<1, 0)
	x2, d2 := bits.Add64(arg1[1], arg1[1]<<1|arg1[0]>>63, d1)
	x3, d3 := bits.Add64(arg1[2], arg1[2]<<1|arg1[1]>>63, d2)
	x4 := arg1[3] + (arg1[3]<<1 | arg1[2]>>63) + d3

@JasonGross
Copy link
Collaborator Author

Note that this isn't the kind of code we can write for Montgomery, because for the Montgomery code, we only have access to the prime as an integer, not to its representation as a sum/difference of taps. If you want to try to come up with a template that generates this sort of code for any prime, given just the integer representation and the bitwidth, I'm happy to integrate such a template and try to prove it correct.

@mdempsky
Copy link

Note that this isn't the kind of code we can write for Montgomery, because for the Montgomery code, we only have access to the prime as an integer, not to its representation as a sum/difference of taps.

Does this help?

$ python3 foo.py
P-224:	100000000000000000000000000000000000000000000000000000001 - 1000000000000000000000000 = ffffffffffffffffffffffffffffffff000000000000000000000001
P-256:	10000000000000000000000000000000000000000000000000000000100000050 - 200000421 = fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f
P-384:	1000000000000000000000000000000000000000000000000000000000000000100000000000000000000000100000000 - 200000001000000000000000000000001 = fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff
P-521:	20000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 - 1 = 1ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
X448:	10000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000 - 200000000000000000000000000000000000000000000000000000001 = fffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffffffffffffffffffffffffffffffffffffffffffffffffffff

$ cat foo.py
# Return n, with its least significant 1 bit cleared.
def masklow(n):
  return n & (n - 1)

# Return the smallest power of 2 >= n.
def roundup(n):
  if masklow(n) == 0: return n
  n <<= 1
  while True:
    m = masklow(n)
    if m == 0: return n
    n = m

# Returns a tuple (a, b) such that a - b = n,
# with the minimum number of bits set in a and b.
def taps(n):
  if n == 0: return 0, 0
  x = roundup(n)
  a, b = taps(x - n)
  return b + x, a

def solve(name, P):
  a, b = taps(P)
  if a - b != P: raise "oops"
  print("%s:\t%x - %x = %x" % (name, a, b, P))

solve("P-224", 2**224 - 2**96 + 1)
solve("P-256", 2**256 - 2**32 - 977)
solve("P-384", 2**384 - 2**128 - 2**96 + 2**32 - 1)
solve("P-521", 2**521 - 1)

solve("X448", 2**448 - 2**224 - 1)

@mdempsky
Copy link

Alternatively:

$ python3 foo.py
P-224:	[224, 0] - [96]
P-256:	[256, 32, 6, 4] - [33, 10, 5, 0]
P-384:	[384, 128, 32] - [129, 96, 0]
P-521:	[521] - [0]
X448:	[448, 224] - [225, 0]

$ cat foo.py
def lg2(n):
  i = 0
  while n > (1 << i):
    i += 1
  return i

def taps(n):
  if n == 0: return [], []
  x = lg2(n)
  a, b = taps((1 << x) - n)
  return [x] + b, a

def eval(xs):
  return sum(1 << x for x in xs)

def solve(name, P):
  a, b = taps(P)
  if eval(a) - eval(b) != P: raise "oops"
  print("%s:\t%s - %s" % (name, a, b))

solve("P-224", 2**224 - 2**96 + 1)
solve("P-256", 2**256 - 2**32 - 977)
solve("P-384", 2**384 - 2**128 - 2**96 + 2**32 - 1)
solve("P-521", 2**521 - 1)

solve("X448", 2**448 - 2**224 - 1)

@mdempsky
Copy link

mdempsky commented Jul 14, 2020

Here it is as 32-bit/64-bit limb sequences (in little endian):

$ python3 limb.py
p224/32:	[1, 0, 0, 0, 0, 0, 0, 1] - [0, 0, 0, 1, 0, 0, 0, 0]
p224/64:	[1, 0, 0, 4294967296] - [0, 4294967296, 0, 0]

p256/32:	[0, 0, 0, 1, 0, 0, 1, 0, 1] - [1, 0, 0, 0, 0, 0, 0, 1, 0]
p256/64:	[0, 4294967296, 0, 0, 1] - [1, 0, 0, 4294967295, 0]

p384/32:	[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] - [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
p384/64:	[4294967295, 0, 0, 0, 0, 0, 1] - [0, 4294967296, 1, 0, 0, 0, 0]

p521/32:	[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 512] - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p521/64:	[0, 0, 0, 0, 0, 0, 0, 0, 512] - [1, 0, 0, 0, 0, 0, 0, 0, 0]

secp256k1/32:	[0, 0, 0, 0, 0, 0, 0, 0, 1] - [977, 1, 0, 0, 0, 0, 0, 0, 0]
secp256k1/64:	[0, 0, 0, 0, 1] - [4294968273, 0, 0, 0, 0]

x25519/32:	[0, 0, 0, 0, 0, 0, 0, 2147483648] - [19, 0, 0, 0, 0, 0, 0, 0]
x25519/64:	[0, 0, 0, 9223372036854775808] - [19, 0, 0, 0]

x448/32:	[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] - [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
x448/64:	[0, 0, 0, 0, 0, 0, 0, 1] - [1, 0, 0, 4294967296, 0, 0, 0, 0]

poly1305/32:	[0, 0, 0, 0, 4] - [5, 0, 0, 0, 0]
poly1305/64:	[0, 0, 4] - [5, 0, 0]

p434/32:	[0, 0, 0, 0, 0, 0, 0, 0, 827895460, 2076597368, 0, 1828478935, 655848260, 144415] - [1, 0, 0, 0, 0, 0, 486539264, 37652869, 0, 0, 2117787562, 0, 0, 0]
p434/64:	[0, 0, 0, 0, 8918917783347572388, 7853257225132122198, 620258357900100] - [1, 0, 0, 161717841442111488, 0, 0, 0]

$ cat limb.py
def bits(n):
  i = 0
  while n >= (1 << i): i += 1
  return i

def taps(n, i, W):
  if i == 0: return [], []

  i -= 1
  shift = i * W

  # Compute nearest integer to "n / (1 << shift)" (rounding 0.5 up).
  half = (1 << shift) >> 1
  v = (n + half) >> shift

  delta = n - (v << shift)
  if delta >= 0:
      a, b = taps(delta, i, W)
  else:
      b, a = taps(-delta, i, W)
  return a + [v], b + [0]

def eval(xs, W):
  return sum(x << (i * W) for (i, x) in enumerate(xs))

def solve(name, P):
  for W in [32, 64]:
    a, b = taps(P, bits(P) // W + 1, W)
    if eval(a, W) - eval(b, W) != P: raise "oops"
    print("%s/%d:\t%s - %s" % (name, W, a, b))
  print()

solve("p224", 2**224 - 2**96 + 1)
solve("p256", 2**256 - 2**224 + 2**192 + 2**96 - 1)
solve("p384", 2**384 - 2**128 - 2**96 + 2**32 - 1)
solve("p521", 2**521 - 1)

solve("secp256k1", 2**256 - 2**32 - 977)

solve("x25519", 2**255 - 19)
solve("x448", 2**448 - 2**224 - 1)

solve("poly1305", 2**130 - 5)

solve("p434", 2**216 * 3**137 - 1)

@andres-erbsen
Copy link
Contributor

my reading of #847 (comment) (the example implementation above) is that it does saturated solinas multiplication followed by canonicalization, ignoring the assumption that the input is in Montgomery form but simplifying under the assumptions that z is a single limb and multiplying z by any limb of the prime never overflows. The implementation of solinas multiplication without the last two restrictions is stuck in technical debt at (#375 (comment)). The case here is particularly simple because z is just one limb, so it might be the best example to iterate on for #375 or just to implement directly. I believe we already have code for the canonicalization part.

As for how to access the solinas form of primes in Montgomery code, I think it would be just fine to pass them in when available or to infer them using code like @mdempsky posts aobe.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants