From 373eabeca619c6d4281781d8568377c89b411863 Mon Sep 17 00:00:00 2001 From: cairo Date: Tue, 3 Sep 2024 14:29:26 +0200 Subject: [PATCH] Optimizations to P256 operations (#5181) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Hadrien Croubois Co-authored-by: Ernesto García --- contracts/utils/cryptography/P256.sol | 37 ++++++++++++++------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/contracts/utils/cryptography/P256.sol b/contracts/utils/cryptography/P256.sol index 83c9c975447..69717bdc51a 100644 --- a/contracts/utils/cryptography/P256.sol +++ b/contracts/utils/cryptography/P256.sol @@ -125,11 +125,12 @@ library P256 { return (0, 0); } + uint256 p = P; // cache P on the stack uint256 rx = uint256(r); - uint256 ry2 = addmod(mulmod(addmod(mulmod(rx, rx, P), A, P), rx, P), B, P); // weierstrass equation y² = x³ + a.x + b - uint256 ry = Math.modExp(ry2, P1DIV4, P); // This formula for sqrt work because P ≡ 3 (mod 4) - if (mulmod(ry, ry, P) != ry2) return (0, 0); // Sanity check - if (ry % 2 != v % 2) ry = P - ry; + uint256 ry2 = addmod(mulmod(addmod(mulmod(rx, rx, p), A, p), rx, p), B, p); // weierstrass equation y² = x³ + a.x + b + uint256 ry = Math.modExp(ry2, P1DIV4, p); // This formula for sqrt work because P ≡ 3 (mod 4) + if (mulmod(ry, ry, p) != ry2) return (0, 0); // Sanity check + if (ry % 2 != v % 2) ry = p - ry; JPoint[16] memory points = _preComputeJacobianPoints(rx, ry); uint256 w = Math.invModPrime(uint256(r), N); @@ -170,11 +171,13 @@ library P256 { */ function _affineFromJacobian(uint256 jx, uint256 jy, uint256 jz) private view returns (uint256 ax, uint256 ay) { if (jz == 0) return (0, 0); - uint256 zinv = Math.invModPrime(jz, P); - uint256 zzinv = mulmod(zinv, zinv, P); - uint256 zzzinv = mulmod(zzinv, zinv, P); - ax = mulmod(jx, zzinv, P); - ay = mulmod(jy, zzzinv, P); + uint256 p = P; // cache P on the stack + uint256 zinv = Math.invModPrime(jz, p); + assembly ("memory-safe") { + let zzinv := mulmod(zinv, zinv, p) + ax := mulmod(jx, zzinv, p) + ay := mulmod(jy, mulmod(zzinv, zinv, p), p) + } } /** @@ -190,12 +193,11 @@ library P256 { assembly ("memory-safe") { let p := P let z1 := mload(add(p1, 0x40)) + let zz1 := mulmod(z1, z1, p) // zz1 = z1² let s1 := mulmod(mload(add(p1, 0x20)), mulmod(mulmod(z2, z2, p), z2, p), p) // s1 = y1*z2³ - let s2 := mulmod(y2, mulmod(mulmod(z1, z1, p), z1, p), p) // s2 = y2*z1³ - let r := addmod(s2, sub(p, s1), p) // r = s2-s1 + let r := addmod(mulmod(y2, mulmod(zz1, z1, p), p), sub(p, s1), p) // r = s2-s1 = y2*z1³-s1 let u1 := mulmod(mload(p1), mulmod(z2, z2, p), p) // u1 = x1*z2² - let u2 := mulmod(x2, mulmod(z1, z1, p), p) // u2 = x2*z1² - let h := addmod(u2, sub(p, u1), p) // h = u2-u1 + let h := addmod(mulmod(x2, zz1, p), sub(p, u1), p) // h = u2-u1 = x2*z1²-u1 let hh := mulmod(h, h, p) // h² // x' = r²-h³-2*u1*h² @@ -226,12 +228,11 @@ library P256 { let zz := mulmod(z, z, p) let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y² let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴ - let t := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p) // t = m²-2*s - // x' = t - rx := t - // y' = m*(s-t)-8*y⁴ - ry := addmod(mulmod(m, addmod(s, sub(p, t), p), p), sub(p, mulmod(8, mulmod(yy, yy, p), p)), p) + // x' = t = m²-2*s + rx := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p) + // y' = m*(s-t)-8*y⁴ = m*(s-x')-8*y⁴ + ry := addmod(mulmod(m, addmod(s, sub(p, rx), p), p), sub(p, mulmod(8, mulmod(yy, yy, p), p)), p) // z' = 2*y*z rz := mulmod(2, mulmod(y, z, p), p) }