Skip to content

Commit

Permalink
Force Provider to take ownership of DbConnection (#982)
Browse files Browse the repository at this point in the history
* force XmtpOpenMlsProvider to take ownership of connection. Don't clone connections

* restrict excessive clones of connections
  • Loading branch information
insipx authored Aug 21, 2024
1 parent 2ca8be1 commit abe9373
Show file tree
Hide file tree
Showing 15 changed files with 455 additions and 446 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions bindings_ffi/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions bindings_node/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions xmtp_mls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ xmtp_cryptography = { workspace = true }
xmtp_id = { path = "../xmtp_id" }
xmtp_proto = { workspace = true, features = ["proto_full", "convert"] }
xmtp_v2 = { path = "../xmtp_v2" }
scoped-futures = "0.1"

# Test/Bench Utils
xmtp_api_grpc = { path = "../xmtp_api_grpc", optional = true }
Expand Down
78 changes: 42 additions & 36 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,16 @@ impl XmtpMlsLocalContext {
self.identity.sequence_id(conn)
}

/// Pulls a new database connection and creates a new provider.
///
/// The returned `XmtpOpenMlsProvider` takes ownership of the freshly
/// pulled connection (via the `From<DbConnection>` conversion), so no
/// connection handle is shared or cloned.
///
/// # Errors
/// Returns a `ClientError` if a connection cannot be obtained from the store.
pub fn mls_provider(&self) -> Result<XmtpOpenMlsProvider, ClientError> {
Ok(self.store.conn()?.into())
}

/// Integrators should always check the `signature_request` return value of this function before calling [`register_identity`](Self::register_identity).
/// If `signature_request` returns `None`, then the wallet signature is not required and [`register_identity`](Self::register_identity) can be called with None as an argument.
///
/// Delegates directly to the underlying identity's pending signature request.
pub fn signature_request(&self) -> Option<SignatureRequest> {
self.identity.signature_request()
}

pub(crate) fn mls_provider(&self, conn: DbConnection) -> XmtpOpenMlsProvider {
XmtpOpenMlsProvider::new(conn)
}
}

impl<ApiClient> Client<ApiClient>
Expand Down Expand Up @@ -280,6 +281,11 @@ where
self.context.inbox_id()
}

/// Pulls a connection and creates a new MLS Provider.
///
/// Thin convenience wrapper that delegates to
/// [`XmtpMlsLocalContext::mls_provider`]; the provider owns the new
/// connection, so callers never clone a `DbConnection`.
///
/// # Errors
/// Returns a `ClientError` if a database connection cannot be obtained.
pub fn mls_provider(&self) -> Result<XmtpOpenMlsProvider, ClientError> {
self.context.mls_provider()
}

pub async fn find_inbox_id_from_address(
&self,
address: String,
Expand Down Expand Up @@ -332,10 +338,6 @@ where
&self.context.identity
}

pub(crate) fn mls_provider(&self, conn: DbConnection) -> XmtpOpenMlsProvider {
XmtpOpenMlsProvider::new(conn)
}

