diff --git a/Cargo.lock b/Cargo.lock index ea44ef7..ea4a342 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -385,6 +385,12 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "drain_filter_polyfill" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca9f76bdd86dfc8d64eecb0484d02ad4cf0e6767d4682f686d1b0580b6c27f82" + [[package]] name = "ecdsa-mpc" version = "0.3.0" @@ -1483,6 +1489,7 @@ dependencies = [ "anyhow", "curv-kzen", "digest 0.9.0", + "drain_filter_polyfill", "ecdsa-mpc", "futures", "lazy_static", diff --git a/round-based-ing/Cargo.toml b/round-based-ing/Cargo.toml index 19e2326..8ca4011 100644 --- a/round-based-ing/Cargo.toml +++ b/round-based-ing/Cargo.toml @@ -17,6 +17,7 @@ curv-kzen = "0.2" secp256k1 = "0.15" sorted-vec = "0.7" digest = "0.9" +drain_filter_polyfill = "0.1" [dev-dependencies] anyhow = "1" diff --git a/round-based-ing/src/generic.rs b/round-based-ing/src/generic.rs index 6705da0..576cc7b 100644 --- a/round-based-ing/src/generic.rs +++ b/round-based-ing/src/generic.rs @@ -5,6 +5,7 @@ use round_based::{Delivery, Incoming, MessageDestination, Mpc, MpcParty, Outgoin use ecdsa_mpc::protocol::{Address, InputMessage, OutputMessage, PartyIndex}; use ecdsa_mpc::state_machine::{self, State, Transition}; +use drain_filter_polyfill::VecExt; use futures::{SinkExt, StreamExt}; use thiserror::Error; use tracing::{error, trace, trace_span, warn}; @@ -21,6 +22,7 @@ where T: StateMachineTraits, T::ErrorState: fmt::Debug, M: Mpc, + T::Msg: MessageRound, { let span = trace_span!("MPC protocol execution", protocol = %protocol_name, i = party_index); trace!(parent: &span, "Starting the protocol"); @@ -29,6 +31,7 @@ where let (mut incomings, mut outgoings) = delivery.split(); let mut state: Box + Send> = Box::new(initial_state); + let mut out_of_order_messages: Vec> = vec![]; for round_i in 1u16.. { trace!(parent: &span, i = round_i, "Proceeding to round `i`"); @@ -38,14 +41,17 @@ where let msg = match convert_output_message_to_outgoing(&parties, msg) { Ok(m) => m, Err(UnknownDestination { recipient }) => { - warn!(?recipient, "Protocol wants to send message to the party that doesn't take part in computation. Ignore that message."); + warn!( + parent: &span, + ?recipient, + "Protocol wants to send message to the party that doesn't take part in computation. Ignore that message." + ); continue; } }; trace!( parent: &span, recipient = ?msg.recipient, - is_broadcast = msg.is_broadcast(), "Sending message to `recipient`" ); outgoings.feed(msg).await.map_err(Error::SendMessage)?; @@ -53,25 +59,56 @@ where outgoings.flush().await.map_err(Error::SendMessage)?; } + let mut out_of_order_messages_for_this_round = out_of_order_messages + .drain_filter(|incoming| incoming.msg.round() == round_i) + .collect::>(); + out_of_order_messages_for_this_round.reverse(); + let mut received_msgs = vec![]; while !state.is_input_complete(&received_msgs) { - let incoming = incomings - .next() - .await - .ok_or(Error::UnexpectedEof)? - .map_err(Error::ReceiveNextMessage)?; + let incoming = if let Some(msg) = out_of_order_messages_for_this_round.pop() { + trace!(parent: &span, "Retrieved out of order message"); + msg + } else { + incomings + .next() + .await + .ok_or(Error::UnexpectedEof)? + .map_err(Error::ReceiveNextMessage)? + }; let sender = incoming.sender; - if sender == party_index { - // Ignore own messages - continue; - } trace!( parent: &span, sender = incoming.sender, is_broadcast = incoming.is_broadcast(), + message_round = incoming.msg.round(), "Received message from `sender`" ); + + if sender == party_index { + trace!( + parent: &span, + "Message was sent by this party - ignoring it" + ); + continue; + } + if incoming.msg.round() < round_i { + warn!( + parent: &span, + "Received message from previous round. Ignore that message." + ); + continue; + } + if incoming.msg.round() > round_i { + trace!( + parent: &span, + "Received out of order message, save it to process later" + ); + out_of_order_messages.push(incoming); + continue; + } + let msg = convert_incoming_to_input_message(&parties, incoming)?; if !state.is_message_expected(&msg, &received_msgs) { error!( @@ -272,3 +309,35 @@ pub enum InvalidPartiesList { #[error("list of parties too large: it must fit into u16")] TooLarge, } + +pub trait MessageRound { + fn round(&self) -> u16; +} + +impl MessageRound for ecdsa_mpc::ecdsa::messages::keygen::Message { + fn round(&self) -> u16 { + match self { + Self::R1(..) => 1, + Self::R2(..) => 2, + Self::R3(..) => 3, + Self::R4(..) => 4, + } + } +} + +impl MessageRound for ecdsa_mpc::ecdsa::messages::signing::Message { + fn round(&self) -> u16 { + match self { + Self::R1(..) => 1, + Self::R2(..) => 2, + Self::R2b(..) => 3, + Self::R3(..) => 4, + Self::R4(..) => 5, + Self::R5(..) => 6, + Self::R6(..) => 7, + Self::R7(..) => 8, + Self::R8(..) => 9, + Self::R9(..) => 10, + } + } +}