Skip to content

Commit

Permalink
src: drivers: Split callbacks into two: input and output
Browse files Browse the repository at this point in the history
  • Loading branch information
joaoantoniocardoso committed Sep 7, 2024
1 parent 3f02501 commit a07581a
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 82 deletions.
48 changes: 34 additions & 14 deletions src/drivers/fake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ use crate::{
};

pub struct FakeSink {
on_message: Callbacks<Arc<Protocol>>,
on_message_input: Callbacks<Arc<Protocol>>,
on_message_output: Callbacks<Arc<Protocol>>,
print: bool,
}

impl FakeSink {
pub fn builder() -> FakeSinkBuilder {
FakeSinkBuilder(Self {
on_message: Callbacks::new(),
on_message_input: Callbacks::new(),
on_message_output: Callbacks::new(),
print: false,
})
}
Expand All @@ -36,11 +38,19 @@ impl FakeSinkBuilder {
self
}

pub fn on_message<C>(self, callback: C) -> Self
pub fn on_message_input<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message.add_callback(callback.into_boxed());
self.0.on_message_input.add_callback(callback.into_boxed());
self
}

pub fn on_message_output<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message_output.add_callback(callback.into_boxed());
self
}
}
Expand All @@ -51,9 +61,9 @@ impl Driver for FakeSink {
let mut hub_receiver = hub_sender.subscribe();

while let Ok(message) = hub_receiver.recv().await {
for future in self.on_message.call_all(Arc::clone(&message)) {
for future in self.on_message_input.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!("Dropping message: on_message callback returned error: {error:?}");
debug!("Dropping message: on_message_input callback returned error: {error:?}");
continue;
}
}
Expand Down Expand Up @@ -116,14 +126,16 @@ impl DriverInfo for FakeSinkInfo {

pub struct FakeSource {
period: std::time::Duration,
on_message: Callbacks<Arc<Protocol>>,
on_message_input: Callbacks<Arc<Protocol>>,
on_message_output: Callbacks<Arc<Protocol>>,
}

impl FakeSource {
pub fn builder(period: std::time::Duration) -> FakeSourceBuilder {
FakeSourceBuilder(Self {
period,
on_message: Callbacks::new(),
on_message_input: Callbacks::new(),
on_message_output: Callbacks::new(),
})
}
}
Expand All @@ -135,11 +147,19 @@ impl FakeSourceBuilder {
self.0
}

pub fn on_message<C>(self, callback: C) -> Self
pub fn on_message_input<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message_input.add_callback(callback.into_boxed());
self
}

pub fn on_message_output<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message.add_callback(callback.into_boxed());
self.0.on_message_output.add_callback(callback.into_boxed());
self
}
}
Expand Down Expand Up @@ -185,10 +205,10 @@ impl Driver for FakeSource {
async move {
trace!("Fake message created: {message:?}");

for future in self.on_message.call_all(Arc::clone(&message)) {
for future in self.on_message_input.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!(
"Dropping message: on_message callback returned error: {error:?}"
"Dropping message: on_message_input callback returned error: {error:?}"
);
continue;
}
Expand Down Expand Up @@ -271,7 +291,7 @@ mod test {
// FakeSink and task
let sink_messages_clone = sink_messages.clone();
let sink = FakeSink::builder()
.on_message(move |message: Arc<Protocol>| {
.on_message_input(move |message: Arc<Protocol>| {
let sink_messages = sink_messages_clone.clone();

async move {
Expand All @@ -289,7 +309,7 @@ mod test {
// FakeSource and task
let source_messages_clone = source_messages.clone();
let source = FakeSource::builder(message_period)
.on_message(move |message: Arc<Protocol>| {
.on_message_input(move |message: Arc<Protocol>| {
let source_messages = source_messages_clone.clone();

async move {
Expand Down
20 changes: 12 additions & 8 deletions src/drivers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,15 @@ mod tests {

// Example struct implementing Driver
pub struct ExampleDriver {
on_message: Callbacks<Arc<Protocol>>,
on_message_input: Callbacks<Arc<Protocol>>,
on_message_output: Callbacks<(u64, Arc<Protocol>)>,
}

impl ExampleDriver {
pub fn new() -> ExampleDriverBuilder {
ExampleDriverBuilder(Self {
on_message: Callbacks::new(),
on_message_input: Callbacks::new(),
on_message_output: Callbacks::new(),
})
}
}
Expand All @@ -258,11 +260,11 @@ mod tests {
self.0
}

pub fn on_message<C>(self, callback: C) -> Self
pub fn on_message_input<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message.add_callback(callback.into_boxed());
self.0.on_message_input.add_callback(callback.into_boxed());
self
}
}
Expand All @@ -273,9 +275,11 @@ mod tests {
let mut hub_receiver = hub_sender.subscribe();

while let Ok(message) = hub_receiver.recv().await {
for future in self.on_message.call_all(Arc::clone(&message)) {
for future in self.on_message_input.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!("Dropping message: on_message callback returned error: {error:?}");
debug!(
"Dropping message: on_message_input callback returned error: {error:?}"
);
continue;
}
}
Expand Down Expand Up @@ -315,13 +319,13 @@ mod tests {
}

#[tokio::test]
async fn on_message_callback_test() -> Result<()> {
async fn on_message_input_callback_test() -> Result<()> {
let (sender, _receiver) = tokio::sync::broadcast::channel(1);

let called = Arc::new(RwLock::new(false));
let called_cloned = called.clone();
let driver = ExampleDriver::new()
.on_message(move |_msg| {
.on_message_input(move |_msg| {
let called = called_cloned.clone();

async move {
Expand Down
47 changes: 38 additions & 9 deletions src/drivers/serial/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use crate::{
pub struct Serial {
pub port_name: String,
pub baud_rate: u32,
on_message: Callbacks<Arc<Protocol>>,
on_message_input: Callbacks<Arc<Protocol>>,
on_message_output: Callbacks<Arc<Protocol>>,
}

pub struct SerialBuilder(Serial);
Expand All @@ -27,11 +28,19 @@ impl SerialBuilder {
self.0
}

pub fn on_message<C>(self, callback: C) -> Self
pub fn on_message_input<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message.add_callback(callback.into_boxed());
self.0.on_message_input.add_callback(callback.into_boxed());
self
}

pub fn on_message_output<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message_output.add_callback(callback.into_boxed());
self
}
}
Expand All @@ -42,15 +51,18 @@ impl Serial {
SerialBuilder(Self {
port_name: port_name.to_string(),
baud_rate,
on_message: Callbacks::new(),
on_message_input: Callbacks::new(),
on_message_output: Callbacks::new(),
})
}

#[instrument(level = "debug", skip(port))]
#[instrument(level = "debug", skip(port, on_message_input))]
async fn serial_receive_task(
port_name: &str,
port: Arc<Mutex<tokio::io::ReadHalf<tokio_serial::SerialStream>>>,
hub_sender: broadcast::Sender<Arc<Protocol>>,

on_message_input: &Callbacks<Arc<Protocol>>,
) -> Result<()> {
let mut buf = vec![0; 1024];

Expand All @@ -59,7 +71,16 @@ impl Serial {
// We got something
Ok(bytes_received) if bytes_received > 0 => {
read_all_messages("serial", &mut buf, |message| async {
if let Err(error) = hub_sender.send(Arc::new(message)) {
let message = Arc::new(message);

for future in on_message_input.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!("Dropping message: on_message_input callback returned error: {error:?}");
continue;
}
}

if let Err(error) = hub_sender.send(message) {
error!("Failed to send message to hub: {error:?}, from {port_name:?}");
}
})
Expand All @@ -80,15 +101,23 @@ impl Serial {
Ok(())
}

#[instrument(level = "debug", skip(port))]
#[instrument(level = "debug", skip(port, on_message_output))]
async fn serial_send_task(
port_name: &str,
port: Arc<Mutex<tokio::io::WriteHalf<tokio_serial::SerialStream>>>,
mut hub_receiver: broadcast::Receiver<Arc<Protocol>>,
on_message_output: &Callbacks<Arc<Protocol>>,
) -> Result<()> {
loop {
match hub_receiver.recv().await {
Ok(message) => {
for future in on_message_output.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!("Dropping message: on_message_output callback returned error: {error:?}");
continue;
}
}

if let Err(error) = port.lock().await.write_all(&message.raw_bytes()).await {
error!("Failed to send serial message: {error:?}");
break;
Expand Down Expand Up @@ -126,12 +155,12 @@ impl Driver for Serial {
let hub_receiver = hub_sender.subscribe();

tokio::select! {
result = Serial::serial_send_task(&port_name, write.clone(), hub_receiver) => {
result = Serial::serial_send_task(&port_name, write.clone(), hub_receiver, &self.on_message_output) => {
if let Err(e) = result {
error!("Error in serial receive task for {port_name}: {e:?}");
}
}
result = Serial::serial_receive_task(&port_name, read.clone(), hub_sender) => {
result = Serial::serial_receive_task(&port_name, read.clone(), hub_sender, &self.on_message_input) => {
if let Err(e) = result {
error!("Error in serial send task for {port_name}: {e:?}");
}
Expand Down
22 changes: 16 additions & 6 deletions src/drivers/tcp/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use crate::{

pub struct TcpClient {
pub remote_addr: String,
on_message: Callbacks<Arc<Protocol>>,
on_message_input: Callbacks<Arc<Protocol>>,
on_message_output: Callbacks<Arc<Protocol>>,
}

pub struct TcpClientBuilder(TcpClient);
Expand All @@ -25,11 +26,19 @@ impl TcpClientBuilder {
self.0
}

pub fn on_message<C>(self, callback: C) -> Self
pub fn on_message_input<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message.add_callback(callback.into_boxed());
self.0.on_message_input.add_callback(callback.into_boxed());
self
}

pub fn on_message_output<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message_output.add_callback(callback.into_boxed());
self
}
}
Expand All @@ -39,7 +48,8 @@ impl TcpClient {
pub fn builder(remote_addr: &str) -> TcpClientBuilder {
TcpClientBuilder(Self {
remote_addr: remote_addr.to_string(),
on_message: Callbacks::new(),
on_message_input: Callbacks::new(),
on_message_output: Callbacks::new(),
})
}
}
Expand Down Expand Up @@ -67,12 +77,12 @@ impl Driver for TcpClient {
let hub_sender_cloned = Arc::clone(&hub_sender);

tokio::select! {
result = tcp_receive_task(read, server_addr, hub_sender_cloned, &self.on_message) => {
result = tcp_receive_task(read, server_addr, hub_sender_cloned, &self.on_message_input) => {
if let Err(e) = result {
error!("Error in TCP receive task: {e:?}");
}
}
result = tcp_send_task(write, server_addr, hub_receiver, &self.on_message) => {
result = tcp_send_task(write, server_addr, hub_receiver, &self.on_message_output) => {
if let Err(e) = result {
error!("Error in TCP send task: {e:?}");
}
Expand Down
Loading

0 comments on commit a07581a

Please sign in to comment.