diff --git a/lib/ak-381/groth16.ak b/lib/ak-381/groth16.ak index 96e661c..02ead42 100644 --- a/lib/ak-381/groth16.ak +++ b/lib/ak-381/groth16.ak @@ -1,9 +1,8 @@ use aiken/builtin.{ - bls12_381_final_verify, bls12_381_g1_add, bls12_381_g1_compress, - bls12_381_g1_scalar_mul, bls12_381_g1_uncompress, bls12_381_g2_uncompress, - bls12_381_miller_loop, bls12_381_mul_miller_loop_result, + bls12_381_final_verify, bls12_381_g1_add, bls12_381_g1_scalar_mul, + bls12_381_g1_uncompress, bls12_381_g2_uncompress, bls12_381_miller_loop, + bls12_381_mul_miller_loop_result, } -use aiken/list.{head, map, map2, reduce, tail} pub type SnarkVerificationKey { nPublic: Int, @@ -23,7 +22,7 @@ pub type SnarkVerificationKey { // pub type Proof { - // G1Element + // G1Element piA: ByteArray, // G2Element piB: ByteArray, @@ -38,6 +37,29 @@ pub fn pairing(g1: ByteArray, g2: ByteArray) { ) } +pub fn derive( + vk_ic: List, + public: List, + result: G1Element, +) -> G1Element { + when vk_ic is { + [] -> result + [i, ..vk_ic] -> + when public is { + [] -> fail + [scalar, ..public] -> + derive( + vk_ic, + public, + bls12_381_g1_add( + result, + bls12_381_g1_scalar_mul(scalar, bls12_381_g1_uncompress(i)), + ), + ) + } + } +} + pub fn groth_verify( vk: SnarkVerificationKey, proof: Proof, @@ -47,14 +69,13 @@ pub fn groth_verify( let eAB = pairing(proof.piA, proof.piB) let eAlphaBeta = pairing(vk.vkAlpha, vk.vkBeta) - expect Some(vk_ic_head) = head(vk.vkIC) - expect Some(vk_ic_tail) = tail(vk.vkIC) - let vkICHead: G1Element = bls12_381_g1_uncompress(vk_ic_head) - let vkICTail: List = - map(vk_ic_tail, fn(n) { bls12_381_g1_uncompress(n) }) - let derived_vkIC = map2(public, vkICTail, bls12_381_g1_scalar_mul) - let vkI = reduce(derived_vkIC, vkICHead, bls12_381_g1_add) - let eIGamma = pairing(bls12_381_g1_compress(vkI), vk.vkGamma) + let vkI = + when vk.vkIC is { + [] -> fail @"empty vkIC?" + [head, ..tail] -> derive(tail, public, bls12_381_g1_uncompress(head)) + } + + let eIGamma = bls12_381_miller_loop(vkI, bls12_381_g2_uncompress(vk.vkGamma)) let eCDelta = pairing(proof.piC, vk.vkDelta) // * Miller functions