Skip to content

Commit

Permalink
fix ensure_initialized
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-maron committed Jul 2, 2024
1 parent d2347eb commit 0938936
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 15 deletions.
22 changes: 12 additions & 10 deletions cdn-client/src/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub struct Inner<C: ConnectionDef> {
connection: Arc<RwLock<Option<Connection>>>,

/// The semaphore to ensure only one reconnection is happening at a time
connecting_guard: Semaphore,
connecting_guard: Arc<Semaphore>,

/// The keypair to use when authenticating
pub keypair: KeyPair<Scheme<C>>,
Expand Down Expand Up @@ -159,7 +159,7 @@ impl<C: ConnectionDef> Retry<C> {
let Some(connection) = possible_connection else {
// If the connection is not initialized for one reason or another, try to reconnect
// Acquire the semaphore to ensure only one reconnection is happening at a time
if let Ok(permit) = self.inner.connecting_guard.try_acquire() {
if let Ok(permit) = Arc::clone(&self.inner.connecting_guard).try_acquire_owned() {
// We were the first to try reconnecting, spawn a reconnection task
let inner = self.inner.clone();
spawn(async move {
Expand All @@ -174,15 +174,15 @@ impl<C: ConnectionDef> Retry<C> {
*connection = Some(new_connection);
break;
}
Err(e) => {
Err(err) => {
// We failed to reconnect
// Sleep for 2 seconds and then try again
error!(error = %e, "failed to reconnect");
error!("failed to connect: {err}");
sleep(Duration::from_secs(2)).await;
}
}
}
_ = permit;
drop(permit);
});
}

Expand Down Expand Up @@ -243,7 +243,7 @@ impl<C: ConnectionDef> Retry<C> {
use_local_authority,
// TODO: parameterize batch params
connection: Arc::default(),
connecting_guard: Semaphore::const_new(1),
connecting_guard: Arc::from(Semaphore::const_new(1)),
keypair,
subscribed_topics,
}),
Expand All @@ -252,11 +252,13 @@ impl<C: ConnectionDef> Retry<C> {

/// Returns only when the connection is fully initialized
pub async fn ensure_initialized(&self) {
// Try to get the underlying connection
while let Err(err) = self.get_connection().await {
error!("failed to initialize connection: {err}");
sleep(Duration::from_secs(2)).await;
// If we are already connected, return
if self.try_get_connection().is_ok() {
return;
}

// Otherwise, wait to acquire the connecting guard
let _ = self.inner.connecting_guard.acquire().await;
}

/// Sends a message to the underlying connection. Reconnection is handled under
Expand Down
8 changes: 8 additions & 0 deletions tests/src/tests/basic_connect.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::time::Duration;

use cdn_proto::{
def::TestTopic,
message::{Direct, Message},
};
use tokio::time::timeout;

use crate::tests::*;

Expand All @@ -21,6 +24,11 @@ async fn test_end_to_end_connection() {
let client = new_client(0, vec![TestTopic::Global as u8], "8082");
let client_public_key = keypair_from_seed(0).1;

// Ensure we are connected
let Ok(()) = timeout(Duration::from_secs(1), client.ensure_initialized()).await else {
panic!("client failed to connect");
};

// Send a message to ourself
client
.send_direct_message(&client_public_key, b"hello direct".to_vec())
Expand Down
32 changes: 27 additions & 5 deletions tests/src/tests/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ async fn test_subscribe() {
// Create and get the handle to a new client subscribed to the global topic
let client = new_client(0, vec![TestTopic::Global as u8], "8097");

// Ensure the client is connected
let Ok(()) = timeout(Duration::from_secs(1), client.ensure_initialized()).await else {
panic!("client failed to connect");
};

// Send a message to the global topic
client
.send_broadcast_message(vec![TestTopic::Global as u8], b"hello global".to_vec())
Expand Down Expand Up @@ -120,8 +125,8 @@ async fn test_invalid_subscribe() {
// Create and start a new marshal
new_marshal("8100", &discovery_endpoint).await;

// Create and get the handle to a new client subscribed to an invalid topic
let client = new_client(0, vec![99], "8100");
// Create and get the handle to a new client
let client = new_client(0, vec![], "8100");

// Ensure the connection is open
let Ok(()) = timeout(Duration::from_secs(1), client.ensure_initialized()).await else {
Expand All @@ -143,9 +148,25 @@ async fn test_invalid_subscribe() {
|| client.soft_close().await.is_err(),
"sent message but should've been disconnected"
);
}

// Reinitialize the connection
let Ok(()) = timeout(Duration::from_secs(4), client.ensure_initialized()).await else {
// Test that unsubscribing from an invalid topic kills the connection.
#[tokio::test]
async fn test_invalid_unsubscribe() {
// Get a temporary path for the discovery endpoint
let discovery_endpoint = get_temp_db_path();

// Create and start a new broker
new_broker(0, "8101", "8102", &discovery_endpoint).await;

// Create and start a new marshal
new_marshal("8103", &discovery_endpoint).await;

// Create and get the handle to a new client
let client = new_client(0, vec![], "8103");

// Ensure the connection is open
let Ok(()) = timeout(Duration::from_secs(1), client.ensure_initialized()).await else {
panic!("client failed to connect");
};

Expand All @@ -160,7 +181,8 @@ async fn test_invalid_subscribe() {
client
.send_broadcast_message(vec![1], b"hello invalid".to_vec())
.await
.is_err(),
.is_err()
|| client.soft_close().await.is_err(),
"sent message but should've been disconnected"
);
}

0 comments on commit 0938936

Please sign in to comment.