From 235afa08994232e02a8275ef7b6496e35cfc1782 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 30 Dec 2024 15:23:55 -0800 Subject: [PATCH] Update KeyRefresh --- synedrion/src/cggmp21/aux_gen.rs | 13 +- synedrion/src/cggmp21/entities.rs | 10 - synedrion/src/cggmp21/key_refresh.rs | 948 ++++++++++++++------------- synedrion/src/paillier/encryption.rs | 13 +- 4 files changed, 507 insertions(+), 477 deletions(-) diff --git a/synedrion/src/cggmp21/aux_gen.rs b/synedrion/src/cggmp21/aux_gen.rs index cb45a5c..eab90dc 100644 --- a/synedrion/src/cggmp21/aux_gen.rs +++ b/synedrion/src/cggmp21/aux_gen.rs @@ -607,7 +607,6 @@ impl Round for Round3 { ( id, PublicAuxInfo { - el_gamal_pk: data.data.cap_y, paillier_pk: data.paillier_pk.into_wire(), rp_params: data.rp_params.to_wire(), }, @@ -617,7 +616,6 @@ impl Round for Round3 { let secret_aux = SecretAuxInfo { paillier_sk: self.context.paillier_sk.into_wire(), - el_gamal_sk: self.context.y, }; let aux_info = AuxInfo { @@ -659,18 +657,9 @@ mod tests { }) .collect::>(); - let aux_infos = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) + let _aux_infos = run_sync::<_, TestSessionParams>(&mut OsRng, entry_points) .unwrap() .results() .unwrap(); - - for (id, aux_info) in aux_infos.iter() { - for other_aux_info in aux_infos.values() { - assert_eq!( - aux_info.secret_aux.el_gamal_sk.mul_by_generator(), - other_aux_info.public_aux[id].el_gamal_pk - ); - } - } } } diff --git a/synedrion/src/cggmp21/entities.rs b/synedrion/src/cggmp21/entities.rs index 6119b02..7760a16 100644 --- a/synedrion/src/cggmp21/entities.rs +++ b/synedrion/src/cggmp21/entities.rs @@ -45,14 +45,12 @@ pub struct AuxInfo { #[serde(bound(deserialize = "SecretKeyPaillierWire: for <'x> Deserialize<'x>"))] pub(crate) struct SecretAuxInfo { pub(crate) paillier_sk: SecretKeyPaillierWire, - pub(crate) el_gamal_sk: Secret, // `y_i` } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound(serialize = "PublicKeyPaillierWire: Serialize"))] #[serde(bound(deserialize = "PublicKeyPaillierWire: for <'x> Deserialize<'x>"))] pub(crate) struct PublicAuxInfo { - pub(crate) el_gamal_pk: Point, // `Y_i` /// The Paillier public key. pub(crate) paillier_pk: PublicKeyPaillierWire, /// The ring-Pedersen parameters. @@ -68,14 +66,10 @@ pub(crate) struct AuxInfoPrecomputed { #[derive(Debug, Clone)] pub(crate) struct SecretAuxInfoPrecomputed { pub(crate) paillier_sk: SecretKeyPaillier, - #[allow(dead_code)] // TODO (#36): this will be needed for the 6-round presigning protocol. - pub(crate) el_gamal_sk: Secret, // `y_i` } #[derive(Debug, Clone)] pub(crate) struct PublicAuxInfoPrecomputed { - #[allow(dead_code)] // TODO (#36): this will be needed for the 6-round presigning protocol. - pub(crate) el_gamal_pk: Point, pub(crate) paillier_pk: PublicKeyPaillier, pub(crate) rp_params: RPParams, } @@ -259,7 +253,6 @@ impl AuxInfo { let secret_aux = (0..ids.len()) .map(|_| SecretAuxInfo { paillier_sk: SecretKeyPaillierWire::::random(rng), - el_gamal_sk: Secret::init_with(|| Scalar::random(rng)), }) .collect::>(); @@ -271,7 +264,6 @@ impl AuxInfo { id.clone(), PublicAuxInfo { paillier_pk: secret.paillier_sk.public_key(), - el_gamal_pk: secret.el_gamal_sk.mul_by_generator(), rp_params: RPParams::random(rng).to_wire(), }, ) @@ -297,7 +289,6 @@ impl AuxInfo { AuxInfoPrecomputed { secret_aux: SecretAuxInfoPrecomputed { paillier_sk: self.secret_aux.paillier_sk.clone().into_precomputed(), - el_gamal_sk: self.secret_aux.el_gamal_sk.clone(), }, public_aux: self .public_aux @@ -307,7 +298,6 @@ impl AuxInfo { ( id.clone(), PublicAuxInfoPrecomputed { - el_gamal_pk: public_aux.el_gamal_pk, paillier_pk: paillier_pk.clone(), rp_params: public_aux.rp_params.to_precomputed(), }, diff --git a/synedrion/src/cggmp21/key_refresh.rs b/synedrion/src/cggmp21/key_refresh.rs index b38b90b..bbfe301 100644 --- a/synedrion/src/cggmp21/key_refresh.rs +++ b/synedrion/src/cggmp21/key_refresh.rs @@ -1,26 +1,24 @@ -//! KeyRefresh protocol, in the paper Auxiliary Info. & Key Refresh in Three Rounds (Fig. 6). +//! KeyRefresh protocol, in the paper Auxiliary Info. & Key Refresh in Three Rounds (Fig. 7). //! This protocol generates an update to the secret key shares and new auxiliary parameters //! for ZK proofs (e.g. Paillier keys). -use alloc::{ - collections::{BTreeMap, BTreeSet}, - format, - string::String, - vec::Vec, +use alloc::collections::{BTreeMap, BTreeSet}; +use core::{ + fmt::{self, Debug, Display}, + marker::PhantomData, }; -use core::{fmt::Debug, marker::PhantomData}; use crypto_bigint::BitOps; use manul::protocol::{ Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage, - ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessages, Round, RoundId, Serializer, + ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, Round, RoundId, + Serializer, }; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; use super::{ - conversion::{secret_scalar_from_signed, secret_signed_from_scalar}, entities::{AuxInfo, KeyShareChange, PublicAuxInfo, SecretAuxInfo}, params::SchemeParams, sigma::{FacProof, ModProof, PrmProof, SchCommitment, SchProof, SchSecret}, @@ -28,13 +26,13 @@ use super::{ use crate::{ curve::{secret_split, Point, Scalar}, paillier::{ - Ciphertext, CiphertextWire, PaillierParams, PublicKeyPaillier, PublicKeyPaillierWire, RPParams, RPParamsWire, - RPSecret, SecretKeyPaillier, SecretKeyPaillierWire, + PaillierParams, PublicKeyPaillier, PublicKeyPaillierWire, RPParams, RPParamsWire, RPSecret, SecretKeyPaillier, + SecretKeyPaillierWire, }, tools::{ bitvec::BitVec, - hashing::{Chain, FofHasher, HashOutput}, - DowncastMap, Secret, Without, + hashing::{Chain, FofHasher, HashOutput, XofHasher}, + DowncastMap, GetRound, SafeGet, Secret, Without, }, }; @@ -45,78 +43,167 @@ pub struct KeyRefreshProtocol(PhantomData<(P, I)>); impl Protocol for KeyRefreshProtocol { type Result = (KeyShareChange, AuxInfo); - type ProtocolError = KeyRefreshError

; + type ProtocolError = KeyRefreshError; fn verify_direct_message_is_invalid( - _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &DirectMessage, + deserializer: &Deserializer, + round_id: &RoundId, + message: &DirectMessage, ) -> Result<(), MessageValidationError> { - unimplemented!() + match round_id { + r if r == &RoundId::new(1) => message.verify_is_some(), + r if r == &RoundId::new(2) => message.verify_is_some(), + r if r == &RoundId::new(3) => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } } fn verify_echo_broadcast_is_invalid( - _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &EchoBroadcast, + deserializer: &Deserializer, + round_id: &RoundId, + message: &EchoBroadcast, ) -> Result<(), MessageValidationError> { - unimplemented!() + match round_id { + r if r == &RoundId::new(1) => message.verify_is_not::(deserializer), + r if r == &RoundId::new(2) => message.verify_is_some(), + r if r == &RoundId::new(3) => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } } fn verify_normal_broadcast_is_invalid( - _deserializer: &Deserializer, - _round_id: &RoundId, - _message: &NormalBroadcast, + deserializer: &Deserializer, + round_id: &RoundId, + message: &NormalBroadcast, ) -> Result<(), MessageValidationError> { - unimplemented!() + match round_id { + r if r == &RoundId::new(1) => message.verify_is_some(), + r if r == &RoundId::new(2) => message.verify_is_not::>(deserializer), + r if r == &RoundId::new(3) => message.verify_is_not::>(deserializer), + _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), + } } } /// Provable KeyRefresh faults. -#[derive(displaydoc::Display, Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound(serialize = " - KeyRefreshErrorEnum

