From abe937398fc88c2cd7499e64a4739a2ef31b5800 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Wed, 21 Aug 2024 17:35:33 -0400 Subject: [PATCH] Force Provider to take ownership of `DbConnection` (#982) * force XmtpOpenMlsProvider to take ownership of connection. Dont clone connections * restrict excessive clones of connections --- Cargo.lock | 11 + bindings_ffi/Cargo.lock | 11 + bindings_node/Cargo.lock | 11 + xmtp_mls/Cargo.toml | 1 + xmtp_mls/src/client.rs | 78 ++-- xmtp_mls/src/groups/members.rs | 3 +- xmtp_mls/src/groups/message_history.rs | 4 +- xmtp_mls/src/groups/mod.rs | 147 +++---- xmtp_mls/src/groups/sync.rs | 110 ++--- xmtp_mls/src/groups/validated_commit.rs | 396 +++++++++--------- xmtp_mls/src/identity.rs | 10 +- .../storage/encrypted_store/db_connection.rs | 22 +- xmtp_mls/src/storage/encrypted_store/mod.rs | 35 +- xmtp_mls/src/storage/sql_key_store.rs | 49 ++- xmtp_mls/src/xmtp_openmls_provider.rs | 13 - 15 files changed, 455 insertions(+), 446 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1a18b2acf..23664eb4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4490,6 +4490,16 @@ dependencies = [ "parking_lot", ] +[[package]] +name = "scoped-futures" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1473e24c637950c9bd38763220bea91ec3e095a89f672bbd7a10d03e77ba467" +dependencies = [ + "cfg-if", + "pin-utils", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -6398,6 +6408,7 @@ dependencies = [ "rand", "reqwest 0.12.5", "ring 0.17.8", + "scoped-futures", "serde", "serde_json", "sha2 0.10.8", diff --git a/bindings_ffi/Cargo.lock b/bindings_ffi/Cargo.lock index 8d174507a..7dc79d249 100644 --- a/bindings_ffi/Cargo.lock +++ b/bindings_ffi/Cargo.lock @@ -4074,6 +4074,16 @@ dependencies = [ "parking_lot", ] +[[package]] +name = "scoped-futures" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1473e24c637950c9bd38763220bea91ec3e095a89f672bbd7a10d03e77ba467" +dependencies = [ + "cfg-if", + "pin-utils", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -5831,6 +5841,7 @@ dependencies = [ "rand", "reqwest 0.12.4", "ring 0.17.8", + "scoped-futures", "serde", "serde_json", "sha2", diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index 292ff858a..9e2f2f37f 100644 --- a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -3827,6 +3827,16 @@ dependencies = [ "parking_lot", ] +[[package]] +name = "scoped-futures" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1473e24c637950c9bd38763220bea91ec3e095a89f672bbd7a10d03e77ba467" +dependencies = [ + "cfg-if", + "pin-utils", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -5288,6 +5298,7 @@ dependencies = [ "rand", "reqwest 0.12.4", "ring 0.17.8", + "scoped-futures", "serde", "serde_json", "sha2", diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index 4b0f47072..eff01d7bd 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -58,6 +58,7 @@ xmtp_cryptography = { workspace = true } xmtp_id = { path = "../xmtp_id" } xmtp_proto = { workspace = true, features = ["proto_full", "convert"] } xmtp_v2 = { path = "../xmtp_v2" } +scoped-futures = "0.1" # Test/Bench Utils xmtp_api_grpc = { path = "../xmtp_api_grpc", optional = true } diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 225d33c2f..c23f26b4c 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -238,15 +238,16 @@ impl XmtpMlsLocalContext { self.identity.sequence_id(conn) } + /// Pulls a new database connection and creates a new provider + pub fn mls_provider(&self) -> Result { + Ok(self.store.conn()?.into()) + } + /// Integrators should always check the `signature_request` return value of this function before calling [`register_identity`](Self::register_identity). /// If `signature_request` returns `None`, then the wallet signature is not required and [`register_identity`](Self::register_identity) can be called with None as an argument. pub fn signature_request(&self) -> Option { self.identity.signature_request() } - - pub(crate) fn mls_provider(&self, conn: DbConnection) -> XmtpOpenMlsProvider { - XmtpOpenMlsProvider::new(conn) - } } impl Client @@ -280,6 +281,11 @@ where self.context.inbox_id() } + /// Pulls a connection and creates a new MLS Provider + pub fn mls_provider(&self) -> Result { + self.context.mls_provider() + } + pub async fn find_inbox_id_from_address( &self, address: String, @@ -332,10 +338,6 @@ where &self.context.identity } - pub(crate) fn mls_provider(&self, conn: DbConnection) -> XmtpOpenMlsProvider { - XmtpOpenMlsProvider::new(conn) - } - pub fn context(&self) -> &Arc { &self.context } @@ -439,7 +441,8 @@ where ) -> Result<(), ClientError> { log::info!("registering identity"); // Register the identity before applying the signature request - let provider = self.mls_provider(self.store().conn()?); + let provider: XmtpOpenMlsProvider = self.store().conn()?.into(); + self.identity() .register(&provider, &self.api_client) .await?; @@ -452,10 +455,9 @@ where /// Upload a new key package to the network replacing an existing key package /// This is expected to be run any time the client receives new Welcome messages pub async fn rotate_key_package(&self) -> Result<(), ClientError> { - let connection = self.store().conn()?; - let kp = self - .identity() - .new_key_package(&self.mls_provider(connection))?; + let provider: XmtpOpenMlsProvider = self.store().conn()?.into(); + + let kp = self.identity().new_key_package(&provider)?; let kp_bytes = kp.tls_serialize_detached()?; self.api_client.upload_key_package(kp_bytes, true).await?; @@ -499,8 +501,7 @@ where ) -> Result, ClientError> { let key_package_results = self.api_client.fetch_key_packages(installation_ids).await?; - let conn = self.store().conn()?; - let mls_provider = self.mls_provider(conn); + let mls_provider = self.mls_provider()?; Ok(key_package_results .values() .map(|bytes| VerifiedKeyPackageV2::from_bytes(mls_provider.crypto(), bytes.as_slice())) @@ -522,7 +523,7 @@ where .transaction_async(|provider| async move { let is_updated = provider - .conn() + .conn_ref() .update_cursor(entity_id, entity_kind, cursor as i64)?; if !is_updated { return Err(MessageProcessingError::AlreadyProcessed(cursor)); @@ -589,30 +590,36 @@ where } pub async fn sync_all_groups(&self, groups: Vec) -> Result<(), GroupError> { + use scoped_futures::ScopedFutureExt; + // Acquire a single connection to be reused - let conn = &self.store().conn()?; + let provider: XmtpOpenMlsProvider = self.mls_provider()?; let sync_futures: Vec<_> = groups .into_iter() .map(|group| { - let conn = conn.clone(); - let mls_provider = self.mls_provider(conn.clone()); - - async move { - log::info!("[{}] syncing group", self.inbox_id()); - log::info!( - "current epoch for [{}] in sync_all_groups() is Epoch: [{}]", - self.inbox_id(), - group.load_mls_group(mls_provider.clone()).unwrap().epoch() - ); - - group - .maybe_update_installations(conn.clone(), None, self) - .await?; - - group.sync_with_conn(conn.clone(), self).await?; - Ok::<(), GroupError>(()) + async { + // create new provider ref that gets moved, leaving original + // provider alone. + let provider_ref = &provider; + async move { + log::info!("[{}] syncing group", self.inbox_id()); + log::info!( + "current epoch for [{}] in sync_all_groups() is Epoch: [{}]", + self.inbox_id(), + group.load_mls_group(provider_ref)?.epoch() + ); + + group + .maybe_update_installations(provider_ref, None, self) + .await?; + + group.sync_with_conn(provider_ref, self).await?; + Ok::<(), GroupError>(()) + } + .await } + .scoped() }) .collect(); @@ -904,8 +911,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_welcome_encryption() { let client = ClientBuilder::new_test_client(&generate_local_wallet()).await; - let conn = client.store().conn().unwrap(); - let provider = client.mls_provider(conn); + let provider = client.mls_provider().unwrap(); let kp = client.identity().new_key_package(&provider).unwrap(); let hpke_public_key = kp.hpke_init_key().as_slice(); diff --git a/xmtp_mls/src/groups/members.rs b/xmtp_mls/src/groups/members.rs index dffee76d1..0cdedf843 100644 --- a/xmtp_mls/src/groups/members.rs +++ b/xmtp_mls/src/groups/members.rs @@ -24,8 +24,7 @@ pub enum PermissionLevel { impl MlsGroup { // Load the member list for the group from the DB, merging together multiple installations into a single entry pub fn members(&self) -> Result, GroupError> { - let conn = self.context.store.conn()?; - let provider = self.context.mls_provider(conn); + let provider = self.mls_provider()?; self.members_with_provider(&provider) } diff --git a/xmtp_mls/src/groups/message_history.rs b/xmtp_mls/src/groups/message_history.rs index 3a5450db8..8b5e26512 100644 --- a/xmtp_mls/src/groups/message_history.rs +++ b/xmtp_mls/src/groups/message_history.rs @@ -125,7 +125,7 @@ where })?; // publish the intent - if let Err(err) = sync_group.publish_intents(conn, self).await { + if let Err(err) = sync_group.publish_intents(&conn.into(), self).await { log::error!("error publishing sync group intents: {:?}", err); } @@ -172,7 +172,7 @@ where })?; // publish the intent - if let Err(err) = sync_group.publish_intents(conn, self).await { + if let Err(err) = sync_group.publish_intents(&conn.into(), self).await { log::error!("error publishing sync group intents: {:?}", err); } Ok(()) diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 1645769b2..9c5d43d6b 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -247,13 +247,12 @@ impl MlsGroup { /// Instantiate a new [`XmtpOpenMlsProvider`] pulling a connection from the database. /// prefer to use an already-instantiated mls provider if possible. pub fn mls_provider(&self) -> Result { - let conn = self.context.store.conn()?; - Ok(self.context.mls_provider(conn)) + Ok(self.context.store.conn()?.into()) } // Load the stored MLS group from the OpenMLS provider's keystore #[tracing::instrument(level = "trace", skip_all)] - pub fn load_mls_group( + pub(crate) fn load_mls_group( &self, provider: impl OpenMlsProvider, ) -> Result { @@ -350,7 +349,7 @@ impl MlsGroup { validate_initial_group_membership(client, provider.conn_ref(), &mls_group).await?; - let stored_group = provider.conn().insert_or_replace_group(to_store)?; + let stored_group = provider.conn_ref().insert_or_replace_group(to_store)?; Ok(Self::new( client.context.clone(), @@ -444,18 +443,21 @@ impl MlsGroup { { let update_interval = Some(5_000_000); // 5 seconds in nanoseconds let conn = self.context.store.conn()?; - self.maybe_update_installations(conn.clone(), update_interval, client) + let provider = XmtpOpenMlsProvider::from(conn); + self.maybe_update_installations(&provider, update_interval, client) .await?; - let message_id = - self.prepare_message(message, &conn, |now| Self::into_envelope(message, now)); + let message_id = self.prepare_message(message, provider.conn_ref(), |now| { + Self::into_envelope(message, now) + }); // Skipping a full sync here and instead just firing and forgetting - if let Err(err) = self.publish_intents(conn.clone(), client).await { + if let Err(err) = self.publish_intents(&provider, client).await { log::error!("Send: error publishing intents: {:?}", err); } - self.sync_until_last_intent_resolved(conn, client).await?; + self.sync_until_last_intent_resolved(&provider, client) + .await?; message_id } @@ -469,11 +471,13 @@ impl MlsGroup { ApiClient: XmtpApi, { let conn = self.context.store.conn()?; + let provider = XmtpOpenMlsProvider::from(conn); let update_interval = Some(5_000_000); - self.maybe_update_installations(conn.clone(), update_interval, client) + self.maybe_update_installations(&provider, update_interval, client) + .await?; + self.publish_intents(&provider, client).await?; + self.sync_until_last_intent_resolved(&provider, client) .await?; - self.publish_intents(conn.clone(), client).await?; - self.sync_until_last_intent_resolved(conn, client).await?; Ok(()) } @@ -607,8 +611,7 @@ impl MlsGroup { client: &Client, inbox_ids: Vec, ) -> Result<(), GroupError> { - let conn = client.store().conn()?; - let provider = client.mls_provider(conn); + let provider = client.mls_provider()?; let intent_data = self .get_membership_update_intent(client, &provider, inbox_ids, vec![]) .await?; @@ -621,13 +624,15 @@ impl MlsGroup { return Ok(()); } - let intent = provider.conn().insert_group_intent(NewGroupIntent::new( - IntentKind::UpdateGroupMembership, - self.group_id.clone(), - intent_data.into(), - ))?; + let intent = provider + .conn_ref() + .insert_group_intent(NewGroupIntent::new( + IntentKind::UpdateGroupMembership, + self.group_id.clone(), + intent_data.into(), + ))?; - self.sync_until_intent_resolved(provider.conn(), intent.id, client) + self.sync_until_intent_resolved(&provider, intent.id, client) .await } @@ -648,8 +653,8 @@ impl MlsGroup { client: &Client, inbox_ids: Vec, ) -> Result<(), GroupError> { - let conn = client.store().conn()?; - let provider = client.mls_provider(conn); + let provider = client.store().conn()?.into(); + let intent_data = self .get_membership_update_intent(client, &provider, vec![], inbox_ids) .await?; @@ -662,7 +667,7 @@ impl MlsGroup { intent_data.into(), ))?; - self.sync_until_intent_resolved(provider.conn(), intent.id, client) + self.sync_until_intent_resolved(&provider, intent.id, client) .await } @@ -683,7 +688,7 @@ impl MlsGroup { intent_data, ))?; - self.sync_until_intent_resolved(conn, intent.id, client) + self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await } @@ -715,7 +720,7 @@ impl MlsGroup { intent_data, ))?; - self.sync_until_intent_resolved(conn, intent.id, client) + self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await } @@ -749,7 +754,7 @@ impl MlsGroup { intent_data, ))?; - self.sync_until_intent_resolved(conn, intent.id, client) + self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await } @@ -784,7 +789,7 @@ impl MlsGroup { intent_data, ))?; - self.sync_until_intent_resolved(conn, intent.id, client) + self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await } @@ -821,7 +826,7 @@ impl MlsGroup { intent_data, ))?; - self.sync_until_intent_resolved(conn, intent.id, client) + self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await } @@ -896,7 +901,7 @@ impl MlsGroup { intent_data, ))?; - self.sync_until_intent_resolved(conn, intent.id, client) + self.sync_until_intent_resolved(&conn.into(), intent.id, client) .await } @@ -921,7 +926,7 @@ impl MlsGroup { let intent = NewGroupIntent::new(IntentKind::KeyUpdate, self.group_id.clone(), vec![]); intent.store(&conn)?; - self.sync_with_conn(conn, client).await + self.sync_with_conn(&conn.into(), client).await } pub fn is_active(&self, provider: impl OpenMlsProvider) -> Result { @@ -1257,7 +1262,6 @@ mod tests { members::{GroupMember, PermissionLevel}, DeliveryStatus, GroupMetadataOptions, PreconfiguredPolicies, UpdateAdminListType, }, - identity_updates::tests::sign_with_wallet, storage::{ group_intent::IntentState, group_message::{GroupMessageKind, StoredGroupMessage}, @@ -1302,8 +1306,7 @@ mod tests { sender_mls_group: &mut OpenMlsGroup, sender_provider: &XmtpOpenMlsProvider, ) { - let new_member_provider = - new_member_client.mls_provider(new_member_client.store().conn().unwrap()); + let new_member_provider = new_member_client.mls_provider().unwrap(); let key_package = new_member_client .identity() @@ -1371,7 +1374,7 @@ mod tests { .expect("send message"); group - .receive(&client.store().conn().unwrap(), &client) + .receive(&client.store().conn().unwrap().into(), &client) .await .unwrap(); // Check for messages @@ -1498,27 +1501,24 @@ mod tests { .expect("bola's add should succeed in a no-op"); amal_group - .receive(&amal.store().conn().unwrap(), &amal) + .receive(&amal.store().conn().unwrap().into(), &amal) .await .expect_err("expected error"); // Check Amal's MLS group state. - let amal_db = amal.context.store.conn().unwrap(); - let amal_mls_group = amal_group - .load_mls_group(amal.mls_provider(amal_db.clone())) - .unwrap(); + let amal_db = XmtpOpenMlsProvider::from(amal.context.store.conn().unwrap()); + let amal_mls_group = amal_group.load_mls_group(&amal_db).unwrap(); let amal_members: Vec = amal_mls_group.members().collect(); assert_eq!(amal_members.len(), 3); // Check Bola's MLS group state. - let bola_db = bola.context.store.conn().unwrap(); - let bola_mls_group = bola_group - .load_mls_group(bola.mls_provider(bola_db.clone())) - .unwrap(); + let bola_db = XmtpOpenMlsProvider::from(bola.context.store.conn().unwrap()); + let bola_mls_group = bola_group.load_mls_group(&bola_db).unwrap(); let bola_members: Vec = bola_mls_group.members().collect(); assert_eq!(bola_members.len(), 3); let amal_uncommitted_intents = amal_db + .conn_ref() .find_group_intents( amal_group.group_id.clone(), Some(vec![IntentState::ToPublish, IntentState::Published]), @@ -1528,6 +1528,7 @@ mod tests { assert_eq!(amal_uncommitted_intents.len(), 0); let bola_failed_intents = bola_db + .conn_ref() .find_group_intents( bola_group.group_id.clone(), Some(vec![IntentState::Error]), @@ -1547,7 +1548,7 @@ mod tests { let alix_group: MlsGroup = alix .create_group(None, GroupMetadataOptions::default()) .unwrap(); - let provider = alix.mls_provider(alix.store().conn().unwrap()); + let provider = alix.mls_provider().unwrap(); // Doctor the group membership let mut mls_group = alix_group.load_mls_group(&provider).unwrap(); let mut existing_extensions = mls_group.extensions().clone(); @@ -1685,8 +1686,7 @@ mod tests { .unwrap(); assert_eq!(messages.len(), 2); - let conn = &client.context.store.conn().unwrap(); - let provider = super::XmtpOpenMlsProvider::new(conn.clone()); + let provider: XmtpOpenMlsProvider = client.context.store.conn().unwrap().into(); let mls_group = group.load_mls_group(&provider).unwrap(); let pending_commit = mls_group.pending_commit(); assert!(pending_commit.is_none()); @@ -1859,8 +1859,7 @@ mod tests { assert_eq!(group.members().unwrap().len(), 2); - let conn = &amal.context.store.conn().unwrap(); - let provider = super::XmtpOpenMlsProvider::new(conn.clone()); + let provider: XmtpOpenMlsProvider = amal.context.store.conn().unwrap().into(); // Finished with setup // add a second installation for amal using the same wallet @@ -2901,12 +2900,12 @@ mod tests { .await .unwrap(); - let conn_1 = bo.store().conn().unwrap(); + let conn_1: XmtpOpenMlsProvider = bo.store().conn().unwrap().into(); let mut conn_2 = bo.store().raw_conn().unwrap(); // Begin an exclusive transaction on a second connection to lock the database conn_2.batch_execute("BEGIN EXCLUSIVE").unwrap(); - let process_result = bo_group.process_messages(bo_messages, conn_1, &bo).await; + let process_result = bo_group.process_messages(bo_messages, &conn_1, &bo).await; if let Some(GroupError::ReceiveErrors(errors)) = process_result.err() { assert_eq!(errors.len(), 1); assert!(errors @@ -2918,52 +2917,4 @@ mod tests { panic!("Expected error") } } - - #[tokio::test(flavor = "multi_thread")] - async fn ensure_removed_after_revoke() { - let alix_wallet = generate_local_wallet(); - let bo_wallet = generate_local_wallet(); - let alix1 = ClientBuilder::new_test_client(&alix_wallet).await; - let alix2 = ClientBuilder::new_test_client(&alix_wallet).await; - let bo = ClientBuilder::new_test_client(&bo_wallet).await; - - let alix_group = alix1 - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - alix_group - .add_members(&alix1, vec![bo_wallet.get_address()]) - .await - .unwrap(); - let bo_group = receive_group_invite(&bo).await; - - // Check the MLS group for the number of members - let bo_provider = bo.mls_provider(bo.store().conn().unwrap()); - let bo_mls_group = bo_group.load_mls_group(&bo_provider).unwrap(); - let members = bo_mls_group.members().collect::>(); - assert_eq!(members.len(), 3); - - let mut revoke_installation_request = alix1 - .revoke_installations(vec![alix2.installation_public_key()]) - .await - .unwrap(); - - sign_with_wallet(&alix_wallet, &mut revoke_installation_request).await; - alix1 - .apply_signature_request(revoke_installation_request) - .await - .unwrap(); - - bo_group.sync(&bo).await.unwrap(); - // Check the MLS group for the number of members after alix2 has been removed - let bo_mls_group = bo_group.load_mls_group(&bo_provider).unwrap(); - let members = bo_mls_group.members().collect::>(); - assert_eq!(members.len(), 2); - - let members = bo_group.members().unwrap(); - let alix_member = members - .iter() - .find(|m| m.inbox_id == alix1.inbox_id()) - .unwrap(); - assert_eq!(alix_member.installation_ids.len(), 1); - } } diff --git a/xmtp_mls/src/groups/sync.rs b/xmtp_mls/src/groups/sync.rs index 0723de47c..0ba94c0de 100644 --- a/xmtp_mls/src/groups/sync.rs +++ b/xmtp_mls/src/groups/sync.rs @@ -83,24 +83,24 @@ impl MlsGroup { ApiClient: XmtpApi, { let conn = self.context.store.conn()?; - let mls_provider = client.mls_provider(conn.clone()); + let mls_provider = XmtpOpenMlsProvider::from(conn); log::info!("[{}] syncing group", client.inbox_id()); log::info!( "current epoch for [{}] in sync() is Epoch: [{}]", client.inbox_id(), - self.load_mls_group(mls_provider).unwrap().epoch() + self.load_mls_group(&mls_provider).unwrap().epoch() ); - self.maybe_update_installations(conn.clone(), None, client) + self.maybe_update_installations(&mls_provider, None, client) .await?; - self.sync_with_conn(conn, client).await + self.sync_with_conn(&mls_provider, client).await } - #[tracing::instrument(level = "trace", skip(client, self, conn))] - pub async fn sync_with_conn( + #[tracing::instrument(level = "trace", skip(self, provider, client))] + pub(crate) async fn sync_with_conn( &self, - conn: DbConnection, + provider: &XmtpOpenMlsProvider, client: &Client, ) -> Result<(), GroupError> where @@ -108,21 +108,23 @@ impl MlsGroup { { let mut errors: Vec = vec![]; + let conn = provider.conn_ref(); + // Even if publish fails, continue to receiving - if let Err(publish_error) = self.publish_intents(conn.clone(), client).await { + if let Err(publish_error) = self.publish_intents(provider, client).await { log::error!("Sync: error publishing intents {:?}", publish_error); errors.push(publish_error); } // Even if receiving fails, continue to post_commit - if let Err(receive_error) = self.receive(&conn, client).await { + if let Err(receive_error) = self.receive(provider, client).await { log::error!("receive error {:?}", receive_error); // We don't return an error if receive fails, because it's possible this is caused // by malicious data sent over the network, or messages from before the user was // added to the group } - if let Err(post_commit_err) = self.post_commit(&conn, client).await { + if let Err(post_commit_err) = self.post_commit(conn, client).await { log::error!("post commit error {:?}", post_commit_err); errors.push(post_commit_err); } @@ -136,13 +138,13 @@ impl MlsGroup { pub(super) async fn sync_until_last_intent_resolved( &self, - conn: DbConnection, + provider: &XmtpOpenMlsProvider, client: &Client, ) -> Result<(), GroupError> where ApiClient: XmtpApi, { - let intents = conn.find_group_intents( + let intents = provider.conn_ref().find_group_intents( self.group_id.clone(), Some(vec![IntentState::ToPublish, IntentState::Published]), None, @@ -152,7 +154,7 @@ impl MlsGroup { return Ok(()); } - self.sync_until_intent_resolved(conn, intents[intents.len() - 1].id, client) + self.sync_until_intent_resolved(provider, intents[intents.len() - 1].id, client) .await } @@ -163,10 +165,10 @@ impl MlsGroup { * * This method will retry up to `crate::configuration::MAX_GROUP_SYNC_RETRIES` times. */ - #[tracing::instrument(level = "trace", skip(client, self, conn))] + #[tracing::instrument(level = "trace", skip(client, self, provider))] pub(super) async fn sync_until_intent_resolved( &self, - conn: DbConnection, + provider: &XmtpOpenMlsProvider, intent_id: ID, client: &Client, ) -> Result<(), GroupError> @@ -177,12 +179,12 @@ impl MlsGroup { // Return the last error to the caller if we fail to sync let mut last_err: Option = None; while num_attempts < crate::configuration::MAX_GROUP_SYNC_RETRIES { - if let Err(err) = self.sync_with_conn(conn.clone(), client).await { + if let Err(err) = self.sync_with_conn(provider, client).await { log::error!("error syncing group {:?}", err); last_err = Some(err); } - match Fetch::::fetch(&conn, &intent_id) { + match Fetch::::fetch(provider.conn_ref(), &intent_id) { Ok(None) => { // This is expected. The intent gets deleted on success return Ok(()); @@ -271,7 +273,7 @@ impl MlsGroup { message_epoch ); - let conn = provider.conn(); + let conn = provider.conn_ref(); match intent.kind { IntentKind::KeyUpdate | IntentKind::UpdateGroupMembership @@ -301,7 +303,7 @@ impl MlsGroup { ); let maybe_validated_commit = ValidatedCommit::from_staged_commit( client, - &conn, + conn, maybe_pending_commit.expect("already checked"), openmls_group, ) @@ -336,7 +338,7 @@ impl MlsGroup { conn.set_group_intent_to_publish(intent.id)?; } else { // If no error committing the change, write a transcript message - self.save_transcript_message(&conn, validated_commit, envelope_timestamp_ns)?; + self.save_transcript_message(conn, validated_commit, envelope_timestamp_ns)?; } } IntentKind::SendMessage => { @@ -594,7 +596,7 @@ impl MlsGroup { }?; let intent = provider - .conn() + .conn_ref() .find_group_intent_by_payload_hash(sha256(envelope.data.as_slice())); match intent { @@ -676,14 +678,13 @@ impl MlsGroup { pub async fn process_messages( &self, messages: Vec, - conn: DbConnection, + provider: &XmtpOpenMlsProvider, client: &Client, ) -> Result<(), GroupError> where ApiClient: XmtpApi, { - let provider = self.context.mls_provider(conn); - let mut openmls_group = self.load_mls_group(&provider)?; + let mut openmls_group = self.load_mls_group(provider)?; let mut receive_errors = vec![]; for message in messages.into_iter() { @@ -718,18 +719,19 @@ impl MlsGroup { } } - #[tracing::instrument(level = "trace", skip(conn, client, self))] + #[tracing::instrument(level = "trace", skip_all)] pub(super) async fn receive( &self, - conn: &DbConnection, + provider: &XmtpOpenMlsProvider, client: &Client, ) -> Result<(), GroupError> where ApiClient: XmtpApi, { - let messages = client.query_group_messages(&self.group_id, conn).await?; - self.process_messages(messages, conn.clone(), client) + let messages = client + .query_group_messages(&self.group_id, provider.conn_ref()) .await?; + self.process_messages(messages, provider, client).await?; Ok(()) } @@ -780,19 +782,18 @@ impl MlsGroup { Ok(Some(msg)) } - #[tracing::instrument(level = "trace", skip(conn, self, client))] + #[tracing::instrument(level = "trace", skip(self, provider, client))] pub(super) async fn publish_intents( &self, - conn: DbConnection, + provider: &XmtpOpenMlsProvider, client: &Client, ) -> Result<(), GroupError> where ApiClient: XmtpApi, { - let provider = self.context.mls_provider(conn); - let mut openmls_group = self.load_mls_group(&provider)?; + let mut openmls_group = self.load_mls_group(provider)?; - let intents = provider.conn().find_group_intents( + let intents = provider.conn_ref().find_group_intents( self.group_id.clone(), Some(vec![IntentState::ToPublish]), None, @@ -802,7 +803,7 @@ impl MlsGroup { let result = retry_async!( Retry::default(), (async { - self.get_publish_intent_data(&provider, client, &mut openmls_group, &intent) + self.get_publish_intent_data(provider, client, &mut openmls_group, &intent) .await }) ); @@ -814,11 +815,11 @@ impl MlsGroup { log::error!("intent {} has reached max publish attempts", intent.id); // TODO: Eventually clean up errored attempts provider - .conn() + .conn_ref() .set_group_intent_error_and_fail_msg(&intent)?; } else { provider - .conn() + .conn_ref() .increment_intent_publish_attempt_count(intent.id)?; } @@ -837,7 +838,7 @@ impl MlsGroup { intent.id, intent.kind ); - provider.conn().set_group_intent_published( + provider.conn_ref().set_group_intent_published( intent.id, sha256(payload_slice), post_commit_data, @@ -850,7 +851,7 @@ impl MlsGroup { } Ok(None) => { log::info!("Skipping intent because no publish data returned"); - let deleter: &dyn Delete = &provider.conn(); + let deleter: &dyn Delete = provider.conn_ref(); deleter.delete(intent.id)?; } } @@ -1002,7 +1003,7 @@ impl MlsGroup { pub async fn maybe_update_installations( &self, - conn: DbConnection, + provider: &XmtpOpenMlsProvider, update_interval: Option, client: &Client, ) -> Result<(), GroupError> @@ -1016,12 +1017,17 @@ impl MlsGroup { }; let now = crate::utils::time::now_ns(); - let last = conn.get_installations_time_checked(self.group_id.clone())?; + let last = provider + .conn_ref() + .get_installations_time_checked(self.group_id.clone())?; let elapsed = now - last; if elapsed > interval { - let provider = self.context.mls_provider(conn.clone()); - self.add_missing_installations(&provider, client).await?; - conn.update_installations_time_checked(self.group_id.clone())?; + self.add_missing_installations(provider, client) + .await + .unwrap(); + provider + .conn_ref() + .update_installations_time_checked(self.group_id.clone())?; } Ok(()) @@ -1054,14 +1060,14 @@ impl MlsGroup { debug!("Adding missing installations {:?}", intent_data); - let conn = provider.conn(); + let conn = provider.conn_ref(); let intent = conn.insert_group_intent(NewGroupIntent::new( IntentKind::UpdateGroupMembership, self.group_id.clone(), intent_data.into(), ))?; - self.sync_until_intent_resolved(conn, intent.id, client) + self.sync_until_intent_resolved(provider, intent.id, client) .await } @@ -1231,7 +1237,7 @@ async fn apply_update_group_membership_intent( // This function goes to the network and fills in any missing Identity Updates let installation_diff = client .get_installation_diff( - &provider.conn(), + provider.conn_ref(), &old_group_membership, &new_group_membership, &membership_diff, @@ -1311,6 +1317,7 @@ mod tests { use super::*; use crate::builder::ClientBuilder; use futures::future; + use scoped_futures::ScopedFutureExt; use std::sync::Arc; use xmtp_cryptography::utils::generate_local_wallet; @@ -1328,17 +1335,12 @@ mod tests { amal_group.send_message_optimistic(b"5").unwrap(); amal_group.send_message_optimistic(b"6").unwrap(); - let mut futures = vec![]; let conn = amal.context().store.conn().unwrap(); + let provider: XmtpOpenMlsProvider = conn.into(); + let mut futures = vec![]; for _ in 0..10 { - let client = amal.clone(); - let conn = conn.clone(); - let group = amal_group.clone(); - - futures.push(async move { - group.publish_intents(conn, &client).await.unwrap(); - }); + futures.push(amal_group.publish_intents(&provider, &amal).scoped()) } future::join_all(futures).await; } diff --git a/xmtp_mls/src/groups/validated_commit.rs b/xmtp_mls/src/groups/validated_commit.rs index 6c4f8271a..04162b98d 100644 --- a/xmtp_mls/src/groups/validated_commit.rs +++ b/xmtp_mls/src/groups/validated_commit.rs @@ -861,200 +861,204 @@ impl From for GroupUpdatedProto { } // TODO:nm bring these tests back in add/remove members PR +/* +#[cfg(test)] +mod tests { + use openmls::{ + credentials::{BasicCredential, CredentialWithKey}, + extensions::ExtensionType, + messages::proposals::ProposalType, + prelude::Capabilities, + prelude_test::KeyPackage, + }; + use xmtp_api_grpc::Client as GrpcClient; + use xmtp_cryptography::utils::generate_local_wallet; + + use super::ValidatedCommit; + use crate::{ + builder::ClientBuilder, + configuration::{ + CIPHERSUITE, GROUP_MEMBERSHIP_EXTENSION_ID, MUTABLE_METADATA_EXTENSION_ID, + }, + Client, + }; + + fn get_key_package(client: &Client) -> KeyPackage { + client + .identity() + .new_key_package(client.mls_provider().unwrap()) + .unwrap() + } + + #[tokio::test] + async fn test_membership_changes() { + let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bola_key_package = get_key_package(&bola); + + let amal_group = amal.create_group(None, Default::default()).unwrap(); + let amal_provider = amal.mls_provider().unwrap(); + let mut mls_group = amal_group.load_mls_group(&amal_provider).unwrap(); + // Create a pending commit to add bola to the group + mls_group + .add_members( + &amal_provider, + &amal.identity().installation_keys, + &[bola_key_package], + ) + .unwrap(); + + let mut staged_commit = mls_group.pending_commit().unwrap(); + + let validated_commit = ValidatedCommit::from_staged_commit( + &amal, + amal_provider.conn_ref(), + staged_commit, + &mls_group, + ) + .await + .unwrap(); + + assert_eq!(validated_commit.added_inboxes.len(), 1); + assert_eq!(validated_commit.added_inboxes[0].inbox_id, bola.inbox_id()); + // Amal is the creator of the group and the actor + assert!(validated_commit.actor.is_creator); + // Bola is not the creator of the group + assert!(!validated_commit.added_inboxes[0].is_creator); + + // Merge the commit adding bola + mls_group.merge_pending_commit(&amal_provider).unwrap(); + // Now we are going to remove bola + + let bola_leaf_node = mls_group + .members() + .find(|m| { + m.signature_key + .eq(&bola.identity().installation_keys.public()) + }) + .unwrap() + .index; + mls_group + .remove_members( + &amal_provider, + &amal.identity().installation_keys, + &[bola_leaf_node], + ) + .unwrap(); + + staged_commit = mls_group.pending_commit().unwrap(); + let remove_message = ValidatedCommit::from_staged_commit( + &amal, + amal_provider.conn_ref(), + staged_commit, + &mls_group, + ) + .await + .unwrap(); + + assert_eq!(remove_message.removed_inboxes.len(), 1); + } + + #[tokio::test] + async fn test_installation_changes() { + let wallet = generate_local_wallet(); + let amal_1 = ClientBuilder::new_test_client(&wallet).await; + let amal_2 = ClientBuilder::new_test_client(&wallet).await; + + let amal_1_provider = amal_1.mls_provider().unwrap(); + let amal_2_provider = amal_2.mls_provider().unwrap(); + + let amal_group = amal_1.create_group(None, Default::default()).unwrap(); + let mut amal_mls_group = amal_group.load_mls_group(&amal_1_provider).unwrap(); + + let amal_2_kp = amal_2.identity().new_key_package(&amal_2_provider).unwrap(); -// #[cfg(test)] -// mod tests { -// use openmls::{ -// credentials::{BasicCredential, CredentialWithKey}, -// extensions::ExtensionType, -// group::config::CryptoConfig, -// messages::proposals::ProposalType, -// prelude::Capabilities, -// prelude_test::KeyPackage, -// versions::ProtocolVersion, -// }; -// use xmtp_api_grpc::Client as GrpcClient; -// use xmtp_cryptography::utils::generate_local_wallet; - -// use super::ValidatedCommit; -// use crate::{ -// builder::ClientBuilder, -// configuration::{ -// CIPHERSUITE, GROUP_MEMBERSHIP_EXTENSION_ID, MUTABLE_METADATA_EXTENSION_ID, -// }, -// Client, -// }; - -// fn get_key_package(client: &Client) -> KeyPackage { -// client -// .identity() -// .new_key_package(&client.mls_provider(client.store().conn().unwrap())) -// .unwrap() -// } - -// #[tokio::test] -// async fn test_membership_changes() { -// let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; -// let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; -// let bola_key_package = get_key_package(&bola); - -// let amal_group = amal.create_group(None).unwrap(); -// let amal_conn = amal.store().conn().unwrap(); -// let amal_provider = amal.mls_provider(amal_conn); -// let mut mls_group = amal_group.load_mls_group(&amal_provider).unwrap(); -// // Create a pending commit to add bola to the group -// mls_group -// .add_members( -// &amal_provider, -// &amal.identity().installation_keys, -// &[bola_key_package], -// ) -// .unwrap(); - -// let mut staged_commit = mls_group.pending_commit().unwrap(); - -// let validated_commit = ValidatedCommit::from_staged_commit( -// &amal.store().conn().unwrap(), -// staged_commit, -// &mls_group, -// &amal, -// ) -// .await -// .unwrap(); - -// assert_eq!(validated_commit.added_inboxes.len(), 1); -// assert_eq!(validated_commit.added_inboxes[0].inbox_id, bola.inbox_id()); -// // Amal is the creator of the group and the actor -// assert!(validated_commit.actor.is_creator); -// // Bola is not the creator of the group -// assert!(!validated_commit.added_inboxes[0].is_creator); - -// // Merge the commit adding bola -// mls_group.merge_pending_commit(&amal_provider).unwrap(); -// // Now we are going to remove bola - -// let bola_leaf_node = mls_group -// .members() -// .find(|m| { -// m.signature_key -// .eq(&bola.identity.installation_keys.public()) -// }) -// .unwrap() -// .index; -// mls_group -// .remove_members( -// &amal_provider, -// &amal.identity.installation_keys, -// &[bola_leaf_node], -// ) -// .unwrap(); - -// staged_commit = mls_group.pending_commit().unwrap(); -// let remove_message = ValidatedCommit::from_staged_commit(staged_commit, &mls_group) -// .unwrap() -// .unwrap(); - -// assert_eq!(remove_message.members_removed.len(), 1); -// assert_eq!(remove_message.installations_removed.len(), 0); -// } - -// #[tokio::test] -// async fn test_installation_changes() { -// let wallet = generate_local_wallet(); -// let amal_1 = ClientBuilder::new_test_client(&wallet).await; -// let amal_2 = ClientBuilder::new_test_client(&wallet).await; - -// let amal_1_conn = amal_1.store().conn().unwrap(); -// let amal_2_conn = amal_2.store().conn().unwrap(); - -// let amal_1_provider = amal_1().mls_provider(&amal_1_conn); -// let amal_2_provider = amal_2().mls_provider(&amal_2_conn); - -// let amal_group = amal_1.create_group(None).unwrap(); -// let mut amal_mls_group = amal_group.load_mls_group(&amal_1_provider).unwrap(); - -// let amal_2_kp = amal_2.identity.new_key_package(&amal_2_provider).unwrap(); - -// // Add Amal's second installation to the existing group -// amal_mls_group -// .add_members( -// &amal_1_provider, -// &amal_1.identity.installation_keys, -// &[amal_2_kp], -// ) -// .unwrap(); - -// let staged_commit = amal_mls_group.pending_commit().unwrap(); - -// let validated_commit = ValidatedCommit::from_staged_commit(staged_commit, &amal_mls_group) -// .unwrap() -// .unwrap(); - -// assert_eq!(validated_commit.installations_added.len(), 1); -// assert_eq!( -// validated_commit.installations_added[0].installation_ids[0], -// amal_2.installation_public_key() -// ) -// } - -// #[tokio::test] -// async fn test_bad_key_package() { -// let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; -// let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; - -// let amal_conn = amal.store.conn().unwrap(); -// let bola_conn = bola.store.conn().unwrap(); - -// let amal_provider = amal.mls_provider(&amal_conn); -// let bola_provider = bola.mls_provider(&bola_conn); - -// let amal_group = amal.create_group(None).unwrap(); -// let mut amal_mls_group = amal_group.load_mls_group(&amal_provider).unwrap(); - -// let capabilities = Capabilities::new( -// None, -// Some(&[CIPHERSUITE]), -// Some(&[ -// ExtensionType::LastResort, -// ExtensionType::ApplicationId, -// ExtensionType::Unknown(MUTABLE_METADATA_EXTENSION_ID), -// ExtensionType::Unknown(GROUP_MEMBERSHIP_EXTENSION_ID), -// ExtensionType::ImmutableMetadata, -// ]), -// Some(&[ProposalType::GroupContextExtensions]), -// None, -// ); - -// // Create a key package with a malformed credential -// let bad_key_package = KeyPackage::builder() -// .leaf_node_capabilities(capabilities) -// .build( -// CryptoConfig { -// ciphersuite: CIPHERSUITE, -// version: ProtocolVersion::default(), -// }, -// &bola_provider, -// &bola.identity.installation_keys, -// CredentialWithKey { -// // Broken credential -// credential: BasicCredential::new(vec![1, 2, 3]).unwrap().into(), -// signature_key: bola.identity.installation_keys.to_public_vec().into(), -// }, -// ) -// .unwrap(); - -// amal_mls_group -// .add_members( -// &amal_provider, -// &amal.identity.installation_keys, -// &[bad_key_package], -// ) -// .unwrap(); - -// let staged_commit = amal_mls_group.pending_commit().unwrap(); - -// let validated_commit = ValidatedCommit::from_staged_commit(staged_commit, &amal_mls_group); - -// assert!(validated_commit.is_err()); -// } -// } + // Add Amal's second installation to the existing group + amal_mls_group + .add_members( + &amal_1_provider, + &amal_1.identity().installation_keys, + &[amal_2_kp], + ) + .unwrap(); + + let staged_commit = amal_mls_group.pending_commit().unwrap(); + + let validated_commit = ValidatedCommit::from_staged_commit( + &amal_1, + amal_1_provider.conn_ref(), + staged_commit, + &amal_mls_group, + ) + .await + .unwrap(); + + assert_eq!(validated_commit.added_inboxes.len(), 1); + assert_eq!( + validated_commit.added_inboxes[0].inbox_id, + amal_2.inbox_id() + ) + } + + #[tokio::test] + async fn test_bad_key_package() { + let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + let amal_provider = amal.mls_provider().unwrap(); + let bola_provider = bola.mls_provider().unwrap(); + + let amal_group = amal.create_group(None, Default::default()).unwrap(); + let mut amal_mls_group = amal_group.load_mls_group(&amal_provider).unwrap(); + + let capabilities = Capabilities::new( + None, + Some(&[CIPHERSUITE]), + Some(&[ + ExtensionType::LastResort, + ExtensionType::ApplicationId, + ExtensionType::Unknown(MUTABLE_METADATA_EXTENSION_ID), + ExtensionType::Unknown(GROUP_MEMBERSHIP_EXTENSION_ID), + ExtensionType::ImmutableMetadata, + ]), + Some(&[ProposalType::GroupContextExtensions]), + None, + ); + + // Create a key package with a malformed credential + let bad_key_package = KeyPackage::builder() + .leaf_node_capabilities(capabilities) + .build( + CIPHERSUITE, + &bola_provider, + &bola.identity().installation_keys, + CredentialWithKey { + // Broken credential + credential: BasicCredential::new(vec![1, 2, 3]).into(), + signature_key: bola.identity().installation_keys.to_public_vec().into(), + }, + ) + .unwrap(); + + amal_mls_group + .add_members( + &amal_provider, + &amal.identity().installation_keys, + &[bad_key_package.key_package().clone()], + ) + .unwrap(); + + let staged_commit = amal_mls_group.pending_commit().unwrap(); + + let validated_commit = ValidatedCommit::from_staged_commit( + &amal, + amal_provider.conn_ref(), + staged_commit, + &amal_mls_group, + ) + .await; + + assert!(validated_commit.is_err()); + } +} +*/ diff --git a/xmtp_mls/src/identity.rs b/xmtp_mls/src/identity.rs index bb40eee84..8322810ed 100644 --- a/xmtp_mls/src/identity.rs +++ b/xmtp_mls/src/identity.rs @@ -4,7 +4,7 @@ use crate::configuration::GROUP_PERMISSIONS_EXTENSION_ID; use crate::retry::RetryableError; use crate::storage::db_connection::DbConnection; use crate::storage::identity::StoredIdentity; -use crate::storage::sql_key_store::{SqlKeyStoreError, KEY_PACKAGE_REFERENCES}; +use crate::storage::sql_key_store::{SqlKeyStore, SqlKeyStoreError, KEY_PACKAGE_REFERENCES}; use crate::storage::EncryptedMessageStore; use crate::{ api::{ApiClientWrapper, WrappedApiError}, @@ -66,7 +66,7 @@ impl IdentityStrategy { let conn = store.conn()?; let provider = XmtpOpenMlsProvider::new(conn); let stored_identity: Option = provider - .conn() + .conn_ref() .fetch(&())? .map(|i: StoredIdentity| i.into()); debug!("identity in store: {:?}", stored_identity); @@ -352,7 +352,7 @@ impl Identity { pub(crate) fn new_key_package( &self, - provider: &XmtpOpenMlsProvider, + provider: impl OpenMlsProvider, ) -> Result { let last_resort = Extension::LastResort(LastResortExtension::default()); let key_package_extensions = Extensions::single(last_resort); @@ -382,7 +382,7 @@ impl Identity { .key_package_lifetime(Lifetime::new(6 * 30 * 86400)) .build( CIPHERSUITE, - provider, + &provider, &self.installation_keys, CredentialWithKey { credential: self.credential(), @@ -419,7 +419,7 @@ impl Identity { provider: &XmtpOpenMlsProvider, api_client: &ApiClientWrapper, ) -> Result<(), IdentityError> { - let stored_identity: Option = provider.conn().fetch(&())?; + let stored_identity: Option = provider.conn_ref().fetch(&())?; if stored_identity.is_some() { info!("Identity already registered. skipping key package publishing"); return Ok(()); diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 6227c1a43..d53d51ec8 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -3,12 +3,15 @@ use std::fmt; use std::sync::Arc; use crate::storage::RawDbConnection; +use crate::xmtp_openmls_provider::XmtpOpenMlsProvider; /// A wrapper for RawDbConnection that houses all XMTP DB operations. /// Uses a [`Mutex]` internally for interior mutability, so that the connection /// and transaction state can be shared between the OpenMLS Provider and /// native XMTP operations -#[derive(Clone)] +/// ~~~~ *_NOTE_* ~~~~~ +// Do not derive clone here. +// callers should be able to accomplish everything with one conn/reference. pub struct DbConnection { wrapped_conn: Arc>, } @@ -16,12 +19,16 @@ pub struct DbConnection { /// Owned DBConnection Methods /// Lifetime is 'static' because we are using [`RefOrValue::Value`] variant. impl DbConnection { - pub(crate) fn new(conn: RawDbConnection) -> Self { + pub(super) fn new(conn: RawDbConnection) -> Self { Self { wrapped_conn: Arc::new(Mutex::new(conn)), } } + pub(super) fn from_arc_mutex(conn: Arc>) -> Self { + Self { wrapped_conn: conn } + } + // Note: F is a synchronous fn. If it ever becomes async, we need to use // tokio::sync::mutex instead of std::sync::Mutex pub(crate) fn raw_query(&self, fun: F) -> Result @@ -33,6 +40,17 @@ impl DbConnection { } } +// Forces a move for conn +// This is an important distinction from deriving `Clone` on `DbConnection`. +// This way, conn will be moved into XmtpOpenMlsProvider. This forces codepaths to +// use a connection from the provider, rather than pulling a new one from the pool, resulting +// in two connections in the same scope. +impl From for XmtpOpenMlsProvider { + fn from(conn: DbConnection) -> XmtpOpenMlsProvider { + XmtpOpenMlsProvider::new(conn) + } +} + impl fmt::Debug for DbConnection { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("DbConnection") diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index fa18b2e73..f5d2f843d 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -277,19 +277,28 @@ impl EncryptedMessageStore { log::debug!("Transaction async beginning"); let mut connection = self.raw_conn()?; AnsiTransactionManager::begin_transaction(&mut *connection)?; - - let db_connection = DbConnection::new(connection); + let connection = Arc::new(parking_lot::Mutex::new(connection)); + let local_connection = Arc::clone(&connection); + let db_connection = DbConnection::from_arc_mutex(connection); let provider = XmtpOpenMlsProvider::new(db_connection); - let local_provider = provider.clone(); + // the other connection is dropped in the closure + // ensuring we have only one strong reference let result = fun(provider).await; + if Arc::strong_count(&local_connection) > 1 { + log::warn!("More than 1 strong connection references still exist during transaction"); + } + + if Arc::weak_count(&local_connection) > 1 { + log::warn!("More than 1 weak connection references still exist during transaction"); + } // after the closure finishes, `local_provider` should have the only reference ('strong') // to `XmtpOpenMlsProvider` inner `DbConnection`.. - let conn_ref = local_provider.conn_ref(); + let local_connection = DbConnection::from_arc_mutex(local_connection); match result { Ok(value) => { - conn_ref.raw_query(|conn| { + local_connection.raw_query(|conn| { PoolTransactionManager::::commit_transaction(&mut *conn) })?; log::debug!("Transaction async being committed"); @@ -297,7 +306,7 @@ impl EncryptedMessageStore { } Err(err) => { log::debug!("Transaction async being rolled back"); - match conn_ref.raw_query(|conn| { + match local_connection.raw_query(|conn| { PoolTransactionManager::::rollback_transaction( &mut *conn, ) @@ -657,9 +666,9 @@ mod tests { let barrier_pointer = barrier.clone(); let handle = std::thread::spawn(move || { store_pointer.transaction(|provider| { - let conn1 = provider.conn(); + let conn1 = provider.conn_ref(); StoredIdentity::new("correct".to_string(), rand_vec(), rand_vec()) - .store(&conn1) + .store(conn1) .unwrap(); // wait for second transaction to start barrier_pointer.wait(); @@ -673,14 +682,14 @@ mod tests { let handle2 = std::thread::spawn(move || { barrier.wait(); let result = store_pointer.transaction(|provider| -> Result<(), anyhow::Error> { - let connection = provider.conn(); + let connection = provider.conn_ref(); let group = StoredGroup::new( b"should not exist".to_vec(), 0, GroupMembershipState::Allowed, "goodbye".to_string(), ); - group.store(&connection)?; + group.store(connection)?; Ok(()) }); barrier.wait(); @@ -720,9 +729,9 @@ mod tests { let handle = tokio::spawn(async move { store_pointer .transaction_async(|provider| async move { - let conn1 = provider.conn(); + let conn1 = provider.conn_ref(); StoredIdentity::new("crab".to_string(), rand_vec(), rand_vec()) - .store(&conn1) + .store(conn1) .unwrap(); let group = StoredGroup::new( @@ -731,7 +740,7 @@ mod tests { GroupMembershipState::Allowed, "goodbye".to_string(), ); - group.store(&conn1).unwrap(); + group.store(conn1).unwrap(); anyhow::bail!("force a rollback") }) diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index ec18964b0..12f795115 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -25,7 +25,7 @@ struct StorageData { value_bytes: Vec, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SqlKeyStore { // Directly wrap the DbConnection which is a SqliteConnection in this case conn: DbConnection, @@ -36,10 +36,6 @@ impl SqlKeyStore { Self { conn } } - pub fn conn(&self) -> DbConnection { - self.conn.clone() - } - pub fn conn_ref(&self) -> &DbConnection { &self.conn } @@ -48,7 +44,7 @@ impl SqlKeyStore { &self, storage_key: &Vec, ) -> Result, diesel::result::Error> { - self.conn().raw_query(|conn| { + self.conn_ref().raw_query(|conn| { sql_query(SELECT_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -61,7 +57,7 @@ impl SqlKeyStore { storage_key: &Vec, value: &[u8], ) -> Result { - self.conn().raw_query(|conn| { + self.conn_ref().raw_query(|conn| { sql_query(REPLACE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -75,7 +71,7 @@ impl SqlKeyStore { storage_key: &Vec, modified_data: &Vec, ) -> Result { - self.conn().raw_query(|conn| { + self.conn_ref().raw_query(|conn| { sql_query(UPDATE_QUERY) .bind::(&modified_data) .bind::(&storage_key) @@ -223,7 +219,7 @@ impl SqlKeyStore { ) -> Result<(), >::Error> { let storage_key = build_key_from_vec::(label, key.to_vec()); - let _ = self.conn().raw_query(|conn| { + let _ = self.conn_ref().raw_query(|conn| { sql_query(DELETE_QUERY) .bind::(&storage_key) .bind::(VERSION as i32) @@ -804,7 +800,7 @@ impl StorageProvider for SqlKeyStore { let query = "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?"; - let data: Vec = self.conn().raw_query(|conn| { + let data: Vec = self.conn_ref().raw_query(|conn| { sql_query(query) .bind::(&storage_key) .bind::(CURRENT_VERSION as i32) @@ -1046,8 +1042,8 @@ mod tests { ) .unwrap(); - let conn = &store.conn().unwrap(); - let key_store = SqlKeyStore::new(conn.clone()); + let conn = store.conn().unwrap(); + let key_store = SqlKeyStore::new(conn); let signature_keys = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm()).unwrap(); let public_key = StorageId::from(signature_keys.to_public_vec()); @@ -1095,7 +1091,6 @@ mod tests { ) .unwrap(); let conn = store.conn().unwrap(); - let key_store = SqlKeyStore::new(conn.clone()); let provider = XmtpOpenMlsProvider::new(conn); let group_id = GroupId::random(provider.rand()); let proposals = (0..10) @@ -1104,7 +1099,8 @@ mod tests { // Store proposals for (i, proposal) in proposals.iter().enumerate() { - key_store + provider + .storage() .queue_proposal::( &group_id, &ProposalRef(i), @@ -1115,7 +1111,8 @@ mod tests { log::debug!("Finished with queued proposals"); // Read proposal refs - let proposal_refs_read: Vec = key_store + let proposal_refs_read: Vec = provider + .storage() .queued_proposal_refs(&group_id) .expect("Failed to read proposal refs"); assert_eq!( @@ -1125,7 +1122,7 @@ mod tests { // Read proposals let proposals_read: Vec<(ProposalRef, Proposal)> = - key_store.queued_proposals(&group_id).unwrap(); + provider.storage().queued_proposals(&group_id).unwrap(); let proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10) .map(ProposalRef) .zip(proposals.clone().into_iter()) @@ -1133,18 +1130,19 @@ mod tests { assert_eq!(proposals_expected, proposals_read); // Remove proposal 5 - key_store + provider + .storage() .remove_proposal(&group_id, &ProposalRef(5)) .unwrap(); let proposal_refs_read: Vec = - key_store.queued_proposal_refs(&group_id).unwrap(); + provider.storage().queued_proposal_refs(&group_id).unwrap(); let mut expected = (0..10).map(ProposalRef).collect::>(); expected.remove(5); assert_eq!(expected, proposal_refs_read); let proposals_read: Vec<(ProposalRef, Proposal)> = - key_store.queued_proposals(&group_id).unwrap(); + provider.storage().queued_proposals(&group_id).unwrap(); let mut proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10) .map(ProposalRef) .zip(proposals.clone().into_iter()) @@ -1153,15 +1151,16 @@ mod tests { assert_eq!(proposals_expected, proposals_read); // Clear all proposals - key_store + provider + .storage() .clear_proposal_queue::(&group_id) .unwrap(); let proposal_refs_read: Result, SqlKeyStoreError> = - key_store.queued_proposal_refs(&group_id); + provider.storage().queued_proposal_refs(&group_id); assert!(proposal_refs_read.unwrap().is_empty()); let proposals_read: Result, SqlKeyStoreError> = - key_store.queued_proposals(&group_id); + provider.storage().queued_proposals(&group_id); assert!(proposals_read.unwrap().is_empty()); } @@ -1174,7 +1173,6 @@ mod tests { ) .unwrap(); let conn = store.conn().unwrap(); - let key_store = SqlKeyStore::new(conn.clone()); let provider = XmtpOpenMlsProvider::new(conn); #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)] @@ -1185,12 +1183,13 @@ mod tests { let group_id = GroupId::random(provider.rand()); // Group state - key_store + provider + .storage() .write_group_state(&group_id, &GroupState(77)) .unwrap(); // Read group state - let group_state: Option = key_store.group_state(&group_id).unwrap(); + let group_state: Option = provider.storage().group_state(&group_id).unwrap(); assert_eq!(GroupState(77), group_state.unwrap()); } } diff --git a/xmtp_mls/src/xmtp_openmls_provider.rs b/xmtp_mls/src/xmtp_openmls_provider.rs index dcc93ee1f..b3e7b36a8 100644 --- a/xmtp_mls/src/xmtp_openmls_provider.rs +++ b/xmtp_mls/src/xmtp_openmls_provider.rs @@ -9,15 +9,6 @@ pub struct XmtpOpenMlsProvider { key_store: SqlKeyStore, } -impl Clone for XmtpOpenMlsProvider { - fn clone(&self) -> Self { - Self { - crypto: RustCrypto::default(), - key_store: self.key_store.clone(), - } - } -} - impl XmtpOpenMlsProvider { pub fn new(conn: DbConnection) -> Self { Self { @@ -26,10 +17,6 @@ impl XmtpOpenMlsProvider { } } - pub(crate) fn conn(&self) -> DbConnection { - self.key_store.conn() - } - pub(crate) fn conn_ref(&self) -> &DbConnection { self.key_store.conn_ref() }