Skip to content

Commit

Permalink
finish test (#897)
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx authored Jul 9, 2024
1 parent ec1da4e commit d9ec08e
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions xmtp_mls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ smart-default = "0.7.1"
thiserror = { workspace = true }
tls_codec = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tokio-stream = { version = "0.1", features = ["sync"] }
toml = "0.8.4"
xmtp_cryptography = { workspace = true }
xmtp_id = { path = "../xmtp_id" }
Expand Down
8 changes: 8 additions & 0 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use openmls::{
use openmls_traits::OpenMlsProvider;
use prost::EncodeError;
use thiserror::Error;
use tokio::sync::broadcast;

use xmtp_cryptography::signature::{sanitize_evm_addresses, AddressValidationError};
use xmtp_id::{
Expand Down Expand Up @@ -46,6 +47,7 @@ use crate::{
refresh_state::EntityKind,
sql_key_store, EncryptedMessageStore, StorageError,
},
subscriptions::LocalEvents,
verified_key_package_v2::{KeyPackageVerificationError, VerifiedKeyPackageV2},
xmtp_openmls_provider::XmtpOpenMlsProvider,
Fetch, XmtpApi,
Expand Down Expand Up @@ -207,6 +209,7 @@ pub struct Client<ApiClient> {
pub(crate) api_client: ApiClientWrapper<ApiClient>,
pub(crate) context: Arc<XmtpMlsLocalContext>,
pub(crate) history_sync_url: Option<String>,
pub(crate) local_events: broadcast::Sender<LocalEvents>,
}

/// The local context a XMTP MLS needs to function:
Expand Down Expand Up @@ -261,10 +264,12 @@ where
history_sync_url: Option<String>,
) -> Self {
let context = XmtpMlsLocalContext { identity, store };
let (tx, _) = broadcast::channel(10);
Self {
api_client,
context: Arc::new(context),
history_sync_url,
local_events: tx,
}
}

Expand Down Expand Up @@ -339,6 +344,9 @@ where
)
.map_err(Box::new)?;

// notify any streams of the new group
let _ = self.local_events.send(LocalEvents::NewGroup(group.clone()));

Ok(group)
}

Expand Down
81 changes: 80 additions & 1 deletion xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{
use futures::{Stream, StreamExt};
use prost::Message;
use tokio::sync::oneshot::{self, Sender};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage;

use crate::{
Expand All @@ -22,6 +23,14 @@ use crate::{
Client, XmtpApi,
};

/// Events local to this client
/// are broadcast across all senders/receivers of streams
#[derive(Clone, Debug)]
pub(crate) enum LocalEvents {
// a new group was created
NewGroup(MlsGroup),
}

// TODO simplify FfiStreamCloser + StreamCloser duplication
pub struct StreamCloser {
pub close_fn: Arc<Mutex<Option<Sender<()>>>>,
Expand Down Expand Up @@ -117,6 +126,19 @@ where
pub async fn stream_conversations(
&self,
) -> Result<Pin<Box<dyn Stream<Item = MlsGroup> + Send + '_>>, ClientError> {
let event_queue =
tokio_stream::wrappers::BroadcastStream::new(self.local_events.subscribe());

let event_queue = event_queue.filter_map(|event| async move {
match event {
Ok(LocalEvents::NewGroup(g)) => Some(g),
Err(BroadcastStreamRecvError::Lagged(missed)) => {
log::warn!("Missed {missed} messages due to local event queue lagging");
None
}
}
});

let installation_key = self.installation_public_key();
let id_cursor = 0;

Expand All @@ -141,7 +163,7 @@ where
}
});

Ok(Box::pin(stream))
Ok(Box::pin(futures::stream::select(stream, event_queue)))
}

pub(crate) async fn stream_messages(
Expand Down Expand Up @@ -365,6 +387,7 @@ mod tests {
};
use futures::StreamExt;
use std::sync::{Arc, Mutex};
use tokio::sync::Notify;
use xmtp_api_grpc::grpc_api_helper::Client as GrpcClient;
use xmtp_cryptography::utils::generate_local_wallet;

Expand Down Expand Up @@ -546,4 +569,60 @@ mod tests {
let messages = messages.lock().unwrap();
assert_eq!(messages.len(), 5);
}

#[tokio::test(flavor = "multi_thread")]
async fn test_self_group_creation() {
let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);

let groups = Arc::new(Mutex::new(Vec::new()));
let notify = Arc::new(Notify::new());
let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone());

let closer = Client::<GrpcClient>::stream_conversations_with_callback(
alix.clone(),
move |g| {
let mut groups = groups_pointer.lock().unwrap();
groups.push(g);
notify_pointer.notify_one();
},
|| {},
)
.unwrap();

alix.create_group(None, GroupMetadataOptions::default())
.unwrap();

tokio::time::timeout(std::time::Duration::from_secs(60), async {
notify.notified().await
})
.await
.expect("Stream never received group");

{
let grps = groups.lock().unwrap();
assert_eq!(grps.len(), 1);
}

let group = bo
.create_group(None, GroupMetadataOptions::default())
.unwrap();
group
.add_members_by_inbox_id(&bo, vec![alix.inbox_id()])
.await
.unwrap();

tokio::time::timeout(std::time::Duration::from_secs(60), async {
notify.notified().await
})
.await
.expect("Stream never received group");

{
let grps = groups.lock().unwrap();
assert_eq!(grps.len(), 2);
}

closer.end();
}
}

0 comments on commit d9ec08e

Please sign in to comment.