diff --git a/synedrion/src/cggmp21/aux_gen.rs b/synedrion/src/cggmp21/aux_gen.rs index 6611f28..c6f6219 100644 --- a/synedrion/src/cggmp21/aux_gen.rs +++ b/synedrion/src/cggmp21/aux_gen.rs @@ -602,7 +602,6 @@ impl Round for Round3 { ( id, PublicAuxInfo { - el_gamal_pk: data.data.cap_y, paillier_pk: data.paillier_pk.into_wire(), rp_params: data.rp_params.to_wire(), }, @@ -612,7 +611,6 @@ impl Round for Round3 { let secret_aux = SecretAuxInfo { paillier_sk: self.context.paillier_sk.into_wire(), - el_gamal_sk: self.context.y, }; let aux_info = AuxInfo { @@ -654,18 +652,9 @@ mod tests { }) .collect::>(); - let aux_infos = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) + let _aux_infos = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) .unwrap() .results() .unwrap(); - - for (id, aux_info) in aux_infos.iter() { - for other_aux_info in aux_infos.values() { - assert_eq!( - aux_info.secret_aux.el_gamal_sk.mul_by_generator(), - other_aux_info.public_aux[id].el_gamal_pk - ); - } - } } } diff --git a/synedrion/src/cggmp21/entities.rs b/synedrion/src/cggmp21/entities.rs index 6119b02..7760a16 100644 --- a/synedrion/src/cggmp21/entities.rs +++ b/synedrion/src/cggmp21/entities.rs @@ -45,14 +45,12 @@ pub struct AuxInfo { #[serde(bound(deserialize = "SecretKeyPaillierWire: for <'x> Deserialize<'x>"))] pub(crate) struct SecretAuxInfo { pub(crate) paillier_sk: SecretKeyPaillierWire, - pub(crate) el_gamal_sk: Secret, // `y_i` } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound(serialize = "PublicKeyPaillierWire: Serialize"))] #[serde(bound(deserialize = "PublicKeyPaillierWire: for <'x> Deserialize<'x>"))] pub(crate) struct PublicAuxInfo { - pub(crate) el_gamal_pk: Point, // `Y_i` /// The Paillier public key. pub(crate) paillier_pk: PublicKeyPaillierWire, /// The ring-Pedersen parameters. @@ -68,14 +66,10 @@ pub(crate) struct AuxInfoPrecomputed { #[derive(Debug, Clone)] pub(crate) struct SecretAuxInfoPrecomputed { pub(crate) paillier_sk: SecretKeyPaillier, - #[allow(dead_code)] // TODO (#36): this will be needed for the 6-round presigning protocol. - pub(crate) el_gamal_sk: Secret, // `y_i` } #[derive(Debug, Clone)] pub(crate) struct PublicAuxInfoPrecomputed { - #[allow(dead_code)] // TODO (#36): this will be needed for the 6-round presigning protocol. - pub(crate) el_gamal_pk: Point, pub(crate) paillier_pk: PublicKeyPaillier, pub(crate) rp_params: RPParams, } @@ -259,7 +253,6 @@ impl AuxInfo { let secret_aux = (0..ids.len()) .map(|_| SecretAuxInfo { paillier_sk: SecretKeyPaillierWire::::random(rng), - el_gamal_sk: Secret::init_with(|| Scalar::random(rng)), }) .collect::>(); @@ -271,7 +264,6 @@ impl AuxInfo { id.clone(), PublicAuxInfo { paillier_pk: secret.paillier_sk.public_key(), - el_gamal_pk: secret.el_gamal_sk.mul_by_generator(), rp_params: RPParams::random(rng).to_wire(), }, ) @@ -297,7 +289,6 @@ impl AuxInfo { AuxInfoPrecomputed { secret_aux: SecretAuxInfoPrecomputed { paillier_sk: self.secret_aux.paillier_sk.clone().into_precomputed(), - el_gamal_sk: self.secret_aux.el_gamal_sk.clone(), }, public_aux: self .public_aux @@ -307,7 +298,6 @@ impl AuxInfo { ( id.clone(), PublicAuxInfoPrecomputed { - el_gamal_pk: public_aux.el_gamal_pk, paillier_pk: paillier_pk.clone(), rp_params: public_aux.rp_params.to_precomputed(), }, diff --git a/synedrion/src/cggmp21/key_refresh.rs b/synedrion/src/cggmp21/key_refresh.rs index 4daae83..bf671e5 100644 --- a/synedrion/src/cggmp21/key_refresh.rs +++ b/synedrion/src/cggmp21/key_refresh.rs @@ -1,26 +1,23 @@ -//! KeyRefresh protocol, in the paper Auxiliary Info. & Key Refresh in Three Rounds (Fig. 6). +//! KeyRefresh protocol, in the paper Auxiliary Info. & Key Refresh in Three Rounds (Fig. 7). //! This protocol generates an update to the secret key shares and new auxiliary parameters //! for ZK proofs (e.g. Paillier keys). use alloc::{ collections::{BTreeMap, BTreeSet}, - format, string::String, - vec::Vec, }; use core::{fmt::Debug, marker::PhantomData}; use crypto_bigint::BitOps; use manul::protocol::{ Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, LocalError, - NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessagePart, ProtocolValidationError, - ReceiveError, Round, RoundId, Serializer, + MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessagePart, + ProtocolValidationError, ReceiveError, Round, RoundId, Serializer, }; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; use super::{ - conversion::{secret_scalar_from_signed, secret_signed_from_scalar}, entities::{AuxInfo, KeyShareChange, PublicAuxInfo, SecretAuxInfo}, params::SchemeParams, sigma::{FacProof, ModProof, PrmProof, SchCommitment, SchProof, SchSecret}, @@ -28,13 +25,13 @@ use super::{ use crate::{ curve::{secret_split, Point, Scalar}, paillier::{ - Ciphertext, CiphertextWire, PaillierParams, PublicKeyPaillier, PublicKeyPaillierWire, RPParams, RPParamsWire, - RPSecret, SecretKeyPaillier, SecretKeyPaillierWire, + PaillierParams, PublicKeyPaillier, PublicKeyPaillierWire, RPParams, RPParamsWire, RPSecret, SecretKeyPaillier, + SecretKeyPaillierWire, }, tools::{ bitvec::BitVec, - hashing::{Chain, FofHasher, HashOutput}, - DowncastMap, Secret, Without, + hashing::{Chain, FofHasher, HashOutput, XofHasher}, + DowncastMap, SafeGet, Secret, Without, }, }; @@ -45,36 +42,75 @@ pub struct KeyRefreshProtocol(PhantomData<(P, I)>); impl Protocol for KeyRefreshProtocol { type Result = (KeyShareChange, AuxInfo); - type ProtocolError = KeyRefreshError

; + type ProtocolError = KeyRefreshError; + + fn verify_direct_message_is_invalid( + deserializer: &Deserializer, + round_id: RoundId, + message: &DirectMessage, + ) -> Result<(), MessageValidationError> { + match round_id { + r if r == RoundId::new(1) => message.verify_is_some(), + r if r == RoundId::new(2) => message.verify_is_some(), + r if r == RoundId::new(3) => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } + } + + fn verify_echo_broadcast_is_invalid( + deserializer: &Deserializer, + round_id: RoundId, + message: &EchoBroadcast, + ) -> Result<(), MessageValidationError> { + match round_id { + r if r == RoundId::new(1) => message.verify_is_not::(deserializer), + r if r == RoundId::new(2) => message.verify_is_some(), + r if r == RoundId::new(3) => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } + } + + fn verify_normal_broadcast_is_invalid( + deserializer: &Deserializer, + round_id: RoundId, + message: &NormalBroadcast, + ) -> Result<(), MessageValidationError> { + match round_id { + r if r == RoundId::new(1) => message.verify_is_some(), + r if r == RoundId::new(2) => message.verify_is_not::>(deserializer), + r if r == RoundId::new(3) => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } + } } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound(serialize = " - KeyRefreshErrorEnum

: Serialize, + KeyRefreshErrorEnum: Serialize, "))] #[serde(bound(deserialize = " - KeyRefreshErrorEnum

: for<'x> Deserialize<'x>, + KeyRefreshErrorEnum: for<'x> Deserialize<'x>, "))] -pub struct KeyRefreshError(KeyRefreshErrorEnum

); +pub struct KeyRefreshError(KeyRefreshErrorEnum); #[derive(Debug, Clone, Serialize, Deserialize)] -enum KeyRefreshErrorEnum { - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round2(String), - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round3(String), - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round3MismatchedSecret { - cap_c: CiphertextWire, - x: Scalar, - mu: ::Uint, - }, +enum KeyRefreshErrorEnum { + R2HashMismatch, + R2WrongIdsX, + R2WrongIdsY, + R2WrongIdsA, + R2PaillierModulusTooSmall, + R2RPModulusTooSmall, + R2NonZeroSumOfChanges, + R2PrmFailed, + R3ShareChangeMismatch, + R3ModFailed, + R3FacFailed, + R3WrongIdsHatPsi, + R3SchFailed(I), } -impl ProtocolError for KeyRefreshError

