diff --git a/src/src/definitions.cairo b/src/src/definitions.cairo index a3bc38fc..c87fe9f2 100644 --- a/src/src/definitions.cairo +++ b/src/src/definitions.cairo @@ -789,6 +789,17 @@ fn get_min_one(curve_index: usize) -> u384 { return u384 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 }; } + +fn get_modulus(curve_index: usize) -> CircuitModulus { + match curve_index { + 0 => get_BN254_modulus(), + 1 => get_BLS12_381_modulus(), + 2 => get_SECP256K1_modulus(), + 3 => get_SECP256R1_modulus(), + 4 => get_ED25519_modulus(), + _ => panic_with_felt252('Invalid curve index'), + } +} // Returns the modulus of BLS12_381 #[inline(always)] fn get_BLS12_381_modulus() -> CircuitModulus { diff --git a/src/src/ec_ops_g2.cairo b/src/src/ec_ops_g2.cairo index 659656d8..fefa8096 100644 --- a/src/src/ec_ops_g2.cairo +++ b/src/src/ec_ops_g2.cairo @@ -9,7 +9,9 @@ use garaga::circuits::tower_circuits::{run_BLS12_381_FP2_MUL_circuit, run_BN254_ use core::option::Option; use garaga::core::circuit::AddInputResultTrait2; -use garaga::definitions::{G2Point, G2PointZero, get_BLS12_381_modulus, get_b2, get_a, get_p}; +use garaga::definitions::{ + G2Point, G2PointZero, get_BLS12_381_modulus, get_b2, get_a, get_p, get_modulus +}; use garaga::circuits::ec; use garaga::utils::u384_assert_zero; use garaga::basic_field_ops::neg_mod_p; @@ -254,10 +256,10 @@ fn ec_safe_add_with_options( fn ec_safe_add(P: G2Point, Q: G2Point, curve_index: usize) -> Option { // assumes that the points are on the curve and not the point at infinity. // Returns None if the points are the same and opposite y coordinates (Point at infinity) - let same_x = eq_mod_p(P.x0, P.x1, Q.x0, Q.x1); + let same_x = eq_mod_p(P.x0, P.x1, Q.x0, Q.x1, curve_index); if same_x { - let opposite_y = eq_neg_mod_p(P.y0, P.y1, Q.y0, Q.y1); + let opposite_y = eq_neg_mod_p(P.y0, P.y1, Q.y0, Q.y1, curve_index); if opposite_y { return Option::None; @@ -293,7 +295,7 @@ fn ec_mul_inner(pt: G2Point, mut bits: Array, curve_index: usize) -> Op // returns true if a == b mod p bls12-381 #[inline] -pub fn eq_mod_p(a0: u384, a1: u384, b0: u384, b1: u384) -> bool { +pub fn eq_mod_p(a0: u384, a1: u384, b0: u384, b1: u384, curve_index: usize) -> bool { let _a0 = CE::> {}; let _a1 = CE::> {}; let _b0 = CE::> {}; @@ -301,7 +303,7 @@ pub fn eq_mod_p(a0: u384, a1: u384, b0: u384, b1: u384) -> bool { let sub0 = circuit_sub(_a0, _b0); let sub1 = circuit_sub(_a1, _b1); - let modulus = get_BLS12_381_modulus(); + let modulus = get_modulus(curve_index); let outputs = (sub0, sub1) .new_inputs() @@ -318,7 +320,7 @@ pub fn eq_mod_p(a0: u384, a1: u384, b0: u384, b1: u384) -> bool { // returns true if a == -b mod p bls12-381 #[inline] -pub fn eq_neg_mod_p(a0: u384, a1: u384, b0: u384, b1: u384) -> bool { +pub fn eq_neg_mod_p(a0: u384, a1: u384, b0: u384, b1: u384, curve_index: usize) -> bool { let _a0 = CE::> {}; let _a1 = CE::> {}; let _b0 = CE::> {}; @@ -326,7 +328,7 @@ pub fn eq_neg_mod_p(a0: u384, a1: u384, b0: u384, b1: u384) -> bool { let check0 = circuit_add(_a0, _b0); let check1 = circuit_add(_a1, _b1); - let modulus = get_BLS12_381_modulus(); + let modulus = get_modulus(curve_index); let outputs = (check0, check1) .new_inputs() .next_2(a0)