Skip to content

Commit

Permalink
Make Protocol::verify_*_is_invalid() mandatory to implement
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Jan 1, 2025
1 parent c9a594c commit c0cd439
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Message parts in `Round::receive_message()` and `ProtocolError::verify_messages_constitute_error()` are bundled in `ProtocolMessage`. ([#79])
- `RoundId`s are passed by reference in public methods since they are not `Copy`. ([#79])
- Using a single `ProtocolError::required_messages()` instead of multiple methods. ([#79])
- `Protocol::verify_*_is_invalid()` are now mandatory to implement. ([#79])


### Added
Expand Down
12 changes: 12 additions & 0 deletions examples/src/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ impl<Id> Protocol<Id> for SimpleProtocol {
_ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())),
}
}

fn verify_normal_broadcast_is_invalid(
_deserializer: &Deserializer,
round_id: &RoundId,
message: &NormalBroadcast,
) -> Result<(), MessageValidationError> {
if round_id == &RoundId::new(1) || round_id == &RoundId::new(2) {
message.verify_is_some()
} else {
Err(MessageValidationError::InvalidEvidence("Invalid round number".into()))
}
}
}

#[derive(Debug)]
Expand Down
28 changes: 26 additions & 2 deletions manul/benches/empty_rounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use manul::{
dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner},
protocol::{
Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, LocalError,
NoProtocolErrors, PartyId, Payload, Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError, Round,
RoundId, Serializer,
MessageValidationError, NoProtocolErrors, NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessage,
ProtocolMessagePart, ReceiveError, Round, RoundId, Serializer,
},
signature::Keypair,
};
Expand All @@ -22,6 +22,30 @@ pub struct EmptyProtocol;
impl<Id> Protocol<Id> for EmptyProtocol {
type Result = ();
type ProtocolError = NoProtocolErrors;

fn verify_direct_message_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &DirectMessage,
) -> Result<(), MessageValidationError> {
unimplemented!()
}

fn verify_echo_broadcast_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &EchoBroadcast,
) -> Result<(), MessageValidationError> {
unimplemented!()
}

fn verify_normal_broadcast_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &NormalBroadcast,
) -> Result<(), MessageValidationError> {
unimplemented!()
}
}

#[derive(Debug)]
Expand Down
43 changes: 41 additions & 2 deletions manul/src/combinators/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ use serde::{Deserialize, Serialize};

use crate::protocol::{
Artifact, BoxedRng, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EchoRoundParticipation, EntryPoint,
FinalizeOutcome, LocalError, NormalBroadcast, ObjectSafeRound, PartyId, Payload, Protocol, ProtocolError,
ProtocolMessage, ProtocolValidationError, ReceiveError, RequiredMessages, RoundId, Serializer,
FinalizeOutcome, LocalError, MessageValidationError, NormalBroadcast, ObjectSafeRound, PartyId, Payload, Protocol,
ProtocolError, ProtocolMessage, ProtocolValidationError, ReceiveError, RequiredMessages, RoundId, Serializer,
};

/// A marker trait that is used to disambiguate blanket trait implementations for [`Protocol`] and [`EntryPoint`].
Expand Down Expand Up @@ -217,6 +217,45 @@ where
{
type Result = <C::Protocol2 as Protocol<Id>>::Result;
type ProtocolError = ChainedProtocolError<Id, C>;

fn verify_direct_message_is_invalid(
deserializer: &Deserializer,
round_id: &RoundId,
message: &DirectMessage,
) -> Result<(), MessageValidationError> {
let (group, round_id) = round_id.split_group()?;
if group == 1 {
C::Protocol1::verify_direct_message_is_invalid(deserializer, &round_id, message)
} else {
C::Protocol2::verify_direct_message_is_invalid(deserializer, &round_id, message)
}
}

fn verify_echo_broadcast_is_invalid(
deserializer: &Deserializer,
round_id: &RoundId,
message: &EchoBroadcast,
) -> Result<(), MessageValidationError> {
let (group, round_id) = round_id.split_group()?;
if group == 1 {
C::Protocol1::verify_echo_broadcast_is_invalid(deserializer, &round_id, message)
} else {
C::Protocol2::verify_echo_broadcast_is_invalid(deserializer, &round_id, message)
}
}

fn verify_normal_broadcast_is_invalid(
deserializer: &Deserializer,
round_id: &RoundId,
message: &NormalBroadcast,
) -> Result<(), MessageValidationError> {
let (group, round_id) = round_id.split_group()?;
if group == 1 {
C::Protocol1::verify_normal_broadcast_is_invalid(deserializer, &round_id, message)
} else {
C::Protocol2::verify_normal_broadcast_is_invalid(deserializer, &round_id, message)
}
}
}

/// A trait defining how the entry point for the whole chained protocol
Expand Down
59 changes: 34 additions & 25 deletions manul/src/protocol/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,25 @@ impl RoundId {
}
}

/// Removes the top group prefix from this round ID.
/// Removes the top group prefix from this round ID
/// and returns this prefix along with the resulting round ID.
///
/// Returns the `Err` variant if the round ID is not nested.
pub(crate) fn split_group(&self) -> Result<(u8, Self), LocalError> {
if self.round_nums.len() == 1 {
Err(LocalError::new("This round ID is not in a group"))
} else {
let mut round_nums = self.round_nums.clone();
let group = round_nums.pop().expect("vector size greater than 1");
let round_id = Self {
round_nums,
is_echo: self.is_echo,
};
Ok((group, round_id))
}
}