{ +impl ProtocolError for KeyRefreshError { fn description(&self) -> String { unimplemented!() } @@ -141,160 +177,108 @@ impl EntryPoint for KeyRefresh { let other_ids = self.all_ids.clone().without(id); - let ids_ordering = self - .all_ids - .iter() - .cloned() - .enumerate() - .map(|(idx, id)| (id, idx)) - .collect(); - let sid_hash = FofHasher::new_with_dst(b"SID") .chain_type::

() .chain(&shared_randomness) - .chain(&self.all_ids) .finalize(); - // $p_i$, $q_i$ + // Paillier secret key $p_i$, $q_i$ let paillier_sk = SecretKeyPaillierWire::::random(rng); - // $N_i$ + // Paillier public key $N_i$ let paillier_pk = paillier_sk.public_key(); - // El-Gamal key - let y = Secret::init_with(|| Scalar::random(rng)); - let cap_y = y.mul_by_generator(); + // Ring-Pedersen secret $\lambda$. + let rp_secret = RPSecret::random(rng); + // Ring-Pedersen parameters ($N$, $s$, $t$) bundled in a single object. + let rp_params = RPParams::random_with_secret(rng, &rp_secret); - // The secret and the commitment for the Schnorr PoK of the El-Gamal key - let tau_y = SchSecret::random(rng); // $\tau$ - let cap_b = SchCommitment::new(&tau_y); + let aux = (&sid_hash, id); + let psi = PrmProof::

::new(rng, &rp_secret, &rp_params, &aux); - // Secret share updates for each node ($x_i^j$ where $i$ is this party's index). - let x_to_send = self + // Ephemeral DH keys $y_{i,j}$ where $i$ is this party's index. + let ys = self .all_ids .iter() .cloned() - .zip(secret_split( - rng, - Secret::init_with(|| Scalar::ZERO), - self.all_ids.len(), - )) + .map(|id| (id, Secret::init_with(|| Scalar::random(rng)))) .collect::>(); + // Corresponding public keys $Y_{i,j}$. + let cap_ys = ys.iter().map(|(id, y)| (id.clone(), y.mul_by_generator())).collect(); - // Public counterparts of secret share updates ($X_i^j$ where $i$ is this party's index). - let cap_x_to_send = x_to_send.values().map(|x| x.mul_by_generator()).collect(); - - let rp_secret = RPSecret::random(rng); - // Ring-Pedersen parameters ($s$, $t$) bundled in a single object. - let rp_params = RPParams::random_with_secret(rng, &rp_secret); + // Secret share updates for each node ($x_{i,j}$ where $i$ is this party's index). + let split_zero = secret_split(rng, Secret::init_with(|| Scalar::ZERO), self.all_ids.len()); + let xs = self.all_ids.iter().cloned().zip(split_zero).collect::>(); - let aux = (&sid_hash, id); - let hat_psi = PrmProof::

::new(rng, &rp_secret, &rp_params, &aux); + // Public counterparts of secret share updates ($X_i^j$ where $i$ is this party's index). + let cap_xs = xs.iter().map(|(id, x)| (id.clone(), x.mul_by_generator())).collect(); - // The secrets share changes ($\tau_j$, not to be confused with $\tau$) - let tau_x = self + // Schnorr proof secrets $\tau_j$ + let taus = self .all_ids .iter() .map(|id| (id.clone(), SchSecret::random(rng))) .collect::>(); - // The commitments for share changes ($A_i^j$ where $i$ is this party's index) - let cap_a_to_send = tau_x.values().map(SchCommitment::new).collect(); + // Schnorr commitments for share changes ($A_{i,j}$ where $i$ is this party's index) + let cap_as = taus + .iter() + .map(|(id, tau)| (id.clone(), SchCommitment::new(tau))) + .collect(); - let rho = BitVec::random(rng, P::SECURITY_PARAMETER); + let rid_part = BitVec::random(rng, P::SECURITY_PARAMETER); let u = BitVec::random(rng, P::SECURITY_PARAMETER); - let data = PublicData1 { - cap_x_to_send, - cap_a_to_send, - cap_y, - cap_b, + // Note: typo in the paper, $V$ hashes in $B_i$ which is not present in the '24 version of the paper. + let r2_broadcast = Round2Broadcast { + cap_xs, + cap_ys, + cap_as, paillier_pk: paillier_pk.clone(), rp_params: rp_params.to_wire(), - hat_psi, - rho, + psi, + rid_part, u, }; - let data_precomp = PublicData1Precomp { - data, - paillier_pk: paillier_pk.into_precomputed(), - rp_params, - }; - let context = Context { paillier_sk: paillier_sk.into_precomputed(), - y, - x_to_send, - tau_x, - tau_y, - data_precomp, + rp_params, + xs, + ys, + taus, my_id: id.clone(), other_ids, + all_ids: self.all_ids, sid_hash, - ids_ordering, }; - Ok(BoxedRound::new_dynamic(Round1 { context })) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound(serialize = " - PrmProof

: Serialize, - "))] -#[serde(bound(deserialize = " - PrmProof

: for<'x> Deserialize<'x>, - "))] -struct PublicData1 { - cap_x_to_send: Vec, // $X_i^j$ where $i$ is this party's index - cap_a_to_send: Vec, // $A_i^j$ where $i$ is this party's index - cap_y: Point, - cap_b: SchCommitment, - paillier_pk: PublicKeyPaillierWire, // $N_i$ - rp_params: RPParamsWire, // $s_i$ and $t_i$ - hat_psi: PrmProof

, - rho: BitVec, - u: BitVec, -} + let round = Round1 { context, r2_broadcast }; -#[derive(Debug, Clone)] -struct PublicData1Precomp { - data: PublicData1

, - paillier_pk: PublicKeyPaillier, - rp_params: RPParams, + Ok(BoxedRound::new_dynamic(round)) + } } #[derive(Debug)] struct Context { paillier_sk: SecretKeyPaillier, - y: Secret, - x_to_send: BTreeMap>, // $x_i^j$ where $i$ is this party's index - tau_y: SchSecret, - tau_x: BTreeMap, - data_precomp: PublicData1Precomp