: Serialize, + KeyRefreshErrorEnum: Serialize, "))] #[serde(bound(deserialize = " - KeyRefreshErrorEnum

: for<'x> Deserialize<'x>, + KeyRefreshErrorEnum: for<'x> Deserialize<'x>, "))] -pub struct KeyRefreshError(KeyRefreshErrorEnum

); +pub struct KeyRefreshError { + error: KeyRefreshErrorEnum, + phantom: PhantomData

, +} -#[derive(Debug, Clone, Serialize, Deserialize)] -enum KeyRefreshErrorEnum { - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round2(String), - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round3(String), - // TODO (#43): this can be removed when error verification is added - #[allow(dead_code)] - Round3MismatchedSecret { - cap_c: CiphertextWire, - x: Scalar, - mu: ::Uint, - }, +impl KeyRefreshError { + fn new(error: KeyRefreshErrorEnum) -> Self { + Self { + error, + phantom: PhantomData, + } + } } -impl ProtocolError for KeyRefreshError

{ - type AssociatedData = (); +impl Display for KeyRefreshError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "{:?}", self.error) + } +} + +#[derive(displaydoc::Display, Debug, Clone, Serialize, Deserialize)] +enum KeyRefreshErrorEnum { + /// Round2: public data hash mismatch + R2HashMismatch, + /// Round2: wrong IDs in public shares map + R2WrongIdsX, + /// Round2: wrong IDs in Elgamal keys map + R2WrongIdsY, + /// Round2: wrong IDs in Schnorr commitments map + R2WrongIdsA, + /// Round2: Paillier modulus is too small + R2PaillierModulusTooSmall, + /// Round2: ring-Pedersent modulus is too small + R2RPModulusTooSmall, + /// Round2: sum of share changes is not zero + R2NonZeroSumOfChanges, + /// Round2: P_prm verification failed + R2PrmFailed, + /// Round3: secret share change does not match the public commitment + R3ShareChangeMismatch, + /// Round3: P_mod verification failed + R3ModFailed, + /// Round3: P_fac verification failed + R3FacFailed, + /// Round3: Wrong IDs in Schnorr proofs map + R3WrongIdsHatPsi, + /// Round3: P_sch verification failed + R3SchFailed(I), +} + +impl ProtocolError for KeyRefreshError { + type AssociatedData = BTreeSet; fn required_messages(&self) -> RequiredMessages { - unimplemented!() + match self.error { + KeyRefreshErrorEnum::R2HashMismatch => RequiredMessages::new( + RequiredMessageParts::normal_broadcast_only(), + Some([(RoundId::new(1), RequiredMessageParts::echo_broadcast_only())].into()), + None, + ), + KeyRefreshErrorEnum::R2WrongIdsX => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast_only(), None, None) + } + _ => unimplemented!(), + } } fn verify_messages_constitute_error( &self, - _deserializer: &Deserializer, - _guilty_party: &I, - _shared_randomness: &[u8], - _associated_data: &Self::AssociatedData, - _message: ProtocolMessage, - _previous_messages: BTreeMap, + deserializer: &Deserializer, + guilty_party: &I, + shared_randomness: &[u8], + associated_data: &Self::AssociatedData, + message: ProtocolMessage, + previous_messages: BTreeMap, _combined_echos: BTreeMap>, ) -> Result<(), ProtocolValidationError> { - unimplemented!() + let sid_hash = FofHasher::new_with_dst(b"SID") + .chain_type::

