Skip to content

Commit

Permalink
Merge pull request #50 from EspressoSystems/abort-on-drop
Browse files Browse the repository at this point in the history
abort task on drop
  • Loading branch information
rob-maron authored Jul 16, 2024
2 parents d28f82f + 26a7b21 commit 4e1a846
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 28 deletions.
2 changes: 1 addition & 1 deletion cdn-broker/src/connections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use cdn_proto::{
connection::{protocols::Connection, UserPublicKey},
discovery::BrokerIdentifier,
message::Topic,
mnemonic,
util::mnemonic,
};
use tokio::task::AbortHandle;
use tracing::{error, info, warn};
Expand Down
19 changes: 11 additions & 8 deletions cdn-broker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use cdn_proto::{
def::{Listener, Protocol, RunDef, Scheme},
discovery::{BrokerIdentifier, DiscoveryClient},
error::{Error, Result},
util::AbortOnDropHandle,
};
use cdn_proto::{crypto::signature::KeyPair, metrics as proto_metrics};
use connections::Connections;
Expand Down Expand Up @@ -245,26 +246,28 @@ impl<R: RunDef> Broker<R> {
// Spawn the heartbeat task, which we use to register with `Discovery` every so often.
// We also use it to check for new brokers who may have joined.
let inner_ = self.inner.clone();
let heartbeat_task = spawn(inner_.run_heartbeat_task());
let heartbeat_task = AbortOnDropHandle(spawn(inner_.run_heartbeat_task()));

// Spawn the sync task, which updates other brokers with our keys periodically.
let inner_ = self.inner.clone();
let sync_task = spawn(inner_.run_sync_task());
let sync_task = AbortOnDropHandle(spawn(inner_.run_sync_task()));

// Spawn the public (user) listener task
// TODO: maybe macro this, since it's repeat code with the private listener task
let inner_ = self.inner.clone();
let user_listener_task = spawn(inner_.clone().run_user_listener_task(self.user_listener));
let user_listener_task = AbortOnDropHandle(spawn(
inner_.clone().run_user_listener_task(self.user_listener),
));

// Spawn the private (broker) listener task
let inner_ = self.inner.clone();
let broker_listener_task = spawn(inner_.run_broker_listener_task(self.broker_listener));
let broker_listener_task =
AbortOnDropHandle(spawn(inner_.run_broker_listener_task(self.broker_listener)));

// Serve the (possible) metrics task
if let Some(metrics_bind_endpoint) = self.metrics_bind_endpoint {
// Spawn the serving task
spawn(proto_metrics::serve_metrics(metrics_bind_endpoint));
}
let _possible_metrics_task = self
.metrics_bind_endpoint
.map(|endpoint| AbortOnDropHandle(spawn(proto_metrics::serve_metrics(endpoint))));

// If one of the tasks exists, we want to return (stopping the program)
select! {
Expand Down
3 changes: 2 additions & 1 deletion cdn-broker/src/tasks/broker/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use cdn_proto::{
discovery::BrokerIdentifier,
error::{Error, Result},
message::{Message, Topic},
mnemonic, verify_broker,
util::mnemonic,
verify_broker,
};
use tokio::{spawn, time::timeout};
use tracing::{debug, error};
Expand Down
3 changes: 2 additions & 1 deletion cdn-broker/src/tasks/user/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use std::time::Duration;
use cdn_proto::connection::{protocols::Connection, UserPublicKey};
use cdn_proto::def::{RunDef, Topic as _};
use cdn_proto::error::{Error, Result};
use cdn_proto::{connection::auth::broker::BrokerAuth, message::Message, mnemonic};
use cdn_proto::util::mnemonic;
use cdn_proto::{connection::auth::broker::BrokerAuth, message::Message};
use tokio::spawn;
use tokio::time::timeout;
use tracing::{error, warn};
Expand Down
1 change: 0 additions & 1 deletion cdn-client/src/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ impl<C: ConnectionDef> Retry<C> {
inner: Arc::from(Inner {
endpoint,
use_local_authority,
// TODO: parameterize batch params
connection: Arc::default(),
connecting_guard: Arc::from(Semaphore::const_new(1)),
keypair,
Expand Down
2 changes: 1 addition & 1 deletion cdn-marshal/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::time::Duration;
use cdn_proto::{
connection::{auth::marshal::MarshalAuth, protocols::Connection},
def::RunDef,
mnemonic,
util::mnemonic,
};
use tokio::time::timeout;
use tracing::info;
Expand Down
8 changes: 4 additions & 4 deletions cdn-marshal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use cdn_proto::{
discovery::DiscoveryClient,
error::{Error, Result},
metrics as proto_metrics,
util::AbortOnDropHandle,
};
use tokio::spawn;
use tracing::info;
Expand Down Expand Up @@ -143,10 +144,9 @@ impl<R: RunDef> Marshal<R> {
/// Right now, we return a `Result` but don't actually ever error.
pub async fn start(self) -> Result<()> {
// Serve the (possible) metrics task
if let Some(metrics_bind_endpoint) = self.metrics_bind_endpoint {
// Spawn the serving task
spawn(proto_metrics::serve_metrics(metrics_bind_endpoint));
}
let _possible_metrics_task = self
.metrics_bind_endpoint
.map(|endpoint| AbortOnDropHandle(spawn(proto_metrics::serve_metrics(endpoint))));

// Listen for connections forever
loop {
Expand Down
1 change: 0 additions & 1 deletion cdn-proto/src/connection/auth/broker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ impl<R: RunDef> BrokerAuth<R> {

// See if we're the right type of message
let Message::AuthenticateWithKey(auth_message) = auth_message else {
// TODO: macro for this error thing
fail_verification_with_message!(connection, "wrong message type");
};

Expand Down
11 changes: 1 addition & 10 deletions cdn-proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

#![forbid(unsafe_code)]

use std::hash::{Hash, Hasher};

pub mod connection;
pub mod crypto;
pub mod def;
pub mod discovery;
pub mod error;
pub mod message;
pub mod util;

#[cfg(feature = "metrics")]
pub mod metrics;
Expand All @@ -24,11 +23,3 @@ pub const MAX_MESSAGE_SIZE: u32 = u32::MAX / 8;
pub mod messages_capnp {
include!("../schema/messages_capnp.rs");
}

/// A function for generating a cute little user mnemonic from a hash
#[must_use]
pub fn mnemonic<H: Hash>(bytes: H) -> String {
let mut state = std::collections::hash_map::DefaultHasher::new();
bytes.hash(&mut state);
mnemonic::to_string(state.finish().to_le_bytes())
}
32 changes: 32 additions & 0 deletions cdn-proto/src/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::{
future::Future,
hash::{Hash, Hasher},
pin::Pin,
task::{Context, Poll},
};

use tokio::task::{JoinError, JoinHandle};

/// A function for generating a cute little user mnemonic from a hash
#[must_use]
pub fn mnemonic<H: Hash>(bytes: H) -> String {
let mut state = std::collections::hash_map::DefaultHasher::new();
bytes.hash(&mut state);
mnemonic::to_string(state.finish().to_le_bytes())
}

/// A wrapper for a `JoinHandle` that will abort the task if dropped
pub struct AbortOnDropHandle<T>(pub JoinHandle<T>);

impl<T> Drop for AbortOnDropHandle<T> {
fn drop(&mut self) {
self.0.abort();
}
}

impl<T> Future for AbortOnDropHandle<T> {
type Output = Result<T, JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx)
}
}

0 comments on commit 4e1a846

Please sign in to comment.