, + rp_params: RPParams, + xs: BTreeMap>, // $x_{i,j}$ where $i$ is this party's index + ys: BTreeMap>, // $y_{i,j}$ where $i$ is this party's index + taus: BTreeMap, my_id: I, other_ids: BTreeSet, + all_ids: BTreeSet, sid_hash: HashOutput, - ids_ordering: BTreeMap, -} - -impl PublicData1

{ - fn hash(&self, sid_hash: &HashOutput, id: &I) -> HashOutput { - FofHasher::new_with_dst(b"Auxiliary") - .chain(sid_hash) - .chain(id) - .chain(self) - .finalize() - } } #[derive(Debug)] -struct Round1 { +struct Round1 { context: Context, + r2_broadcast: Round2Broadcast, } #[derive(Debug, Clone, Serialize, Deserialize)] -struct Round1Message { +struct Round1EchoBroadcast { cap_v: HashOutput, } @@ -326,16 +310,10 @@ impl Round for Round1 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, ) -> Result { - EchoBroadcast::new( - serializer, - Round1Message { - cap_v: self - .context - .data_precomp - .data - .hash(&self.context.sid_hash, &self.context.my_id), - }, - ) + let message = Round1EchoBroadcast { + cap_v: self.r2_broadcast.hash(&self.context.sid_hash, &self.context.my_id), + }; + EchoBroadcast::new(serializer, message) } fn receive_message( @@ -349,10 +327,11 @@ impl Round for Round1 { ) -> Result> { normal_broadcast.assert_is_none()?; direct_message.assert_is_none()?; - let echo_broadcast = echo_broadcast.deserialize::(deserializer)?; - Ok(Payload::new(Round1Payload { + let echo_broadcast = echo_broadcast.deserialize::(deserializer)?; + let payload = Round1Payload { cap_v: echo_broadcast.cap_v, - })) + }; + Ok(Payload::new(payload)) } fn finalize( @@ -363,28 +342,58 @@ impl Round for Round1 { ) -> Result, LocalError> { let payloads = payloads.downcast_all::()?; let others_cap_v = payloads.into_iter().map(|(id, payload)| (id, payload.cap_v)).collect(); - Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Round2 { + let next_round = Round2 { context: self.context, + r2_broadcast: self.r2_broadcast, others_cap_v, - }))) + }; + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(next_round))) } } #[derive(Debug)] -struct Round2 { +struct Round2 { context: Context, + r2_broadcast: Round2Broadcast, others_cap_v: BTreeMap, } -#[derive(Clone, Serialize, Deserialize)] -#[serde(bound(serialize = "PublicData1

