diff --git a/src/lib.rs b/src/lib.rs
index ba7f6ef..180c417 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,8 @@ mod interfaces;
 mod message;
 mod network_connect;
 mod repo;
+pub mod share_policy;
+pub use share_policy::{SharePolicy, SharePolicyError};
 
 pub use crate::dochandle::DocHandle;
 pub use crate::interfaces::{
diff --git a/src/repo.rs b/src/repo.rs
index 65b8174..571eb03 100644
--- a/src/repo.rs
+++ b/src/repo.rs
@@ -1,6 +1,8 @@
 use crate::dochandle::{DocHandle, SharedDocument};
 use crate::interfaces::{DocumentId, RepoId};
 use crate::interfaces::{NetworkError, RepoMessage, Storage, StorageError};
+use crate::share_policy::ShareDecision;
+use crate::{share_policy, SharePolicy, SharePolicyError};
 use automerge::sync::{Message as SyncMessage, State as SyncState, SyncDoc};
 use automerge::{Automerge, ChangeHash};
 use core::pin::Pin;
@@ -356,10 +358,6 @@ impl DocState {
         matches!(self, DocState::LoadPending { .. })
     }
 
-    fn should_announce(&self) -> bool {
-        matches!(self, DocState::Sync(_))
-    }
-
     fn should_sync(&self) -> bool {
         matches!(self, DocState::Sync(_))
             || matches!(
@@ -569,7 +567,7 @@ pub(crate) struct DocumentInfo {
     /// Ref count for handles (shared with handles).
     handle_count: Arc<AtomicUsize>,
     /// Per repo automerge sync state.
-    sync_states: HashMap<RepoId, SyncState>,
+    peer_connections: HashMap<RepoId, PeerConnection>,
     /// Used to resolve futures for DocHandle::changed.
     change_observers: Vec<RepoFutureResolver<Result<(), RepoError>>>,
     /// Counter of patches since last save,
@@ -577,6 +575,50 @@ pub(crate) struct DocumentInfo {
     patches_since_last_save: usize,
 }
 
+/// A state machine representing a connection between a remote repo and a particular document
+#[derive(Debug)]
+enum PeerConnection {
+    /// We've accepted the peer and are syncing with them
+    Accepted(SyncState),
+    /// We're waiting for a response from the share policy
+    PendingAuth { received_messages: Vec<SyncMessage> },
+}
+
+impl PeerConnection {
+    fn pending() -> Self {
+        PeerConnection::PendingAuth {
+            received_messages: vec![],
+        }
+    }
+
+    fn receive_sync_message(
+        &mut self,
+        doc: &mut Automerge,
+        msg: SyncMessage,
+    ) -> Result<(), automerge::AutomergeError> {
+        match self {
+            PeerConnection::Accepted(sync_state) => doc.receive_sync_message(sync_state, msg),
+            PeerConnection::PendingAuth { received_messages } => {
+                received_messages.push(msg);
+                Ok(())
+            }
+        }
+    }
+
+    fn generate_sync_message(&mut self, doc: &Automerge) -> Option<SyncMessage> {
+        match self {
+            Self::Accepted(sync_state) => doc.generate_sync_message(sync_state),
+            Self::PendingAuth { .. } => None,
+        }
+    }
+}
+
+/// A change requested by a peer connection
+enum PeerConnCommand {
+    /// Request authorization from the share policy
+    RequestAuth(RepoId),
+}
+
 impl DocumentInfo {
     fn new(
         state: DocState,
@@ -587,7 +629,7 @@
             state,
             document,
             handle_count,
-            sync_states: Default::default(),
+            peer_connections: Default::default(),
             change_observers: Default::default(),
             patches_since_last_save: 0,
         }
@@ -753,42 +795,95 @@ impl DocumentInfo {
     }
 
     /// Apply incoming sync messages,
-    /// returns whether the document changed due to applying the message.
-    fn receive_sync_message(&mut self, per_remote: HashMap<RepoId, Vec<SyncMessage>>) -> bool {
+    ///
+    /// # Returns
+    ///
+    /// A tuple of `(has_changes, commands)` where `has_changes` is true if the document changed as
+    /// a result of applying the sync message and `commands` is a list of changes requested by the
+    /// peer connections for this document (e.g. requesting authorization from the share policy).
+    fn receive_sync_message(
+        &mut self,
+        per_remote: HashMap<RepoId, Vec<SyncMessage>>,
+    ) -> (bool, Vec<PeerConnCommand>) {
+        let mut commands = Vec::new();
         let (start_heads, new_heads) = {
             let mut document = self.document.write();
             let start_heads = document.automerge.get_heads();
             for (repo_id, messages) in per_remote {
-                let sync_state = self.sync_states.entry(repo_id).or_default();
-
-                // TODO: remove remote if there is an error.
+                let conn = match self.peer_connections.entry(repo_id.clone()) {
+                    Entry::Vacant(entry) => {
+                        // If this is a new peer, request authorization
+                        commands.push(PeerConnCommand::RequestAuth(repo_id.clone()));
+                        entry.insert(PeerConnection::pending())
+                    }
+                    Entry::Occupied(entry) => entry.into_mut(),
+                };
                 for message in messages {
-                    document
-                        .automerge
-                        .receive_sync_message(sync_state, message)
-                        .expect("Failed to apply sync message.");
+                    conn.receive_sync_message(&mut document.automerge, message)
+                        .expect("Failed to receive sync message.");
                 }
             }
             let new_heads = document.automerge.get_heads();
             (start_heads, new_heads)
         };
-        start_heads != new_heads
+        (start_heads != new_heads, commands)
+    }
+
+    /// Promote a peer awaiting authorization to a full peer
+    ///
+    /// Returns any messages which the peer sent while we were waiting for authorization
+    fn promote_pending_peer(&mut self, repo_id: &RepoId) -> Option<Vec<SyncMessage>> {
+        if let Some(PeerConnection::PendingAuth { received_messages }) =
+            self.peer_connections.remove(repo_id)
+        {
+            self.peer_connections
+                .insert(repo_id.clone(), PeerConnection::Accepted(SyncState::new()));
+            Some(received_messages)
+        } else {
+            tracing::warn!(remote=%repo_id, "Tried to promote a peer which was not pending authorization");
+            None
+        }
     }
 
     /// Potentially generate an outgoing sync message.
     fn generate_first_sync_message(&mut self, repo_id: RepoId) -> Option<SyncMessage> {
-        let sync_state = self.sync_states.entry(repo_id).or_default();
-        let document = self.document.read();
-        document.automerge.generate_sync_message(sync_state)
+        match self.peer_connections.entry(repo_id) {
+            Entry::Vacant(entry) => {
+                let mut sync_state = SyncState::new();
+                let document = self.document.read();
+                let message = document.automerge.generate_sync_message(&mut sync_state);
+                entry.insert(PeerConnection::Accepted(sync_state));
+                message
+            }
+            Entry::Occupied(mut entry) => match entry.get_mut() {
+                PeerConnection::PendingAuth { received_messages } => {
+                    let mut document = self.document.write();
+                    let mut sync_state = SyncState::new();
+                    for msg in received_messages.drain(..) {
+                        document
+                            .automerge
+                            .receive_sync_message(&mut sync_state, msg)
+                            .expect("Failed to receive sync message.");
+                    }
+                    let message = document.automerge.generate_sync_message(&mut sync_state);
+                    entry.insert(PeerConnection::Accepted(sync_state));
+                    message
+                }
+                PeerConnection::Accepted(ref mut sync_state) => {
+                    let document = self.document.read();
+                    document.automerge.generate_sync_message(sync_state)
+                }
+            },
+        }
     }
 
     /// Generate outgoing sync message for all repos we are syncing with.
     fn generate_sync_messages(&mut self) -> Vec<(RepoId, SyncMessage)> {
         let document = self.document.read();
-        self.sync_states
+        self.peer_connections
             .iter_mut()
-            .filter_map(|(repo_id, sync_state)| {
-                let message = document.automerge.generate_sync_message(sync_state);
+            .filter_map(|(repo_id, conn)| {
+                let message = conn.generate_sync_message(&document.automerge);
                 message.map(|msg| (repo_id.clone(), msg))
             })
             .collect()
@@ -803,6 +898,7 @@ enum WakeSignal {
     PendingCloseSink(RepoId),
     Storage(DocumentId),
     StorageList,
+    ShareDecision(RepoId),
 }
 
 /// Waking mechanism for stream and sinks.
@@ -813,6 +909,7 @@ enum RepoWaker {
     PendingCloseSink(Sender<WakeSignal>, RepoId),
     Storage(Sender<WakeSignal>, DocumentId),
     StorageList(Sender<WakeSignal>),
+    ShareDecision(Sender<WakeSignal>, RepoId),
 }
 
 ///
@@ -828,6 +925,9 @@ impl ArcWake for RepoWaker {
             }
             RepoWaker::Storage(sender, doc_id) => sender.send(WakeSignal::Storage(doc_id.clone())),
             RepoWaker::StorageList(sender) => sender.send(WakeSignal::StorageList),
+            RepoWaker::ShareDecision(sender, repo_id) => {
+                sender.send(WakeSignal::ShareDecision(repo_id.clone()))
+            }
         };
     }
 }
@@ -848,6 +948,19 @@ struct RemoteRepo {
 
 type PendingCloseSinks = Vec<Box<dyn Send + Unpin + Sink<RepoMessage, Error = NetworkError>>>;
 
+struct PendingShareDecision {
+    doc_id: DocumentId,
+    share_type: ShareType,
+    future: BoxFuture<'static, Result<ShareDecision, SharePolicyError>>,
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
+enum ShareType {
+    Request,
+    Announce,
+    Synchronize,
+}
+
 /// The backend of a repo: runs an event-loop in a background thread.
 pub struct Repo {
     /// The Id of the repo.
@@ -869,6 +982,7 @@ pub struct Repo {
     /// to poll in the current loop iteration.
     streams_to_poll: HashSet<RepoId>,
     sinks_to_poll: HashSet<RepoId>,
+    share_decisions_to_poll: HashSet<RepoId>,
 
     /// Sender and receiver of repo events.
     repo_sender: Sender<RepoEvent>,
@@ -888,6 +1002,12 @@ pub struct Repo {
 
     /// Network sinks that are pending close.
     pending_close_sinks: HashMap<RepoId, PendingCloseSinks>,
+
+    /// The authorization API
+    share_policy: Box<dyn SharePolicy>,
+
+    /// Pending share policy futures
+    pending_share_decisions: HashMap<RepoId, Vec<PendingShareDecision>>,
 }
 
 impl Repo {
@@ -896,6 +1016,7 @@
         let (wake_sender, wake_receiver) = unbounded();
         let (repo_sender, repo_receiver) = unbounded();
         let repo_id = repo_id.map_or_else(|| RepoId(Uuid::new_v4().to_string()), RepoId);
+        let share_policy = Box::new(share_policy::Permissive);
         Repo {
             repo_id,
             documents: Default::default(),
@@ -912,9 +1033,17 @@ impl Repo {
             documents_with_changes: Default::default(),
             storage,
             pending_close_sinks: Default::default(),
+            share_policy,
+            pending_share_decisions: HashMap::new(),
+            share_decisions_to_poll: HashSet::new(),
         }
     }
 
+    pub fn with_share_policy(mut self, share_policy: Box<dyn SharePolicy>) -> Self {
+        self.share_policy = share_policy;
+        self
+    }
+
     fn get_repo_id(&self) -> &RepoId {
         &self.repo_id
     }
@@ -934,7 +1063,7 @@ impl Repo {
     fn remove_unused_sync_states(&mut self) {
        for document_info in self.documents.values_mut() {
             let sync_keys = document_info
-                .sync_states
+                .peer_connections
                 .keys()
                 .cloned()
                 .collect::<HashSet<_>>();
@@ -942,7 +1071,7 @@
 
             let delenda = sync_keys.difference(&live_keys).collect::<Vec<_>>();
             for key in delenda {
-                document_info.sync_states.remove(key);
+                document_info.peer_connections.remove(key);
             }
         }
     }
@@ -1188,26 +1317,23 @@ impl Repo {
                         &self.repo_sender,
                         &self.repo_id,
                     );
-                    if info.state.should_sync() {
-                        tracing::trace!(remotes=?self.remote_repos.keys().collect::<Vec<_>>(), "sending sync message to remotes");
-                        // Send a sync message to all other repos we are connected with.
-                        for to_repo_id in self.remote_repos.keys().cloned() {
-                            if let Some(message) =
-                                info.generate_first_sync_message(to_repo_id.clone())
-                            {
-                                let outgoing = NetworkMessage::Sync {
-                                    from_repo_id: self.repo_id.clone(),
-                                    to_repo_id: to_repo_id.clone(),
-                                    document_id: document_id.clone(),
-                                    message,
-                                };
-                                self.pending_messages
-                                    .entry(to_repo_id.clone())
-                                    .or_default()
-                                    .push_back(outgoing);
-                                self.sinks_to_poll.insert(to_repo_id);
-                            }
-                        }
+
+                    let share_type = if info.is_boostrapping() {
+                        Some(ShareType::Request)
+                    } else if info.state.should_sync() {
+                        Some(ShareType::Announce)
+                    } else {
+                        None
+                    };
+                    if let Some(share_type) = share_type {
+                        Self::enqueue_share_decisions(
+                            self.remote_repos.keys(),
+                            &mut self.pending_share_decisions,
+                            &mut self.share_decisions_to_poll,
+                            self.share_policy.as_ref(),
+                            document_id.clone(),
+                            share_type,
+                        );
                     }
                 }
             }
@@ -1240,23 +1366,16 @@ impl Repo {
                     self.sinks_to_poll.insert(to_repo_id);
                 }
                 if is_first_edit {
-                    // Send a sync message to all other repos we are connected with.
-                    for repo_id in self.remote_repos.keys() {
-                        if let Some(message) = info.generate_first_sync_message(repo_id.clone())
-                        {
-                            let outgoing = NetworkMessage::Sync {
-                                from_repo_id: local_repo_id.clone(),
-                                to_repo_id: repo_id.clone(),
-                                document_id: doc_id.clone(),
-                                message,
-                            };
-                            self.pending_messages
-                                .entry(repo_id.clone())
-                                .or_default()
-                                .push_back(outgoing);
-                            self.sinks_to_poll.insert(repo_id.clone());
-                        }
-                    }
+                    // Send a sync message to all other repos we are connected with and with
+                    // whom we should share this document
+                    Self::enqueue_share_decisions(
+                        self.remote_repos.keys(),
+                        &mut self.pending_share_decisions,
+                        &mut self.share_decisions_to_poll,
+                        self.share_policy.as_ref(),
+                        doc_id.clone(),
+                        ShareType::Announce,
+                    );
                 }
             }
         }
@@ -1291,7 +1410,6 @@ impl Repo {
                 }
             },
             RepoEvent::LoadDoc(doc_id, resolver) => {
-                // TODO: handle multiple calls, through a list of resolvers.
                 let mut resolver_clone = resolver.clone();
                 let info = self.documents.entry(doc_id.clone()).or_insert_with(|| {
                     let storage_fut = self.storage.get(doc_id.clone());
@@ -1358,23 +1476,16 @@ impl Repo {
                     .insert(repo_id.clone(), RemoteRepo { stream, sink })
                     .is_none());
                 // Try to sync all docs we know about.
-                let our_id = self.get_repo_id().clone();
-                for (document_id, info) in self.documents.iter_mut() {
-                    if !info.state.should_announce() {
-                        continue;
-                    }
-                    tracing::trace!(?document_id, remote=%repo_id, "sending sync message to new remote");
-                    if let Some(message) = info.generate_first_sync_message(repo_id.clone()) {
-                        let outgoing = NetworkMessage::Sync {
-                            from_repo_id: our_id.clone(),
-                            to_repo_id: repo_id.clone(),
-                            document_id: document_id.clone(),
-                            message,
-                        };
-                        self.pending_messages
-                            .entry(repo_id.clone())
-                            .or_default()
-                            .push_back(outgoing);
+                for (document_id, info) in self.documents.iter() {
+                    if info.state.should_sync() {
+                        Self::enqueue_share_decisions(
+                            std::iter::once(&repo_id),
+                            &mut self.pending_share_decisions,
+                            &mut self.share_decisions_to_poll,
+                            self.share_policy.as_ref(),
+                            document_id.clone(),
+                            ShareType::Announce,
+                        );
                     }
                 }
                 self.sinks_to_poll.insert(repo_id.clone());
@@ -1443,11 +1554,25 @@ impl Repo {
                     .get_mut(&document_id)
                     .expect("Doc should have an info by now.");
 
-                if info.receive_sync_message(per_remote) {
+                let (has_changes, peer_conn_commands) = info.receive_sync_message(per_remote);
+                if has_changes {
                     info.note_changes();
                     self.documents_with_changes.push(document_id.clone());
                 }
 
+                for cmd in peer_conn_commands {
+                    match cmd {
+                        PeerConnCommand::RequestAuth(peer_id) => Self::enqueue_share_decisions(
+                            std::iter::once(&peer_id),
+                            &mut self.pending_share_decisions,
+                            &mut self.share_decisions_to_poll,
+                            self.share_policy.as_ref(),
+                            document_id.clone(),
+                            ShareType::Synchronize,
+                        ),
+                    }
+                }
+
                 // Note: since receiving and generating sync messages is done
                 // in two separate critical sections,
                 // local changes could be made in between those,
@@ -1511,6 +1636,105 @@ impl Repo {
         }
     }
 
+    fn collect_sharepolicy_responses(&mut self) {
+        let mut decisions = Vec::new();
+        for repo_id in mem::take(&mut self.share_decisions_to_poll) {
+            if let Some(pending) = self.pending_share_decisions.remove(&repo_id) {
+                let mut still_pending = Vec::new();
+                for PendingShareDecision {
+                    doc_id,
+                    mut future,
+                    share_type,
+                } in pending
+                {
+                    let waker = Arc::new(RepoWaker::ShareDecision(
+                        self.wake_sender.clone(),
+                        repo_id.clone(),
+                    ));
+                    let waker = waker_ref(&waker);
+                    let pinned_fut = Pin::new(&mut future);
+                    let result = pinned_fut.poll(&mut Context::from_waker(&waker));
+
+                    match result {
+                        Poll::Pending => {
+                            still_pending.push(PendingShareDecision {
+                                doc_id,
+                                future,
+                                share_type,
+                            });
+                        }
+                        Poll::Ready(Ok(res)) => {
+                            decisions.push((repo_id.clone(), doc_id, res, share_type))
+                        }
+                        Poll::Ready(Err(e)) => {
+                            tracing::error!(err=?e, "error while polling share policy decision");
+                        }
+                    }
+                }
+                if !still_pending.is_empty() {
+                    self.pending_share_decisions
+                        .insert(repo_id.clone(), still_pending);
+                }
+            }
+        }
+        for (peer, doc, share_decision, share_type) in decisions {
+            let our_id = self.get_repo_id().clone();
+            let Some(info) = self.documents.get_mut(&doc) else {
+                tracing::warn!(document=?doc, peer=?peer, "document not found when evaluating share policy decision result");
+                continue;
+            };
+            if share_decision == ShareDecision::Share {
+                match share_type {
+                    ShareType::Announce | ShareType::Request => {
+                        tracing::debug!(%doc, remote=%peer, "sharing document with remote");
+                        if let Some(pending_messages) = info.promote_pending_peer(&peer) {
+                            tracing::trace!(remote=%peer, %doc, "we already had pending messages for this peer when announcing, so we wait to generate a sync message");
+                            for message in pending_messages {
+                                self.pending_events.push_back(NetworkEvent::Sync {
+                                    from_repo_id: peer.clone(),
+                                    to_repo_id: our_id.clone(),
+                                    document_id: doc.clone(),
+                                    message,
+                                });
+                            }
+                        } else if let Some(message) = info.generate_first_sync_message(peer.clone())
+                        {
+                            tracing::trace!(remote=%peer, %doc, "sending first sync message");
+                            let outgoing = NetworkMessage::Sync {
+                                from_repo_id: our_id.clone(),
+                                to_repo_id: peer.clone(),
+                                document_id: doc.clone(),
+                                message,
+                            };
+                            self.pending_messages
+                                .entry(peer.clone())
+                                .or_default()
+                                .push_back(outgoing);
+                            self.sinks_to_poll.insert(peer);
+                        }
+                    }
+                    ShareType::Synchronize => {
+                        tracing::debug!(%doc, remote=%peer, "synchronizing document with remote");
+                        if let Some(pending_messages) = info.promote_pending_peer(&peer) {
+                            let events =
+                                pending_messages
+                                    .into_iter()
+                                    .map(|message| NetworkEvent::Sync {
+                                        from_repo_id: peer.clone(),
+                                        to_repo_id: our_id.clone(),
+                                        document_id: doc.clone(),
+                                        message,
+                                    });
+                            self.pending_events.extend(events);
+                        }
+                    }
+                }
+            } else {
+                tracing::debug!(?doc, ?peer, "refusing to share document with remote");
+            }
+        }
+    }
+
     /// The event-loop of the repo.
     /// Handles events from handles and adapters.
     /// Returns a handle for optional clean shutdown.
@@ -1527,6 +1751,7 @@ impl Repo {
         let handle = thread::spawn(move || {
             let _entered = span.entered();
             loop {
+                self.collect_sharepolicy_responses();
                 self.collect_network_events();
                 self.sync_documents();
                 self.process_outgoing_network_messages();
@@ -1570,27 +1795,24 @@ impl Repo {
                             &self.repo_id,
                         );
                         if info.state.should_sync() {
-                            // Send a sync message to all other repos we are connected with.
-                            for to_repo_id in self.remote_repos.keys().cloned() {
-                                if let Some(message) = info.generate_first_sync_message(to_repo_id.clone()) {
-                                    let outgoing = NetworkMessage::Sync {
-                                        from_repo_id: self.repo_id.clone(),
-                                        to_repo_id: to_repo_id.clone(),
-                                        document_id: doc_id.clone(),
-                                        message,
-                                    };
-                                    self.pending_messages
-                                        .entry(to_repo_id.clone())
-                                        .or_default()
-                                        .push_back(outgoing);
-                                    self.sinks_to_poll.insert(to_repo_id);
-                                }
-                            }
+                            // Send a sync message to all other repos we are connected
+                            // with and with whom we should share this document
+                            Self::enqueue_share_decisions(
+                                self.remote_repos.keys(),
+                                &mut self.pending_share_decisions,
+                                &mut self.share_decisions_to_poll,
+                                self.share_policy.as_ref(),
+                                doc_id.clone(),
+                                ShareType::Announce,
+                            );
                         }
                     }
                 }
                 WakeSignal::PendingCloseSink(repo_id) => self.poll_close_sinks(repo_id),
                 WakeSignal::StorageList => self.process_pending_storage_list(),
+                WakeSignal::ShareDecision(repo_id) => {
+                    self.share_decisions_to_poll.insert(repo_id);
+                }
            }
        },
    }
@@ -1659,6 +1881,7 @@ impl Repo {
                         );
                     }
                 }
+                WakeSignal::ShareDecision(_) => {}
             }
         }
         // Shutdown finished.
@@ -1685,4 +1908,42 @@ impl Repo {
             false
         }
     }
+
+    fn enqueue_share_decisions<'a, I: Iterator<Item = &'a RepoId>>(
+        remote_repos: I,
+        pending_share_decisions: &mut HashMap<RepoId, Vec<PendingShareDecision>>,
+        share_decisions_to_poll: &mut HashSet<RepoId>,
+        share_policy: &dyn SharePolicy,
+        document_id: DocumentId,
+        share_type: ShareType,
+    ) {
+        let remote_repos = remote_repos.collect::<Vec<_>>();
+        match share_type {
+            ShareType::Request => {
+                tracing::debug!(remotes=?remote_repos, ?document_id, "checking if we should request this document from remotes");
+            }
+            ShareType::Announce => {
+                tracing::debug!(remotes=?remote_repos, ?document_id, "checking if we should announce this document to remotes");
+            }
+            ShareType::Synchronize => {
+                tracing::debug!(remotes=?remote_repos, ?document_id, "checking if we should synchronize this document with remotes");
+            }
+        }
+        for repo_id in remote_repos {
+            let future = match share_type {
+                ShareType::Request => share_policy.should_request(&document_id, repo_id),
+                ShareType::Announce => share_policy.should_announce(&document_id, repo_id),
+                ShareType::Synchronize => share_policy.should_sync(&document_id, repo_id),
+            };
+            pending_share_decisions
+                .entry(repo_id.clone())
+                .or_default()
+                .push(PendingShareDecision {
+                    doc_id: document_id.clone(),
+                    future,
+                    share_type,
+                });
+            share_decisions_to_poll.insert(repo_id.clone());
+        }
+    }
 }
diff --git a/src/share_policy.rs b/src/share_policy.rs
new file mode 100644
index 0000000..d30220c
--- /dev/null
+++ b/src/share_policy.rs
@@ -0,0 +1,205 @@
+use futures::{future::BoxFuture, FutureExt};
+
+use crate::{DocumentId, RepoId};
+
+/// A policy for deciding whether to share a document with a peer
+///
+/// There are three situations when we need to decide whether to share a document with a peer:
+///
+/// 1. When we receive a sync message from a peer, we need to decide whether to incorporate the
+///    changes in the sync message and whether to respond to the sync message with our own changes
+/// 2. When we are trying to find a document that we don't have locally, we need to decide whether
+///    to request the document from other peers we are connected to
+/// 3. When we need to decide whether to announce a document to another peer. This happens either
+///    when a document is created locally, in which case we need to decide which of our connected
+///    peers to announce to; or when a peer connects for the first time, in which case we need to
+///    decide whether to announce any of the documents we have locally to the new peer.
+///
+/// This trait is implemented for `Fn(&RepoId, &DocumentId) -> ShareDecision`, so if you don't need
+/// to make different decisions for these three situations you can just pass a boxed closure
+/// to the repo.
+///
+/// ## Examples
+///
+/// ### Using the `Fn(&RepoId, &DocumentId) -> ShareDecision` implementation
+///
+/// ```no_run
+/// use automerge_repo::{Repo, RepoId, DocumentId, share_policy::ShareDecision, Storage};
+///
+/// let storage: Box<dyn Storage> = unimplemented!();
+/// let repo = Repo::new(None, storage)
+///     .with_share_policy(Box::new(|peer: &RepoId, document: &DocumentId| {
+///         // A share policy which only responds to peers with a particular peer ID
+///         if peer == &RepoId::from("some-peer-id") {
+///             ShareDecision::Share
+///         } else {
+///             ShareDecision::DontShare
+///         }
+///     }));
+/// ```
+///
+/// ### Using a custom share policy
+///
+/// ```no_run
+/// use automerge_repo::{Repo, RepoId, DocumentId, share_policy::{SharePolicy, ShareDecision, SharePolicyError}, Storage};
+/// use futures::future::BoxFuture;
+///
+/// /// A share policy which only allows requests to a particular peer
+/// struct OnlyRequestFrom(RepoId);
+///
+/// impl SharePolicy for OnlyRequestFrom {
+///     fn should_sync(
+///         &self,
+///         document_id: &DocumentId,
+///         with_peer: &RepoId,
+///     ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+///         Box::pin(async move { Ok(ShareDecision::Share) })
+///     }
+///
+///     fn should_request(
+///         &self,
+///         document_id: &DocumentId,
+///         from_peer: &RepoId,
+///     ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+///         let us = self.0.clone();
+///         let them = from_peer.clone();
+///         Box::pin(async move {
+///             if them == us {
+///                 Ok(ShareDecision::Share)
+///             } else {
+///                 Ok(ShareDecision::DontShare)
+///             }
+///         })
+///     }
+///
+///     fn should_announce(
+///         &self,
+///         document_id: &DocumentId,
+///         to_peer: &RepoId,
+///     ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+///         Box::pin(async move { Ok(ShareDecision::Share) })
+///     }
+/// }
+/// ```
+///
+pub trait SharePolicy: Send {
+    /// Whether we should incorporate changes from this peer into our local document and respond to
+    /// sync messages from them
+    fn should_sync(
+        &self,
+        document_id: &DocumentId,
+        with_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>>;
+
+    /// Whether we should request this document from this peer if we don't have the document
+    /// locally
+    fn should_request(
+        &self,
+        document_id: &DocumentId,
+        from_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>>;
+
+    /// Whether we should announce this document to this peer
+    fn should_announce(
+        &self,
+        document_id: &DocumentId,
+        to_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>>;
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ShareDecision {
+    Share,
+    DontShare,
+}
+
+pub struct SharePolicyError(String);
+
+impl From<String> for SharePolicyError {
+    fn from(s: String) -> Self {
+        Self(s)
+    }
+}
+
+impl<'a> From<&'a str> for SharePolicyError {
+    fn from(s: &'a str) -> Self {
+        Self(s.to_string())
+    }
+}
+
+impl std::fmt::Display for SharePolicyError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl std::fmt::Debug for SharePolicyError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self.0)
+    }
+}
+
+impl std::error::Error for SharePolicyError {}
+
+impl<F> SharePolicy for F
+where
+    F: for<'a, 'b> Fn(&'a RepoId, &'b DocumentId) -> ShareDecision,
+    F: Send + Sync + 'static,
+{
+    fn should_sync(
+        &self,
+        document_id: &DocumentId,
+        with_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+        let result = self(with_peer, document_id);
+        std::future::ready(Ok(result)).boxed()
+    }
+
+    fn should_request(
+        &self,
+        document_id: &DocumentId,
+        from_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+        let result = self(from_peer, document_id);
+        std::future::ready(Ok(result)).boxed()
+    }
+
+    fn should_announce(
+        &self,
+        document_id: &DocumentId,
+        to_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+        let result = self(to_peer, document_id);
+        std::future::ready(Ok(result)).boxed()
+    }
+}
+
+/// A share policy which always shares documents with all peers
+pub struct Permissive;
+
+impl SharePolicy for Permissive {
+    fn should_sync(
+        &self,
+        _document_id: &DocumentId,
+        _with_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+        Box::pin(async move { Ok(ShareDecision::Share) })
+    }
+
+    fn should_request(
+        &self,
+        _document_id: &DocumentId,
+        _from_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+        Box::pin(async move { Ok(ShareDecision::Share) })
+    }
+
+    fn should_announce(
+        &self,
+        _document_id: &DocumentId,
+        _to_peer: &RepoId,
+    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
+        Box::pin(async move { Ok(ShareDecision::Share) })
+    }
+}
diff --git a/tests/network/document_request.rs b/tests/network/document_request.rs
index 836c06e..91b0efa 100644
--- a/tests/network/document_request.rs
+++ b/tests/network/document_request.rs
@@ -1,7 +1,9 @@
 extern crate test_utils;
 
+use std::time::Duration;
+
 use automerge::transaction::Transactable;
-use automerge_repo::Repo;
+use automerge_repo::{DocumentId, Repo, RepoHandle, RepoId};
 use test_log::test;
 use test_utils::storage_utils::{InMemoryStorage, SimpleStorage};
 
@@ -10,12 +12,12 @@ use crate::tincans::connect_repos;
 #[test(tokio::test)]
 async fn test_requesting_document_connected_peers() {
     // Create two repos.
-    let repo_1 = Repo::new(None, Box::new(SimpleStorage));
+    let repo_1 = Repo::new(Some("repo1".to_string()), Box::new(SimpleStorage));
 
     // Keeping a handle to the storage of repo_2,
     // to later assert requested doc is saved.
     let storage = InMemoryStorage::default();
-    let repo_2 = Repo::new(None, Box::new(storage.clone()));
+    let repo_2 = Repo::new(Some("repo2".to_string()), Box::new(storage.clone()));
 
     // Run the repos in the background.
     let repo_handle_1 = repo_1.run();
@@ -43,16 +45,23 @@ async fn test_requesting_document_connected_peers() {
     let load = repo_handle_2.load(document_handle_1.document_id());
 
     assert_eq!(
-        doc_handle_future.await.unwrap().unwrap().document_id(),
+        tokio::time::timeout(Duration::from_millis(100), doc_handle_future)
+            .await
+            .expect("load future timed out")
+            .unwrap()
+            .expect("document should be found")
+            .document_id(),
         document_handle_1.document_id()
     );
-    let _ = tokio::task::spawn_blocking(move || {
+
+    let _ = tokio::task::spawn(async move {
         // Check that the document has been saved in storage.
         // TODO: replace the loop with an async notification mechanism.
         loop {
             if storage.contains_document(document_handle_1.document_id()) {
                 break;
             }
+            tokio::time::sleep(Duration::from_millis(100)).await;
         }
     })
     .await;
@@ -338,3 +347,38 @@ async fn test_request_unavailable_point_to_point() {
     // Since the repo is stopping, the future should error.
     assert!(doc_handle_future.await.is_err());
 }
+
+#[test(tokio::test)]
+async fn request_doc_which_is_not_shared_does_not_announce() {
+    let repo_1 = Repo::new(Some("repo1".to_string()), Box::new(SimpleStorage)).with_share_policy(
+        Box::new(|_peer: &RepoId, _doc_id: &DocumentId| {
+            automerge_repo::share_policy::ShareDecision::DontShare
+        }),
+    );
+    let repo_2 = Repo::new(Some("repo2".to_string()), Box::new(SimpleStorage));
+
+    let repo_handle_1 = repo_1.run();
+    let repo_handle_2 = repo_2.run();
+
+    connect_repos(&repo_handle_1, &repo_handle_2);
+
+    let document_id = create_doc_with_contents(&repo_handle_1, "peer", "repo1");
+
+    // Wait for the announcement to have (maybe) taken place
+    tokio::time::sleep(Duration::from_millis(100)).await;
+
+    // Now try to resolve the document from the storage of repo 2
+    let doc_handle = repo_handle_2.load(document_id).await.unwrap();
+    assert!(doc_handle.is_none());
+}
+
+fn create_doc_with_contents(handle: &RepoHandle, key: &str, value: &str) -> DocumentId {
+    let document_handle = handle.new_document();
+    document_handle.with_doc_mut(|doc| {
+        let mut tx = doc.transaction();
+        tx.put(automerge::ROOT, key, value)
+            .expect("Failed to change the document.");
+        tx.commit();
+    });
+    document_handle.document_id()
+}
diff --git a/tests/network/main.rs b/tests/network/main.rs
index 202e7ae..06a4b4a 100644
--- a/tests/network/main.rs
+++ b/tests/network/main.rs
@@ -1,7 +1,7 @@
 extern crate test_utils;
 
 use automerge::transaction::Transactable;
-use automerge_repo::Repo;
+use automerge_repo::{share_policy::ShareDecision, DocumentId, Repo, RepoId};
 use futures::{select, FutureExt};
 use std::time::Duration;
 use test_utils::storage_utils::SimpleStorage;
@@ -337,3 +337,44 @@ async fn test_streams_chained_on_replacement() {
     assert!(new_left_closed.load(std::sync::atomic::Ordering::Acquire));
     assert!(new_right_closed.load(std::sync::atomic::Ordering::Acquire));
 }
+
+#[test(tokio::test)]
+async fn sync_with_unauthorized_peer_never_occurs() {
+    let repo_handle_1 = Repo::new(Some("repo1".to_string()), Box::new(SimpleStorage))
+        .with_share_policy(Box::new(|repo: &RepoId, _: &DocumentId| {
+            if repo == &RepoId::from("repo2") {
+                ShareDecision::DontShare
+            } else {
+                ShareDecision::Share
+            }
+        }))
+        .run();
+    let repo_handle_2 = Repo::new(Some("repo2".to_string()), Box::new(SimpleStorage)).run();
+    let repo_handle_3 = Repo::new(Some("repo3".to_string()), Box::new(SimpleStorage)).run();
+
+    connect_repos(&repo_handle_1, &repo_handle_2);
+    connect_repos(&repo_handle_1, &repo_handle_3);
+
+    let doc_handle_1 = repo_handle_1.new_document();
+    doc_handle_1.with_doc_mut(|doc| {
+        let mut tx = doc.transaction();
+        tx.put(
+            automerge::ROOT,
+            "repo_id",
+            format!("{}", repo_handle_1.get_repo_id()),
+        )
+        .expect("Failed to change the document.");
+        tx.commit();
+    });
+
+    let doc_handle_2 = repo_handle_2.request_document(doc_handle_1.document_id());
+    tokio::time::timeout(Duration::from_secs(1), doc_handle_2)
+        .await
+        .expect_err("doc_handle_2 should never resolve");
+
+    let doc_handle_3 = repo_handle_3.request_document(doc_handle_1.document_id());
+    tokio::time::timeout(Duration::from_secs(1), doc_handle_3)
+        .await
+        .expect("doc_handle_3 should resolve")
+        .expect("doc_handle_3 should resolve to a document");
+}
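
Note: the reason `SharePolicy` returns a `BoxFuture` rather than a plain `ShareDecision` is that the repo's event loop polls each pending decision (see `collect_sharepolicy_responses` above), so a policy can consult state that is only available asynchronously. A minimal sketch of such a policy, assuming the crate is consumed as `automerge_repo`; the `AllowList` type, its field names, and the choice of `tokio::sync::RwLock` are illustrative assumptions, not part of this PR:

```rust
use std::collections::HashSet;
use std::sync::Arc;

use automerge_repo::{
    share_policy::{ShareDecision, SharePolicy, SharePolicyError},
    DocumentId, RepoId,
};
use futures::future::BoxFuture;
use tokio::sync::RwLock;

/// Hypothetical policy: share with a peer only if it appears in a shared,
/// mutable allow-list. Because the answer is produced inside a future, the
/// check can await an async lock (or a database/network lookup) instead of
/// having to answer synchronously.
struct AllowList {
    allowed: Arc<RwLock<HashSet<RepoId>>>,
}

impl AllowList {
    fn decide(
        &self,
        peer: &RepoId,
    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
        // Clone what the future needs so it can be 'static.
        let allowed = self.allowed.clone();
        let peer = peer.clone();
        Box::pin(async move {
            // The repo's event loop polls this future; it resolves once the
            // read lock is available.
            if allowed.read().await.contains(&peer) {
                Ok(ShareDecision::Share)
            } else {
                Ok(ShareDecision::DontShare)
            }
        })
    }
}

impl SharePolicy for AllowList {
    fn should_sync(
        &self,
        _document_id: &DocumentId,
        with_peer: &RepoId,
    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
        self.decide(with_peer)
    }

    fn should_request(
        &self,
        _document_id: &DocumentId,
        from_peer: &RepoId,
    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
        self.decide(from_peer)
    }

    fn should_announce(
        &self,
        _document_id: &DocumentId,
        to_peer: &RepoId,
    ) -> BoxFuture<'static, Result<ShareDecision, SharePolicyError>> {
        self.decide(to_peer)
    }
}
```

Wiring it in would mirror the tests above: `Repo::new(None, storage).with_share_policy(Box::new(AllowList { allowed }))`, after which the allow-list can be mutated at runtime through the shared `Arc`.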