pub fn context(&self) -> &Arc<XmtpMlsLocalContext> {
&self.context
}
Expand Down Expand Up @@ -439,7 +441,8 @@ where
) -> Result<(), ClientError> {
log::info!("registering identity");
// Register the identity before applying the signature request
let provider = self.mls_provider(self.store().conn()?);
let provider: XmtpOpenMlsProvider = self.store().conn()?.into();

self.identity()
.register(&provider, &self.api_client)
.await?;
Expand All @@ -452,10 +455,9 @@ where
/// Upload a new key package to the network replacing an existing key package
/// This is expected to be run any time the client receives new Welcome messages
pub async fn rotate_key_package(&self) -> Result<(), ClientError> {
let connection = self.store().conn()?;
let kp = self
.identity()
.new_key_package(&self.mls_provider(connection))?;
let provider: XmtpOpenMlsProvider = self.store().conn()?.into();

let kp = self.identity().new_key_package(&provider)?;
let kp_bytes = kp.tls_serialize_detached()?;
self.api_client.upload_key_package(kp_bytes, true).await?;

Expand Down Expand Up @@ -499,8 +501,7 @@ where
) -> Result<Vec<VerifiedKeyPackageV2>, ClientError> {
let key_package_results = self.api_client.fetch_key_packages(installation_ids).await?;

let conn = self.store().conn()?;
let mls_provider = self.mls_provider(conn);
let mls_provider = self.mls_provider()?;
Ok(key_package_results
.values()
.map(|bytes| VerifiedKeyPackageV2::from_bytes(mls_provider.crypto(), bytes.as_slice()))
Expand All @@ -522,7 +523,7 @@ where
.transaction_async(|provider| async move {
let is_updated =
provider
.conn()
.conn_ref()
.update_cursor(entity_id, entity_kind, cursor as i64)?;
if !is_updated {
return Err(MessageProcessingError::AlreadyProcessed(cursor));
Expand Down Expand Up @@ -589,30 +590,36 @@ where
}

pub async fn sync_all_groups(&self, groups: Vec<MlsGroup>) -> Result<(), GroupError> {
use scoped_futures::ScopedFutureExt;

// Acquire a single connection to be reused
let conn = &self.store().conn()?;
let provider: XmtpOpenMlsProvider = self.mls_provider()?;

let sync_futures: Vec<_> = groups
.into_iter()
.map(|group| {
let conn = conn.clone();
let mls_provider = self.mls_provider(conn.clone());

async move {
log::info!("[{}] syncing group", self.inbox_id());
log::info!(
"current epoch for [{}] in sync_all_groups() is Epoch: [{}]",
self.inbox_id(),
group.load_mls_group(mls_provider.clone()).unwrap().epoch()
);

group
.maybe_update_installations(conn.clone(), None, self)
.await?;

group.sync_with_conn(conn.clone(), self).await?;
Ok::<(), GroupError>(())
async {
// create new provider ref that gets moved, leaving original
// provider alone.
let provider_ref = &provider;
async move {
log::info!("[{}] syncing group", self.inbox_id());
log::info!(
"current epoch for [{}] in sync_all_groups() is Epoch: [{}]",
self.inbox_id(),
group.load_mls_group(provider_ref)?.epoch()
);

group
.maybe_update_installations(provider_ref, None, self)
.await?;

group.sync_with_conn(provider_ref, self).await?;
Ok::<(), GroupError>(())
}
.await
}
.scoped()
})
.collect();

Expand Down Expand Up @@ -904,8 +911,7 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_welcome_encryption() {
let client = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let conn = client.store().conn().unwrap();
let provider = client.mls_provider(conn);
let provider = client.mls_provider().unwrap();

let kp = client.identity().new_key_package(&provider).unwrap();
let hpke_public_key = kp.hpke_init_key().as_slice();
Expand Down
3 changes: 1 addition & 2 deletions xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ pub enum PermissionLevel {
impl MlsGroup {
// Load the member list for the group from the DB, merging together multiple installations into a single entry
pub fn members(&self) -> Result<Vec<GroupMember>, GroupError> {
let conn = self.context.store.conn()?;
let provider = self.context.mls_provider(conn);
let provider = self.mls_provider()?;
self.members_with_provider(&provider)
}

Expand Down
4 changes: 2 additions & 2 deletions xmtp_mls/src/groups/message_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ where
})?;

// publish the intent
if let Err(err) = sync_group.publish_intents(conn, self).await {
if let Err(err) = sync_group.publish_intents(&conn.into(), self).await {
log::error!("error publishing sync group intents: {:?}", err);
}

Expand Down Expand Up @@ -172,7 +172,7 @@ where
})?;

// publish the intent
if let Err(err) = sync_group.publish_intents(conn, self).await {
if let Err(err) = sync_group.publish_intents(&conn.into(), self).await {
log::error!("error publishing sync group intents: {:?}", err);
}
Ok(())
Expand Down
Loading

0 comments on commit abe9373

Please sign in to comment.