: Serialize"))] -#[serde(bound(deserialize = "PublicData1

: for<'x> Deserialize<'x>"))] -struct Round2Message { - data: PublicData1

, +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + PrmProof

: Serialize, +"))] +#[serde(bound(deserialize = " + PrmProof

: for<'x> Deserialize<'x>, +"))] +struct Round2Broadcast { + cap_xs: BTreeMap, // $X_{i,j}$ where $i$ is this party's index + cap_as: BTreeMap, // $A_{i,j}$ where $i$ is this party's index + cap_ys: BTreeMap, // $Y_{i,j}$ where $i$ is this party's index + paillier_pk: PublicKeyPaillierWire, // $N_i$ + rp_params: RPParamsWire, // $\hat{N}_i$, $s_i$, and $t_i$ + psi: PrmProof

, + rid_part: BitVec, + u: BitVec, +} + +impl Round2Broadcast { + fn hash(&self, sid_hash: &HashOutput, id: &I) -> HashOutput { + FofHasher::new_with_dst(b"Auxiliary") + .chain(sid_hash) + .chain(id) + .chain(self) + .finalize() + } } -struct Round2Payload { - data: PublicData1Precomp

, +#[derive(Debug)] +struct Round2Payload { + cap_xs: BTreeMap, // $X_{i,j}$ where $i$ is this party's index + cap_as: BTreeMap, // $A_{i,j}$ where $i$ is this party's index + cap_ys: BTreeMap, // $Y_{i,j}$ where $i$ is this party's index + paillier_pk: PublicKeyPaillier, // $N_i$ + rp_params: RPParams, // $\hat{N}_i$, $s_i$, and $t_i$ + rid_part: BitVec, } impl Round for Round2 { @@ -411,12 +420,7 @@ impl Round for Round2 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, ) -> Result { - NormalBroadcast::new( - serializer, - Round2Message { - data: self.context.data_precomp.data.clone(), - }, - ) + NormalBroadcast::new(serializer, self.r2_broadcast.clone()) } fn receive_message( @@ -430,48 +434,72 @@ impl Round for Round2 { ) -> Result> { echo_broadcast.assert_is_none()?; direct_message.assert_is_none()?; - let normal_broadcast = normal_broadcast.deserialize::>(deserializer)?; - let cap_v = self - .others_cap_v - .get(from) - .ok_or_else(|| LocalError::new(format!("Missing `V` for {from:?}")))?; - - if &normal_broadcast.data.hash(&self.context.sid_hash, from) != cap_v { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Hash mismatch".into(), - )))); + let normal_broadcast = normal_broadcast.deserialize::>(deserializer)?; + + let cap_v = self.others_cap_v.safe_get("other nodes' `V`", from)?; + + if &normal_broadcast.hash(&self.context.sid_hash, from) != cap_v { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R2HashMismatch, + ))); } - let paillier_pk = normal_broadcast.data.paillier_pk.clone().into_precomputed(); + if normal_broadcast.cap_xs.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R2WrongIdsX, + ))); + } + + if normal_broadcast.cap_ys.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R2WrongIdsY, + ))); + } - if (paillier_pk.modulus().bits_vartime() as usize) < 8 * P::SECURITY_PARAMETER { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Paillier modulus is too small".into(), - )))); + if normal_broadcast.cap_as.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R2WrongIdsA, + ))); } - if normal_broadcast.data.cap_x_to_send.iter().sum::() != Point::IDENTITY { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Sum of X points is not identity".into(), - )))); + let paillier_pk = normal_broadcast.paillier_pk.clone().into_precomputed(); + let rp_params = normal_broadcast.rp_params.to_precomputed(); + + if paillier_pk.modulus().bits_vartime() < ::MODULUS_BITS - 2 { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R2PaillierModulusTooSmall, + ))); } - let aux = (&self.context.sid_hash, &from); + if rp_params.modulus().bits_vartime() < ::MODULUS_BITS - 2 { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R2RPModulusTooSmall, + ))); + } - let rp_params = normal_broadcast.data.rp_params.to_precomputed(); - if !normal_broadcast.data.hat_psi.verify(&rp_params, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "PRM verification failed".into(), - )))); + if normal_broadcast.cap_xs.values().sum::() != Point::IDENTITY { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R2NonZeroSumOfChanges, + ))); } - Ok(Payload::new(Round2Payload { - data: PublicData1Precomp { - data: normal_broadcast.data, - paillier_pk, - rp_params, - }, - })) + let aux = (&self.context.sid_hash, &from); + if !normal_broadcast.psi.verify(&rp_params, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R2PrmFailed, + ))); + } + + let payload = Round2Payload:: { + cap_xs: normal_broadcast.cap_xs, + cap_as: normal_broadcast.cap_as, + cap_ys: normal_broadcast.cap_ys, + paillier_pk: normal_broadcast.paillier_pk.into_precomputed(), + rp_params: normal_broadcast.rp_params.to_precomputed(), + rid_part: normal_broadcast.rid_part, + }; + + Ok(Payload::new(payload)) } fn finalize( @@ -480,86 +508,107 @@ impl Round for Round2 { payloads: BTreeMap, _artifacts: BTreeMap, ) -> Result, LocalError> { - let payloads = payloads.downcast_all::>()?; - let others_data = payloads - .into_iter() - .map(|(id, payload)| (id, payload.data)) - .collect::>(); - let mut rho = self.context.data_precomp.data.rho.clone(); - for data in others_data.values() { - rho ^= &data.data.rho; + let mut payloads = payloads.downcast_all::>()?; + + let mut rid = self.r2_broadcast.rid_part.clone(); + for payload in payloads.values() { + rid ^= &payload.rid_part; } + // Add in the payload with this node's info, for the sake of uniformity + let my_payload = Round2Payload:: { + cap_xs: self.r2_broadcast.cap_xs, + cap_as: self.r2_broadcast.cap_as, + cap_ys: self.r2_broadcast.cap_ys, + paillier_pk: self.r2_broadcast.paillier_pk.into_precomputed(), + rp_params: self.r2_broadcast.rp_params.to_precomputed(), + rid_part: self.r2_broadcast.rid_part, + }; + payloads.insert(self.context.my_id.clone(), my_payload); + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Round3::new( rng, self.context, - others_data, - rho, - )))) + payloads, + rid, + )?))) } } #[derive(Debug)] struct Round3 { context: Context, - rho: BitVec, - others_data: BTreeMap>, - psi_mod: ModProof

, - pi: SchProof, + rid: BitVec, + r2_payloads: BTreeMap>, + psi_prime: ModProof

, + hat_psis: BTreeMap, +} + +impl Round3 { + fn new( + rng: &mut impl CryptoRngCore, + context: Context, + r2_payloads: BTreeMap>, + rid: BitVec, + ) -> Result { + let my_id = &context.my_id; + let aux = (&context.sid_hash, my_id, &rid); + let psi_prime = ModProof::new(rng, &context.paillier_sk, &aux); + + let my_r2_payload = r2_payloads.safe_get("Round 2 payloads", my_id)?; + + let mut hat_psis = BTreeMap::new(); + for id in context.all_ids.iter() { + let x = context.xs.safe_get("secret share changes", id)?; + let tau = context.taus.safe_get("Schnorr secrets", id)?; + let cap_a = my_r2_payload.cap_as.safe_get("Schnorr commitments", id)?; + let cap_x = my_r2_payload.cap_xs.safe_get("public share changes", id)?; + let hat_psi = SchProof::new(tau, x, cap_a, cap_x, &aux); + hat_psis.insert(id.clone(), hat_psi); + } + + Ok(Self { + context, + r2_payloads, + rid, + psi_prime, + hat_psis, + }) + } } #[derive(Clone, Serialize, Deserialize)] #[serde(bound(serialize = " - ModProof