/// Removes the top group prefix from this round ID and returns the resulting Round ID.
///
/// Returns the `Err` variant if the round ID is not nested.
pub(crate) fn ungroup(&self) -> Result<Self, LocalError> {
Expand Down Expand Up @@ -135,44 +153,35 @@ pub trait Protocol<Id>: 'static {
/// Returns `Ok(())` if the given direct message cannot be deserialized
/// assuming it is a direct message from the round `round_id`.
///
/// Normally one would use [`DirectMessage::verify_is_not`] when implementing this.
/// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`]
/// when implementing this.
fn verify_direct_message_is_invalid(
#[allow(unused_variables)] deserializer: &Deserializer,
deserializer: &Deserializer,
round_id: &RoundId,
#[allow(unused_variables)] message: &DirectMessage,
) -> Result<(), MessageValidationError> {
Err(MessageValidationError::InvalidEvidence(format!(
"Invalid round number: {round_id:?}"
)))
}
message: &DirectMessage,
) -> Result<(), MessageValidationError>;

/// Returns `Ok(())` if the given echo broadcast cannot be deserialized
/// assuming it is an echo broadcast from the round `round_id`.
///
/// Normally one would use [`EchoBroadcast::verify_is_not`] when implementing this.
/// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`]
/// when implementing this.
fn verify_echo_broadcast_is_invalid(
#[allow(unused_variables)] deserializer: &Deserializer,
deserializer: &Deserializer,
round_id: &RoundId,
#[allow(unused_variables)] message: &EchoBroadcast,
) -> Result<(), MessageValidationError> {
Err(MessageValidationError::InvalidEvidence(format!(
"Invalid round number: {round_id:?}"
)))
}
message: &EchoBroadcast,
) -> Result<(), MessageValidationError>;

/// Returns `Ok(())` if the given echo broadcast cannot be deserialized
/// assuming it is an echo broadcast from the round `round_id`.
///
/// Normally one would use [`NormalBroadcast::verify_is_not`] when implementing this.
/// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`]
/// when implementing this.
fn verify_normal_broadcast_is_invalid(
#[allow(unused_variables)] deserializer: &Deserializer,
deserializer: &Deserializer,
round_id: &RoundId,
#[allow(unused_variables)] message: &NormalBroadcast,
) -> Result<(), MessageValidationError> {
Err(MessageValidationError::InvalidEvidence(format!(
"Invalid round number: {round_id:?}"
)))
}
message: &NormalBroadcast,
) -> Result<(), MessageValidationError>;
}

/// Declares which parts of the message from a round have to be stored to serve as the evidence of malicious behavior.
Expand Down
31 changes: 29 additions & 2 deletions manul/src/session/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,10 +841,13 @@ fn filter_messages<Verifier>(
mod tests {
use impls::impls;

use super::{Message, ProcessedArtifact, ProcessedMessage, Session, SessionParameters, VerifiedMessage};
use super::{
Deserializer, Message, ProcessedArtifact, ProcessedMessage, RoundId, Session, SessionParameters,
VerifiedMessage,
};
use crate::{
dev::{BinaryFormat, TestSessionParams, TestVerifier},
protocol::{NoProtocolErrors, Protocol},
protocol::{DirectMessage, EchoBroadcast, MessageValidationError, NoProtocolErrors, NormalBroadcast, Protocol},
};

#[test]
Expand All @@ -864,6 +867,30 @@ mod tests {
impl Protocol<<SP as SessionParameters>::Verifier> for DummyProtocol {
type Result = ();
type ProtocolError = NoProtocolErrors;

fn verify_direct_message_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &DirectMessage,
) -> Result<(), MessageValidationError> {
unimplemented!()
}

fn verify_echo_broadcast_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &EchoBroadcast,
) -> Result<(), MessageValidationError> {
unimplemented!()
}

fn verify_normal_broadcast_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &NormalBroadcast,
) -> Result<(), MessageValidationError> {
unimplemented!()
}
}

// We need `Session` to be `Send` so that we send a `Session` object to a task
Expand Down
30 changes: 27 additions & 3 deletions manul/src/tests/partial_echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ use serde::{Deserialize, Serialize};
use crate::{
dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
protocol::{
Artifact, BoxedRound, Deserializer, EchoBroadcast, EchoRoundParticipation, EntryPoint, FinalizeOutcome,
LocalError, NoProtocolErrors, PartyId, Payload, Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError,
Round, RoundId, Serializer,
Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EchoRoundParticipation, EntryPoint,
FinalizeOutcome, LocalError, MessageValidationError, NoProtocolErrors, NormalBroadcast, PartyId, Payload,
Protocol, ProtocolMessage, ProtocolMessagePart, ReceiveError, Round, RoundId, Serializer,
},
signature::Keypair,
};
Expand All @@ -24,6 +24,30 @@ struct PartialEchoProtocol<Id>(PhantomData<Id>);
impl<Id: PartyId> Protocol<Id> for PartialEchoProtocol<Id> {
type Result = ();
type ProtocolError = NoProtocolErrors;

fn verify_direct_message_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &DirectMessage,
) -> Result<(), MessageValidationError> {
unimplemented!()
}

fn verify_echo_broadcast_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &EchoBroadcast,
) -> Result<(), MessageValidationError> {
unimplemented!()
}

fn verify_normal_broadcast_is_invalid(
_deserializer: &Deserializer,
_round_id: &RoundId,
_message: &NormalBroadcast,
) -> Result<(), MessageValidationError> {
unimplemented!()
}
}

#[derive(Debug, Clone)]
Expand Down

0 comments on commit c0cd439

Please sign in to comment.