() + .chain(&shared_randomness) + .finalize(); + + match self.error { + KeyRefreshErrorEnum::R2HashMismatch => { + let r1_message = previous_messages + .get_round(1)? + .echo_broadcast + .deserialize::(deserializer)?; + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + if r2_message.hash(&sid_hash, guilty_party) != r1_message.cap_v { + Ok(()) + } else { + Err(ProtocolValidationError::InvalidEvidence( + "The received hash is valid".into(), + )) + } + } + KeyRefreshErrorEnum::R2WrongIdsX => { + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + if &r2_message.cap_xs.keys().cloned().collect::>() != associated_data { + Ok(()) + } else { + Err(ProtocolValidationError::InvalidEvidence("The IDs are correct".into())) + } + } + _ => unimplemented!(), + } } } @@ -153,160 +240,108 @@ impl EntryPoint for KeyRefresh { let other_ids = self.all_ids.clone().without(id); - let ids_ordering = self - .all_ids - .iter() - .cloned() - .enumerate() - .map(|(idx, id)| (id, idx)) - .collect(); - let sid_hash = FofHasher::new_with_dst(b"SID") .chain_type::

() .chain(&shared_randomness) - .chain(&self.all_ids) .finalize(); - // $p_i$, $q_i$ + // Paillier secret key $p_i$, $q_i$ let paillier_sk = SecretKeyPaillierWire::::random(rng); - // $N_i$ + // Paillier public key $N_i$ let paillier_pk = paillier_sk.public_key(); - // El-Gamal key - let y = Secret::init_with(|| Scalar::random(rng)); - let cap_y = y.mul_by_generator(); + // Ring-Pedersen secret $\lambda$. + let rp_secret = RPSecret::random(rng); + // Ring-Pedersen parameters ($N$, $s$, $t$) bundled in a single object. + let rp_params = RPParams::random_with_secret(rng, &rp_secret); - // The secret and the commitment for the Schnorr PoK of the El-Gamal key - let tau_y = SchSecret::random(rng); // $\tau$ - let cap_b = SchCommitment::new(&tau_y); + let aux = (&sid_hash, id); + let psi = PrmProof::

::new(rng, &rp_secret, &rp_params, &aux); - // Secret share updates for each node ($x_i^j$ where $i$ is this party's index). - let x_to_send = self + // Ephemeral DH keys $y_{i,j}$ where $i$ is this party's index. + let ys = self .all_ids .iter() .cloned() - .zip(secret_split( - rng, - Secret::init_with(|| Scalar::ZERO), - self.all_ids.len(), - )) + .map(|id| (id, Secret::init_with(|| Scalar::random(rng)))) .collect::>(); + // Corresponding public keys $Y_{i,j}$. + let cap_ys = ys.iter().map(|(id, y)| (id.clone(), y.mul_by_generator())).collect(); - // Public counterparts of secret share updates ($X_i^j$ where $i$ is this party's index). - let cap_x_to_send = x_to_send.values().map(|x| x.mul_by_generator()).collect(); + // Secret share updates for each node ($x_{i,j}$ where $i$ is this party's index). + let split_zero = secret_split(rng, Secret::init_with(|| Scalar::ZERO), self.all_ids.len()); + let xs = self.all_ids.iter().cloned().zip(split_zero).collect::>(); - let rp_secret = RPSecret::random(rng); - // Ring-Pedersen parameters ($s$, $t$) bundled in a single object. - let rp_params = RPParams::random_with_secret(rng, &rp_secret); - - let aux = (&sid_hash, id); - let hat_psi = PrmProof::

::new(rng, &rp_secret, &rp_params, &aux); + // Public counterparts of secret share updates ($X_i^j$ where $i$ is this party's index). + let cap_xs = xs.iter().map(|(id, x)| (id.clone(), x.mul_by_generator())).collect(); - // The secrets share changes ($\tau_j$, not to be confused with $\tau$) - let tau_x = self + // Schnorr proof secrets $\tau_j$ + let taus = self .all_ids .iter() .map(|id| (id.clone(), SchSecret::random(rng))) .collect::>(); - // The commitments for share changes ($A_i^j$ where $i$ is this party's index) - let cap_a_to_send = tau_x.values().map(SchCommitment::new).collect(); + // Schnorr commitments for share changes ($A_{i,j}$ where $i$ is this party's index) + let cap_as = taus + .iter() + .map(|(id, tau)| (id.clone(), SchCommitment::new(tau))) + .collect(); - let rho = BitVec::random(rng, P::SECURITY_PARAMETER); + let rid_part = BitVec::random(rng, P::SECURITY_PARAMETER); let u = BitVec::random(rng, P::SECURITY_PARAMETER); - let data = PublicData1 { - cap_x_to_send, - cap_a_to_send, - cap_y, - cap_b, + // Note: typo in the paper, $V$ hashes in $B_i$ which is not present in the '24 version of the paper. + let r2_broadcast = Round2Broadcast { + cap_xs, + cap_ys, + cap_as, paillier_pk: paillier_pk.clone(), rp_params: rp_params.to_wire(), - hat_psi, - rho, + psi, + rid_part, u, }; - let data_precomp = PublicData1Precomp { - data, - paillier_pk: paillier_pk.into_precomputed(), - rp_params, - }; - let context = Context { paillier_sk: paillier_sk.into_precomputed(), - y, - x_to_send, - tau_x, - tau_y, - data_precomp, + rp_params, + xs, + ys, + taus, my_id: id.clone(), other_ids, + all_ids: self.all_ids, sid_hash, - ids_ordering, }; - Ok(BoxedRound::new_dynamic(Round1 { context })) - } -} + let round = Round1 { context, r2_broadcast }; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound(serialize = " - PrmProof

: Serialize, - "))] -#[serde(bound(deserialize = " - PrmProof

: for<'x> Deserialize<'x>, - "))] -struct PublicData1 { - cap_x_to_send: Vec, // $X_i^j$ where $i$ is this party's index - cap_a_to_send: Vec, // $A_i^j$ where $i$ is this party's index - cap_y: Point, - cap_b: SchCommitment, - paillier_pk: PublicKeyPaillierWire, // $N_i$ - rp_params: RPParamsWire, // $s_i$ and $t_i$ - hat_psi: PrmProof

, - rho: BitVec, - u: BitVec, -} - -#[derive(Debug, Clone)] -struct PublicData1Precomp { - data: PublicData1