: Serialize, - FacProof

: Serialize, - CiphertextWire: Serialize, + SchProof: Serialize, "))] #[serde(bound(deserialize = " - ModProof

: for<'x> Deserialize<'x>, - FacProof

: for<'x> Deserialize<'x>, - CiphertextWire: for<'x> Deserialize<'x>, + SchProof: for<'x> Deserialize<'x>, "))] -struct PublicData2 { - psi_mod: ModProof

, // $\psi_i$, a P^{mod} for the Paillier modulus - phi: FacProof

, - pi: SchProof, - paillier_enc_x: CiphertextWire, // `C_j,i` - psi_sch: SchProof, // $psi_i^j$, a P^{sch} for the secret share change +struct Round3EchoBroadcast { + hat_psis: BTreeMap, } -impl Round3 { - fn new( - rng: &mut impl CryptoRngCore, - context: Context, - others_data: BTreeMap>, - rho: BitVec, - ) -> Self { - let aux = (&context.sid_hash, &context.my_id, &rho); - let psi_mod = ModProof::new(rng, &context.paillier_sk, &aux); - - let pi = SchProof::new( - &context.tau_y, - &context.y, - &context.data_precomp.data.cap_b, - &context.data_precomp.data.cap_y, - &aux, - ); - - Self { - context, - others_data, - rho, - psi_mod, - pi, - } - } +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + ModProof

