Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into coda/find-dm-groups
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Oct 15, 2024
2 parents 5167e47 + 6edeed8 commit aa46bd7
Show file tree
Hide file tree
Showing 10 changed files with 846 additions and 207 deletions.
644 changes: 502 additions & 142 deletions bindings_ffi/src/mls.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bindings_node/src/consent_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub enum NapiConsentEntityType {
impl From<NapiConsentEntityType> for ConsentType {
fn from(entity_type: NapiConsentEntityType) -> Self {
match entity_type {
NapiConsentEntityType::GroupId => ConsentType::GroupId,
NapiConsentEntityType::GroupId => ConsentType::ConversationId,
NapiConsentEntityType::InboxId => ConsentType::InboxId,
NapiConsentEntityType::Address => ConsentType::Address,
}
Expand Down
4 changes: 3 additions & 1 deletion bindings_node/src/conversations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFun
use napi::JsFunction;
use napi_derive::napi;
use xmtp_mls::client::FindGroupParams;
use xmtp_mls::groups::group_metadata::ConversationType;
use xmtp_mls::groups::{GroupMetadataOptions, PreconfiguredPolicies};

use crate::messages::NapiMessage;
Expand Down Expand Up @@ -210,7 +211,7 @@ impl NapiConversations {
ThreadsafeFunctionCallMode::Blocking,
);
},
false,
Some(ConversationType::Group),
);

Ok(NapiStreamCloser::new(stream_closer))
Expand All @@ -225,6 +226,7 @@ impl NapiConversations {
move |message| {
tsfn.call(Ok(message.into()), ThreadsafeFunctionCallMode::Blocking);
},
Some(ConversationType::Group),
);

Ok(NapiStreamCloser::new(stream_closer))
Expand Down
9 changes: 5 additions & 4 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ use xmtp_proto::xmtp::mls::api::v1::{
use crate::{
api::ApiClientWrapper,
groups::{
group_permissions::PolicySet, validated_commit::CommitValidationError, GroupError,
GroupMetadataOptions, IntentError, MlsGroup,
group_metadata::ConversationType, group_permissions::PolicySet,
validated_commit::CommitValidationError, GroupError, GroupMetadataOptions, IntentError,
MlsGroup,
},
identity::{parse_credential, Identity, IdentityError},
identity_updates::{load_identity_updates, IdentityUpdateError},
Expand Down Expand Up @@ -224,7 +225,7 @@ pub struct FindGroupParams {
pub created_after_ns: Option<i64>,
pub created_before_ns: Option<i64>,
pub limit: Option<i64>,
pub include_dm_groups: bool,
pub conversation_type: Option<ConversationType>,
}

/// Clients manage access to the network, identity, and data store
Expand Down Expand Up @@ -660,7 +661,7 @@ where
params.created_after_ns,
params.created_before_ns,
params.limit,
params.include_dm_groups,
params.conversation_type,
)?
.into_iter()
.map(|stored_group| {
Expand Down
9 changes: 5 additions & 4 deletions xmtp_mls/src/groups/message_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use xmtp_proto::{
},
};

use super::group_metadata::ConversationType;
use super::{GroupError, MlsGroup};

use crate::XmtpApi;
Expand Down Expand Up @@ -135,7 +136,7 @@ where

pub async fn ensure_member_of_all_groups(&self, inbox_id: String) -> Result<(), GroupError> {
let conn = self.store().conn()?;
let groups = conn.find_groups(None, None, None, None, false)?;
let groups = conn.find_groups(None, None, None, None, Some(ConversationType::Group))?;
for group in groups {
let group = self.group(group.id)?;
Box::pin(group.add_members_by_inbox_id(self, vec![inbox_id.clone()])).await?;
Expand Down Expand Up @@ -384,7 +385,7 @@ where
self.sync_welcomes().await?;

let conn = self.store().conn()?;
let groups = conn.find_groups(None, None, None, None, false)?;
let groups = conn.find_groups(None, None, None, None, Some(ConversationType::Group))?;
for crate::storage::group::StoredGroup { id, .. } in groups.into_iter() {
let group = self.group(id)?;
Box::pin(group.sync(self)).await?;
Expand Down Expand Up @@ -502,14 +503,14 @@ where

async fn prepare_groups_to_sync(&self) -> Result<Vec<StoredGroup>, MessageHistoryError> {
let conn = self.store().conn()?;
Ok(conn.find_groups(None, None, None, None, false)?)
Ok(conn.find_groups(None, None, None, None, Some(ConversationType::Group))?)
}

async fn prepare_messages_to_sync(
&self,
) -> Result<Vec<StoredGroupMessage>, MessageHistoryError> {
let conn = self.store().conn()?;
let groups = conn.find_groups(None, None, None, None, false)?;
let groups = conn.find_groups(None, None, None, None, Some(ConversationType::Group))?;
let mut all_messages: Vec<StoredGroupMessage> = vec![];

for StoredGroup { id, .. } in groups.into_iter() {
Expand Down
20 changes: 11 additions & 9 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,11 @@ impl MlsGroup {
);

stored_group.store(provider.conn_ref())?;
Ok(Self::new(
context.clone(),
group_id,
stored_group.created_at_ns,
))
let new_group = Self::new(context.clone(), group_id, stored_group.created_at_ns);

// Consent state defaults to allowed when the user creates the group
new_group.update_consent_state(ConsentState::Allowed)?;
Ok(new_group)
}

// Create a group from a decrypted and decoded welcome message
Expand Down Expand Up @@ -1041,8 +1041,10 @@ impl MlsGroup {
/// Find the `consent_state` of the group
pub fn consent_state(&self) -> Result<ConsentState, GroupError> {
let conn = self.context.store.conn()?;
let record =
conn.get_consent_record(hex::encode(self.group_id.clone()), ConsentType::GroupId)?;
let record = conn.get_consent_record(
hex::encode(self.group_id.clone()),
ConsentType::ConversationId,
)?;

match record {
Some(rec) => Ok(rec.state),
Expand All @@ -1053,7 +1055,7 @@ impl MlsGroup {
pub fn update_consent_state(&self, state: ConsentState) -> Result<(), GroupError> {
let conn = self.context.store.conn()?;
conn.insert_or_replace_consent_records(vec![StoredConsentRecord::new(
ConsentType::GroupId,
ConsentType::ConversationId,
state,
hex::encode(self.group_id.clone()),
)])?;
Expand Down Expand Up @@ -3229,7 +3231,7 @@ mod tests {
let _ = bola.sync_welcomes().await;
let bola_groups = bola
.find_groups(FindGroupParams {
include_dm_groups: true,
conversation_type: None,
..FindGroupParams::default()
})
.unwrap();
Expand Down
8 changes: 4 additions & 4 deletions xmtp_mls/src/storage/encrypted_store/consent_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use serde::{Deserialize, Serialize};
#[diesel(table_name = consent_records)]
#[diesel(primary_key(entity_type, entity))]
pub struct StoredConsentRecord {
/// Enum, [`ConsentType`] representing the type of consent (group_id inbox_id, etc..)
/// Enum, [`ConsentType`] representing the type of consent (conversation_id inbox_id, etc..)
pub entity_type: ConsentType,
/// Enum, [`ConsentState`] representing the state of consent (allowed, denied, etc..)
pub state: ConsentState,
Expand Down Expand Up @@ -85,8 +85,8 @@ impl DbConnection {
#[diesel(sql_type = Integer)]
/// Type of consent record stored
pub enum ConsentType {
/// Consent is for a group
GroupId = 1,
/// Consent is for a conversation
ConversationId = 1,
/// Consent is for an inbox
InboxId = 2,
/// Consent is for an address
Expand All @@ -109,7 +109,7 @@ where
{
fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
match i32::from_sql(bytes)? {
1 => Ok(ConsentType::GroupId),
1 => Ok(ConsentType::ConversationId),
2 => Ok(ConsentType::InboxId),
3 => Ok(ConsentType::Address),
x => Err(format!("Unrecognized variant {}", x).into()),
Expand Down
45 changes: 36 additions & 9 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ use super::{
db_connection::DbConnection,
schema::{groups, groups::dsl},
};
use crate::{impl_fetch, impl_store, DuplicateItem, StorageError};
use crate::{
groups::group_metadata::ConversationType, impl_fetch, impl_store, DuplicateItem, StorageError,
};

/// The Group ID type.
pub type ID = Vec<u8>;
Expand Down Expand Up @@ -122,7 +124,7 @@ impl DbConnection {
created_after_ns: Option<i64>,
created_before_ns: Option<i64>,
limit: Option<i64>,
include_dm_groups: bool,
conversation_type: Option<ConversationType>,
) -> Result<Vec<StoredGroup>, StorageError> {
let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed();

Expand All @@ -142,8 +144,16 @@ impl DbConnection {
query = query.limit(limit);
}

if !include_dm_groups {
query = query.filter(dsl::dm_inbox_id.is_null());
if let Some(conversation_type) = conversation_type {
match conversation_type {
ConversationType::Group => {
query = query.filter(dsl::dm_inbox_id.is_null());
}
ConversationType::Dm => {
query = query.filter(dsl::dm_inbox_id.is_not_null());
}
ConversationType::Sync => {}
}
}

query = query.filter(dsl::purpose.eq(Purpose::Conversation));
Expand Down Expand Up @@ -481,7 +491,9 @@ pub(crate) mod tests {
let test_group_3 = generate_dm(Some(GroupMembershipState::Allowed));
test_group_3.store(conn).unwrap();

let all_results = conn.find_groups(None, None, None, None, false).unwrap();
let all_results = conn
.find_groups(None, None, None, None, Some(ConversationType::Group))
.unwrap();
assert_eq!(all_results.len(), 2);

let pending_results = conn
Expand All @@ -490,19 +502,27 @@ pub(crate) mod tests {
None,
None,
None,
false,
Some(ConversationType::Group),
)
.unwrap();
assert_eq!(pending_results[0].id, test_group_1.id);
assert_eq!(pending_results.len(), 1);

// Offset and limit
let results_with_limit = conn.find_groups(None, None, None, Some(1), false).unwrap();
let results_with_limit = conn
.find_groups(None, None, None, Some(1), Some(ConversationType::Group))
.unwrap();
assert_eq!(results_with_limit.len(), 1);
assert_eq!(results_with_limit[0].id, test_group_1.id);

let results_with_created_at_ns_after = conn
.find_groups(None, Some(test_group_1.created_at_ns), None, Some(1), false)
.find_groups(
None,
Some(test_group_1.created_at_ns),
None,
Some(1),
Some(ConversationType::Group),
)
.unwrap();
assert_eq!(results_with_created_at_ns_after.len(), 1);
assert_eq!(results_with_created_at_ns_after[0].id, test_group_2.id);
Expand All @@ -512,9 +532,16 @@ pub(crate) mod tests {
assert_eq!(synced_groups.len(), 0);

// test that dm groups are included
let dm_results = conn.find_groups(None, None, None, None, true).unwrap();
let dm_results = conn.find_groups(None, None, None, None, None).unwrap();
assert_eq!(dm_results.len(), 3);
assert_eq!(dm_results[2].id, test_group_3.id);

// test only dms are returned
let dm_results = conn
.find_groups(None, None, None, None, Some(ConversationType::Dm))
.unwrap();
assert_eq!(dm_results.len(), 1);
assert_eq!(dm_results[0].id, test_group_3.id);
})
}

Expand Down
Loading

0 comments on commit aa46bd7

Please sign in to comment.