From e276a80f953b5fd3288c8f30e0dea14264c3c669 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Wed, 16 Oct 2024 14:00:22 -0300 Subject: [PATCH] WIP --- Cargo.lock | 38 ++- Cargo.toml | 8 +- cross_build_and_install.sh | 18 ++ src/build.rs | 92 -------- src/lib/callbacks.rs | 6 + src/lib/cli.rs | 12 +- src/lib/drivers/fake.rs | 119 +++++----- src/lib/drivers/generic_tasks.rs | 149 ++++++++++++ src/lib/drivers/mod.rs | 23 +- src/lib/drivers/rest/mod.rs | 46 ++-- src/lib/drivers/serial/mod.rs | 145 ++++-------- src/lib/drivers/tcp/client.rs | 52 +++-- src/lib/drivers/tcp/mod.rs | 106 --------- src/lib/drivers/tcp/server.rs | 78 ++++--- src/lib/drivers/tlog/reader.rs | 37 +-- src/lib/drivers/tlog/writer.rs | 3 +- src/lib/drivers/udp/client.rs | 276 ++++++++++++---------- src/lib/drivers/udp/server.rs | 318 +++++++++++++++----------- src/lib/hub/actor.rs | 8 +- src/lib/logger.rs | 5 +- src/lib/protocol.rs | 59 +---- src/lib/stats/accumulated/messages.rs | 4 +- src/lib/stats/accumulated/mod.rs | 2 +- src/lib/web/mod.rs | 2 +- 24 files changed, 829 insertions(+), 777 deletions(-) create mode 100755 cross_build_and_install.sh create mode 100644 src/lib/drivers/generic_tasks.rs diff --git a/Cargo.lock b/Cargo.lock index 2ede5053..4f6d3fb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -271,6 +271,9 @@ name = "bytes" version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +dependencies = [ + "serde", +] [[package]] name = "camino" @@ -1689,7 +1692,7 @@ checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" [[package]] name = "mavlink" version = "0.13.2" -source = "git+https://github.com/joaoantoniocardoso/rust-mavlink?branch=add-tokio#eb988ac7edf2782be25c4e835933359b9e4c2b07" +source = "git+https://github.com/mavlink/rust-mavlink#5f2ecbe856a02e7bd17f43de63c8136db7947715" dependencies = [ "bitflags 1.3.2", "mavlink-bindgen", @@ -1703,7 +1706,7 @@ dependencies = [ [[package]] name = "mavlink-bindgen" version = "0.13.2" -source = "git+https://github.com/joaoantoniocardoso/rust-mavlink?branch=add-tokio#eb988ac7edf2782be25c4e835933359b9e4c2b07" +source = "git+https://github.com/mavlink/rust-mavlink#5f2ecbe856a02e7bd17f43de63c8136db7947715" dependencies = [ "crc-any", "lazy_static", @@ -1713,11 +1716,24 @@ dependencies = [ "thiserror", ] +[[package]] +name = "mavlink-codec" +version = "0.1.0" +source = "git+https://github.com/bluerobotics/rust-mavlink-codec?branch=master#04486bb0ed7ca5be98e9c2186f194452edde9214" +dependencies = [ + "bytes", + "log", + "mavlink", + "thiserror", + "tokio-util", +] + [[package]] name = "mavlink-core" version = "0.13.2" -source = "git+https://github.com/joaoantoniocardoso/rust-mavlink?branch=add-tokio#eb988ac7edf2782be25c4e835933359b9e4c2b07" +source = "git+https://github.com/mavlink/rust-mavlink#5f2ecbe856a02e7bd17f43de63c8136db7947715" dependencies = [ + "async-trait", "byteorder", "crc-any", "serde", @@ -1734,6 +1750,7 @@ dependencies = [ "async-trait", "axum", "byteorder", + "bytes", "chrono", "clap", "criterion", @@ -1743,6 +1760,7 @@ dependencies = [ "json5", "lazy_static", "mavlink", + "mavlink-codec", "mime_guess", "regex", "serde", @@ -1751,6 +1769,7 @@ dependencies = [ "shellexpand", "tokio", "tokio-serial", + "tokio-util", "tracing", "tracing-appender", "tracing-log", @@ -2648,6 +2667,19 @@ dependencies = [ "tungstenite", ] +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tower" version = "0.4.13" diff --git a/Cargo.toml b/Cargo.toml index e2f90346..bc39368f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ anyhow = "1" arc-swap = "1.7" async-trait = "0.1.81" axum = { version = "0.7.5", features = ["ws"] } +bytes = { version = "1.7", features = ["serde"] } byteorder = "1.5.0" chrono = { version = "0.4", features = ["serde"] } clap = { version = "4.5", features = ["derive"] } @@ -32,9 +33,9 @@ include_dir = "0.7.4" indexmap = { version = "2.5.0", features = ["serde"] } json5 = "0.4.1" lazy_static = "1.5.0" -# mavlink = { default-features = false, features = ["ardupilotmega", "std", "tokio-1"], path = "../rust-mavlink/mavlink" } -# mavlink = { version = "0.13.1", default-features = false, features = ["ardupilotmega", "std"] } -mavlink = { default-features = false, features = ["ardupilotmega", "serde", "std", "tokio-1"], git = "https://github.com/joaoantoniocardoso/rust-mavlink", branch = "add-tokio" } +mavlink = { default-features = false, features = ["std", "ardupilotmega", "serde", "tokio-1"], git = "https://github.com/mavlink/rust-mavlink", hash = "5f2ecbe8" } +# mavlink-codec = { path = "../mavlink-codec" } +mavlink-codec = { git = "https://github.com/bluerobotics/rust-mavlink-codec", branch = "master" } regex = "1.10.6" serde = { version = "1", features = ["rc"] } serde_derive = "1.0.210" @@ -42,6 +43,7 @@ serde_json = "1.0.128" shellexpand = "3.1" tokio = { version = "1", features = ["full"] } tokio-serial = "5.4.4" +tokio-util = { version = "0.7", features = [ "codec", "net" ] } url = { version = "2.5.2", features = ["serde"] } uuid = { version = "1", features = ["v5", "v4", "serde"] } mime_guess = "2.0.5" diff --git a/cross_build_and_install.sh b/cross_build_and_install.sh new file mode 100755 index 00000000..5de29333 --- /dev/null +++ b/cross_build_and_install.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +set -e + +TARGET=armv7-unknown-linux-gnueabihf +BUILDTYPE=release + +cross build --$BUILDTYPE --target=$TARGET + +/home/joaoantoniocardoso/BlueRobotics/cross_build_dev/old/upload_to_blueos.sh \ + target/$TARGET/$BUILDTYPE/mavlink-server \ + /home/pi/mavlink-server + +echo "" +echo "" +echo 'clear; sshpass -p raspberry scp -o StrictHostKeyChecking=no pi@localhost:/home/pi/mavlink-server "$(which mavlink-server)" ; /home/pi/services/ardupilot_manager/main.py' +echo "" +echo "" diff --git a/src/build.rs b/src/build.rs index bb8786c3..4cd6a684 100644 --- a/src/build.rs +++ b/src/build.rs @@ -1,62 +1,5 @@ -use std::process::{exit, Command}; - use vergen_gix::{BuildBuilder, CargoBuilder, DependencyKind, GixBuilder}; -macro_rules! info { - ($($tokens: tt)*) => { - println!("cargo:warning={}", format!($($tokens)*)) - } -} - -fn is_wasm_target_installed() -> bool { - let output = Command::new("rustup") - .args(["target", "list", "--installed"]) - .output() - .expect("Failed to execute rustup"); - - let installed_targets = String::from_utf8_lossy(&output.stdout); - installed_targets.contains("wasm32-unknown-unknown") -} - -fn install_wasm_target() { - info!("Adding wasm32-unknown-unknown target..."); - let output = Command::new("rustup") - .args(["target", "add", "wasm32-unknown-unknown"]) - .output() - .expect("Failed to execute rustup"); - - if !output.status.success() { - eprintln!("T{}", String::from_utf8_lossy(&output.stderr)); - exit(1); - } -} - -fn get_trunk_version() -> Option { - Command::new("trunk") - .arg("--version") - .output() - .ok() - .and_then(|output| String::from_utf8(output.stdout).ok()) - .and_then(|version_string| version_string.split_whitespace().last().map(String::from)) -} - -fn install_trunk() -> Result<(), Box> { - info!("Installing trunk..."); - - let output = Command::new("cargo") - .arg("install") - .arg("trunk") - .arg("--force") - .output()?; - - if !output.status.success() { - eprintln!("TT{}", String::from_utf8_lossy(&output.stderr)); - exit(1); - } - - Ok(()) -} - fn main() -> Result<(), Box> { println!("cargo:rerun-if-changed=./src/webpage/"); @@ -68,40 +11,5 @@ fn main() -> Result<(), Box> { )? .emit()?; - if std::env::var("SKIP_FRONTEND").is_ok() { - return Ok(()); - } - - if !is_wasm_target_installed() { - install_wasm_target(); - } - - if get_trunk_version().is_none() { - info!("trunk not found"); - install_trunk().unwrap_or_else(|e| { - eprintln!("Error: {}", e); - exit(1); - }); - } - - let mut trunk_command = Command::new("trunk"); - trunk_command.args(["build", "./src/webpage/index.html"]); - - // Add --release argument if not in debug mode - if cfg!(not(debug_assertions)) { - trunk_command.args(["--release", "--locked"]); - } - - let trunk_output = trunk_command.output().expect("Failed to execute trunk"); - - if !trunk_output.status.success() { - eprintln!( - "Trunk build failed: {}", - String::from_utf8_lossy(&trunk_output.stderr) - ); - exit(1); - } - info!("{}", String::from_utf8_lossy(&trunk_output.stdout)); - Ok(()) } diff --git a/src/lib/callbacks.rs b/src/lib/callbacks.rs index 463b77dc..4e4084a7 100644 --- a/src/lib/callbacks.rs +++ b/src/lib/callbacks.rs @@ -12,6 +12,12 @@ pub struct Callbacks { callbacks: Arc>>>, } +impl std::fmt::Debug for Callbacks { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Callbacks").finish() + } +} + impl Callbacks { pub fn new() -> Self { Self { diff --git a/src/lib/cli.rs b/src/lib/cli.rs index 290ccb0f..cfb7f128 100644 --- a/src/lib/cli.rs +++ b/src/lib/cli.rs @@ -6,7 +6,7 @@ use tracing::*; use crate::drivers; -#[derive(Parser)] +#[derive(Debug, Parser)] #[command( version = env!("CARGO_PKG_VERSION"), author = env!("CARGO_PKG_AUTHORS"), @@ -83,7 +83,7 @@ fn build_endpoints_help() -> String { .collect::>() .join("\n\t\t "); - vec![ + [ format!("{name}\t {help_schemas}").to_string(), format!("\t legacy: {help_legacy}").to_string(), format!("\t url: {help_url}\n").to_string(), @@ -154,7 +154,7 @@ pub fn endpoints() -> Vec> { let default_endpoints = Arc::new(crate::drivers::rest::Rest::builder("Default").build()); let mut endpoints = MANAGER.clap_matches.endpoints.clone(); endpoints.push(default_endpoints); - return endpoints; + endpoints } #[instrument(level = "debug")] @@ -162,6 +162,12 @@ pub fn command_line_string() -> String { std::env::args().collect::>().join(" ") } +// Return a clone of current Args struct +#[instrument(level = "debug")] +pub fn command_line() -> String { + format!("{:#?}", MANAGER.clap_matches) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/lib/drivers/fake.rs b/src/lib/drivers/fake.rs index 4ac0937a..f9228de5 100644 --- a/src/lib/drivers/fake.rs +++ b/src/lib/drivers/fake.rs @@ -1,19 +1,22 @@ +use bytes::{BufMut, BytesMut}; use std::sync::Arc; use anyhow::Result; +use mavlink_codec::{v2::V2Packet, Packet}; use tokio::sync::{broadcast, RwLock}; use tracing::*; use crate::{ callbacks::{Callbacks, MessageCallback}, drivers::{Driver, DriverInfo}, - protocol::{read_all_messages, Protocol}, + protocol::Protocol, stats::{ accumulated::driver::{AccumulatedDriverStats, AccumulatedDriverStatsProvider}, driver::DriverUuid, }, }; +#[derive(Debug)] pub struct FakeSink { name: arc_swap::ArcSwap, uuid: DriverUuid, @@ -75,9 +78,28 @@ impl Driver for FakeSink { } } - let mut bytes = mavlink::async_peek_reader::AsyncPeekReader::new(message.raw_bytes()); - let (header, message): (mavlink::MavHeader, mavlink::ardupilotmega::MavMessage) = - mavlink::read_v2_msg_async(&mut bytes).await.unwrap(); + let version = match &**message { + Packet::V1(_) => mavlink::MavlinkVersion::V1, + Packet::V2(_) => mavlink::MavlinkVersion::V2, + }; + + let frame = match mavlink::MavFrame::::deser( + version, + message.as_slice(), + ) { + Ok(frame) => frame, + Err(error) => { + warn!("Failed to deserialize Mavlink Message: {error:?}"); + continue; + } + }; + + let mavlink::MavFrame { + header, + msg: message, + protocol_version: _, + } = frame; + if self.print { println!("Message received: {header:?} {message:?}"); } else { @@ -148,6 +170,7 @@ impl DriverInfo for FakeSinkInfo { } } +#[derive(Debug)] pub struct FakeSource { name: arc_swap::ArcSwap, uuid: DriverUuid, @@ -193,64 +216,54 @@ impl Driver for FakeSource { async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { let mut sequence = 0; - let mut buf: Vec = Vec::with_capacity(280); - use mavlink::ardupilotmega::{ MavAutopilot, MavMessage, MavModeFlag, MavState, MavType, HEARTBEAT_DATA, }; + let mut header = mavlink::MavHeader { + sequence, + system_id: 1, + component_id: 2, + }; + let data = MavMessage::HEARTBEAT(HEARTBEAT_DATA { + custom_mode: 5, + mavtype: MavType::MAV_TYPE_QUADROTOR, + autopilot: MavAutopilot::MAV_AUTOPILOT_ARDUPILOTMEGA, + base_mode: MavModeFlag::MAV_MODE_FLAG_MANUAL_INPUT_ENABLED + | MavModeFlag::MAV_MODE_FLAG_STABILIZE_ENABLED + | MavModeFlag::MAV_MODE_FLAG_GUIDED_ENABLED + | MavModeFlag::MAV_MODE_FLAG_CUSTOM_MODE_ENABLED, + system_status: MavState::MAV_STATE_STANDBY, + mavlink_version: 3, + }); + loop { - let header = mavlink::MavHeader { - sequence, - system_id: 1, - component_id: 2, - }; - let data = MavMessage::HEARTBEAT(HEARTBEAT_DATA { - custom_mode: 5, - mavtype: MavType::MAV_TYPE_QUADROTOR, - autopilot: MavAutopilot::MAV_AUTOPILOT_ARDUPILOTMEGA, - base_mode: MavModeFlag::MAV_MODE_FLAG_MANUAL_INPUT_ENABLED - | MavModeFlag::MAV_MODE_FLAG_STABILIZE_ENABLED - | MavModeFlag::MAV_MODE_FLAG_GUIDED_ENABLED - | MavModeFlag::MAV_MODE_FLAG_CUSTOM_MODE_ENABLED, - system_status: MavState::MAV_STATE_STANDBY, - mavlink_version: 3, - }); + header.sequence = sequence; sequence = sequence.overflowing_add(1).0; - buf.clear(); - mavlink::write_v2_msg(&mut buf, header, &data).expect("Failed to write message"); + let buf = BytesMut::with_capacity(V2Packet::MAX_PACKET_SIZE); + let mut writer = buf.writer(); + if let Err(error) = mavlink::write_v2_msg(&mut writer, header, &data) { + warn!("Failed to serialize message: {error:?}"); + continue; + } - read_all_messages("FakeSource", &mut buf, { - let hub_sender = hub_sender.clone(); - move |message| { - let message = Arc::new(message); - let hub_sender = hub_sender.clone(); + let packet = Packet::V2(V2Packet::new(writer.into_inner().freeze())); - async move { - trace!("Fake message created: {message:?}"); - - self.stats - .write() - .await - .stats - .update_output(&message); - - for future in self.on_message_output.call_all(message.clone()) { - if let Err(error) = future.await { - debug!( - "Dropping message: on_message_input callback returned error: {error:?}" - ); - continue; - } - } - - if let Err(error) = hub_sender.send(message) { - error!("Failed to send message to hub: {error:?}"); - } - } - }}) - .await; + let message = Arc::new(Protocol::new("fake_source", packet)); + + self.stats.write().await.stats.update_output(&message); + + for future in self.on_message_output.call_all(message.clone()) { + if let Err(error) = future.await { + debug!("Dropping message: on_message_input callback returned error: {error:?}"); + continue; + } + } + + if let Err(error) = hub_sender.send(message) { + error!("Failed to send message to hub: {error:?}"); + } tokio::time::sleep(self.period).await; } diff --git a/src/lib/drivers/generic_tasks.rs b/src/lib/drivers/generic_tasks.rs new file mode 100644 index 00000000..92b16cdc --- /dev/null +++ b/src/lib/drivers/generic_tasks.rs @@ -0,0 +1,149 @@ +use std::sync::Arc; + +use anyhow::Result; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use mavlink_codec::{error::DecoderError, Packet}; +use tokio::sync::{broadcast, RwLock}; +use tracing::*; + +use crate::{ + callbacks::Callbacks, protocol::Protocol, stats::accumulated::driver::AccumulatedDriverStats, +}; + +#[derive(Clone)] +pub struct SendReceiveContext { + pub hub_sender: broadcast::Sender>, + pub on_message_output: Callbacks>, + pub on_message_input: Callbacks>, + pub stats: Arc>, +} + +#[instrument(level = "debug", skip(writer, reader, context))] +pub async fn default_send_receive_run( + mut writer: S, + mut reader: T, + identifier: &str, + context: &SendReceiveContext, +) -> Result<()> +where + S: Sink + std::marker::Unpin, + T: Stream>> + + std::marker::Unpin, +{ + tokio::select! { + result = default_send_task(&mut writer, identifier, context) => { + if let Err(error) = result { + error!("Error in send task for {identifier}: {error:?}"); + } + } + result = default_receive_task(&mut reader, identifier, context) => { + if let Err(error) = result { + error!("Error in receive task for {identifier}: {error:?}"); + } + } + } + + Ok(()) +} + +/// Receives messages from a Stream and sends them to the HUB Channel +#[instrument(level = "debug", skip(reader, context))] +pub async fn default_receive_task( + reader: &mut T, + identifier: &str, + context: &SendReceiveContext, +) -> Result<()> +where + T: Stream>> + + std::marker::Unpin, +{ + loop { + let packet = match reader.next().await { + Some(Ok(Ok(packet))) => packet, + Some(Ok(Err(decode_error))) => { + error!("Failed to decode packet: {decode_error:?}"); + continue; + } + Some(Err(io_error)) => { + error!("Critical error trying to decode data from: {io_error:?}"); + break; + } + None => break, + }; + + let message = Arc::new(Protocol::new(identifier, packet)); + + trace!("Received message: {message:?}"); + + context.stats.write().await.stats.update_input(&message); + + for future in context.on_message_input.call_all(message.clone()) { + if let Err(error) = future.await { + debug!("Dropping message: on_message_input callback returned error: {error:?}"); + continue; + } + } + + if let Err(send_error) = context.hub_sender.send(message) { + error!("Failed to send message to hub: {send_error:?}"); + continue; + } + + trace!("Message sent to hub"); + } + + debug!("Driver receiver task stopped!"); + + Ok(()) +} + +/// Receives messages from the HUB Channel and sends them to a Sink +#[instrument(level = "debug", skip(writer, context))] +pub async fn default_send_task( + writer: &mut S, + identifier: &str, + context: &SendReceiveContext, +) -> Result<()> +where + S: Sink + std::marker::Unpin, +{ + let mut hub_receiver = context.hub_sender.subscribe(); + + loop { + let message = match hub_receiver.recv().await { + Ok(message) => message, + Err(broadcast::error::RecvError::Closed) => { + error!("Hub channel closed!"); + break; + } + Err(broadcast::error::RecvError::Lagged(count)) => { + warn!("Channel lagged by {count} messages."); + continue; + } + }; + + if message.origin.eq(&identifier) { + continue; // Don't do loopback + } + + context.stats.write().await.stats.update_output(&message); + + for future in context.on_message_output.call_all(message.clone()) { + if let Err(error) = future.await { + debug!("Dropping message: on_message_output callback returned error: {error:?}"); + continue; + } + } + + if let Err(error) = writer.send((**message).clone()).await { + error!("Failed to send message: {error:?}"); + break; + } + + trace!("Message sent to {identifier}: {:?}", message.as_slice()); + } + + debug!("Driver sender task stopped!"); + + Ok(()) +} diff --git a/src/lib/drivers/mod.rs b/src/lib/drivers/mod.rs index 90b2681a..d9f59de0 100644 --- a/src/lib/drivers/mod.rs +++ b/src/lib/drivers/mod.rs @@ -1,4 +1,5 @@ pub mod fake; +pub mod generic_tasks; pub mod rest; pub mod serial; pub mod tcp; @@ -17,7 +18,7 @@ use crate::{protocol::Protocol, stats::accumulated::driver::AccumulatedDriverSta #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum Type { - FakeClient, + FakeSink, FakeSource, Serial, TlogWriter, @@ -37,7 +38,7 @@ pub struct DriverDescriptionLegacy { } #[async_trait::async_trait] -pub trait Driver: Send + Sync + AccumulatedDriverStatsProvider { +pub trait Driver: Send + Sync + AccumulatedDriverStatsProvider + std::fmt::Debug { async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()>; fn info(&self) -> Box; @@ -215,11 +216,11 @@ pub fn endpoints() -> Vec { typ: Type::UdpServer, }, ExtInfo { - driver_ext: Box::new(fake::FakeSourceInfo), - typ: Type::FakeClient, + driver_ext: Box::new(fake::FakeSinkInfo), + typ: Type::FakeSink, }, ExtInfo { - driver_ext: Box::new(fake::FakeSinkInfo), + driver_ext: Box::new(fake::FakeSourceInfo), typ: Type::FakeSource, }, ] @@ -230,7 +231,7 @@ mod tests { use std::{collections::HashSet, sync::Arc}; use anyhow::{anyhow, Result}; - use mavlink::MAVLinkV2MessageRaw; + use mavlink_codec::{v2::V2Packet, Packet}; use tokio::sync::RwLock; use tracing::*; @@ -254,6 +255,7 @@ mod tests { } // Example struct implementing Driver + #[derive(Debug)] pub struct ExampleDriver { name: arc_swap::ArcSwap, uuid: DriverUuid, @@ -299,7 +301,7 @@ mod tests { let mut hub_receiver = hub_sender.subscribe(); while let Ok(message) = hub_receiver.recv().await { - self.stats.write().await.stats.update_input(&message); + self.stats.write().await.stats.update_output(&message); for future in self.on_message_input.call_all(message.clone()) { if let Err(error) = future.await { @@ -310,7 +312,7 @@ mod tests { } } - trace!("Message received: {message:?}"); + trace!("Message sent: {message:?}"); } Ok(()) @@ -395,7 +397,10 @@ mod tests { async move { sender - .send(Arc::new(Protocol::new("test", MAVLinkV2MessageRaw::new()))) + .send(Arc::new(Protocol::new( + "test", + Packet::V2(V2Packet::default()), + ))) .unwrap(); } }); diff --git a/src/lib/drivers/rest/mod.rs b/src/lib/drivers/rest/mod.rs index 69b94583..eb1534ad 100644 --- a/src/lib/drivers/rest/mod.rs +++ b/src/lib/drivers/rest/mod.rs @@ -3,6 +3,8 @@ pub mod data; use std::sync::Arc; use anyhow::Result; +use mavlink::ardupilotmega::MavMessage; +use mavlink_codec::Packet; use serde::{Deserialize, Serialize}; use tokio::sync::{broadcast, RwLock}; use tracing::*; @@ -23,6 +25,7 @@ pub struct MAVLinkMessage { pub message: T, } +#[derive(Debug)] pub struct Rest { name: arc_swap::ArcSwap, uuid: DriverUuid, @@ -83,26 +86,7 @@ impl Rest { { let mut message_raw = mavlink::MAVLinkV2MessageRaw::new(); message_raw.serialize_message(content.header, &content.message); - let bus_message = Arc::new(Protocol::new("Ws", message_raw)); - stats.write().await.stats.update_input(&bus_message); - - for future in on_message_input.call_all(bus_message.clone()) { - if let Err(error) = future.await { - debug!("Dropping message: on_message_input callback returned error: {error:?}"); - continue; - } - } - - if let Err(error) = hub_sender.send(bus_message) { - error!("Failed to send message to hub: {error:?}"); - } - return Ok(()); - } else if let Ok(content) = - json5::from_str::>(&message) - { - let mut message_raw = mavlink::MAVLinkV2MessageRaw::new(); - message_raw.serialize_message(content.header, &content.message); - let bus_message = Arc::new(Protocol::new("Ws", message_raw)); + let bus_message = Arc::new(Protocol::new("Ws", Packet::from(message_raw))); stats.write().await.stats.update_input(&bus_message); for future in on_message_input.call_all(bus_message.clone()) { @@ -148,16 +132,14 @@ impl Rest { } let mut bytes = - mavlink::async_peek_reader::AsyncPeekReader::new(message.raw_bytes()); - let (header, message): ( - mavlink::MavHeader, - mavlink::ardupilotmega::MavMessage, - ) = mavlink::read_v2_msg_async(&mut bytes).await.unwrap(); - - let mavlink_message = MAVLinkMessage { - header: header, - message: message, + mavlink::async_peek_reader::AsyncPeekReader::new(message.as_slice()); + let Ok((header, message)) = + mavlink::read_v2_msg_async::(&mut bytes).await + else { + continue; }; + + let mavlink_message = MAVLinkMessage { header, message }; let json_string = parse_query(&mavlink_message); data::update((header, mavlink_message.message)); crate::web::send_message(json_string).await; @@ -180,7 +162,13 @@ pub fn parse_query(message: &T) -> String { impl Driver for Rest { #[instrument(level = "debug", skip(self, hub_sender))] async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { + let mut first = true; loop { + if !first { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + first = false; + } + let hub_sender = hub_sender.clone(); let hub_receiver = hub_sender.subscribe(); let mut ws_receiver = crate::web::create_message_receiver(); diff --git a/src/lib/drivers/serial/mod.rs b/src/lib/drivers/serial/mod.rs index 95293598..5694f93f 100644 --- a/src/lib/drivers/serial/mod.rs +++ b/src/lib/drivers/serial/mod.rs @@ -1,23 +1,26 @@ use std::sync::Arc; use anyhow::Result; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - sync::{broadcast, Mutex, RwLock}, -}; +use futures::StreamExt; +use mavlink_codec::codec::MavlinkCodec; +use tokio::sync::{broadcast, RwLock}; use tokio_serial::{self, SerialPortBuilderExt}; +use tokio_util::codec::Framed; use tracing::*; use crate::{ callbacks::{Callbacks, MessageCallback}, drivers::{Driver, DriverInfo}, - protocol::{read_all_messages, Protocol}, + protocol::Protocol, stats::{ accumulated::driver::{AccumulatedDriverStats, AccumulatedDriverStatsProvider}, driver::DriverUuid, }, }; +use super::generic_tasks::{default_send_receive_run, SendReceiveContext}; + +#[derive(Debug)] pub struct Serial { name: arc_swap::ArcSwap, uuid: DriverUuid, @@ -67,117 +70,53 @@ impl Serial { stats: Arc::new(RwLock::new(AccumulatedDriverStats::new(name, &SerialInfo))), }) } +} - #[instrument(level = "debug", skip(port, on_message_input))] - async fn serial_receive_task( - port_name: &str, - port: Arc>>, - hub_sender: broadcast::Sender>, +#[async_trait::async_trait] +impl Driver for Serial { + #[instrument(level = "debug", skip(self, hub_sender))] + async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { + let port_name = self.port_name.clone(); - on_message_input: &Callbacks>, - ) -> Result<()> { - let mut buf = vec![0; 1024]; + let context = SendReceiveContext { + hub_sender, + on_message_output: self.on_message_output.clone(), + on_message_input: self.on_message_input.clone(), + stats: self.stats.clone(), + }; + let mut first = true; loop { - match port.lock().await.read(&mut buf).await { - // We got something - Ok(bytes_received) if bytes_received > 0 => { - read_all_messages("serial", &mut buf, |message| async { - let message = Arc::new(message); + if !first { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + first = false; + } - for future in on_message_input.call_all(message.clone()) { - if let Err(error) = future.await { - debug!("Dropping message: on_message_input callback returned error: {error:?}"); - continue; - } - } + debug!("Trying to connect..."); - if let Err(error) = hub_sender.send(message) { - error!("Failed to send message to hub: {error:?}, from {port_name:?}"); - } - }) - .await; - } - // We got nothing - Ok(_) => { - break; - } - // We got problems + let stream = match tokio_serial::new(&port_name, self.baud_rate) + .timeout(tokio::time::Duration::from_secs(1)) + .open_native_async() + { + Ok(stream) => stream, Err(error) => { - error!("Failed to receive serial message: {error:?}, from {port_name:?}"); - break; + error!("Failed to open serial port {port_name:?}: {error:?}"); + continue; } - } - } + }; - Ok(()) - } + debug!("Successfully connected"); - #[instrument(level = "debug", skip(port, on_message_output))] - async fn serial_send_task( - port_name: &str, - port: Arc>>, - mut hub_receiver: broadcast::Receiver>, - on_message_output: &Callbacks>, - ) -> Result<()> { - loop { - match hub_receiver.recv().await { - Ok(message) => { - for future in on_message_output.call_all(message.clone()) { - if let Err(error) = future.await { - debug!("Dropping message: on_message_output callback returned error: {error:?}"); - continue; - } - } + let codec = MavlinkCodec::::default(); + let (writer, reader) = Framed::new(stream, codec).split(); - if let Err(error) = port.lock().await.write_all(&message.raw_bytes()).await { - error!("Failed to send serial message: {error:?}"); - break; - } - } - Err(error) => { - error!("Failed to receive message from hub: {error:?}"); - } + if let Err(reason) = + default_send_receive_run(writer, reader, &port_name, &context).await + { + warn!("Driver send/receive tasks closed: {reason}"); } - } - Ok(()) - } -} -#[async_trait::async_trait] -impl Driver for Serial { - #[instrument(level = "debug", skip(self, hub_sender))] - async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { - let port_name = self.port_name.clone(); - let (read, write) = match tokio_serial::new(&port_name, self.baud_rate) - .timeout(tokio::time::Duration::from_secs(1)) - .open_native_async() - { - Ok(port) => { - let (read, write) = tokio::io::split(port); - (Arc::new(Mutex::new(read)), Arc::new(Mutex::new(write))) - } - Err(error) => { - error!("Failed to open serial port {port_name:?}: {error:?}"); - return Err(error.into()); - } - }; - loop { - let hub_sender = hub_sender.clone(); - let hub_receiver = hub_sender.subscribe(); - - tokio::select! { - result = Serial::serial_send_task(&port_name, write.clone(), hub_receiver, &self.on_message_output) => { - if let Err(e) = result { - error!("Error in serial receive task for {port_name}: {e:?}"); - } - } - result = Serial::serial_receive_task(&port_name, read.clone(), hub_sender, &self.on_message_input) => { - if let Err(e) = result { - error!("Error in serial send task for {port_name}: {e:?}"); - } - } - } + debug!("Restarting connection loop..."); } } diff --git a/src/lib/drivers/tcp/client.rs b/src/lib/drivers/tcp/client.rs index 706aaaef..f320f3b7 100644 --- a/src/lib/drivers/tcp/client.rs +++ b/src/lib/drivers/tcp/client.rs @@ -1,16 +1,19 @@ use std::sync::Arc; use anyhow::Result; +use futures::StreamExt; +use mavlink_codec::codec::MavlinkCodec; use tokio::{ net::TcpStream, sync::{broadcast, RwLock}, }; +use tokio_util::codec::Framed; use tracing::*; use crate::{ callbacks::{Callbacks, MessageCallback}, drivers::{ - tcp::{tcp_receive_task, tcp_send_task}, + generic_tasks::{default_send_receive_run, SendReceiveContext}, Driver, DriverInfo, }, protocol::Protocol, @@ -20,6 +23,7 @@ use crate::{ }, }; +#[derive(Debug)] pub struct TcpClient { pub remote_addr: String, name: arc_swap::ArcSwap, @@ -77,36 +81,44 @@ impl Driver for TcpClient { #[instrument(level = "debug", skip(self, hub_sender))] async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { let server_addr = &self.remote_addr; - let hub_sender = Arc::new(hub_sender); + let context = SendReceiveContext { + hub_sender, + on_message_output: self.on_message_output.clone(), + on_message_input: self.on_message_input.clone(), + stats: self.stats.clone(), + }; + + let mut first = true; loop { - debug!("Trying to connect to {server_addr:?}..."); - let (read, write) = match TcpStream::connect(server_addr).await { - Ok(socket) => socket.into_split(), + if !first { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + first = false; + } + + debug!("Trying to connect..."); + + let stream = match TcpStream::connect(server_addr).await { + Ok(stream) => stream, Err(error) => { - error!("Failed connecting to {server_addr:?}: {error:?}"); + error!("Failed connecting: {error:?}"); tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; continue; } }; - debug!("TcpClient successfully connected to {server_addr:?}"); - let hub_receiver = hub_sender.subscribe(); + debug!("Successfully connected"); - tokio::select! { - result = tcp_receive_task(read, server_addr, hub_sender.clone(), &self.on_message_input, &self.stats) => { - if let Err(e) = result { - error!("Error in TCP receive task: {e:?}"); - } - } - result = tcp_send_task(write, server_addr, hub_receiver, &self.on_message_output, &self.stats) => { - if let Err(e) = result { - error!("Error in TCP send task: {e:?}"); - } - } + let codec = MavlinkCodec::::default(); + let (writer, reader) = Framed::new(stream, codec).split(); + + if let Err(reason) = + default_send_receive_run(writer, reader, server_addr, &context).await + { + warn!("Driver send/receive tasks closed: {reason:?}"); } - debug!("Restarting TCP Client connection loop..."); + debug!("Restarting connection loop..."); } } diff --git a/src/lib/drivers/tcp/mod.rs b/src/lib/drivers/tcp/mod.rs index 8f57c4bd..c07f47e0 100644 --- a/src/lib/drivers/tcp/mod.rs +++ b/src/lib/drivers/tcp/mod.rs @@ -1,108 +1,2 @@ -use std::sync::Arc; - -use anyhow::Result; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::tcp::{OwnedReadHalf, OwnedWriteHalf}, - sync::{broadcast, RwLock}, -}; -use tracing::*; - -use crate::{ - callbacks::Callbacks, - protocol::{read_all_messages, Protocol}, - stats::accumulated::driver::AccumulatedDriverStats, -}; - pub mod client; pub mod server; - -/// Receives messages from the TCP Socket and sends them to the HUB Channel -#[instrument(level = "debug", skip(socket, hub_sender, on_message_input))] -async fn tcp_receive_task( - mut socket: OwnedReadHalf, - remote_addr: &str, - hub_sender: Arc>>, - on_message_input: &Callbacks>, - stats: &Arc>, -) -> Result<()> { - let mut buf = Vec::with_capacity(1024); - - loop { - let bytes_received = socket.read_buf(&mut buf).await?; - if bytes_received == 0 { - warn!("TCP connection closed by {remote_addr}."); - break; - } - - trace!("Received TCP packet: {buf:?}"); - - read_all_messages(remote_addr, &mut buf, |message| async { - let message = Arc::new(message); - - stats.write().await.stats.update_input(&message); - - for future in on_message_input.call_all(message.clone()) { - if let Err(error) = future.await { - debug!("Dropping message: on_message_input callback returned error: {error:?}"); - continue; - } - } - - if let Err(error) = hub_sender.send(message) { - error!("Failed to send message to hub: {error:?}"); - } - }) - .await; - } - - debug!("TCP Receive task for {remote_addr} finished"); - Ok(()) -} - -/// Receives messages from the HUB Channel and sends them to the TCP Socket -#[instrument(level = "debug", skip(socket, hub_receiver, on_message_output))] -async fn tcp_send_task( - mut socket: OwnedWriteHalf, - remote_addr: &str, - mut hub_receiver: broadcast::Receiver>, - on_message_output: &Callbacks>, - stats: &Arc>, -) -> Result<()> { - loop { - let message = match hub_receiver.recv().await { - Ok(message) => message, - Err(broadcast::error::RecvError::Closed) => { - error!("Hub channel closed!"); - break; - } - Err(broadcast::error::RecvError::Lagged(count)) => { - warn!("Channel lagged by {count} messages."); - continue; - } - }; - - if message.origin.eq(&remote_addr) { - continue; // Don't do loopback - } - - stats.write().await.stats.update_output(&message); - - for future in on_message_output.call_all(message.clone()) { - if let Err(error) = future.await { - debug!("Dropping message: on_message_output callback returned error: {error:?}"); - continue; - } - } - - socket.write_all(message.raw_bytes()).await?; - - trace!( - "Message sent to {remote_addr} from TCP server: {:?}", - message.raw_bytes() - ); - } - - debug!("TCP Send task for {remote_addr} finished"); - Ok(()) -} diff --git a/src/lib/drivers/tcp/server.rs b/src/lib/drivers/tcp/server.rs index f7791f2b..9d4e3b22 100644 --- a/src/lib/drivers/tcp/server.rs +++ b/src/lib/drivers/tcp/server.rs @@ -1,16 +1,19 @@ use std::sync::Arc; -use anyhow::Result; +use anyhow::{anyhow, Result}; +use futures::StreamExt; +use mavlink_codec::codec::MavlinkCodec; use tokio::{ net::{TcpListener, TcpStream}, sync::{broadcast, RwLock}, }; +use tokio_util::codec::Framed; use tracing::*; use crate::{ callbacks::{Callbacks, MessageCallback}, drivers::{ - tcp::{tcp_receive_task, tcp_send_task}, + generic_tasks::{default_send_receive_run, SendReceiveContext}, Driver, DriverInfo, }, protocol::Protocol, @@ -20,6 +23,7 @@ use crate::{ }, }; +#[derive(Debug)] pub struct TcpServer { pub local_addr: String, name: arc_swap::ArcSwap, @@ -72,36 +76,24 @@ impl TcpServer { } /// Handles communication with a single client - #[instrument( - level = "debug", - skip(socket, hub_sender, on_message_input, on_message_output) - )] + #[instrument(level = "debug", skip(stream, context))] async fn handle_client( - socket: TcpStream, + stream: TcpStream, remote_addr: String, - hub_sender: Arc>>, - on_message_input: Callbacks>, - on_message_output: Callbacks>, - stats: Arc>, + context: SendReceiveContext, ) -> Result<()> { - let hub_receiver = hub_sender.subscribe(); + debug!("New TCP client"); - let (read, write) = socket.into_split(); + let codec = MavlinkCodec::::default(); + let (writer, reader) = Framed::new(stream, codec).split(); - tokio::select! { - result = tcp_receive_task(read, &remote_addr, hub_sender, &on_message_input, &stats) => { - if let Err(e) = result { - error!("Error in TCP receive task for {remote_addr}: {e:?}"); - } - } - result = tcp_send_task(write, &remote_addr, hub_receiver, &on_message_output, &stats) => { - if let Err(e) = result { - error!("Error in TCP send task for {remote_addr}: {e:?}"); - } - } + if let Err(reason) = default_send_receive_run(writer, reader, &remote_addr, &context).await + { + warn!("Driver send/receive tasks closed: {reason:?}"); } - debug!("Finished handling connection with {remote_addr}"); + debug!("TCP Client connection terminated"); + Ok(()) } } @@ -110,22 +102,44 @@ impl TcpServer { impl Driver for TcpServer { #[instrument(level = "debug", skip(self, hub_sender))] async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { - let listener = TcpListener::bind(&self.local_addr).await?; - let hub_sender = Arc::new(hub_sender); + debug!("Trying to bind to local address {:?}...", self.local_addr); + + let context = SendReceiveContext { + hub_sender, + on_message_output: self.on_message_output.clone(), + on_message_input: self.on_message_input.clone(), + stats: self.stats.clone(), + }; + + let listener = match TcpListener::bind(&self.local_addr).await { + Ok(listener) => listener, + Err(error) => { + error!( + "Failed to bind TCP Server to {:?}: {error:?}", + self.local_addr + ); + + return Err(anyhow!("Failed to bind TCP Server")); + } + }; + let mut first = true; loop { + if !first { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + first = false; + } + + debug!("Waiting for clients..."); + match listener.accept().await { Ok((socket, remote_addr)) => { let remote_addr = remote_addr.to_string(); - let hub_sender = hub_sender.clone(); tokio::spawn(TcpServer::handle_client( socket, remote_addr, - hub_sender, - self.on_message_input.clone(), - self.on_message_output.clone(), - self.stats.clone(), + context.clone(), )); } Err(error) => { diff --git a/src/lib/drivers/tlog/reader.rs b/src/lib/drivers/tlog/reader.rs index e3c7a84c..7549b228 100644 --- a/src/lib/drivers/tlog/reader.rs +++ b/src/lib/drivers/tlog/reader.rs @@ -3,6 +3,7 @@ use std::{path::PathBuf, sync::Arc}; use anyhow::{Context, Result}; use chrono::DateTime; use mavlink::ardupilotmega::MavMessage; +use mavlink_codec::Packet; use tokio::sync::{broadcast, RwLock}; use tracing::*; @@ -16,6 +17,7 @@ use crate::{ }, }; +#[derive(Debug)] pub struct TlogReader { pub path: PathBuf, name: arc_swap::ArcSwap, @@ -96,24 +98,27 @@ impl TlogReader { reader.consume(8); assert_eq!(reader.peek_exact(1).await?[0], mavlink::MAV_STX_V2); - let message = match mavlink::read_v2_raw_message_async::(&mut reader) - .await - { - Ok(message) => Protocol::new_with_timestamp(us_since_epoch, &source_name, message), - Err(error) => { - match error { - mavlink::error::MessageReadError::Io(_) => (), - mavlink::error::MessageReadError::Parse(_) => { - error!("Failed to parse MAVLink message: {error:?}") + let message = + match mavlink::read_v2_raw_message_async::(&mut reader).await { + Ok(message) => Protocol::new_with_timestamp( + us_since_epoch, + &source_name, + Packet::from(message), + ), + Err(error) => { + match error { + mavlink::error::MessageReadError::Io(_) => (), + mavlink::error::MessageReadError::Parse(_) => { + error!("Failed to parse MAVLink message: {error:?}") + } } - } - continue; - } - }; - reader.consume(message.raw_bytes().len() - 1); + continue; + } + }; + reader.consume(message.bytes().len() - 1); - trace!("Parsed message: {:?}", message.raw_bytes()); + trace!("Parsed message: {:?}", message.bytes()); let message = Arc::new(message); @@ -316,7 +321,7 @@ mod tests { .map(|message| { let parsed_message = mavlink::MavFrame::::deser( mavlink::MavlinkVersion::V2, - &message.raw_bytes()[4..], + &message.bytes()[4..], ); (message.timestamp, parsed_message) diff --git a/src/lib/drivers/tlog/writer.rs b/src/lib/drivers/tlog/writer.rs index 3a61d30a..33877a45 100644 --- a/src/lib/drivers/tlog/writer.rs +++ b/src/lib/drivers/tlog/writer.rs @@ -17,6 +17,7 @@ use crate::{ }, }; +#[derive(Debug)] pub struct TlogWriter { pub path: PathBuf, name: arc_swap::ArcSwap, @@ -83,7 +84,7 @@ impl TlogWriter { } } - let raw_bytes = message.raw_bytes(); + let raw_bytes = message.bytes(); writer.write_all(×tamp.to_be_bytes()).await?; writer.write_all(raw_bytes).await?; writer.flush().await?; diff --git a/src/lib/drivers/udp/client.rs b/src/lib/drivers/udp/client.rs index f3ddd008..07ca0bfd 100644 --- a/src/lib/drivers/udp/client.rs +++ b/src/lib/drivers/udp/client.rs @@ -1,22 +1,26 @@ -use std::sync::Arc; +use std::{net::SocketAddr, sync::Arc}; use anyhow::Result; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use mavlink_codec::{codec::MavlinkCodec, error::DecoderError, Packet}; use tokio::{ net::UdpSocket, sync::{broadcast, RwLock}, }; +use tokio_util::udp::UdpFramed; use tracing::*; use crate::{ callbacks::{Callbacks, MessageCallback}, - drivers::{Driver, DriverInfo}, - protocol::{read_all_messages, Protocol}, + drivers::{generic_tasks::SendReceiveContext, Driver, DriverInfo}, + protocol::Protocol, stats::{ accumulated::driver::{AccumulatedDriverStats, AccumulatedDriverStatsProvider}, driver::DriverUuid, }, }; +#[derive(Debug)] pub struct UdpClient { pub remote_addr: String, name: arc_swap::ArcSwap, @@ -67,114 +71,33 @@ impl UdpClient { ))), }) } - - #[instrument(level = "debug", skip(self, socket))] - async fn udp_receive_task( - &self, - socket: Arc, - hub_sender: Arc>>, - ) -> Result<()> { - let mut buf = Vec::with_capacity(1024); - - loop { - match socket.recv_buf_from(&mut buf).await { - Ok((bytes_received, client_addr)) if bytes_received > 0 => { - let client_addr = &client_addr.to_string(); - - read_all_messages(client_addr, &mut buf, |message| async { - let message = Arc::new(message); - - self.stats - .write() - .await - .stats - .update_input(&message); - - for future in self.on_message_input.call_all(message.clone()) { - if let Err(error) = future.await { - debug!("Dropping message: on_message_input callback returned error: {error:?}"); - continue; - } - } - - if let Err(error) = hub_sender.send(message) { - error!("Failed to send message to hub: {error:?}"); - } - }) - .await; - } - Ok((_, client_addr)) => { - warn!("UDP connection closed by {client_addr}."); - break; - } - Err(error) => { - error!("Failed to receive UDP message: {error:?}"); - break; - } - } - } - - debug!("UdpClient Receiver task finished"); - Ok(()) - } - - #[instrument(level = "debug", skip(self, socket, hub_receiver))] - async fn udp_send_task( - &self, - socket: Arc, - mut hub_receiver: broadcast::Receiver>, - ) -> Result<()> { - loop { - match hub_receiver.recv().await { - Ok(message) => { - if message.origin.eq(&socket.peer_addr()?.to_string()) { - continue; // Don't do loopback - } - - self.stats.write().await.stats.update_output(&message); - - for future in self.on_message_output.call_all(message.clone()) { - if let Err(error) = future.await { - debug!( - "Dropping message: on_message_output callback returned error: {error:?}" - ); - continue; - } - } - - match socket.send(message.raw_bytes()).await { - Ok(_) => { - // Message sent successfully - } - Err(ref error) if error.kind() == std::io::ErrorKind::ConnectionRefused => { - // error!("UDP connection refused: {error:?}"); - continue; - } - Err(error) => { - error!("Failed to send UDP message: {error:?}"); - break; - } - } - } - Err(error) => { - error!("Failed to receive message from hub: {error:?}"); - } - } - } - Ok(()) - } } #[async_trait::async_trait] impl Driver for UdpClient { #[instrument(level = "debug", skip(self, hub_sender))] async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { - let local_addr = "0.0.0.0:0"; - let remote_addr = self.remote_addr.clone(); + let local_addr = "0.0.0.0:0".parse::().unwrap(); + let remote_addr = self.remote_addr.parse::()?; + + let context = SendReceiveContext { + hub_sender, + on_message_output: self.on_message_output.clone(), + on_message_input: self.on_message_input.clone(), + stats: self.stats.clone(), + }; + let mut first = true; loop { + if !first { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + first = false; + } + + debug!("Trying to bind to address {local_addr:?}..."); + let socket = match UdpSocket::bind(local_addr).await { - Ok(socket) => Arc::new(socket), + Ok(socket) => socket, Err(error) => { error!("Failed binding UdpClient to address {local_addr:?}: {error:?}"); tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; @@ -192,20 +115,12 @@ impl Driver for UdpClient { debug!("UdpClient successfully connected to {remote_addr:?}"); - let hub_sender = Arc::new(hub_sender.clone()); - let hub_receiver = hub_sender.subscribe(); + let codec = MavlinkCodec::::default(); + let (writer, reader) = UdpFramed::new(socket, codec).split(); - tokio::select! { - result = self.udp_receive_task(socket.clone(), hub_sender) => { - if let Err(error) = result { - error!("Error in receiving UDP messages: {error:?}"); - } - } - result = self.udp_send_task(socket, hub_receiver) => { - if let Err(error) = result { - error!("Error in sending UDP messages: {error:?}"); - } - } + if let Err(reason) = udp_send_receive_run(writer, reader, &remote_addr, &context).await + { + warn!("Driver send/receive tasks closed: {reason:?}"); } } } @@ -224,6 +139,137 @@ impl Driver for UdpClient { } } +#[instrument(level = "debug", skip(writer, reader, context,))] +async fn udp_send_receive_run( + mut writer: S, + mut reader: T, + remote_addr: &SocketAddr, + context: &SendReceiveContext, +) -> Result<()> +where + S: Sink<(Packet, SocketAddr), Error = std::io::Error> + std::marker::Unpin, + T: Stream, SocketAddr)>> + + std::marker::Unpin, +{ + tokio::select! { + result = udp_send_task(&mut writer, remote_addr, context) => { + if let Err(error) = result { + error!("Error in send task for {remote_addr}: {error:?}"); + } + } + result = udp_receive_task(&mut reader, remote_addr, context) => { + if let Err(error) = result { + error!("Error in receive task for {remote_addr}: {error:?}"); + } + } + } + + Ok(()) +} + +/// Receives messages from a Stream and sends them to the HUB Channel +#[instrument(level = "debug", skip(reader, context))] +async fn udp_receive_task( + reader: &mut T, + remote_addr: &SocketAddr, + context: &SendReceiveContext, +) -> Result<()> +where + T: Stream, SocketAddr)>> + + std::marker::Unpin, +{ + loop { + let (packet, remote_addr) = match reader.next().await { + Some(Ok((Ok(packet), remote_addr))) => (packet, remote_addr), + Some(Ok((Err(decode_error), remote_addr))) => { + error!(origin = ?remote_addr, "Failed to decode packet: {decode_error:?}"); + continue; + } + Some(Err(io_error)) => { + error!("Critical error trying to decode data from: {io_error:?}"); + break; + } + None => break, + }; + + let message = Arc::new(Protocol::new(&remote_addr.to_string(), packet)); + + trace!(origin = ?remote_addr, "Received message: {message:?}"); + + context.stats.write().await.stats.update_input(&message); + + for future in context.on_message_input.call_all(message.clone()) { + if let Err(error) = future.await { + debug!(origin = ?remote_addr, "Dropping message: on_message_input callback returned error: {error:?}"); + continue; + } + } + + if let Err(send_error) = context.hub_sender.send(message) { + error!(origin = ?remote_addr, "Failed to send message to hub: {send_error:?}"); + continue; + } + + trace!(origin = ?remote_addr, "Message sent to hub"); + } + + debug!("Driver receiver task stopped!"); + + Ok(()) +} + +/// Receives messages from the HUB Channel and sends them to a Sink +#[instrument(level = "debug", skip(writer, context))] +async fn udp_send_task( + writer: &mut S, + remote_addr: &SocketAddr, + context: &SendReceiveContext, +) -> Result<()> +where + S: Sink<(Packet, SocketAddr), Error = std::io::Error> + std::marker::Unpin, +{ + let mut hub_receiver = context.hub_sender.subscribe(); + + loop { + let message = match hub_receiver.recv().await { + Ok(message) => message, + Err(broadcast::error::RecvError::Closed) => { + error!("Hub channel closed!"); + break; + } + Err(broadcast::error::RecvError::Lagged(count)) => { + warn!("Channel lagged by {count} messages."); + continue; + } + }; + + if message.origin.eq(&remote_addr.to_string()) { + continue; // Don't do loopback + } + + context.stats.write().await.stats.update_output(&message); + + for future in context.on_message_output.call_all(message.clone()) { + if let Err(error) = future.await { + debug!(remote = ?remote_addr, "Dropping message: on_message_output callback returned error: {error:?}"); + continue; + } + } + + if let Err(io_error) = writer.send(((**message).clone(), *remote_addr)).await { + if io_error.kind() == std::io::ErrorKind::ConnectionRefused { + continue; + } + + trace!(remote = ?remote_addr, "Failed to send message: {io_error:?}"); + break; + } + + trace!("Message sent to {remote_addr}: {:?}", message.as_slice()); + } + Ok(()) +} + #[async_trait::async_trait] impl AccumulatedDriverStatsProvider for UdpClient { async fn stats(&self) -> AccumulatedDriverStats { diff --git a/src/lib/drivers/udp/server.rs b/src/lib/drivers/udp/server.rs index 649d3a1e..0dff3e81 100644 --- a/src/lib/drivers/udp/server.rs +++ b/src/lib/drivers/udp/server.rs @@ -1,27 +1,31 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use anyhow::Result; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use mavlink_codec::{codec::MavlinkCodec, error::DecoderError, Packet}; use tokio::{ net::UdpSocket, sync::{broadcast, RwLock}, }; +use tokio_util::udp::UdpFramed; use tracing::*; use crate::{ callbacks::{Callbacks, MessageCallback}, - drivers::{Driver, DriverInfo}, - protocol::{read_all_messages, Protocol}, + drivers::{generic_tasks::SendReceiveContext, Driver, DriverInfo}, + protocol::Protocol, stats::{ accumulated::driver::{AccumulatedDriverStats, AccumulatedDriverStatsProvider}, driver::DriverUuid, }, }; +#[derive(Debug)] pub struct UdpServer { pub local_addr: String, name: arc_swap::ArcSwap, uuid: DriverUuid, - clients: Arc>>, + clients: Arc>>, on_message_input: Callbacks>, on_message_output: Callbacks>, stats: Arc>, @@ -69,130 +73,33 @@ impl UdpServer { ))), }) } - - #[instrument(level = "debug", skip(self, socket, hub_sender, clients))] - async fn udp_receive_task( - &self, - socket: Arc, - hub_sender: Arc>>, - clients: Arc>>, - ) -> Result<()> { - let mut buf = Vec::with_capacity(1024); - - loop { - match socket.recv_buf_from(&mut buf).await { - Ok((bytes_received, client_addr)) if bytes_received > 0 => { - let client_addr = &client_addr.to_string(); - - read_all_messages(client_addr, &mut buf, |message| async { - let message = Arc::new(message); - - self.stats - .write() - .await - .stats - .update_input(&message); - - for future in self.on_message_input.call_all(message.clone()) { - if let Err(error) = future.await { - debug!("Dropping message: on_message_input callback returned error: {error:?}"); - continue; - } - } - - // Update clients - let sysid = message.system_id(); - let compid = message.component_id(); - if let Some(old_client_addr) = clients - .write() - .await - .insert((sysid, compid), client_addr.clone()) - { - debug!("Client ({sysid},{compid}) updated from {old_client_addr:?} (OLD) to {client_addr:?} (NEW)"); - } else { - debug!("Client added: ({sysid},{compid}) -> {client_addr:?}"); - } - - - if let Err(error) = hub_sender.send(message) { - error!("Failed to send message to hub: {error:?}"); - } - }) - .await; - } - Ok((_, client_addr)) => { - warn!("UDP connection closed by {client_addr}."); - break; - } - Err(error) => { - error!("Failed to receive UDP message: {error:?}"); - break; - } - } - } - - debug!("UdpServer Receiver task finished"); - Ok(()) - } - - #[instrument(level = "debug", skip(self, socket, hub_receiver, clients))] - async fn udp_send_task( - &self, - socket: Arc, - mut hub_receiver: broadcast::Receiver>, - clients: Arc>>, - ) -> Result<()> { - loop { - match hub_receiver.recv().await { - Ok(message) => { - for ((_, _), client_addr) in clients.read().await.iter() { - if message.origin.eq(client_addr) { - continue; // Don't do loopback - } - - self.stats.write().await.stats.update_output(&message); - - for future in self.on_message_output.call_all(message.clone()) { - if let Err(error) = future.await { - debug!("Dropping message: on_message_output callback returned error: {error:?}"); - continue; - } - } - - match socket.send_to(message.raw_bytes(), client_addr).await { - Ok(_) => { - // Message sent successfully - } - Err(ref error) - if error.kind() == std::io::ErrorKind::ConnectionRefused => - { - // error!("UDP connection refused: {error:?}"); - continue; - } - Err(error) => { - error!("Failed to send UDP message: {error:?}"); - break; - } - } - } - } - Err(error) => { - error!("Failed to receive message from hub: {error:?}"); - } - } - } - } } #[async_trait::async_trait] impl Driver for UdpServer { #[instrument(level = "debug", skip(self, hub_sender))] async fn run(&self, hub_sender: broadcast::Sender>) -> Result<()> { - let local_addr = &self.local_addr; + let local_addr = self.local_addr.parse::()?; let clients = self.clients.clone(); + + let context = SendReceiveContext { + hub_sender, + on_message_output: self.on_message_output.clone(), + on_message_input: self.on_message_input.clone(), + stats: self.stats.clone(), + }; + + let mut first = true; loop { + if !first { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + first = false; + } + + debug!("Trying to bind to address {local_addr:?}..."); + let socket = match UdpSocket::bind(&local_addr).await { - Ok(socket) => Arc::new(socket), + Ok(socket) => socket, Err(error) => { error!("Failed binding UdpServer to address {local_addr:?}: {error:?}"); tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; @@ -200,20 +107,15 @@ impl Driver for UdpServer { } }; - let hub_sender = Arc::new(hub_sender.clone()); - let hub_receiver = hub_sender.subscribe(); + debug!("UdpServer successfully bound to {local_addr}"); - tokio::select! { - result = self.udp_receive_task(socket.clone(), hub_sender, clients.clone()) => { - if let Err(error) = result { - error!("Error in receiving UDP messages: {error:?}"); - } - } - result = self.udp_send_task(socket, hub_receiver, clients.clone()) => { - if let Err(error) = result { - error!("Error in sending UDP messages: {error:?}"); - } - } + let codec = MavlinkCodec::::default(); + let (writer, reader) = UdpFramed::new(socket, codec).split(); + + if let Err(reason) = + udp_send_receive_run(writer, reader, clients.clone(), local_addr, &context).await + { + warn!("Driver send/receive tasks closed: {reason:?}"); } } } @@ -232,6 +134,158 @@ impl Driver for UdpServer { } } +async fn udp_send_receive_run( + mut writer: S, + mut reader: T, + clients: Arc>>, + local_addr: SocketAddr, + context: &SendReceiveContext, +) -> Result<()> +where + S: Sink<(Packet, SocketAddr), Error = std::io::Error> + std::marker::Unpin, + T: Stream, SocketAddr)>> + + std::marker::Unpin, +{ + tokio::select! { + result = udp_send_task(&mut writer, clients.clone(), local_addr, context) => { + if let Err(error) = result { + error!("Error in send task for {clients:?}: {error:?}"); + } + } + result = udp_receive_task(&mut reader, clients.clone(), local_addr, context) => { + if let Err(error) = result { + error!("Error in receive task for {clients:?}: {error:?}"); + } + } + } + + Ok(()) +} + +/// Receives messages from a Stream and sends them to the HUB Channel +#[instrument(level = "debug", skip(reader, context))] +async fn udp_receive_task( + reader: &mut T, + clients: Arc>>, + local_addr: SocketAddr, + context: &SendReceiveContext, +) -> Result<()> +where + T: Stream, SocketAddr)>> + + std::marker::Unpin, +{ + loop { + let (packet, client_addr) = match reader.next().await { + Some(Ok((Ok(packet), client_addr))) => (packet, client_addr), + Some(Ok((Err(decode_error), client_addr))) => { + error!(origin = ?client_addr, "Failed to decode packet: {decode_error:?}"); + continue; + } + Some(Err(io_error)) => { + error!("Critical error trying to decode data from: {io_error:?}"); + break; + } + None => break, + }; + + let message = Arc::new(Protocol::new(&client_addr.to_string(), packet)); + + trace!(origin = ?client_addr, "Received message: {message:?}"); + + context.stats.write().await.stats.update_input(&message); + + for future in context.on_message_input.call_all(message.clone()) { + if let Err(error) = future.await { + debug!(origin = ?client_addr, "Dropping message: on_message_input callback returned error: {error:?}"); + continue; + } + } + + // Update clients + let sysid = *message.system_id(); + let compid = *message.component_id(); + + { + let mut clients = clients.write().await; + + if let Some(old_client_addr) = clients.get_mut(&(sysid, compid)) { + if old_client_addr != &client_addr { + *old_client_addr = client_addr; + + debug!("Client ({sysid},{compid}) updated from {old_client_addr:?} (OLD) to {client_addr:?} (NEW)"); + } + } else { + clients.insert((sysid, compid), client_addr); + debug!("Client added: ({sysid},{compid}) -> {client_addr:?}"); + } + } + + if let Err(send_error) = context.hub_sender.send(message) { + error!(origin = ?client_addr, "Failed to send message to hub: {send_error:?}"); + continue; + } + + trace!(origin = ?client_addr, "Message sent to hub"); + } + + debug!("Driver receiver task stopped!"); + + Ok(()) +} + +/// Receives messages from the HUB Channel and sends them to a Sink +#[instrument(level = "debug", skip(writer, context))] +async fn udp_send_task( + writer: &mut S, + clients: Arc>>, + local_addr: SocketAddr, + context: &SendReceiveContext, +) -> Result<()> +where + S: Sink<(Packet, SocketAddr), Error = std::io::Error> + std::marker::Unpin, +{ + let mut hub_receiver = context.hub_sender.subscribe(); + + loop { + let message = match hub_receiver.recv().await { + Ok(message) => message, + Err(broadcast::error::RecvError::Closed) => { + error!("Hub channel closed!"); + break; + } + Err(broadcast::error::RecvError::Lagged(count)) => { + warn!("Channel lagged by {count} messages."); + continue; + } + }; + + for ((_, _), client_addr) in clients.read().await.iter() { + if message.origin.eq(&client_addr.to_string()) { + continue; // Don't do loopback + } + + context.stats.write().await.stats.update_output(&message); + + for future in context.on_message_output.call_all(message.clone()) { + if let Err(error) = future.await { + debug!( + client = ?client_addr, "Dropping message: on_message_output callback returned error: {error:?}" + ); + continue; + } + } + + if let Err(error) = writer.send(((**message).clone(), *client_addr)).await { + error!(client = ?client_addr, "Failed to send message: {error:?}"); + break; + } + + trace!("Message sent to {client_addr}: {:?}", message.as_slice()); + } + } + Ok(()) +} + #[async_trait::async_trait] impl AccumulatedDriverStatsProvider for UdpServer { async fn stats(&self) -> AccumulatedDriverStats { diff --git a/src/lib/hub/actor.rs b/src/lib/hub/actor.rs index cd7bfc35..36223342 100644 --- a/src/lib/hub/actor.rs +++ b/src/lib/hub/actor.rs @@ -2,7 +2,7 @@ use std::{ops::Div, sync::Arc}; use anyhow::{anyhow, Context, Result}; use indexmap::IndexMap; -use mavlink::MAVLinkV2MessageRaw; +use mavlink_codec::Packet; use tokio::sync::{broadcast, mpsc, RwLock}; use tracing::*; @@ -172,10 +172,12 @@ impl HubActor { ..Default::default() }; - let mut message_raw = Protocol::new("", MAVLinkV2MessageRaw::new()); + let mut message_raw = mavlink::MAVLinkV2MessageRaw::new(); message_raw.serialize_message(header, &message); - if let Err(error) = bcst_sender.send(Arc::new(message_raw)) { + let message = Protocol::new("", Packet::from(message_raw)); + + if let Err(error) = bcst_sender.send(Arc::new(message)) { error!("Failed to send HEARTBEAT message: {error}"); } } diff --git a/src/lib/logger.rs b/src/lib/logger.rs index 3f4b252f..7479fc6c 100644 --- a/src/lib/logger.rs +++ b/src/lib/logger.rs @@ -69,8 +69,5 @@ pub fn init() { chrono::Local::now().format("%Y-%m-%dT%H:%M:%S"), ); debug!("Command line call: {}", cli::command_line_string()); - debug!( - "Command line input struct call: {}", - cli::command_line_string() - ); + debug!("Command line input struct call: {}", cli::command_line()); } diff --git a/src/lib/protocol.rs b/src/lib/protocol.rs index 30d161f1..f93853df 100644 --- a/src/lib/protocol.rs +++ b/src/lib/protocol.rs @@ -1,83 +1,44 @@ -use std::{ - future::Future, - io::Cursor, - ops::{Deref, DerefMut}, -}; +use std::ops::{Deref, DerefMut}; -use mavlink::{ardupilotmega::MavMessage, MAVLinkV2MessageRaw}; +use mavlink_codec::Packet; use serde::Serialize; -use tracing::*; - #[derive(Debug, PartialEq, Serialize)] pub struct Protocol { pub origin: String, pub timestamp: u64, #[serde(skip)] - message: MAVLinkV2MessageRaw, + packet: Packet, } impl Protocol { - pub fn new(origin: &str, message: MAVLinkV2MessageRaw) -> Self { + pub fn new(origin: &str, packet: Packet) -> Self { Self { origin: origin.to_string(), timestamp: chrono::Utc::now().timestamp_micros() as u64, - message, + packet, } } - pub fn new_with_timestamp(timestamp: u64, origin: &str, message: MAVLinkV2MessageRaw) -> Self { + pub fn new_with_timestamp(timestamp: u64, origin: &str, packet: Packet) -> Self { Self { origin: origin.to_string(), timestamp, - message, + packet, } } } -pub async fn read_all_messages(origin: &str, buf: &mut Vec, process_message: F) -where - F: Fn(Protocol) -> Fut, - Fut: Future, -{ - let reader = Cursor::new(buf.as_slice()); - let mut reader: mavlink::async_peek_reader::AsyncPeekReader, 280> = - mavlink::async_peek_reader::AsyncPeekReader::new(reader); - - loop { - let message = match mavlink::read_v2_raw_message_async::(&mut reader).await { - Ok(message) => Protocol::new(origin, message), - Err(error) => { - match error { - mavlink::error::MessageReadError::Io(_) => (), - mavlink::error::MessageReadError::Parse(_) => { - error!("Failed to parse MAVLink message: {error:?}") - } - } - - break; - } - }; - - trace!("Parsed message: {:?}", message.raw_bytes()); - - process_message(message).await; - } - - let bytes_read = reader.reader_ref().position() as usize; - buf.drain(..bytes_read); -} - impl Deref for Protocol { - type Target = MAVLinkV2MessageRaw; + type Target = Packet; fn deref(&self) -> &Self::Target { - &self.message + &self.packet } } impl DerefMut for Protocol { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.message + &mut self.packet } } diff --git a/src/lib/stats/accumulated/messages.rs b/src/lib/stats/accumulated/messages.rs index f2323c31..49890707 100644 --- a/src/lib/stats/accumulated/messages.rs +++ b/src/lib/stats/accumulated/messages.rs @@ -28,10 +28,10 @@ pub struct AccumulatedComponentMessageStats { impl AccumulatedHubMessagesStats { pub fn update(&mut self, message: &Arc) { self.systems_messages_stats - .entry(message.system_id()) + .entry(*message.system_id()) .or_default() .components_messages_stats - .entry(message.component_id()) + .entry(*message.component_id()) .or_default() .messages_stats .entry(message.message_id()) diff --git a/src/lib/stats/accumulated/mod.rs b/src/lib/stats/accumulated/mod.rs index 8df21e59..e82ed2cf 100644 --- a/src/lib/stats/accumulated/mod.rs +++ b/src/lib/stats/accumulated/mod.rs @@ -32,7 +32,7 @@ impl AccumulatedStatsInner { pub fn update(&mut self, message: &Arc) { self.last_message = Some(message.clone()); self.last_update_us = chrono::Utc::now().timestamp_micros() as u64; - self.bytes = self.bytes.wrapping_add(message.raw_bytes().len() as u64); + self.bytes = self.bytes.wrapping_add(message.packet_size() as u64); self.messages = self.messages.wrapping_add(1); self.delay = self .delay diff --git a/src/lib/web/mod.rs b/src/lib/web/mod.rs index 34314e59..2fe83121 100644 --- a/src/lib/web/mod.rs +++ b/src/lib/web/mod.rs @@ -77,7 +77,7 @@ async fn websocket_connection(socket: WebSocket, state: AppState) { broadcast_message_websockets(&state, identifier, Message::Text(text)).await; } Message::Close(frame) => { - debug!("WS client {identifier} disconnected: {frame:#?}"); + debug!("WS client {identifier} disconnected: {frame:?}"); break; } _ => {}