: Serialize, +"))] +#[serde(bound(deserialize = " + ModProof

: for<'x> Deserialize<'x>, +"))] +struct Round3Broadcast { + psi_prime: ModProof

, } #[derive(Clone, Serialize, Deserialize)] -#[serde(bound(serialize = "PublicData2

: Serialize"))] -#[serde(bound(deserialize = "PublicData2

: for<'x> Deserialize<'x>"))] -struct Round3Message { - data2: PublicData2

, +#[serde(bound(serialize = " + FacProof

: Serialize, +"))] +#[serde(bound(deserialize = " + FacProof

: for<'x> Deserialize<'x>, +"))] +struct Round3DirectMessage { + psi: FacProof

, + cap_c: Scalar, } struct Round3Payload { @@ -589,64 +638,55 @@ impl Round for Round3 { &self.context.other_ids } + fn make_echo_broadcast( + &self, + _rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + let message = Round3EchoBroadcast { + hat_psis: self.hat_psis.clone(), + }; + EchoBroadcast::new(serializer, message) + } + + fn make_normal_broadcast( + &self, + _rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + let message = Round3Broadcast { + psi_prime: self.psi_prime.clone(), + }; + NormalBroadcast::new(serializer, message) + } + fn make_direct_message( &self, rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &I, ) -> Result<(DirectMessage, Option), LocalError> { - let aux = (&self.context.sid_hash, &self.context.my_id, &self.rho); - - let data = self - .others_data - .get(destination) - .ok_or_else(|| LocalError::new(format!("Missing data for {destination:?}")))?; - - let phi = FacProof::new(rng, &self.context.paillier_sk, &data.rp_params, &aux); - - let destination_idx = *self - .context - .ids_ordering - .get(destination) - .ok_or_else(|| LocalError::new("destination={destination:?} is missing in ids_ordering"))?; - - let x_secret = self - .context - .x_to_send - .get(destination) - .ok_or_else(|| LocalError::new("destination={destination} is missing in x_to_send"))?; - let x_public = self - .context - .data_precomp - .data - .cap_x_to_send - .get(destination_idx) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in cap_x_to_send"))?; - let ciphertext = Ciphertext::new(rng, &data.paillier_pk, &secret_signed_from_scalar::

(x_secret)); - let proof_secret = self - .context - .tau_x - .get(destination) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in tau_x"))?; - let commitment = self - .context - .data_precomp - .data - .cap_a_to_send - .get(destination_idx) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in cap_a_to_send"))?; - - let psi_sch = SchProof::new(proof_secret, x_secret, commitment, x_public, &aux); - - let data2 = PublicData2 { - psi_mod: self.psi_mod.clone(), - phi, - pi: self.pi.clone(), - paillier_enc_x: ciphertext.to_wire(), - psi_sch, - }; - - let dm = DirectMessage::new(serializer, Round3Message { data2 })?; + let my_id = &self.context.my_id; + let aux = (&self.context.sid_hash, my_id, &self.rid); + + let r2_payload = self.r2_payloads.safe_get("Round 2 payloads", destination)?; + + let psi = FacProof::