, - paillier_pk: PublicKeyPaillier, - rp_params: RPParams, + Ok(BoxedRound::new_dynamic(round)) + } } #[derive(Debug)] struct Context { paillier_sk: SecretKeyPaillier, - y: Secret, - x_to_send: BTreeMap>, // $x_i^j$ where $i$ is this party's index - tau_y: SchSecret, - tau_x: BTreeMap, - data_precomp: PublicData1Precomp

, + rp_params: RPParams, + xs: BTreeMap>, // $x_{i,j}$ where $i$ is this party's index + ys: BTreeMap>, // $y_{i,j}$ where $i$ is this party's index + taus: BTreeMap, my_id: I, other_ids: BTreeSet, + all_ids: BTreeSet, sid_hash: HashOutput, - ids_ordering: BTreeMap, -} - -impl PublicData1

{ - fn hash(&self, sid_hash: &HashOutput, id: &I) -> HashOutput { - FofHasher::new_with_dst(b"Auxiliary") - .chain(sid_hash) - .chain(id) - .chain(self) - .finalize() - } } #[derive(Debug)] -struct Round1 { +struct Round1 { context: Context, + r2_broadcast: Round2Broadcast, } #[derive(Debug, Clone, Serialize, Deserialize)] -struct Round1Message { +struct Round1EchoBroadcast { cap_v: HashOutput, } @@ -338,16 +373,10 @@ impl Round for Round1 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, ) -> Result { - EchoBroadcast::new( - serializer, - Round1Message { - cap_v: self - .context - .data_precomp - .data - .hash(&self.context.sid_hash, &self.context.my_id), - }, - ) + let message = Round1EchoBroadcast { + cap_v: self.r2_broadcast.hash(&self.context.sid_hash, &self.context.my_id), + }; + EchoBroadcast::new(serializer, message) } fn receive_message( @@ -359,10 +388,13 @@ impl Round for Round1 { ) -> Result> { message.normal_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; - let echo_broadcast = message.echo_broadcast.deserialize::(deserializer)?; - Ok(Payload::new(Round1Payload { + let echo_broadcast = message + .echo_broadcast + .deserialize::(deserializer)?; + let payload = Round1Payload { cap_v: echo_broadcast.cap_v, - })) + }; + Ok(Payload::new(payload)) } fn finalize( @@ -373,28 +405,58 @@ impl Round for Round1 { ) -> Result, LocalError> { let payloads = payloads.downcast_all::()?; let others_cap_v = payloads.into_iter().map(|(id, payload)| (id, payload.cap_v)).collect(); - Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Round2 { + let next_round = Round2 { context: self.context, + r2_broadcast: self.r2_broadcast, others_cap_v, - }))) + }; + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(next_round))) } } #[derive(Debug)] -struct Round2 { +struct Round2 { context: Context, + r2_broadcast: Round2Broadcast, others_cap_v: BTreeMap, } -#[derive(Clone, Serialize, Deserialize)] -#[serde(bound(serialize = "PublicData1

: Serialize"))] -#[serde(bound(deserialize = "PublicData1

: for<'x> Deserialize<'x>"))] -struct Round2Message { - data: PublicData1

, +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + PrmProof

: Serialize, +"))] +#[serde(bound(deserialize = " + PrmProof

