From b4d676eaa9577d7e6e1a7ca80c4cb6eacad06eb3 Mon Sep 17 00:00:00 2001 From: lfzkoala Date: Tue, 2 Jul 2024 03:01:28 -0400 Subject: [PATCH] finish key update test --- xmtp_mls/src/groups/mod.rs | 101 ++++++++++++++++++++++++++++++------- 1 file changed, 82 insertions(+), 19 deletions(-) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index cd27c97cb..895b4dec8 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -1020,13 +1020,13 @@ mod tests { use openmls::prelude::tls_codec::Deserialize; use openmls::prelude::MlsMessageIn; use openmls::prelude::MlsMessageBodyIn; - use openmls::prelude::Proposal; use openmls::prelude::ProcessedMessageContent; use crate::groups::GroupMessageVersion; use prost::Message; use tracing_test::traced_test; use xmtp_cryptography::utils::generate_local_wallet; use xmtp_proto::xmtp::mls::message_contents::EncodedContent; + use xmtp_proto::xmtp::mls::api::v1::GroupMessage; use crate::{ assert_logged, @@ -1374,7 +1374,7 @@ mod tests { .await .unwrap(); - assert_eq!(messages.len(), 1); + assert_eq!(messages.len(), 3); } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] @@ -1510,18 +1510,14 @@ mod tests { messages = client_a.api_client.query_group_messages(group.group_id.clone(), None).await.unwrap(); assert_eq!(messages.len(), 3); - // call pre_intent_hook on client B. No new message request because we have updated the `rotated_time` which is not equal to 0 now. - group.pre_intent_hook(&client_b).await.unwrap(); - messages = client_a.api_client.query_group_messages(group.group_id.clone(), None).await.unwrap(); - assert_eq!(messages.len(),3); - - // can key_update on client B. We have one more queried message. - group.key_update(&client_b).await.unwrap(); - messages = client_a.api_client.query_group_messages(group.group_id.clone(), None).await.unwrap(); - assert_eq!(messages.len(),4); + // call pre_intent_hook on client B. + client_b_group.pre_intent_hook(&client_b).await.unwrap(); // Verify client A receives a key rotation payload - + messages = client_b.api_client.query_group_messages(group.group_id.clone(), None).await.unwrap(); + assert_eq!(messages.len(),4); + + // steps to get the leaf node of the updated path. let first_message = &messages[messages.len()-1]; let msgv1 = match &first_message.version { @@ -1540,14 +1536,81 @@ mod tests { let mut openmls_group = group.load_mls_group(&provider).unwrap(); let decrypted_message = openmls_group.process_message(&provider, mls_message).unwrap(); - // let staged_commit = match decrypted_message.into_content(){ - // ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, - // _ => panic!("error staged_commit"), - // }; + let staged_commit = match decrypted_message.into_content(){ + ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, + _ => panic!("error staged_commit"), + }; + + // check there is indeed some updated leaf node, which means the key update works. + let path_update_leaf_node = staged_commit.update_path_leaf_node(); + assert!(path_update_leaf_node.is_some()); + + // call pre_intent_hook on client B again, client A receives nothing new. + client_b_group.pre_intent_hook(&client_b).await.unwrap(); + messages = client_b.api_client.query_group_messages(group.group_id.clone(), None).await.unwrap(); + assert_eq!(messages.len(),4); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_send_message_with_pre_intent_hook(){ + let client_a = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let client_b = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + // client A makes a group with client B. + let group = client_a.create_group(None, GroupMetadataOptions::default()).expect("create group"); + group + .add_members_by_inbox_id(&client_a, vec![client_b.inbox_id()]) + .await + .unwrap(); + + // client B creates it from welcome + let client_b_group = receive_group_invite(&client_b).await; + client_b_group.sync(&client_b).await.unwrap(); + + // Verify no new payloads on client A + let mut messages = client_a.api_client.query_group_messages(group.group_id.clone(), None).await.unwrap(); + assert_eq!(messages.len(), 3); + + // Client B sends a message to Client A + let b_message = b"hello from client b"; + client_b_group.send_message(b_message, &client_b).await.expect("send message"); + + // Verify client A receives a key rotation. + messages = client_b.api_client.query_group_messages(group.group_id.clone(), None).await.unwrap(); + assert_eq!(messages.len(), 5); + + // Steps to get the leaf node of the updated path. + let queried_message = &messages[messages.len()-2]; + + let msgv1 = match &queried_message.version { + Some(GroupMessageVersion::V1(value)) => value, + _ => panic!("error msgv1"), + }; + + let mls_message_in = MlsMessageIn::tls_deserialize_exact(&msgv1.data).unwrap(); + let mls_message = match mls_message_in.extract() { + MlsMessageBodyIn::PrivateMessage(mls_message) => mls_message, + _ => panic!("error mls_message"), + }; + + let conn = &client_a.context.store.conn().unwrap(); + let provider = client_a.mls_provider(conn.clone()); + let mut openmls_group = group.load_mls_group(&provider).unwrap(); + let decrypted_message = openmls_group.process_message(&provider, mls_message).unwrap(); + + let staged_commit = match decrypted_message.into_content(){ + ProcessedMessageContent::StagedCommitMessage(staged_commit) => *staged_commit, + _ => panic!("error staged_commit"), + }; + + // Check there is indeed some updated leaf node, which means the key update works. + let path_update_leaf_node = staged_commit.update_path_leaf_node(); + assert!(path_update_leaf_node.is_some()); + + // Verify client A receives the message. + let message = get_latest_message(&group, &client_a).await; + assert_eq!(message.decrypted_message_bytes, b_message); - // let leaf_node = match staged_commit.update_path_leaf_node(){ - // Some() - // }; } #[tokio::test(flavor = "multi_thread", worker_threads = 1)]