::new(rng, &self.context.paillier_sk, &r2_payload.rp_params, &aux); + + let cap_y = r2_payload.cap_ys.safe_get("Elgamal public keys", my_id)?; + let y = self.context.ys.safe_get("Elgamal secrets", destination)?; + let mut reader = XofHasher::new_with_dst(b"KeyRefresh Round3") + .chain(&self.context.sid_hash) + .chain(&self.rid) + .chain(my_id) + .chain(&(cap_y * y)) + .finalize_to_reader(); + let rho = Scalar::from_xof_reader(&mut reader); + let x = self.context.xs.safe_get("secret share changes", destination)?; + let cap_c = *(x + &rho).expose_secret(); + + let message = Round3DirectMessage { psi, cap_c }; + let dm = DirectMessage::new(serializer, message)?; Ok((dm, None)) } @@ -659,89 +699,62 @@ impl Round for Round3 { normal_broadcast: NormalBroadcast, direct_message: DirectMessage, ) -> Result> { - echo_broadcast.assert_is_none()?; - normal_broadcast.assert_is_none()?; - let direct_message = direct_message.deserialize::>(deserializer)?; - - let sender_data = &self - .others_data - .get(from) - .ok_or_else(|| LocalError::new(format!("Missing data for {from:?}")))?; - - let enc_x = direct_message - .data2 - .paillier_enc_x - .to_precomputed(&self.context.data_precomp.paillier_pk); - - let x = secret_scalar_from_signed::

(&enc_x.decrypt(&self.context.paillier_sk)); - - let my_idx = *self - .context - .ids_ordering - .get(&self.context.my_id) - .ok_or_else(|| LocalError::new(format!("my_id={:?} is missing in ids_ordering", self.context.my_id)))?; - - if x.mul_by_generator() - != *sender_data - .data - .cap_x_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_x_to_send"))? - { - let mu = enc_x.derive_randomizer(&self.context.paillier_sk); + let echo_broadcast = echo_broadcast.deserialize::>(deserializer)?; + let normal_broadcast = normal_broadcast.deserialize::>(deserializer)?; + let direct_message = direct_message.deserialize::>(deserializer)?; + + let my_id = &self.context.my_id; + + let r2_payload = self.r2_payloads.safe_get("Round 2 payloads", from)?; + let cap_y = r2_payload.cap_ys.safe_get("Elgamal public keys", my_id)?; + let y = self.context.ys.safe_get("Elgamal secrets", from)?; + let mut reader = XofHasher::new_with_dst(b"KeyRefresh Round3") + .chain(&self.context.sid_hash) + .chain(&self.rid) + .chain(from) + .chain(&(cap_y * y)) + .finalize_to_reader(); + let rho = Scalar::from_xof_reader(&mut reader); + + let x = Secret::init_with(|| direct_message.cap_c - rho); + let my_cap_x = r2_payload.cap_xs.safe_get("public share changes", my_id)?; + if &x.mul_by_generator() != my_cap_x { + // TODO: can we put all the necessary info in the proof? return Err(ReceiveError::protocol(KeyRefreshError( - KeyRefreshErrorEnum::Round3MismatchedSecret { - cap_c: direct_message.data2.paillier_enc_x, - x: *x.expose_secret(), - mu: mu.expose(), - }, + KeyRefreshErrorEnum::R3ShareChangeMismatch, ))); } - let aux = (&self.context.sid_hash, &from, &self.rho); - - if !direct_message.data2.psi_mod.verify(rng, &sender_data.paillier_pk, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Mod proof verification failed".into(), - )))); + let aux = (&self.context.sid_hash, from, &self.rid); + if !normal_broadcast.psi_prime.verify(rng, &r2_payload.paillier_pk, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R3ModFailed, + ))); } if !direct_message - .data2 - .phi - .verify(&sender_data.paillier_pk, &self.context.data_precomp.rp_params, &aux) + .psi + .verify(&r2_payload.paillier_pk, &self.context.rp_params, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Fac proof verification failed".into(), - )))); + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R3FacFailed, + ))); } - if !direct_message - .data2 - .pi - .verify(&sender_data.data.cap_b, &sender_data.data.cap_y, &aux) - { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Sch proof verification (Y) failed".into(), - )))); + if echo_broadcast.hat_psis.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R3WrongIdsHatPsi, + ))); } - if !direct_message.data2.psi_sch.verify( - sender_data - .data - .cap_a_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_a_to_send"))?, - sender_data - .data - .cap_x_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_a_to_send"))?, - &aux, - ) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Sch proof verification (X) failed".into(), - )))); + for (id, hat_psi) in echo_broadcast.hat_psis.iter() { + let cap_a = r2_payload.cap_as.safe_get("Schnorr commitments", id)?; + let cap_x = r2_payload.cap_xs.safe_get("Public share changes", id)?; + if !hat_psi.verify(cap_a, cap_x, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError( + KeyRefreshErrorEnum::R3SchFailed(id.clone()), + ))); + } } Ok(Payload::new(Round3Payload { x })) @@ -754,58 +767,47 @@ impl Round for Round3 { _artifacts: BTreeMap, ) -> Result, LocalError> { let payloads = payloads.downcast_all::()?; - let others_x = payloads + + let my_id = &self.context.my_id; + + // Share changes from other nodes + let xs = payloads .into_iter() .map(|(id, payload)| (id, payload.x)) .collect::>(); - // The combined secret share change - let x_star = - others_x.into_values().sum::>() - + self.context.x_to_send.get(&self.context.my_id).ok_or_else(|| { - LocalError::new(format!("my_id={:?} is missing in x_to_send", self.context.my_id)) - })?; - - let my_id = self.context.my_id.clone(); - let mut all_ids = self.context.other_ids; - all_ids.insert(self.context.my_id); + // Share change generated by this node + let my_x = self.context.xs.safe_get("secret share changes", my_id)?; - let mut all_data = self.others_data; - all_data.insert(my_id.clone(), self.context.data_precomp); + // The combined secret share change + let x_star = xs.into_values().sum::>() + my_x; // The combined public share changes for each node - let cap_x_star = all_ids - .iter() - .enumerate() - .map(|(idx, id)| { - Ok(( - id.clone(), - all_data - .values() - .map(|data| data.data.cap_x_to_send.get(idx)) - .sum::>() - .ok_or_else(|| LocalError::new("idx={idx} is missing in cap_x_to_send"))?, - )) - }) - .collect::>()?; + let mut cap_x_star = BTreeMap::new(); + + for id_k in self.context.all_ids.iter() { + let mut result = Point::IDENTITY; + for payload in self.r2_payloads.values() { + let cap_x = payload.cap_xs.safe_get("public share changes", id_k)?; + result = result + *cap_x; + } + cap_x_star.insert(id_k.clone(), result); + } - let public_aux = all_data + let public_aux = self + .r2_payloads .into_iter() - .map(|(id, data)| { - ( - id, - PublicAuxInfo { - el_gamal_pk: data.data.cap_y, - paillier_pk: data.paillier_pk.into_wire(), - rp_params: data.rp_params.to_wire(), - }, - ) + .map(|(id, payload)| { + let aux_info = PublicAuxInfo { + paillier_pk: payload.paillier_pk.into_wire(), + rp_params: payload.rp_params.to_wire(), + }; + (id, aux_info) }) - .collect(); + .collect::>(); let secret_aux = SecretAuxInfo { paillier_sk: self.context.paillier_sk.into_wire(), - el_gamal_sk: self.context.y, }; let key_share_change = KeyShareChange { @@ -860,7 +862,7 @@ mod tests { .results() .unwrap(); - let (changes, aux_infos): (BTreeMap<_, _>, BTreeMap<_, _>) = results + let (changes, _aux_infos): (BTreeMap<_, _>, BTreeMap<_, _>) = results .into_iter() .map(|(id, (change, aux))| ((id, change), (id, aux))) .unzip(); @@ -875,15 +877,6 @@ mod tests { } } - for (id, aux_info) in aux_infos.iter() { - for other_aux_info in aux_infos.values() { - assert_eq!( - aux_info.secret_aux.el_gamal_sk.mul_by_generator(), - other_aux_info.public_aux[id].el_gamal_pk - ); - } - } - // The resulting sum of masks should be zero, since the combined secret key // should not change after applying the masks at each node. let mask_sum: Scalar = changes diff --git a/synedrion/src/paillier/encryption.rs b/synedrion/src/paillier/encryption.rs index 7d2b60a..e3f37e7 100644 --- a/synedrion/src/paillier/encryption.rs +++ b/synedrion/src/paillier/encryption.rs @@ -53,13 +53,6 @@ impl Randomizer

