diff --git a/openmls/src/group/core_group/mod.rs b/openmls/src/group/core_group/mod.rs index 3983d127ab..7703260865 100644 --- a/openmls/src/group/core_group/mod.rs +++ b/openmls/src/group/core_group/mod.rs @@ -211,6 +211,14 @@ impl CoreGroupBuilder { self } + /// Set the `group_context_extensions` of the [`CoreGroup`]. + pub fn with_group_context_extensions(mut self, extensions: Extensions) -> Self { + self.public_group_builder = self + .public_group_builder + .with_group_context_extensions(extensions); + self + } + /// Build the [`CoreGroup`]. /// Any values that haven't been set in the builder are set to their default /// values (which might be random). diff --git a/openmls/src/group/mls_group/config.rs b/openmls/src/group/mls_group/config.rs index a78c5e4ff6..ca2f00ea4e 100644 --- a/openmls/src/group/mls_group/config.rs +++ b/openmls/src/group/mls_group/config.rs @@ -60,6 +60,8 @@ pub struct MlsGroupConfig { pub(crate) lifetime: Lifetime, /// Ciphersuite and protocol version pub(crate) crypto_config: CryptoConfig, + // Other extensions + pub(crate) group_context_extensions: Extensions, } impl MlsGroupConfig { @@ -118,6 +120,11 @@ impl MlsGroupConfig { &self.crypto_config } + /// Set the `group_context_extensions` property of the MlsGroupConfig. + pub fn group_context_extensions(&self) -> &Extensions { + &self.group_context_extensions + } + #[cfg(any(feature = "test-utils", test))] pub fn test_default(ciphersuite: Ciphersuite) -> Self { Self::builder() @@ -220,6 +227,12 @@ impl MlsGroupConfigBuilder { self } + /// Sets the `group_context_extensions` property of the MlsGroupConfig. + pub fn group_context_extensions(mut self, extensions: Extensions) -> Self { + self.config.group_context_extensions = extensions; + self + } + /// Finalizes the builder and retursn an `[MlsGroupConfig`]. pub fn build(self) -> MlsGroupConfig { self.config diff --git a/openmls/src/group/mls_group/creation.rs b/openmls/src/group/mls_group/creation.rs index df2164bea4..91e94f486f 100644 --- a/openmls/src/group/mls_group/creation.rs +++ b/openmls/src/group/mls_group/creation.rs @@ -57,6 +57,7 @@ impl MlsGroup { credential_with_key, ) .with_config(group_config) + .with_group_context_extensions(mls_group_config.group_context_extensions.clone()) .with_required_capabilities(mls_group_config.required_capabilities.clone()) .with_external_senders(mls_group_config.external_senders.clone()) .with_max_past_epoch_secrets(mls_group_config.max_past_epochs) diff --git a/openmls/src/group/mls_group/test_mls_group.rs b/openmls/src/group/mls_group/test_mls_group.rs index 8cae8a6679..20a04289b1 100644 --- a/openmls/src/group/mls_group/test_mls_group.rs +++ b/openmls/src/group/mls_group/test_mls_group.rs @@ -645,3 +645,62 @@ fn remove_prosposal_by_ref(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv _ => unreachable!("Expected a StagedCommit."), } } + +#[apply(ciphersuites_and_providers)] +fn test_group_context_extensions(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { + let group_id = GroupId::from_slice(b"Test Group"); + let application_id = b"Test App ID"; + let metadata = vec![1, 2, 3]; + + let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = + setup_client("Alice", ciphersuite, provider); + + // Define the MlsGroup configuration + let mls_group_config = MlsGroupConfig::builder() + .wire_format_policy(WireFormatPolicy::new( + OutgoingWireFormatPolicy::AlwaysPlaintext, + IncomingWireFormatPolicy::Mixed, + )) + .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .group_context_extensions(Extensions::single(Extension::ProtectedMetadata( + ProtectedMetadata::new( + &alice_signer, + application_id.to_vec(), + alice_credential_with_key.credential.clone(), + alice_credential_with_key.signature_key.as_slice().to_vec(), + metadata, + ) + .unwrap(), + ))) + .build(); + + // === Alice creates a group === + let mut alice_group = MlsGroup::new_with_group_id( + provider, + &alice_signer, + &mls_group_config, + group_id.clone(), + alice_credential_with_key, + ) + .expect("An unexpected error occurred."); + + assert!(alice_group + .export_group_context() + .extensions() + .contains(ExtensionType::ProtectedMetadata)); + + // Check the internal state has changed + assert_eq!(alice_group.state_changed(), InnerState::Changed); + + alice_group + .save(provider.key_store()) + .expect("Could not write group state to file"); + + let alice_group_deserialized = + MlsGroup::load(&group_id, provider.key_store()).expect("Could not deserialize MlsGroup"); + + assert!(alice_group_deserialized + .export_group_context() + .extensions() + .contains(ExtensionType::ProtectedMetadata)); +} diff --git a/openmls/src/group/public_group/builder.rs b/openmls/src/group/public_group/builder.rs index aac7ac329f..aeae7f9630 100644 --- a/openmls/src/group/public_group/builder.rs +++ b/openmls/src/group/public_group/builder.rs @@ -26,6 +26,7 @@ pub(crate) struct TempBuilderPG1 { required_capabilities: Option, external_senders: Option, leaf_extensions: Option, + group_context_extensions: Option, } impl TempBuilderPG1 { @@ -34,6 +35,11 @@ impl TempBuilderPG1 { self } + pub(crate) fn with_group_context_extensions(mut self, extensions: Extensions) -> Self { + self.group_context_extensions = Some(extensions); + self + } + pub(crate) fn with_required_capabilities( mut self, required_capabilities: RequiredCapabilitiesExtension, @@ -87,17 +93,22 @@ impl TempBuilderPG1 { _ => LibraryError::custom("Unexpected ExtensionError").into(), })?; let required_capabilities = Extension::RequiredCapabilities(required_capabilities); - let extensions = - if let Some(ext_senders) = self.external_senders.map(Extension::ExternalSenders) { - vec![required_capabilities, ext_senders] - } else { - vec![required_capabilities] - }; + + let mut extensions = Extensions::from_vec(vec![required_capabilities])?; + if let Some(ext_senders) = self.external_senders.map(Extension::ExternalSenders) { + extensions.add(ext_senders)?; + } + if let Some(group_context_extensions) = self.group_context_extensions { + for extension in group_context_extensions.iter() { + extensions.add(extension.clone())?; + } + } + let group_context = GroupContext::create_initial_group_context( self.crypto_config.ciphersuite, self.group_id, treesync.tree_hash().to_vec(), - Extensions::from_vec(extensions)?, + extensions, ); let next_builder = TempBuilderPG2 { treesync, @@ -172,6 +183,7 @@ impl PublicGroup { required_capabilities: None, external_senders: None, leaf_extensions: None, + group_context_extensions: None, } } }