From c6a2fb4535b2764d1080c3f3125bac3715633c2b Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 7 Jun 2024 16:10:36 +0200 Subject: [PATCH 1/6] Apply suggestions from code review Co-authored-by: raphaelrobert --- openmls/src/group/core_group/new_from_welcome.rs | 2 +- openmls/src/group/mls_group/creation.rs | 4 ++-- openmls/src/group/mls_group/mod.rs | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/openmls/src/group/core_group/new_from_welcome.rs b/openmls/src/group/core_group/new_from_welcome.rs index eb8e46c66..7abcb62b8 100644 --- a/openmls/src/group/core_group/new_from_welcome.rs +++ b/openmls/src/group/core_group/new_from_welcome.rs @@ -228,7 +228,7 @@ pub(in crate::group) fn build_staged_welcome( Ok(group) } -/// Process a Welcome message up to the point where the ratchet tree is is required. +/// Process a Welcome message up to the point where the ratchet tree is required. pub(in crate::group) fn process_welcome( welcome: Welcome, key_package_bundle: &KeyPackageBundle, diff --git a/openmls/src/group/mls_group/creation.rs b/openmls/src/group/mls_group/creation.rs index 7aafcbc83..a61a9eae7 100644 --- a/openmls/src/group/mls_group/creation.rs +++ b/openmls/src/group/mls_group/creation.rs @@ -148,8 +148,8 @@ fn transpose_err_opt(v: Result, E>) -> Option> { } impl ProcessedWelcome { - /// Creates a new processed [`Welcome`] message that can be used to parse - /// it before creating a [`StagedWelcome`]. + /// Creates a new processed [`Welcome`] message , which can be + /// inspected before creating a [`StagedWelcome`]. /// /// This does not require a ratchet tree yet. /// diff --git a/openmls/src/group/mls_group/mod.rs b/openmls/src/group/mls_group/mod.rs index b32b229bf..4c7360f5a 100644 --- a/openmls/src/group/mls_group/mod.rs +++ b/openmls/src/group/mls_group/mod.rs @@ -495,12 +495,12 @@ pub struct StagedWelcome { group: StagedCoreWelcome, } -/// A parsed, but not fully processed `Welcome` message. +/// A `Welcome` message that has been processed but not staged yet. /// /// This may be used in order to retrieve information from the `Welcome` about -/// the ratchet tree. +/// the ratchet tree and PSKs. /// -/// Use `into_staged_welcome` to get the [`StagedWelcome`] on this. +/// Use `into_staged_welcome` to stage it into a [`StagedWelcome`]. pub struct ProcessedWelcome { // The group configuration. See [`MlsGroupJoinConfig`] for more information. mls_group_config: MlsGroupJoinConfig, From d7187bc4d4af32d816d86ca6ec5605e7d6986dd9 Mon Sep 17 00:00:00 2001 From: Konrad Kohbrok Date: Fri, 14 Jun 2024 07:35:14 +0200 Subject: [PATCH 2/6] remove partialeq constraint on storage error --- traits/src/storage.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/traits/src/storage.rs b/traits/src/storage.rs index e5aad83e7..defd88315 100644 --- a/traits/src/storage.rs +++ b/traits/src/storage.rs @@ -28,7 +28,7 @@ pub const V_TEST: u16 = u16::MAX; /// More details can be taken from the comments on the respective method. pub trait StorageProvider { /// An opaque error returned by all methods on this trait. - type Error: core::fmt::Debug + std::error::Error + PartialEq; + type Error: core::fmt::Debug + std::error::Error; /// Get the version of this provider. fn version() -> u16 { From ba7aadbb387c41e718153c219b7c11f980b7f5e3 Mon Sep 17 00:00:00 2001 From: raphaelrobert Date: Wed, 19 Jun 2024 23:31:48 +0200 Subject: [PATCH 3/6] Another interop fix --- .github/workflows/interop.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/interop.yml b/.github/workflows/interop.yml index 5dd45ec56..9a706e5c9 100644 --- a/.github/workflows/interop.yml +++ b/.github/workflows/interop.yml @@ -66,7 +66,7 @@ jobs: run: | git clone https://github.com/mlswg/mls-implementations.git cd mls-implementations - git checkout f07090a844ebece12c064ce94ab853fd477db12f + git checkout 8a6ee96bc732abca77d872babf1830ccfec7fa49 - name: test-runner | Install dependencies run: | @@ -74,6 +74,7 @@ jobs: echo $(go env GOPATH)/bin >> $GITHUB_PATH go install google.golang.org/protobuf/cmd/protoc-gen-go@latest go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest + go get -u google.golang.org/grpc - name: test-runner | Build run: | @@ -84,7 +85,6 @@ jobs: make run-go || echo "Build despite errors." cd test-runner # TODO(#1366) - go get -u google.golang.org/grpc go mod tidy -e patch main.go main.go.patch go build From 37867333c4909c1058a3cf8d8e2470a655c46649 Mon Sep 17 00:00:00 2001 From: raphaelrobert Date: Wed, 19 Jun 2024 23:55:24 +0200 Subject: [PATCH 4/6] Change working directory --- .github/workflows/interop.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/interop.yml b/.github/workflows/interop.yml index 9a706e5c9..6c47a8926 100644 --- a/.github/workflows/interop.yml +++ b/.github/workflows/interop.yml @@ -74,6 +74,7 @@ jobs: echo $(go env GOPATH)/bin >> $GITHUB_PATH go install google.golang.org/protobuf/cmd/protoc-gen-go@latest go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest + cd mls-implementations go get -u google.golang.org/grpc - name: test-runner | Build From d0b66273bd04caa0c356c2d3d5a36567c99f9dba Mon Sep 17 00:00:00 2001 From: Jan Winkelmann <146678+keks@users.noreply.github.com> Date: Thu, 20 Jun 2024 11:06:55 +0200 Subject: [PATCH 5/6] Remove duplicate method in storage trait (#1591) Co-authored-by: Jan Winkelmann (keks) --- memory_storage/src/lib.rs | 8 -------- memory_storage/src/test_store.rs | 7 ------- openmls/src/group/mls_group/processing.rs | 2 +- traits/src/storage.rs | 6 ------ 4 files changed, 1 insertion(+), 22 deletions(-) diff --git a/memory_storage/src/lib.rs b/memory_storage/src/lib.rs index 4290d5723..dc1efb626 100644 --- a/memory_storage/src/lib.rs +++ b/memory_storage/src/lib.rs @@ -866,14 +866,6 @@ impl StorageProvider for MemoryStorage { self.append::(OWN_LEAF_NODES_LABEL, &key, value) } - fn clear_own_leaf_nodes>( - &self, - group_id: &GroupId, - ) -> Result<(), Self::Error> { - let key = serde_json::to_vec(group_id)?; - self.delete::(OWN_LEAF_NODES_LABEL, &key) - } - fn aad>( &self, group_id: &GroupId, diff --git a/memory_storage/src/test_store.rs b/memory_storage/src/test_store.rs index f8c7b31bf..c17d76fee 100644 --- a/memory_storage/src/test_store.rs +++ b/memory_storage/src/test_store.rs @@ -487,13 +487,6 @@ impl StorageProvider for MemoryStorage { todo!() } - fn clear_own_leaf_nodes>( - &self, - _group_id: &GroupId, - ) -> Result<(), Self::Error> { - todo!() - } - fn aad>( &self, _group_id: &GroupId, diff --git a/openmls/src/group/mls_group/processing.rs b/openmls/src/group/mls_group/processing.rs index bc61815a2..ba78d64bc 100644 --- a/openmls/src/group/mls_group/processing.rs +++ b/openmls/src/group/mls_group/processing.rs @@ -156,7 +156,7 @@ impl MlsGroup { self.own_leaf_nodes.clear(); provider .storage() - .clear_own_leaf_nodes(self.group_id()) + .delete_own_leaf_nodes(self.group_id()) .map_err(MergeCommitError::StorageError)?; // Delete a potential pending commit diff --git a/traits/src/storage.rs b/traits/src/storage.rs index e5aad83e7..d4c255768 100644 --- a/traits/src/storage.rs +++ b/traits/src/storage.rs @@ -66,12 +66,6 @@ pub trait StorageProvider { leaf_node: &LeafNode, ) -> Result<(), Self::Error>; - /// Clears the own leaf node for the group with given id to storage - fn clear_own_leaf_nodes>( - &self, - group_id: &GroupId, - ) -> Result<(), Self::Error>; - /// Enqueue a proposal. /// /// A good way to implement this could be to add a proposal to a proposal store, indexed by the From 2f835f789d5ef74f812aca37a823eaf7245ffb55 Mon Sep 17 00:00:00 2001 From: Jan Winkelmann <146678+keks@users.noreply.github.com> Date: Thu, 20 Jun 2024 18:27:19 +0200 Subject: [PATCH 6/6] Add Test checking GroupContextExtensionProposal Validation: Commit contains up to one GCE Proposal (#1590) Co-authored-by: Jan Winkelmann (keks) --- openmls/Cargo.toml | 2 + openmls/src/framing/mls_auth_content.rs | 2 +- .../src/group/core_group/new_from_welcome.rs | 7 +- openmls/src/group/core_group/staged_commit.rs | 7 +- openmls/src/group/mls_group/test_mls_group.rs | 161 ++++++++- openmls/src/lib.rs | 1 + openmls/src/skip_validation.rs | 100 ++++++ openmls/src/test_utils/frankenstein/codec.rs | 4 +- .../test_utils/frankenstein/credentials.rs | 11 + openmls/src/test_utils/frankenstein/crypto.rs | 49 +++ .../src/test_utils/frankenstein/extensions.rs | 42 +++ .../src/test_utils/frankenstein/framing.rs | 340 +++++++++++++++++- .../src/test_utils/frankenstein/group_info.rs | 23 ++ .../test_utils/frankenstein/key_package.rs | 2 +- .../src/test_utils/frankenstein/leaf_node.rs | 115 +++++- openmls/src/test_utils/frankenstein/mod.rs | 4 + 16 files changed, 830 insertions(+), 40 deletions(-) create mode 100644 openmls/src/skip_validation.rs create mode 100644 openmls/src/test_utils/frankenstein/crypto.rs diff --git a/openmls/Cargo.toml b/openmls/Cargo.toml index 17a6daa94..0137d32be 100644 --- a/openmls/Cargo.toml +++ b/openmls/Cargo.toml @@ -36,6 +36,7 @@ openmls_memory_storage = { path = "../memory_storage", features = [ ], optional = true } openmls_test = { path = "../openmls_test", optional = true } openmls_libcrux_crypto = { path = "../libcrux_crypto", optional = true } +once_cell = { version = "1.19.0", optional = true } [features] default = ["backtrace"] @@ -49,6 +50,7 @@ test-utils = [ "dep:openmls_basic_credential", "dep:openmls_memory_storage", "dep:openmls_test", + "dep:once_cell", ] libcrux-provider = [ "dep:openmls_libcrux_crypto", diff --git a/openmls/src/framing/mls_auth_content.rs b/openmls/src/framing/mls_auth_content.rs index a28e8cca8..3d7eae0b3 100644 --- a/openmls/src/framing/mls_auth_content.rs +++ b/openmls/src/framing/mls_auth_content.rs @@ -51,7 +51,7 @@ pub(crate) struct FramedContentAuthData { } impl FramedContentAuthData { - pub(super) fn deserialize( + pub(crate) fn deserialize( bytes: &mut R, content_type: ContentType, ) -> Result { diff --git a/openmls/src/group/core_group/new_from_welcome.rs b/openmls/src/group/core_group/new_from_welcome.rs index 7abcb62b8..3a50c1758 100644 --- a/openmls/src/group/core_group/new_from_welcome.rs +++ b/openmls/src/group/core_group/new_from_welcome.rs @@ -183,7 +183,12 @@ pub(in crate::group) fn build_staged_welcome( log_crypto!(trace, " Got: {:x?}", confirmation_tag); log_crypto!(trace, " Expected: {:x?}", public_group.confirmation_tag()); debug_assert!(false, "Confirmation tag mismatch"); - return Err(WelcomeError::ConfirmationTagMismatch); + + // in some tests we need to be able to proceed despite the tag being wrong, + // e.g. to test whether a later validation check is performed correctly. + if !crate::skip_validation::is_disabled::confirmation_tag() { + return Err(WelcomeError::ConfirmationTagMismatch); + } } let message_secrets_store = MessageSecretsStore::new_with_secret(0, message_secrets); diff --git a/openmls/src/group/core_group/staged_commit.rs b/openmls/src/group/core_group/staged_commit.rs index 534047dbf..440c9c0be 100644 --- a/openmls/src/group/core_group/staged_commit.rs +++ b/openmls/src/group/core_group/staged_commit.rs @@ -285,7 +285,12 @@ impl CoreGroup { // TODO: We have tests expecting this error. // They need to be rewritten. // debug_assert!(false, "Confirmation tag mismatch"); - return Err(StageCommitError::ConfirmationTagMismatch); + + // in some tests we need to be able to proceed despite the tag being wrong, + // e.g. to test whether a later validation check is performed correctly. + if !crate::skip_validation::is_disabled::confirmation_tag() { + return Err(StageCommitError::ConfirmationTagMismatch); + } } diff.update_interim_transcript_hash(ciphersuite, provider.crypto(), own_confirmation_tag)?; diff --git a/openmls/src/group/mls_group/test_mls_group.rs b/openmls/src/group/mls_group/test_mls_group.rs index 2c0edcac9..0e9b3bc51 100644 --- a/openmls/src/group/mls_group/test_mls_group.rs +++ b/openmls/src/group/mls_group/test_mls_group.rs @@ -11,9 +11,12 @@ use crate::{ key_packages::*, messages::proposals::*, prelude::Capabilities, - test_utils::test_framework::{ - errors::ClientError, noop_authentication_service, ActionType::Commit, CodecUse, - MlsGroupTestSetup, + test_utils::{ + frankenstein::{self, FrankenMlsMessage}, + test_framework::{ + errors::ClientError, noop_authentication_service, ActionType::Commit, CodecUse, + MlsGroupTestSetup, + }, }, tree::sender_ratchet::SenderRatchetConfiguration, }; @@ -1143,15 +1146,57 @@ fn remove_prosposal_by_ref( // Test that the builder pattern accurately configures the new group. #[openmls_test] fn group_context_extensions_proposal() { + let alice_provider = &mut Provider::default(); + let bob_provider = &mut Provider::default(); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = - setup_client("Alice", ciphersuite, provider); + setup_client("Alice", ciphersuite, alice_provider); + let (bob_credential_with_key, _bob_kpb, bob_signer, _bob_pk) = + setup_client("bob", ciphersuite, bob_provider); // === Alice creates a group === let mut alice_group = MlsGroup::builder() .ciphersuite(ciphersuite) - .build(provider, &alice_signer, alice_credential_with_key) + .with_wire_format_policy(WireFormatPolicy::new( + OutgoingWireFormatPolicy::AlwaysPlaintext, + IncomingWireFormatPolicy::Mixed, + )) + .build(alice_provider, &alice_signer, alice_credential_with_key) .expect("error creating group using builder"); + // === Alice adds Bob === + let bob_key_package = KeyPackage::builder() + .build( + ciphersuite, + bob_provider, + &bob_signer, + bob_credential_with_key, + ) + .expect("error building key package"); + + let (_, welcome, _) = alice_group + .add_members( + alice_provider, + &alice_signer, + &[bob_key_package.key_package().clone()], + ) + .unwrap(); + alice_group.merge_pending_commit(alice_provider).unwrap(); + + let welcome: MlsMessageIn = welcome.into(); + let welcome = welcome + .into_welcome() + .expect("expected message to be a welcome"); + + let mut bob_group = StagedWelcome::new_from_welcome( + bob_provider, + alice_group.configuration(), + welcome, + Some(alice_group.export_ratchet_tree().into()), + ) + .expect("Error creating staged join from Welcome") + .into_group(bob_provider) + .expect("Error creating group from staged join"); + // No required capabilities, so no specifically required extensions. assert!(alice_group .group() @@ -1168,20 +1213,104 @@ fn group_context_extensions_proposal() { RequiredCapabilitiesExtension::new(&[ExtensionType::RatchetTree], &[], &[]), )); - alice_group - .propose_group_context_extensions(provider, new_extensions.clone(), &alice_signer) + let (proposal, _) = alice_group + .propose_group_context_extensions(alice_provider, new_extensions.clone(), &alice_signer) .expect("failed to build group context extensions proposal"); + let proc_msg = bob_group + .process_message(bob_provider, proposal.into_protocol_message().unwrap()) + .unwrap(); + match proc_msg.into_content() { + ProcessedMessageContent::ProposalMessage(proposal) => bob_group + .store_pending_proposal(bob_provider.storage(), *proposal) + .unwrap(), + _ => unreachable!(), + }; + assert_eq!(alice_group.pending_proposals().count(), 1); - alice_group - .commit_to_pending_proposals(provider, &alice_signer) + let (commit, _, _) = alice_group + .commit_to_pending_proposals(alice_provider, &alice_signer) .expect("failed to commit to pending proposals"); + // we'll change the commit we feed to bob to include two GCE proposals + let mut franken_commit = FrankenMlsMessage::tls_deserialize( + &mut commit.tls_serialize_detached().unwrap().as_slice(), + ) + .unwrap(); + + // Craft a commit that has two GroupContextExtension proposals. This is forbidden by the RFC. + // Change the commit before alice commits, so alice's state is still in the old epoch and we can + // use her state to forge the macs and signatures + match &mut franken_commit.body { + frankenstein::FrankenMlsMessageBody::PublicMessage(msg) => { + match &mut msg.content.body { + frankenstein::FrankenFramedContentBody::Commit(commit) => { + let second_gces = frankenstein::FrankenProposalOrRef::Proposal( + frankenstein::FrankenProposal::GroupContextExtensions(vec![ + frankenstein::FrankenExtension::LastResort, + ]), + ); + + commit.proposals.push(second_gces); + } + _ => unreachable!(), + } + + let group_context = alice_group.export_group_context().clone(); + + let bob_group_context = bob_group.export_group_context(); + assert_eq!( + bob_group_context.confirmed_transcript_hash(), + group_context.confirmed_transcript_hash() + ); + + let secrets = alice_group.group.message_secrets(); + let membership_key = secrets.membership_key().as_slice(); + + *msg = frankenstein::FrankenPublicMessage::auth( + alice_provider, + group_context.ciphersuite(), + &alice_signer, + msg.content.clone(), + Some(&group_context.into()), + Some(membership_key), + // this is a dummy confirmation_tag: + Some(vec![0u8; 32].into()), + ); + } + _ => unreachable!(), + } + + // alice merges the unmodified commit alice_group - .merge_pending_commit(provider) + .merge_pending_commit(alice_provider) .expect("error merging pending commit"); + let fake_commit = MlsMessageIn::tls_deserialize( + &mut franken_commit.tls_serialize_detached().unwrap().as_slice(), + ) + .unwrap(); + + let fake_commit_protocol_msg = fake_commit.into_protocol_message().unwrap(); + + let err = { + let validation_skip_handle = crate::skip_validation::checks::confirmation_tag::handle(); + validation_skip_handle.with_disabled(|| { + bob_group.process_message(bob_provider, fake_commit_protocol_msg.clone()) + }) + } + .expect_err("expected an error"); + + assert!(matches!( + err, + ProcessMessageError::InvalidCommit( + StageCommitError::GroupContextExtensionsProposalValidationError( + GroupContextExtensionsProposalValidationError::TooManyGCEProposals + ) + ) + )); + let required_capabilities = alice_group .group() .context() @@ -1195,18 +1324,18 @@ fn group_context_extensions_proposal() { // === committing to two group context extensions should fail alice_group - .propose_group_context_extensions(provider, new_extensions, &alice_signer) + .propose_group_context_extensions(alice_provider, new_extensions, &alice_signer) .expect("failed to build group context extensions proposal"); // the proposals need to be different or they will be deduplicated alice_group - .propose_group_context_extensions(provider, new_extensions_2, &alice_signer) + .propose_group_context_extensions(alice_provider, new_extensions_2, &alice_signer) .expect("failed to build group context extensions proposal"); assert_eq!(alice_group.pending_proposals().count(), 2); alice_group - .commit_to_pending_proposals(provider, &alice_signer) + .commit_to_pending_proposals(alice_provider, &alice_signer) .expect_err( "expected error when committing to multiple group context extensions proposals", ); @@ -1220,12 +1349,8 @@ fn group_context_extensions_proposal() { )); alice_group - .propose_group_context_extensions(provider, new_extensions, &alice_signer) + .propose_group_context_extensions(alice_provider, new_extensions, &alice_signer) .expect_err("expected an error building GCE proposal with bad required_capabilities"); - - // TODO: we need to test that processing a commit with multiple group context extensions - // proposal also fails. however, we can't generate this commit, because our functions for - // constructing commits does not permit it. See #1476 } // Test that the builder pattern accurately configures the new group. diff --git a/openmls/src/lib.rs b/openmls/src/lib.rs index e12ebb36d..3f46f0455 100644 --- a/openmls/src/lib.rs +++ b/openmls/src/lib.rs @@ -189,6 +189,7 @@ pub mod storage; // Private mod binary_tree; +mod skip_validation; mod tree; /// Single place, re-exporting the most used public functions. diff --git a/openmls/src/skip_validation.rs b/openmls/src/skip_validation.rs new file mode 100644 index 000000000..94782322b --- /dev/null +++ b/openmls/src/skip_validation.rs @@ -0,0 +1,100 @@ +//! This module contains helpers for skipping validation. It is built such that setting the flag to +//! disable validation can only by set when the "test-utils" feature is enabled. +//! This module is used in two places, and they use different parts of it. +//! Code that performs validation and wants to check whether a check is disabled only uses the +//! [`is_disabled`] submodule. It contains getter functions that read the current state of the +//! flag. +//! Test code that disables checks uses the code in the [`checks`] submodule. It contains a module +//! for each check that can be disabled, and a getter for a handle, protected by a [`Mutex`]. This +//! is done because the flag state is shared between tests, and tests that set and unset the same +//! checks are not safe to run concurrently. +//! For example, a test could cann [`checks::confirmation_tag::handle`] to get a handle to disable +//! and re-enable the validation of confirmation tags. + +pub(crate) mod is_disabled { + use super::checks::*; + + pub(crate) fn confirmation_tag() -> bool { + confirmation_tag::FLAG.load(core::sync::atomic::Ordering::Relaxed) + } +} + +#[cfg(test)] +use std::sync::atomic::AtomicBool; + +/// Contains a reference to a flag. Provides convenience functions to set and clear the flag. +#[cfg(test)] +#[derive(Clone, Copy, Debug)] +pub struct SkipValidationHandle { + // we keep this field so we can see which handle this is when printing it. we don't need it otherwise + #[allow(dead_code)] + name: &'static str, + flag: &'static AtomicBool, +} + +/// Contains the flags and functions that return handles to control them. +pub(crate) mod checks { + /// Disables validation of the confirmation_tag. + pub(crate) mod confirmation_tag { + use std::sync::atomic::AtomicBool; + + /// A way of disabling verification and validation of confirmation tags. + pub(in crate::skip_validation) static FLAG: AtomicBool = AtomicBool::new(false); + + #[cfg(test)] + pub(crate) use lock::handle; + + #[cfg(test)] + mod lock { + use super::FLAG; + use crate::skip_validation::SkipValidationHandle; + use once_cell::sync::Lazy; + use std::sync::{Mutex, MutexGuard}; + + /// The name of the check that can be skipped here + const NAME: &str = "confirmation_tag"; + + /// A mutex needed to run tests that use this flag sequentially + static MUTEX: Lazy> = + Lazy::new(|| Mutex::new(SkipValidationHandle::new_confirmation_tag_handle())); + + /// Takes the mutex and returns the control handle to the validation skipper + pub(crate) fn handle() -> MutexGuard<'static, SkipValidationHandle> { + MUTEX.lock().unwrap_or_else(|e| { + panic!("error taking skip-validation mutex for '{NAME}': {e}") + }) + } + + impl SkipValidationHandle { + pub fn new_confirmation_tag_handle() -> Self { + Self { + name: NAME, + flag: &FLAG, + } + } + } + } + } +} + +#[cfg(test)] +impl SkipValidationHandle { + /// Disables validation for the check controlled by this handle + pub fn disable_validation(self) { + self.flag.store(true, core::sync::atomic::Ordering::Relaxed); + } + + /// Enables validation for the check controlled by this handle + pub fn enable_validation(self) { + self.flag + .store(false, core::sync::atomic::Ordering::Relaxed); + } + + /// Runs function `f` with validation disabled + pub fn with_disabled R>(self, mut f: F) -> R { + self.disable_validation(); + let r = f(); + self.enable_validation(); + r + } +} diff --git a/openmls/src/test_utils/frankenstein/codec.rs b/openmls/src/test_utils/frankenstein/codec.rs index 428c4e63a..5b8dbc150 100644 --- a/openmls/src/test_utils/frankenstein/codec.rs +++ b/openmls/src/test_utils/frankenstein/codec.rs @@ -214,7 +214,9 @@ impl Size for FrankenExtension { impl Serialize for FrankenExtension { fn tls_serialize(&self, writer: &mut W) -> Result { let written = self.extension_type().tls_serialize(writer)?; - let extension_data_len = self.tls_serialized_len(); + + // subtract the two bytes for the type header + let extension_data_len = self.tls_serialized_len() - 2; let mut extension_data = Vec::with_capacity(extension_data_len); let _ = match self { diff --git a/openmls/src/test_utils/frankenstein/credentials.rs b/openmls/src/test_utils/frankenstein/credentials.rs index 6c672209e..e13d8b52a 100644 --- a/openmls/src/test_utils/frankenstein/credentials.rs +++ b/openmls/src/test_utils/frankenstein/credentials.rs @@ -1,5 +1,7 @@ use tls_codec::*; +use crate::credentials::Credential; + #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] @@ -7,3 +9,12 @@ pub struct FrankenCredential { credential_type: u16, serialized_credential_content: VLBytes, } + +impl From for FrankenCredential { + fn from(value: Credential) -> Self { + FrankenCredential { + credential_type: value.credential_type().into(), + serialized_credential_content: value.serialized_content().to_owned().into(), + } + } +} diff --git a/openmls/src/test_utils/frankenstein/crypto.rs b/openmls/src/test_utils/frankenstein/crypto.rs new file mode 100644 index 000000000..606696d6a --- /dev/null +++ b/openmls/src/test_utils/frankenstein/crypto.rs @@ -0,0 +1,49 @@ +use openmls_traits::{crypto::OpenMlsCrypto, signatures::Signer, types::Ciphersuite}; +use tls_codec::{Serialize, TlsSerialize, TlsSize, VLBytes}; + +use super::FrankenAuthenticatedContentTbm; + +/// Computes a valid membership tag for the provided content. +pub fn compute_membership_tag( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + membership_key: &[u8], + auth_content_tbm: &FrankenAuthenticatedContentTbm, +) -> VLBytes { + let serialized_auth_content_tbm = &auth_content_tbm.tls_serialize_detached().unwrap(); + crypto + .hkdf_extract( + ciphersuite.hash_algorithm(), + membership_key, // Extract salt is HMAC key + serialized_auth_content_tbm, // Extract ikm is HMAC message + ) + .unwrap() + .as_slice() + .into() +} + +/// Implements the "sign with label" function of the spec. +pub fn sign_with_label(signer: &impl Signer, label: &[u8], msg: &[u8]) -> Vec { + let data = FrankenSignContent::new(label, msg) + .tls_serialize_detached() + .unwrap(); + signer.sign(&data).unwrap() +} + +#[derive(Debug, Clone, PartialEq, Eq, TlsSerialize, TlsSize)] +pub struct FrankenSignContent<'a> { + label: Vec, + content: &'a [u8], +} + +impl<'a> FrankenSignContent<'a> { + pub fn new(label: &[u8], content: &'a [u8]) -> Self { + let mut tagged_label = b"MLS 1.0 ".to_vec(); + tagged_label.extend_from_slice(label); + + Self { + label: tagged_label, + content, + } + } +} diff --git a/openmls/src/test_utils/frankenstein/extensions.rs b/openmls/src/test_utils/frankenstein/extensions.rs index 146cf6e5a..cc51f436d 100644 --- a/openmls/src/test_utils/frankenstein/extensions.rs +++ b/openmls/src/test_utils/frankenstein/extensions.rs @@ -1,5 +1,12 @@ use tls_codec::*; +use crate::{ + extensions::{ + ApplicationIdExtension, Extension, RatchetTreeExtension, RequiredCapabilitiesExtension, + }, + treesync::{node::NodeIn, Node, ParentNode}, +}; + use super::{FrankenCredential, FrankenLeafNode}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -67,6 +74,13 @@ impl FrankenExtension { } } +impl From for FrankenExtension { + fn from(value: Extension) -> Self { + let bytes = value.tls_serialize_detached().unwrap(); + FrankenExtension::tls_deserialize(&mut bytes.as_slice()).unwrap() + } +} + #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] @@ -92,6 +106,20 @@ pub enum FrankenNode { ParentNode(FrankenParentNode), } +impl From for FrankenNode { + fn from(value: Node) -> Self { + let bytes = value.tls_serialize_detached().unwrap(); + FrankenNode::tls_deserialize(&mut bytes.as_slice()).unwrap() + } +} + +impl From for FrankenNode { + fn from(value: NodeIn) -> Self { + let bytes = value.tls_serialize_detached().unwrap(); + FrankenNode::tls_deserialize(&mut bytes.as_slice()).unwrap() + } +} + #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] @@ -101,6 +129,13 @@ pub struct FrankenParentNode { pub unmerged_leaves: Vec, } +impl From for FrankenParentNode { + fn from(value: ParentNode) -> Self { + let bytes = value.tls_serialize_detached().unwrap(); + Self::tls_deserialize(&mut bytes.as_slice()).unwrap() + } +} + #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] @@ -110,6 +145,13 @@ pub struct FrankenRequiredCapabilitiesExtension { pub credential_types: Vec, } +impl From for FrankenRequiredCapabilitiesExtension { + fn from(value: RequiredCapabilitiesExtension) -> Self { + let bytes = value.tls_serialize_detached().unwrap(); + Self::tls_deserialize(&mut bytes.as_slice()).unwrap() + } +} + #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] diff --git a/openmls/src/test_utils/frankenstein/framing.rs b/openmls/src/test_utils/frankenstein/framing.rs index 8bc98bd74..5c4afb1c2 100644 --- a/openmls/src/test_utils/frankenstein/framing.rs +++ b/openmls/src/test_utils/frankenstein/framing.rs @@ -1,15 +1,26 @@ +use openmls_traits::{crypto::OpenMlsCrypto, signatures::Signer, types::Ciphersuite}; use tls_codec::*; use crate::{ + binary_tree::LeafNodeIndex, + extensions::SenderExtensionIndex, framing::{ + mls_content::{AuthenticatedContentTbm, FramedContentBody, FramedContentTbs}, + mls_content_in::FramedContentBodyIn, MlsMessageIn, MlsMessageOut, PrivateMessage, PrivateMessageIn, PublicMessage, - PublicMessageIn, + PublicMessageIn, Sender, WireFormat, }, - messages::Welcome, + group::GroupContext, + messages::{ConfirmationTag, Welcome}, + prelude_test::signable::Signable, + schedule::{ConfirmationKey, MembershipKey}, }; use super::{ - commit::FrankenCommit, group_info::FrankenGroupInfo, FrankenKeyPackage, FrankenProposal, + commit::FrankenCommit, + compute_membership_tag, + group_info::{FrankenGroupContext, FrankenGroupInfo}, + sign_with_label, FrankenKeyPackage, FrankenProposal, }; #[derive( @@ -38,15 +49,149 @@ pub enum FrankenMlsMessageBody { KeyPackage(FrankenKeyPackage), } -#[derive( - Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, -)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct FrankenPublicMessage { pub content: FrankenFramedContent, pub auth: FrankenFramedContentAuthData, pub membership_tag: Option, } +impl tls_codec::Size for FrankenPublicMessage { + fn tls_serialized_len(&self) -> usize { + let tag_len = self + .membership_tag + .as_ref() + .map_or(0, |tag| tag.tls_serialized_len()); + + self.content.tls_serialized_len() + self.auth.tls_serialized_len() + tag_len + } +} + +impl Deserialize for FrankenPublicMessage { + fn tls_deserialize(bytes: &mut R) -> Result + where + Self: Sized, + { + let content = FrankenFramedContent::tls_deserialize(bytes)?; + let auth = if matches!(content.body, FrankenFramedContentBody::Commit(_)) { + FrankenFramedContentAuthData::tls_deserialize_with_tag(bytes)? + } else { + FrankenFramedContentAuthData::tls_deserialize_without_tag(bytes)? + }; + + let membership_tag = if matches!(content.sender, FrankenSender::Member(_)) { + Some(VLBytes::tls_deserialize(bytes)?) + } else { + None + }; + + Ok(Self { + content, + auth, + membership_tag, + }) + } +} + +impl DeserializeBytes for FrankenPublicMessage { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> + where + Self: Sized, + { + let (content, bytes) = FrankenFramedContent::tls_deserialize_bytes(bytes)?; + let (auth, bytes) = match content.body { + FrankenFramedContentBody::Commit(_) => { + FrankenFramedContentAuthData::tls_deserialize_bytes_with_tag(bytes) + } + _ => FrankenFramedContentAuthData::tls_deserialize_bytes_without_tag(bytes), + }?; + let (membership_tag, bytes) = match content.sender { + FrankenSender::Member(_) => { + let (tag, bytes) = VLBytes::tls_deserialize_bytes(bytes)?; + (Some(tag), bytes) + } + _ => (None, bytes), + }; + + Ok(( + Self { + content, + auth, + membership_tag, + }, + bytes, + )) + } +} + +impl Serialize for FrankenPublicMessage { + fn tls_serialize(&self, writer: &mut W) -> Result { + let mut written = 0; + written += self.content.tls_serialize(writer)?; + written += self.auth.tls_serialize(writer)?; + if let Some(tag) = &self.membership_tag { + written += tag.tls_serialize(writer)?; + } + + Ok(written) + } +} + +impl FrankenPublicMessage { + /// auth builds a mostly(!) valid fake public message. However, it does not compute a correct + /// confirmation_tag. If the caller wants to process a message that requires a + /// confirmation_tag, they have two options: + /// + /// 1. build a valid tag themselves and provide it through the option + /// 2. provide a dummy tag and disable the verification of confirmation tags using + /// [`crate::disable_confirmation_tag_verification`]. + /// NB: Usually, confirmation tag verification should be turned back on after the call that + /// needs to be tricked! + pub(crate) fn auth( + provider: &impl crate::storage::OpenMlsProvider, + ciphersuite: openmls_traits::types::Ciphersuite, + signer: &impl Signer, + content: FrankenFramedContent, + group_context: Option<&FrankenGroupContext>, + membership_key: Option<&[u8]>, + confirmation_tag: Option, + ) -> Self { + let version = 1; // MLS 1.0 + let wire_format = 1; // PublicMessage + + let franken_tbs = FrankenFramedContentTbs { + version: 1, + wire_format: 1, // PublicMessage + content: &content, + group_context, + }; + + let auth = FrankenFramedContentAuthData::build( + signer, + version, + wire_format, + &content, + group_context, + confirmation_tag, + ); + + let tbm = FrankenAuthenticatedContentTbm { + content_tbs: franken_tbs, + auth: auth.clone(), + }; + + let membership_tag = membership_key.map(|membership_key| { + compute_membership_tag(provider.crypto(), ciphersuite, membership_key, &tbm) + }); + + FrankenPublicMessage { + content, + auth, + membership_tag, + } + } +} + #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] @@ -104,14 +249,152 @@ pub struct FrankenWelcome { pub encrypted_group_info: VLBytes, } -#[derive( - Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, -)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct FrankenFramedContentAuthData { pub signature: VLBytes, pub confirmation_tag: Option, } +impl FrankenFramedContentAuthData { + pub fn tls_deserialize_with_tag( + bytes: &mut R, + ) -> Result { + let signature = VLBytes::tls_deserialize(bytes)?; + let confirmation_tag = VLBytes::tls_deserialize(bytes)?; + + Ok(Self { + signature, + confirmation_tag: Some(confirmation_tag), + }) + } + + pub fn tls_deserialize_bytes_with_tag(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> { + let (signature, bytes) = VLBytes::tls_deserialize_bytes(bytes)?; + let (confirmation_tag, bytes) = VLBytes::tls_deserialize_bytes(bytes)?; + + Ok(( + Self { + signature, + confirmation_tag: Some(confirmation_tag), + }, + bytes, + )) + } + + pub fn tls_deserialize_without_tag( + bytes: &mut R, + ) -> Result { + let signature = VLBytes::tls_deserialize(bytes)?; + + Ok(Self { + signature, + confirmation_tag: None, + }) + } + + pub fn tls_deserialize_bytes_without_tag( + bytes: &[u8], + ) -> Result<(Self, &[u8]), tls_codec::Error> { + let (signature, bytes) = VLBytes::tls_deserialize_bytes(bytes)?; + + Ok(( + Self { + signature, + confirmation_tag: None, + }, + bytes, + )) + } +} + +impl tls_codec::Size for FrankenFramedContentAuthData { + fn tls_serialized_len(&self) -> usize { + if let Some(tag) = &self.confirmation_tag { + self.signature.tls_serialized_len() + tag.tls_serialized_len() + } else { + self.signature.tls_serialized_len() + } + } +} + +impl Serialize for FrankenFramedContentAuthData { + fn tls_serialize(&self, writer: &mut W) -> Result { + let mut written = 0; + written += self.signature.tls_serialize(writer)?; + if let Some(confirmation_tag) = &self.confirmation_tag { + written += confirmation_tag.tls_serialize(writer)?; + } + Ok(written) + } +} + +impl FrankenFramedContentAuthData { + pub fn build( + signer: &impl Signer, + version: u16, + wire_format: u16, + content: &FrankenFramedContent, + group_context: Option<&FrankenGroupContext>, + confirmation_tag: Option, + ) -> Self { + let content_tbs = FrankenFramedContentTbs { + version, + wire_format, + content, + group_context, + }; + + let content_tbs_serialized = content_tbs.tls_serialize_detached().unwrap(); + + let signature = + sign_with_label(signer, b"FramedContentTBS", &content_tbs_serialized).into(); + + Self { + signature, + confirmation_tag, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FrankenFramedContentTbs<'a> { + version: u16, + wire_format: u16, + content: &'a FrankenFramedContent, + group_context: Option<&'a FrankenGroupContext>, +} + +impl<'a> tls_codec::Size for FrankenFramedContentTbs<'a> { + fn tls_serialized_len(&self) -> usize { + if let Some(ctx) = self.group_context { + 4 + self.content.tls_serialized_len() + ctx.tls_serialized_len() + } else { + 4 + self.content.tls_serialized_len() + } + } +} + +impl<'a> Serialize for FrankenFramedContentTbs<'a> { + fn tls_serialize(&self, writer: &mut W) -> Result { + writer.write_all(&self.version.to_be_bytes())?; + writer.write_all(&self.wire_format.to_be_bytes())?; + + let mut written = 4; // contains the two u16 version and wire_format + written += self.content.tls_serialize(writer)?; + if let Some(group_context) = &self.group_context { + written += group_context.tls_serialize(writer)?; + } + + Ok(written) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, TlsSerialize, TlsSize)] +pub struct FrankenAuthenticatedContentTbm<'a> { + content_tbs: FrankenFramedContentTbs<'a>, + auth: FrankenFramedContentAuthData, +} + #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] @@ -187,3 +470,42 @@ impl From for Welcome { Welcome::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()).unwrap() } } + +impl From for FramedContentBodyIn { + fn from(value: FrankenFramedContentBody) -> Self { + FramedContentBodyIn::tls_deserialize( + &mut value.tls_serialize_detached().unwrap().as_slice(), + ) + .unwrap() + } +} + +impl From for FramedContentBody { + fn from(value: FrankenFramedContentBody) -> Self { + FramedContentBodyIn::from(value).into() + } +} + +impl From for FrankenSender { + fn from(value: Sender) -> Self { + match value { + Sender::Member(i) => FrankenSender::Member(i.u32()), + // this cast is safe, because the index method casts it from u32 to usize for some + // reason, so it's known to fit u32 + Sender::External(i) => FrankenSender::External(i.index() as u32), + Sender::NewMemberProposal => FrankenSender::NewMemberProposal, + Sender::NewMemberCommit => FrankenSender::NewMemberCommit, + } + } +} + +impl From for Sender { + fn from(value: FrankenSender) -> Self { + match value { + FrankenSender::Member(i) => Sender::Member(LeafNodeIndex::new(i)), + FrankenSender::External(i) => Sender::External(SenderExtensionIndex::new(i)), + FrankenSender::NewMemberProposal => Sender::NewMemberProposal, + FrankenSender::NewMemberCommit => Sender::NewMemberCommit, + } + } +} diff --git a/openmls/src/test_utils/frankenstein/group_info.rs b/openmls/src/test_utils/frankenstein/group_info.rs index 1f116e1ea..71868e414 100644 --- a/openmls/src/test_utils/frankenstein/group_info.rs +++ b/openmls/src/test_utils/frankenstein/group_info.rs @@ -10,6 +10,7 @@ use crate::{ signable::{Signable, SignedStruct}, signature::{OpenMlsSignaturePublicKey, Signature}, }, + group::GroupContext, messages::group_info::GroupInfo, }; @@ -91,6 +92,28 @@ pub struct FrankenGroupContext { extensions: Vec, } +impl From for FrankenGroupContext { + fn from(value: GroupContext) -> Self { + let extensions = value + .extensions() + .iter() + .map(|ext| ext.clone().into()) + .collect(); + FrankenGroupContext { + protocol_version: match value.protocol_version() { + crate::versions::ProtocolVersion::Mls10 => 1, + crate::versions::ProtocolVersion::Other(other) => other, + }, + ciphersuite: value.ciphersuite().into(), + group_id: value.group_id().as_slice().to_vec().into(), + epoch: value.epoch().as_u64(), + tree_hash: value.tree_hash().to_vec().into(), + confirmed_transcript_hash: value.confirmed_transcript_hash().to_vec().into(), + extensions, + } + } +} + impl From for FrankenGroupInfo { fn from(ln: GroupInfo) -> Self { FrankenGroupInfo::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice()) diff --git a/openmls/src/test_utils/frankenstein/key_package.rs b/openmls/src/test_utils/frankenstein/key_package.rs index d6a41429f..106be806f 100644 --- a/openmls/src/test_utils/frankenstein/key_package.rs +++ b/openmls/src/test_utils/frankenstein/key_package.rs @@ -29,7 +29,7 @@ pub struct FrankenKeyPackage { impl FrankenKeyPackage { // Re-sign both the KeyPackage and the enclosed LeafNode pub fn resign(&mut self, signer: &impl Signer) { - self.payload.leaf_node.resign(signer); + self.payload.leaf_node.resign(None, signer); let new_self = self.payload.clone().sign(signer).unwrap(); let _ = std::mem::replace(self, new_self); } diff --git a/openmls/src/test_utils/frankenstein/leaf_node.rs b/openmls/src/test_utils/frankenstein/leaf_node.rs index 186207333..1c2774618 100644 --- a/openmls/src/test_utils/frankenstein/leaf_node.rs +++ b/openmls/src/test_utils/frankenstein/leaf_node.rs @@ -6,31 +6,39 @@ use tls_codec::*; use super::{extensions::FrankenExtension, key_package::FrankenLifetime, FrankenCredential}; use crate::{ + binary_tree::array_representation::tree, ciphersuite::{ signable::{Signable, SignedStruct}, signature::Signature, }, - treesync::{node::leaf_node::LeafNodeIn, LeafNode}, + treesync::{ + node::leaf_node::{LeafNodeIn, LeafNodeTbs, TreePosition}, + LeafNode, + }, }; #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] pub struct FrankenLeafNode { - pub payload: FrankenLeafNodeTbs, + pub payload: FrankenLeafNodePayload, pub signature: VLBytes, } impl FrankenLeafNode { // Re-sign the LeafNode - pub fn resign(&mut self, signer: &impl Signer) { - let new_self = self.payload.clone().sign(signer).unwrap(); + pub fn resign(&mut self, tree_position: Option, signer: &impl Signer) { + let tbs = FrankenLeafNodeTbs { + payload: self.payload.clone(), + tree_position, + }; + let new_self = tbs.sign(signer).unwrap(); let _ = std::mem::replace(self, new_self); } } impl Deref for FrankenLeafNode { - type Target = FrankenLeafNodeTbs; + type Target = FrankenLeafNodePayload; fn deref(&self) -> &Self::Target { &self.payload @@ -44,9 +52,9 @@ impl DerefMut for FrankenLeafNode { } impl SignedStruct for FrankenLeafNode { - fn from_payload(payload: FrankenLeafNodeTbs, signature: Signature) -> Self { + fn from_payload(tbs: FrankenLeafNodeTbs, signature: Signature) -> Self { Self { - payload, + payload: tbs.payload, signature: signature.as_slice().to_owned().into(), } } @@ -90,7 +98,7 @@ impl From for LeafNodeIn { #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] -pub struct FrankenLeafNodeTbs { +pub struct FrankenLeafNodePayload { pub encryption_key: VLBytes, pub signature_key: VLBytes, pub credential: FrankenCredential, @@ -99,6 +107,97 @@ pub struct FrankenLeafNodeTbs { pub extensions: Vec, } +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenTreePosition { + pub group_id: VLBytes, + pub leaf_index: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq, TlsSize)] +pub struct FrankenLeafNodeTbs { + pub payload: FrankenLeafNodePayload, + pub tree_position: Option, +} + +impl FrankenLeafNodeTbs { + fn deserialize_without_treeposition(bytes: &mut R) -> Result { + let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?; + + Ok(Self { + payload, + tree_position: None, + }) + } + + fn deserialize_with_treeposition(bytes: &mut R) -> Result { + let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?; + let tree_position = FrankenTreePosition::tls_deserialize(bytes)?; + Ok(Self { + payload, + tree_position: Some(tree_position), + }) + } +} + +impl Deserialize for FrankenLeafNodeTbs { + fn tls_deserialize(bytes: &mut R) -> Result + where + Self: Sized, + { + let payload = FrankenLeafNodePayload::tls_deserialize(bytes)?; + let tree_position = match payload.leaf_node_source { + FrankenLeafNodeSource::KeyPackage(_) => None, + FrankenLeafNodeSource::Update | FrankenLeafNodeSource::Commit(_) => { + let tree_position = FrankenTreePosition::tls_deserialize(bytes)?; + Some(tree_position) + } + }; + + Ok(Self { + payload, + tree_position, + }) + } +} + +impl DeserializeBytes for FrankenLeafNodeTbs { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> + where + Self: Sized, + { + let (payload, rest) = FrankenLeafNodePayload::tls_deserialize_bytes(bytes)?; + let (tree_position, rest) = match payload.leaf_node_source { + FrankenLeafNodeSource::KeyPackage(_) => (None, rest), + FrankenLeafNodeSource::Update | FrankenLeafNodeSource::Commit(_) => { + let (tree_position, rest) = FrankenTreePosition::tls_deserialize_bytes(bytes)?; + (Some(tree_position), rest) + } + }; + + Ok(( + Self { + payload, + tree_position, + }, + rest, + )) + } +} + +impl Serialize for FrankenLeafNodeTbs { + fn tls_serialize(&self, writer: &mut W) -> Result { + let mut written = self.payload.tls_serialize(writer)?; + + if let Some(tree_info) = &self.tree_position { + written += tree_info.tls_serialize(writer)? + }; + + Ok(written) + } +} + #[derive( Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, )] diff --git a/openmls/src/test_utils/frankenstein/mod.rs b/openmls/src/test_utils/frankenstein/mod.rs index f26393ea4..2c1a098b7 100644 --- a/openmls/src/test_utils/frankenstein/mod.rs +++ b/openmls/src/test_utils/frankenstein/mod.rs @@ -7,6 +7,7 @@ mod codec; mod commit; mod credentials; +mod crypto; mod extensions; mod framing; mod group_info; @@ -14,7 +15,10 @@ mod key_package; mod leaf_node; mod proposals; +pub use self::commit::*; pub use self::credentials::*; +pub use self::crypto::*; +pub use self::extensions::*; pub use self::framing::*; pub use self::key_package::*; pub use self::leaf_node::*;