: for<'x> Deserialize<'x>, +"))] +struct Round2Broadcast { + cap_xs: BTreeMap, // $X_{i,j}$ where $i$ is this party's index + cap_as: BTreeMap, // $A_{i,j}$ where $i$ is this party's index + cap_ys: BTreeMap, // $Y_{i,j}$ where $i$ is this party's index + paillier_pk: PublicKeyPaillierWire, // $N_i$ + rp_params: RPParamsWire, // $\hat{N}_i$, $s_i$, and $t_i$ + psi: PrmProof

, + rid_part: BitVec, + u: BitVec, } -struct Round2Payload { - data: PublicData1Precomp

, +impl Round2Broadcast { + fn hash(&self, sid_hash: &HashOutput, id: &I) -> HashOutput { + FofHasher::new_with_dst(b"Auxiliary") + .chain(sid_hash) + .chain(id) + .chain(self) + .finalize() + } +} + +#[derive(Debug)] +struct Round2Payload { + cap_xs: BTreeMap, // $X_{i,j}$ where $i$ is this party's index + cap_as: BTreeMap, // $A_{i,j}$ where $i$ is this party's index + cap_ys: BTreeMap, // $Y_{i,j}$ where $i$ is this party's index + paillier_pk: PublicKeyPaillier, // $N_i$ + rp_params: RPParams, // $\hat{N}_i$, $s_i$, and $t_i$ + rid_part: BitVec, } impl Round for Round2 { @@ -421,12 +483,7 @@ impl Round for Round2 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, ) -> Result { - NormalBroadcast::new( - serializer, - Round2Message { - data: self.context.data_precomp.data.clone(), - }, - ) + NormalBroadcast::new(serializer, self.r2_broadcast.clone()) } fn receive_message( @@ -438,48 +495,74 @@ impl Round for Round2 { ) -> Result> { message.echo_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; - let normal_broadcast = message.normal_broadcast.deserialize::>(deserializer)?; - let cap_v = self - .others_cap_v - .get(from) - .ok_or_else(|| LocalError::new(format!("Missing `V` for {from:?}")))?; - - if &normal_broadcast.data.hash(&self.context.sid_hash, from) != cap_v { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Hash mismatch".into(), - )))); + let normal_broadcast = message + .normal_broadcast + .deserialize::>(deserializer)?; + + let cap_v = self.others_cap_v.safe_get("other nodes' `V`", from)?; + + if &normal_broadcast.hash(&self.context.sid_hash, from) != cap_v { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2HashMismatch, + ))); + } + + if normal_broadcast.cap_xs.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2WrongIdsX, + ))); } - let paillier_pk = normal_broadcast.data.paillier_pk.clone().into_precomputed(); + if normal_broadcast.cap_ys.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2WrongIdsY, + ))); + } - if (paillier_pk.modulus().bits_vartime() as usize) < 8 * P::SECURITY_PARAMETER { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Paillier modulus is too small".into(), - )))); + if normal_broadcast.cap_as.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2WrongIdsA, + ))); } - if normal_broadcast.data.cap_x_to_send.iter().sum::() != Point::IDENTITY { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "Sum of X points is not identity".into(), - )))); + let paillier_pk = normal_broadcast.paillier_pk.clone().into_precomputed(); + let rp_params = normal_broadcast.rp_params.to_precomputed(); + + if paillier_pk.modulus().bits_vartime() < ::MODULUS_BITS - 2 { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2PaillierModulusTooSmall, + ))); } - let aux = (&self.context.sid_hash, &from); + if rp_params.modulus().bits_vartime() < ::MODULUS_BITS - 2 { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2RPModulusTooSmall, + ))); + } + + if normal_broadcast.cap_xs.values().sum::() != Point::IDENTITY { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2NonZeroSumOfChanges, + ))); + } - let rp_params = normal_broadcast.data.rp_params.to_precomputed(); - if !normal_broadcast.data.hat_psi.verify(&rp_params, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round2( - "PRM verification failed".into(), - )))); + let aux = (&self.context.sid_hash, &from); + if !normal_broadcast.psi.verify(&rp_params, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R2PrmFailed, + ))); } - Ok(Payload::new(Round2Payload { - data: PublicData1Precomp { - data: normal_broadcast.data, - paillier_pk, - rp_params, - }, - })) + let payload = Round2Payload:: { + cap_xs: normal_broadcast.cap_xs, + cap_as: normal_broadcast.cap_as, + cap_ys: normal_broadcast.cap_ys, + paillier_pk: normal_broadcast.paillier_pk.into_precomputed(), + rp_params: normal_broadcast.rp_params.to_precomputed(), + rid_part: normal_broadcast.rid_part, + }; + + Ok(Payload::new(payload)) } fn finalize( @@ -488,86 +571,107 @@ impl Round for Round2 { payloads: BTreeMap, _artifacts: BTreeMap, ) -> Result, LocalError> { - let payloads = payloads.downcast_all::>()?; - let others_data = payloads - .into_iter() - .map(|(id, payload)| (id, payload.data)) - .collect::>(); - let mut rho = self.context.data_precomp.data.rho.clone(); - for data in others_data.values() { - rho ^= &data.data.rho; + let mut payloads = payloads.downcast_all::>()?; + + let mut rid = self.r2_broadcast.rid_part.clone(); + for payload in payloads.values() { + rid ^= &payload.rid_part; } + // Add in the payload with this node's info, for the sake of uniformity + let my_payload = Round2Payload:: { + cap_xs: self.r2_broadcast.cap_xs, + cap_as: self.r2_broadcast.cap_as, + cap_ys: self.r2_broadcast.cap_ys, + paillier_pk: self.r2_broadcast.paillier_pk.into_precomputed(), + rp_params: self.r2_broadcast.rp_params.to_precomputed(), + rid_part: self.r2_broadcast.rid_part, + }; + payloads.insert(self.context.my_id.clone(), my_payload); + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(Round3::new( rng, self.context, - others_data, - rho, - )))) + payloads, + rid, + )?))) } } #[derive(Debug)] struct Round3 { context: Context, - rho: BitVec, - others_data: BTreeMap>, - psi_mod: ModProof

, - pi: SchProof, -} - -#[derive(Clone, Serialize, Deserialize)] -#[serde(bound(serialize = " - ModProof