{ Self::new(pk, randomizer) } - /// Expose this secret randomizer. - /// - /// Supposed to be used in certain error branches where it is needed to generate a malicious behavior evidence. - pub fn expose(&self) -> P::Uint { - *self.randomizer.expose_secret() - } - /// Converts the randomizer to a publishable form by masking it with another randomizer and a public exponent. pub fn to_masked(&self, coeff: &Self, exponent: &PublicSigned) -> MaskedRandomizer

{ MaskedRandomizer( @@ -209,6 +202,7 @@ impl Ciphertext

{ } /// Encrypts the plaintext with a random randomizer. + #[cfg(test)] pub fn new(rng: &mut impl CryptoRngCore, pk: &PublicKeyPaillier

, plaintext: &SecretSigned) -> Self { Self::new_with_randomizer(pk, plaintext, &Randomizer::random(rng, pk)) } @@ -461,7 +455,10 @@ mod tests { let randomizer = Randomizer::random(&mut OsRng, pk); let ciphertext = Ciphertext::::new_with_randomizer(pk, &plaintext, &randomizer); let randomizer_back = ciphertext.derive_randomizer(&sk); - assert_eq!(randomizer.expose(), randomizer_back.expose()); + assert_eq!( + randomizer.randomizer.expose_secret(), + randomizer_back.randomizer.expose_secret() + ); } #[test]