diff --git a/src/drivers/fake.rs b/src/drivers/fake.rs index 58b89b33..b803cf2e 100644 --- a/src/drivers/fake.rs +++ b/src/drivers/fake.rs @@ -11,14 +11,16 @@ use crate::{ }; pub struct FakeSink { - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks>, print: bool, } impl FakeSink { pub fn builder() -> FakeSinkBuilder { FakeSinkBuilder(Self { - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), + on_message_output: Callbacks::new(), print: false, }) } @@ -36,11 +38,19 @@ impl FakeSinkBuilder { self } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self where C: MessageCallback>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_input.add_callback(callback.into_boxed()); + self + } + + pub fn on_message_output(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message_output.add_callback(callback.into_boxed()); self } } @@ -51,9 +61,9 @@ impl Driver for FakeSink { let mut hub_receiver = hub_sender.subscribe(); while let Ok(message) = hub_receiver.recv().await { - for future in self.on_message.call_all(Arc::clone(&message)) { + for future in self.on_message_input.call_all(Arc::clone(&message)) { if let Err(error) = future.await { - debug!("Dropping message: on_message callback returned error: {error:?}"); + debug!("Dropping message: on_message_input callback returned error: {error:?}"); continue; } } @@ -113,14 +123,16 @@ impl DriverInfo for FakeSinkInfo { pub struct FakeSource { period: std::time::Duration, - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks>, } impl FakeSource { pub fn builder(period: std::time::Duration) -> FakeSourceBuilder { FakeSourceBuilder(Self { period, - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), + on_message_output: Callbacks::new(), }) } } @@ -132,11 +144,19 @@ impl FakeSourceBuilder { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message_input.add_callback(callback.into_boxed()); + self + } + + pub fn on_message_output(self, callback: C) -> Self where C: MessageCallback>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_output.add_callback(callback.into_boxed()); self } } @@ -182,10 +202,10 @@ impl Driver for FakeSource { async move { trace!("Fake message created: {message:?}"); - for future in self.on_message.call_all(Arc::clone(&message)) { + for future in self.on_message_input.call_all(Arc::clone(&message)) { if let Err(error) = future.await { debug!( - "Dropping message: on_message callback returned error: {error:?}" + "Dropping message: on_message_input callback returned error: {error:?}" ); continue; } @@ -268,7 +288,7 @@ mod test { // FakeSink and task let sink_messages_clone = sink_messages.clone(); let sink = FakeSink::builder() - .on_message(move |message: Arc| { + .on_message_input(move |message: Arc| { let sink_messages = sink_messages_clone.clone(); async move { @@ -286,7 +306,7 @@ mod test { // FakeSource and task let source_messages_clone = source_messages.clone(); let source = FakeSource::builder(message_period) - .on_message(move |message: Arc| { + .on_message_input(move |message: Arc| { let source_messages = source_messages_clone.clone(); async move { diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs index 85b2d841..f05495cc 100644 --- a/src/drivers/mod.rs +++ b/src/drivers/mod.rs @@ -240,13 +240,15 @@ mod tests { // Example struct implementing Driver pub struct ExampleDriver { - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks<(u64, Arc)>, } impl ExampleDriver { pub fn new() -> ExampleDriverBuilder { ExampleDriverBuilder(Self { - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), + on_message_output: Callbacks::new(), }) } } @@ -258,11 +260,11 @@ mod tests { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self where C: MessageCallback>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_input.add_callback(callback.into_boxed()); self } } @@ -273,9 +275,11 @@ mod tests { let mut hub_receiver = hub_sender.subscribe(); while let Ok(message) = hub_receiver.recv().await { - for future in self.on_message.call_all(Arc::clone(&message)) { + for future in self.on_message_input.call_all(Arc::clone(&message)) { if let Err(error) = future.await { - debug!("Dropping message: on_message callback returned error: {error:?}"); + debug!( + "Dropping message: on_message_input callback returned error: {error:?}" + ); continue; } } @@ -315,13 +319,13 @@ mod tests { } #[tokio::test] - async fn on_message_callback_test() -> Result<()> { + async fn on_message_input_callback_test() -> Result<()> { let (sender, _receiver) = tokio::sync::broadcast::channel(1); let called = Arc::new(RwLock::new(false)); let called_cloned = called.clone(); let driver = ExampleDriver::new() - .on_message(move |_msg| { + .on_message_input(move |_msg| { let called = called_cloned.clone(); async move { diff --git a/src/drivers/serial/mod.rs b/src/drivers/serial/mod.rs index ffcf711c..7bf4f712 100644 --- a/src/drivers/serial/mod.rs +++ b/src/drivers/serial/mod.rs @@ -17,7 +17,8 @@ use crate::{ pub struct Serial { pub port_name: String, pub baud_rate: u32, - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks>, } pub struct SerialBuilder(Serial); @@ -27,11 +28,19 @@ impl SerialBuilder { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self where C: MessageCallback>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_input.add_callback(callback.into_boxed()); + self + } + + pub fn on_message_output(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message_output.add_callback(callback.into_boxed()); self } } @@ -42,15 +51,18 @@ impl Serial { SerialBuilder(Self { port_name: port_name.to_string(), baud_rate, - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), + on_message_output: Callbacks::new(), }) } - #[instrument(level = "debug", skip(port))] + #[instrument(level = "debug", skip(port, on_message_input))] async fn serial_receive_task( port_name: &str, port: Arc>>, hub_sender: broadcast::Sender>, + + on_message_input: &Callbacks>, ) -> Result<()> { let mut buf = vec![0; 1024]; @@ -59,7 +71,16 @@ impl Serial { // We got something Ok(bytes_received) if bytes_received > 0 => { read_all_messages("serial", &mut buf, |message| async { - if let Err(error) = hub_sender.send(Arc::new(message)) { + let message = Arc::new(message); + + for future in on_message_input.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message_input callback returned error: {error:?}"); + continue; + } + } + + if let Err(error) = hub_sender.send(message) { error!("Failed to send message to hub: {error:?}, from {port_name:?}"); } }) @@ -80,15 +101,23 @@ impl Serial { Ok(()) } - #[instrument(level = "debug", skip(port))] + #[instrument(level = "debug", skip(port, on_message_output))] async fn serial_send_task( port_name: &str, port: Arc>>, mut hub_receiver: broadcast::Receiver>, + on_message_output: &Callbacks>, ) -> Result<()> { loop { match hub_receiver.recv().await { Ok(message) => { + for future in on_message_output.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message_output callback returned error: {error:?}"); + continue; + } + } + if let Err(error) = port.lock().await.write_all(&message.raw_bytes()).await { error!("Failed to send serial message: {error:?}"); break; @@ -126,12 +155,12 @@ impl Driver for Serial { let hub_receiver = hub_sender.subscribe(); tokio::select! { - result = Serial::serial_send_task(&port_name, write.clone(), hub_receiver) => { + result = Serial::serial_send_task(&port_name, write.clone(), hub_receiver, &self.on_message_output) => { if let Err(e) = result { error!("Error in serial receive task for {port_name}: {e:?}"); } } - result = Serial::serial_receive_task(&port_name, read.clone(), hub_sender) => { + result = Serial::serial_receive_task(&port_name, read.clone(), hub_sender, &self.on_message_input) => { if let Err(e) = result { error!("Error in serial send task for {port_name}: {e:?}"); } diff --git a/src/drivers/tcp/client.rs b/src/drivers/tcp/client.rs index f4fa776c..808ffc3e 100644 --- a/src/drivers/tcp/client.rs +++ b/src/drivers/tcp/client.rs @@ -15,7 +15,8 @@ use crate::{ pub struct TcpClient { pub remote_addr: String, - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks>, } pub struct TcpClientBuilder(TcpClient); @@ -25,11 +26,19 @@ impl TcpClientBuilder { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self where C: MessageCallback>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_input.add_callback(callback.into_boxed()); + self + } + + pub fn on_message_output(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message_output.add_callback(callback.into_boxed()); self } } @@ -39,7 +48,8 @@ impl TcpClient { pub fn builder(remote_addr: &str) -> TcpClientBuilder { TcpClientBuilder(Self { remote_addr: remote_addr.to_string(), - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), + on_message_output: Callbacks::new(), }) } } @@ -67,12 +77,12 @@ impl Driver for TcpClient { let hub_sender_cloned = Arc::clone(&hub_sender); tokio::select! { - result = tcp_receive_task(read, server_addr, hub_sender_cloned, &self.on_message) => { + result = tcp_receive_task(read, server_addr, hub_sender_cloned, &self.on_message_input) => { if let Err(e) = result { error!("Error in TCP receive task: {e:?}"); } } - result = tcp_send_task(write, server_addr, hub_receiver, &self.on_message) => { + result = tcp_send_task(write, server_addr, hub_receiver, &self.on_message_output) => { if let Err(e) = result { error!("Error in TCP send task: {e:?}"); } diff --git a/src/drivers/tcp/mod.rs b/src/drivers/tcp/mod.rs index d82f1cbc..7eba4406 100644 --- a/src/drivers/tcp/mod.rs +++ b/src/drivers/tcp/mod.rs @@ -15,12 +15,12 @@ pub mod client; pub mod server; /// Receives messages from the TCP Socket and sends them to the HUB Channel -#[instrument(level = "debug", skip(socket, hub_sender, on_message))] +#[instrument(level = "debug", skip(socket, hub_sender, on_message_input))] async fn tcp_receive_task( mut socket: OwnedReadHalf, remote_addr: &str, hub_sender: Arc>>, - on_message: &Callbacks>, + on_message_input: &Callbacks>, ) -> Result<()> { let mut buf = Vec::with_capacity(1024); @@ -36,9 +36,9 @@ async fn tcp_receive_task( read_all_messages(remote_addr, &mut buf, |message| async { let message = Arc::new(message); - for future in on_message.call_all(Arc::clone(&message)) { + for future in on_message_input.call_all(Arc::clone(&message)) { if let Err(error) = future.await { - debug!("Dropping message: on_message callback returned error: {error:?}"); + debug!("Dropping message: on_message_input callback returned error: {error:?}"); continue; } } @@ -55,12 +55,12 @@ async fn tcp_receive_task( } /// Receives messages from the HUB Channel and sends them to the TCP Socket -#[instrument(level = "debug", skip(socket, hub_receiver, on_message))] +#[instrument(level = "debug", skip(socket, hub_receiver, on_message_output))] async fn tcp_send_task( mut socket: OwnedWriteHalf, remote_addr: &str, mut hub_receiver: broadcast::Receiver>, - on_message: &Callbacks>, + on_message_output: &Callbacks>, ) -> Result<()> { loop { let message = match hub_receiver.recv().await { @@ -79,9 +79,9 @@ async fn tcp_send_task( continue; // Don't do loopback } - for future in on_message.call_all(Arc::clone(&message)) { + for future in on_message_output.call_all(Arc::clone(&message)) { if let Err(error) = future.await { - debug!("Dropping message: on_message callback returned error: {error:?}"); + debug!("Dropping message: on_message_output callback returned error: {error:?}"); continue; } } diff --git a/src/drivers/tcp/server.rs b/src/drivers/tcp/server.rs index e617fc67..82fb58e5 100644 --- a/src/drivers/tcp/server.rs +++ b/src/drivers/tcp/server.rs @@ -19,7 +19,8 @@ use crate::{ #[derive(Clone)] pub struct TcpServer { pub local_addr: String, - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks>, } pub struct TcpServerBuilder(TcpServer); @@ -29,11 +30,19 @@ impl TcpServerBuilder { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self where C: MessageCallback>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_input.add_callback(callback.into_boxed()); + self + } + + pub fn on_message_output(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message_output.add_callback(callback.into_boxed()); self } } @@ -43,29 +52,34 @@ impl TcpServer { pub fn builder(local_addr: &str) -> TcpServerBuilder { TcpServerBuilder(Self { local_addr: local_addr.to_string(), - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), + on_message_output: Callbacks::new(), }) } /// Handles communication with a single client - #[instrument(level = "debug", skip(socket, hub_sender, on_message))] + #[instrument( + level = "debug", + skip(socket, hub_sender, on_message_input, on_message_output) + )] async fn handle_client( socket: TcpStream, remote_addr: String, hub_sender: Arc>>, - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks>, ) -> Result<()> { let hub_receiver = hub_sender.subscribe(); let (read, write) = socket.into_split(); tokio::select! { - result = tcp_receive_task(read, &remote_addr, hub_sender, &on_message) => { + result = tcp_receive_task(read, &remote_addr, hub_sender, &on_message_input) => { if let Err(e) = result { error!("Error in TCP receive task for {remote_addr}: {e:?}"); } } - result = tcp_send_task(write, &remote_addr, hub_receiver, &on_message) => { + result = tcp_send_task(write, &remote_addr, hub_receiver, &on_message_input) => { if let Err(e) = result { error!("Error in TCP send task for {remote_addr}: {e:?}"); } @@ -94,7 +108,8 @@ impl Driver for TcpServer { socket, remote_addr, hub_sender, - self.on_message.clone(), + self.on_message_input.clone(), + self.on_message_output.clone(), )); } Err(error) => { diff --git a/src/drivers/tlog/reader.rs b/src/drivers/tlog/reader.rs index 244d3464..310a4bcc 100644 --- a/src/drivers/tlog/reader.rs +++ b/src/drivers/tlog/reader.rs @@ -14,7 +14,7 @@ use crate::{ pub struct TlogReader { pub path: PathBuf, - on_message: Callbacks<(u64, Arc)>, + on_message_input: Callbacks<(u64, Arc)>, } pub struct TlogReaderBuilder(TlogReader); @@ -24,11 +24,11 @@ impl TlogReaderBuilder { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self where C: MessageCallback<(u64, Arc)>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_input.add_callback(callback.into_boxed()); self } } @@ -38,7 +38,7 @@ impl TlogReader { pub fn builder(path: PathBuf) -> TlogReaderBuilder { TlogReaderBuilder(Self { path, - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), }) } @@ -101,11 +101,11 @@ impl TlogReader { let message = Arc::new(message); for future in self - .on_message + .on_message_input .call_all((us_since_epoch, (Arc::clone(&message)))) { if let Err(error) = future.await { - debug!("Dropping message: on_message callback returned error: {error:?}"); + debug!("Dropping message: on_message_input callback returned error: {error:?}"); continue; } } @@ -219,7 +219,7 @@ mod tests { let tlog_file = PathBuf::from_str("tests/files/00025-2024-04-22_18-49-07.tlog").unwrap(); let driver = TlogReader::builder(tlog_file.clone()) - .on_message(move |args: (u64, Arc)| { + .on_message_input(move |args: (u64, Arc)| { let messages_received = messages_received_cloned.clone(); async move { diff --git a/src/drivers/tlog/writer.rs b/src/drivers/tlog/writer.rs index 0432d496..ef054892 100644 --- a/src/drivers/tlog/writer.rs +++ b/src/drivers/tlog/writer.rs @@ -15,7 +15,7 @@ use crate::{ pub struct TlogWriter { pub path: PathBuf, - on_message: Callbacks<(u64, Arc)>, + on_message_output: Callbacks<(u64, Arc)>, } pub struct TlogWriterBuilder(TlogWriter); @@ -25,11 +25,11 @@ impl TlogWriterBuilder { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_output(self, callback: C) -> Self where C: MessageCallback<(u64, Arc)>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_output.add_callback(callback.into_boxed()); self } } @@ -39,7 +39,7 @@ impl TlogWriter { pub fn builder(path: PathBuf) -> TlogWriterBuilder { TlogWriterBuilder(Self { path, - on_message: Callbacks::new(), + on_message_output: Callbacks::new(), }) } @@ -55,15 +55,14 @@ impl TlogWriter { match hub_receiver.recv().await { Ok(message) => { let timestamp = chrono::Utc::now().timestamp_micros() as u64; - let message = Arc::new(message); for future in self - .on_message + .on_message_output .call_all((timestamp, (Arc::clone(&message)))) { if let Err(error) = future.await { debug!( - "Dropping message: on_message callback returned error: {error:?}" + "Dropping message: on_message_input callback returned error: {error:?}" ); continue; } diff --git a/src/drivers/udp/client.rs b/src/drivers/udp/client.rs index 8164f824..2c181a8d 100644 --- a/src/drivers/udp/client.rs +++ b/src/drivers/udp/client.rs @@ -12,7 +12,8 @@ use crate::{ pub struct UdpClient { pub remote_addr: String, - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks>, } pub struct UdpClientBuilder(UdpClient); @@ -22,11 +23,19 @@ impl UdpClientBuilder { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self where C: MessageCallback>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_input.add_callback(callback.into_boxed()); + self + } + + pub fn on_message_output(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message_output.add_callback(callback.into_boxed()); self } } @@ -36,7 +45,8 @@ impl UdpClient { pub fn builder(remote_addr: &str) -> UdpClientBuilder { UdpClientBuilder(Self { remote_addr: remote_addr.to_string(), - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), + on_message_output: Callbacks::new(), }) } @@ -56,9 +66,9 @@ impl UdpClient { read_all_messages(client_addr, &mut buf, |message| async { let message = Arc::new(message); - for future in self.on_message.call_all(Arc::clone(&message)) { + for future in self.on_message_input.call_all(Arc::clone(&message)) { if let Err(error) = future.await { - debug!("Dropping message: on_message callback returned error: {error:?}"); + debug!("Dropping message: on_message_input callback returned error: {error:?}"); continue; } } @@ -97,10 +107,10 @@ impl UdpClient { continue; // Don't do loopback } - for future in self.on_message.call_all(Arc::clone(&message)) { + for future in self.on_message_output.call_all(Arc::clone(&message)) { if let Err(error) = future.await { debug!( - "Dropping message: on_message callback returned error: {error:?}" + "Dropping message: on_message_output callback returned error: {error:?}" ); continue; } diff --git a/src/drivers/udp/server.rs b/src/drivers/udp/server.rs index d316b146..a08642f2 100644 --- a/src/drivers/udp/server.rs +++ b/src/drivers/udp/server.rs @@ -16,7 +16,8 @@ use crate::{ pub struct UdpServer { pub local_addr: String, clients: Arc>>, - on_message: Callbacks>, + on_message_input: Callbacks>, + on_message_output: Callbacks>, } pub struct UdpServerBuilder(UdpServer); @@ -26,11 +27,19 @@ impl UdpServerBuilder { self.0 } - pub fn on_message(self, callback: C) -> Self + pub fn on_message_input(self, callback: C) -> Self where C: MessageCallback>, { - self.0.on_message.add_callback(callback.into_boxed()); + self.0.on_message_input.add_callback(callback.into_boxed()); + self + } + + pub fn on_message_output(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message_output.add_callback(callback.into_boxed()); self } } @@ -41,7 +50,8 @@ impl UdpServer { UdpServerBuilder(Self { local_addr, clients: Arc::new(RwLock::new(HashMap::new())), - on_message: Callbacks::new(), + on_message_input: Callbacks::new(), + on_message_output: Callbacks::new(), }) } @@ -62,9 +72,9 @@ impl UdpServer { read_all_messages(client_addr, &mut buf, |message| async { let message = Arc::new(message); - for future in self.on_message.call_all(Arc::clone(&message)) { + for future in self.on_message_input.call_all(Arc::clone(&message)) { if let Err(error) = future.await { - debug!("Dropping message: on_message callback returned error: {error:?}"); + debug!("Dropping message: on_message_input callback returned error: {error:?}"); continue; } } @@ -119,6 +129,13 @@ impl UdpServer { continue; // Don't do loopback } + for future in self.on_message_output.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message_output callback returned error: {error:?}"); + continue; + } + } + match socket.send_to(message.raw_bytes(), client_addr).await { Ok(_) => { // Message sent successfully