: Serialize, - FacProof

: Serialize, - CiphertextWire: Serialize, -"))] -#[serde(bound(deserialize = " - ModProof

: for<'x> Deserialize<'x>, - FacProof

: for<'x> Deserialize<'x>, - CiphertextWire: for<'x> Deserialize<'x>, -"))] -struct PublicData2 { - psi_mod: ModProof

, // $\psi_i$, a P^{mod} for the Paillier modulus - phi: FacProof

, - pi: SchProof, - paillier_enc_x: CiphertextWire, // `C_j,i` - psi_sch: SchProof, // $psi_i^j$, a P^{sch} for the secret share change + rid: BitVec, + r2_payloads: BTreeMap>, + psi_prime: ModProof

, + hat_psis: BTreeMap, } impl Round3 { fn new( rng: &mut impl CryptoRngCore, context: Context, - others_data: BTreeMap>, - rho: BitVec, - ) -> Self { - let aux = (&context.sid_hash, &context.my_id, &rho); - let psi_mod = ModProof::new(rng, &context.paillier_sk, &aux); - - let pi = SchProof::new( - &context.tau_y, - &context.y, - &context.data_precomp.data.cap_b, - &context.data_precomp.data.cap_y, - &aux, - ); + r2_payloads: BTreeMap>, + rid: BitVec, + ) -> Result { + let my_id = &context.my_id; + let aux = (&context.sid_hash, my_id, &rid); + let psi_prime = ModProof::new(rng, &context.paillier_sk, &aux); + + let my_r2_payload = r2_payloads.safe_get("Round 2 payloads", my_id)?; + + let mut hat_psis = BTreeMap::new(); + for id in context.all_ids.iter() { + let x = context.xs.safe_get("secret share changes", id)?; + let tau = context.taus.safe_get("Schnorr secrets", id)?; + let cap_a = my_r2_payload.cap_as.safe_get("Schnorr commitments", id)?; + let cap_x = my_r2_payload.cap_xs.safe_get("public share changes", id)?; + let hat_psi = SchProof::new(tau, x, cap_a, cap_x, &aux); + hat_psis.insert(id.clone(), hat_psi); + } - Self { + Ok(Self { context, - others_data, - rho, - psi_mod, - pi, - } + r2_payloads, + rid, + psi_prime, + hat_psis, + }) } } #[derive(Clone, Serialize, Deserialize)] -#[serde(bound(serialize = "PublicData2

: Serialize"))] -#[serde(bound(deserialize = "PublicData2

: for<'x> Deserialize<'x>"))] -struct Round3Message { - data2: PublicData2

, +#[serde(bound(serialize = " + SchProof: Serialize, +"))] +#[serde(bound(deserialize = " + SchProof: for<'x> Deserialize<'x>, +"))] +struct Round3EchoBroadcast { + hat_psis: BTreeMap, +} + +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + ModProof

: Serialize, +"))] +#[serde(bound(deserialize = " + ModProof

: for<'x> Deserialize<'x>, +"))] +struct Round3Broadcast { + psi_prime: ModProof

, +} + +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + FacProof

: Serialize, +"))] +#[serde(bound(deserialize = " + FacProof

: for<'x> Deserialize<'x>, +"))] +struct Round3DirectMessage { + psi: FacProof

, + cap_c: Scalar, } struct Round3Payload { @@ -597,64 +701,55 @@ impl Round for Round3 { &self.context.other_ids } + fn make_echo_broadcast( + &self, + _rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + let message = Round3EchoBroadcast { + hat_psis: self.hat_psis.clone(), + }; + EchoBroadcast::new(serializer, message) + } + + fn make_normal_broadcast( + &self, + _rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + let message = Round3Broadcast { + psi_prime: self.psi_prime.clone(), + }; + NormalBroadcast::new(serializer, message) + } + fn make_direct_message( &self, rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &I, ) -> Result<(DirectMessage, Option), LocalError> { - let aux = (&self.context.sid_hash, &self.context.my_id, &self.rho); - - let data = self - .others_data - .get(destination) - .ok_or_else(|| LocalError::new(format!("Missing data for {destination:?}")))?; - - let phi = FacProof::new(rng, &self.context.paillier_sk, &data.rp_params, &aux); - - let destination_idx = *self - .context - .ids_ordering - .get(destination) - .ok_or_else(|| LocalError::new("destination={destination:?} is missing in ids_ordering"))?; - - let x_secret = self - .context - .x_to_send - .get(destination) - .ok_or_else(|| LocalError::new("destination={destination} is missing in x_to_send"))?; - let x_public = self - .context - .data_precomp - .data - .cap_x_to_send - .get(destination_idx) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in cap_x_to_send"))?; - let ciphertext = Ciphertext::new(rng, &data.paillier_pk, &secret_signed_from_scalar::

(x_secret)); - let proof_secret = self - .context - .tau_x - .get(destination) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in tau_x"))?; - let commitment = self - .context - .data_precomp - .data - .cap_a_to_send - .get(destination_idx) - .ok_or_else(|| LocalError::new("destination_idx={destination_idx} is missing in cap_a_to_send"))?; - - let psi_sch = SchProof::new(proof_secret, x_secret, commitment, x_public, &aux); - - let data2 = PublicData2 { - psi_mod: self.psi_mod.clone(), - phi, - pi: self.pi.clone(), - paillier_enc_x: ciphertext.to_wire(), - psi_sch, - }; - - let dm = DirectMessage::new(serializer, Round3Message { data2 })?; + let my_id = &self.context.my_id; + let aux = (&self.context.sid_hash, my_id, &self.rid); + + let r2_payload = self.r2_payloads.safe_get("Round 2 payloads", destination)?; + + let psi = FacProof::

::new(rng, &self.context.paillier_sk, &r2_payload.rp_params, &aux); + + let cap_y = r2_payload.cap_ys.safe_get("Elgamal public keys", my_id)?; + let y = self.context.ys.safe_get("Elgamal secrets", destination)?; + let mut reader = XofHasher::new_with_dst(b"KeyRefresh Round3") + .chain(&self.context.sid_hash) + .chain(&self.rid) + .chain(my_id) + .chain(&(cap_y * y)) + .finalize_to_reader(); + let rho = Scalar::from_xof_reader(&mut reader); + let x = self.context.xs.safe_get("secret share changes", destination)?; + let cap_c = *(x + &rho).expose_secret(); + + let message = Round3DirectMessage { psi, cap_c }; + let dm = DirectMessage::new(serializer, message)?; Ok((dm, None)) } @@ -665,89 +760,68 @@ impl Round for Round3 { from: &I, message: ProtocolMessage, ) -> Result> { - message.echo_broadcast.assert_is_none()?; - message.normal_broadcast.assert_is_none()?; - let direct_message = message.direct_message.deserialize::>(deserializer)?; - - let sender_data = &self - .others_data - .get(from) - .ok_or_else(|| LocalError::new(format!("Missing data for {from:?}")))?; - - let enc_x = direct_message - .data2 - .paillier_enc_x - .to_precomputed(&self.context.data_precomp.paillier_pk); - - let x = secret_scalar_from_signed::

(&enc_x.decrypt(&self.context.paillier_sk)); - - let my_idx = *self - .context - .ids_ordering - .get(&self.context.my_id) - .ok_or_else(|| LocalError::new(format!("my_id={:?} is missing in ids_ordering", self.context.my_id)))?; - - if x.mul_by_generator() - != *sender_data - .data - .cap_x_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_x_to_send"))? - { - let mu = enc_x.derive_randomizer(&self.context.paillier_sk); - return Err(ReceiveError::protocol(KeyRefreshError( - KeyRefreshErrorEnum::Round3MismatchedSecret { - cap_c: direct_message.data2.paillier_enc_x, - x: *x.expose_secret(), - mu: mu.expose(), - }, + let echo_broadcast = message + .echo_broadcast + .deserialize::>(deserializer)?; + let normal_broadcast = message + .normal_broadcast + .deserialize::>(deserializer)?; + let direct_message = message + .direct_message + .deserialize::>(deserializer)?; + + let my_id = &self.context.my_id; + + let r2_payload = self.r2_payloads.safe_get("Round 2 payloads", from)?; + let cap_y = r2_payload.cap_ys.safe_get("Elgamal public keys", my_id)?; + let y = self.context.ys.safe_get("Elgamal secrets", from)?; + let mut reader = XofHasher::new_with_dst(b"KeyRefresh Round3") + .chain(&self.context.sid_hash) + .chain(&self.rid) + .chain(from) + .chain(&(cap_y * y)) + .finalize_to_reader(); + let rho = Scalar::from_xof_reader(&mut reader); + + let x = Secret::init_with(|| direct_message.cap_c - rho); + let my_cap_x = r2_payload.cap_xs.safe_get("public share changes", my_id)?; + if &x.mul_by_generator() != my_cap_x { + // TODO: can we put all the necessary info in the proof? + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3ShareChangeMismatch, ))); } - let aux = (&self.context.sid_hash, &from, &self.rho); - - if !direct_message.data2.psi_mod.verify(rng, &sender_data.paillier_pk, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Mod proof verification failed".into(), - )))); + let aux = (&self.context.sid_hash, from, &self.rid); + if !normal_broadcast.psi_prime.verify(rng, &r2_payload.paillier_pk, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3ModFailed, + ))); } if !direct_message - .data2 - .phi - .verify(&sender_data.paillier_pk, &self.context.data_precomp.rp_params, &aux) + .psi + .verify(&r2_payload.paillier_pk, &self.context.rp_params, &aux) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Fac proof verification failed".into(), - )))); + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3FacFailed, + ))); } - if !direct_message - .data2 - .pi - .verify(&sender_data.data.cap_b, &sender_data.data.cap_y, &aux) - { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Sch proof verification (Y) failed".into(), - )))); + if echo_broadcast.hat_psis.keys().cloned().collect::>() != self.context.all_ids { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3WrongIdsHatPsi, + ))); } - if !direct_message.data2.psi_sch.verify( - sender_data - .data - .cap_a_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_a_to_send"))?, - sender_data - .data - .cap_x_to_send - .get(my_idx) - .ok_or_else(|| LocalError::new("my_idx={my_idx} is missing in cap_a_to_send"))?, - &aux, - ) { - return Err(ReceiveError::protocol(KeyRefreshError(KeyRefreshErrorEnum::Round3( - "Sch proof verification (X) failed".into(), - )))); + for (id, hat_psi) in echo_broadcast.hat_psis.iter() { + let cap_a = r2_payload.cap_as.safe_get("Schnorr commitments", id)?; + let cap_x = r2_payload.cap_xs.safe_get("Public share changes", id)?; + if !hat_psi.verify(cap_a, cap_x, &aux) { + return Err(ReceiveError::protocol(KeyRefreshError::new( + KeyRefreshErrorEnum::R3SchFailed(id.clone()), + ))); + } } Ok(Payload::new(Round3Payload { x })) @@ -760,58 +834,47 @@ impl Round for Round3 { _artifacts: BTreeMap, ) -> Result, LocalError> { let payloads = payloads.downcast_all::()?; - let others_x = payloads + + let my_id = &self.context.my_id; + + // Share changes from other nodes + let xs = payloads .into_iter() .map(|(id, payload)| (id, payload.x)) .collect::>(); - // The combined secret share change - let x_star = - others_x.into_values().sum::>() - + self.context.x_to_send.get(&self.context.my_id).ok_or_else(|| { - LocalError::new(format!("my_id={:?} is missing in x_to_send", self.context.my_id)) - })?; - - let my_id = self.context.my_id.clone(); - let mut all_ids = self.context.other_ids; - all_ids.insert(self.context.my_id); + // Share change generated by this node + let my_x = self.context.xs.safe_get("secret share changes", my_id)?; - let mut all_data = self.others_data; - all_data.insert(my_id.clone(), self.context.data_precomp); + // The combined secret share change + let x_star = xs.into_values().sum::>() + my_x; // The combined public share changes for each node - let cap_x_star = all_ids - .iter() - .enumerate() - .map(|(idx, id)| { - Ok(( - id.clone(), - all_data - .values() - .map(|data| data.data.cap_x_to_send.get(idx)) - .sum::>() - .ok_or_else(|| LocalError::new("idx={idx} is missing in cap_x_to_send"))?, - )) - }) - .collect::>()?; + let mut cap_x_star = BTreeMap::new(); + + for id_k in self.context.all_ids.iter() { + let mut result = Point::IDENTITY; + for payload in self.r2_payloads.values() { + let cap_x = payload.cap_xs.safe_get("public share changes", id_k)?; + result = result + *cap_x; + } + cap_x_star.insert(id_k.clone(), result); + } - let public_aux = all_data + let public_aux = self + .r2_payloads .into_iter() - .map(|(id, data)| { - ( - id, - PublicAuxInfo { - el_gamal_pk: data.data.cap_y, - paillier_pk: data.paillier_pk.into_wire(), - rp_params: data.rp_params.to_wire(), - }, - ) + .map(|(id, payload)| { + let aux_info = PublicAuxInfo { + paillier_pk: payload.paillier_pk.into_wire(), + rp_params: payload.rp_params.to_wire(), + }; + (id, aux_info) }) - .collect(); + .collect::>(); let secret_aux = SecretAuxInfo { paillier_sk: self.context.paillier_sk.into_wire(), - el_gamal_sk: self.context.y, }; let key_share_change = KeyShareChange { @@ -866,7 +929,7 @@ mod tests { .results() .unwrap(); - let (changes, aux_infos): (BTreeMap<_, _>, BTreeMap<_, _>) = results + let (changes, _aux_infos): (BTreeMap<_, _>, BTreeMap<_, _>) = results .into_iter() .map(|(id, (change, aux))| ((id, change), (id, aux))) .unzip(); @@ -881,15 +944,6 @@ mod tests { } } - for (id, aux_info) in aux_infos.iter() { - for other_aux_info in aux_infos.values() { - assert_eq!( - aux_info.secret_aux.el_gamal_sk.mul_by_generator(), - other_aux_info.public_aux[id].el_gamal_pk - ); - } - } - // The resulting sum of masks should be zero, since the combined secret key // should not change after applying the masks at each node. let mask_sum: Scalar = changes diff --git a/synedrion/src/paillier/encryption.rs b/synedrion/src/paillier/encryption.rs index 7d2b60a..e3f37e7 100644 --- a/synedrion/src/paillier/encryption.rs +++ b/synedrion/src/paillier/encryption.rs @@ -53,13 +53,6 @@ impl Randomizer

{ Self::new(pk, randomizer) } - /// Expose this secret randomizer. - /// - /// Supposed to be used in certain error branches where it is needed to generate a malicious behavior evidence. - pub fn expose(&self) -> P::Uint { - *self.randomizer.expose_secret() - } - /// Converts the randomizer to a publishable form by masking it with another randomizer and a public exponent. pub fn to_masked(&self, coeff: &Self, exponent: &PublicSigned) -> MaskedRandomizer

{ MaskedRandomizer( @@ -209,6 +202,7 @@ impl Ciphertext

{ } /// Encrypts the plaintext with a random randomizer. + #[cfg(test)] pub fn new(rng: &mut impl CryptoRngCore, pk: &PublicKeyPaillier

, plaintext: &SecretSigned) -> Self { Self::new_with_randomizer(pk, plaintext, &Randomizer::random(rng, pk)) } @@ -461,7 +455,10 @@ mod tests { let randomizer = Randomizer::random(&mut OsRng, pk); let ciphertext = Ciphertext::::new_with_randomizer(pk, &plaintext, &randomizer); let randomizer_back = ciphertext.derive_randomizer(&sk); - assert_eq!(randomizer.expose(), randomizer_back.expose()); + assert_eq!( + randomizer.randomizer.expose_secret(), + randomizer_back.randomizer.expose_secret() + ); } #[test]