From 35490baa021d9ac98ba88abefd131dff4b6131c4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juozapas=20Bo=C4=8Dkus?=
Date: Thu, 26 Sep 2024 15:04:48 +0300
Subject: [PATCH] Add allow_multicast flags from mesh config to starcast
 integration

---
 .unreleased/LLT-5108 | 1 +
 crates/telio-starcast/src/nat.rs | 8 +-
 crates/telio-starcast/src/transport.rs | 135 ++++++++++++++++-----
 nat-lab/tests/mesh_api.py | 4 +-
 nat-lab/tests/test_mesh_api.py | 4 +-
 nat-lab/tests/test_multicast_connection.py | 57 +++++++++
 src/device.rs | 14 ++-
 7 files changed, 185 insertions(+), 38 deletions(-)
 create mode 100644 .unreleased/LLT-5108

diff --git a/.unreleased/LLT-5108 b/.unreleased/LLT-5108
new file mode 100644
index 000000000..203c73611
--- /dev/null
+++ b/.unreleased/LLT-5108
@@ -0,0 +1 @@
+Add allow_multicast flags from meshnet config to Starcast implementation.
\ No newline at end of file
diff --git a/crates/telio-starcast/src/nat.rs b/crates/telio-starcast/src/nat.rs
index 994e7fb0c..205a461ce 100644
--- a/crates/telio-starcast/src/nat.rs
+++ b/crates/telio-starcast/src/nat.rs
@@ -32,7 +32,7 @@ pub enum Error {
 pub trait Nat {
 /// Translate incoming packet (from the transport socket to the multicast peer)
 /// Change the source to the multicast peer's ip and natted port
- fn translate_incoming(&mut self, packet: &mut [u8]) -> Result<(), Error>;
+ fn translate_incoming(&mut self, packet: &mut [u8]) -> Result<IpAddr, Error>;
 /// Translate outgoing packet (from the multicast peer to the transport socket)
 /// Change the destination to the peer's original ip and port
 fn translate_outgoing(&mut self, packet: &mut [u8]) -> Result<SocketAddr, Error>;
@@ -92,7 +92,7 @@ impl StarcastNat {
 fn translate_incoming_internal<'a, P: MutableIpPacket<'a>>(
 &mut self,
 packet: &'a mut [u8],
- ) -> Result<(), Error> {
+ ) -> Result<IpAddr, Error> {
 let mut ip_packet = P::new(packet).ok_or(Error::PacketTooShort)?;
 if ip_packet.get_next_level_protocol() != IpNextHeaderProtocols::Udp {
 return Err(Error::UnexpectedTransportProtocol);
@@ -139,7 +139,7 @@
 ))
 }

- Ok(())
+ Ok(old_src_ip.into())
 }

 fn translate_outgoing_internal<'a, P: MutableIpPacket<'a>>(
@@ -181,7 +181,7 @@
 }

 impl Nat for StarcastNat {
- fn translate_incoming(&mut self, packet: &mut [u8]) -> Result<(), Error> {
+ fn translate_incoming(&mut self, packet: &mut [u8]) -> Result<IpAddr, Error> {
 match packet.first().ok_or(Error::PacketTooShort)? >> 4 {
 4 => self.translate_incoming_internal::<MutableIpv4Packet>(packet),
 6 => self.translate_incoming_internal::<MutableIpv6Packet>(packet),
diff --git a/crates/telio-starcast/src/transport.rs b/crates/telio-starcast/src/transport.rs
index fc66f260c..9cdd95b50 100644
--- a/crates/telio-starcast/src/transport.rs
+++ b/crates/telio-starcast/src/transport.rs
@@ -83,15 +83,19 @@ pub struct Peer {
 pub addr: SocketAddr,
 /// The public key of the peer
 pub public_key: PublicKey,
+ /// Whether our node accepts multicast messages from the peer or not.
+ pub allow_multicast: bool,
+ /// Whether the peer node accepts multicast messages from our node or not.
+ pub peer_allows_multicast: bool,
 }

 /// Config for transport component
 /// Contains fields that can change at runtime
 pub enum Config {
 /// Simple transport config, has IP of peer but not port
- Simple(Vec<(PublicKey, IpAddr)>),
+ Simple(Vec<(PublicKey, IpAddr, bool, bool)>),
 /// Full transport config, has full socket address of peer
- Full(Vec<(PublicKey, SocketAddr)>),
+ Full(Vec<(PublicKey, SocketAddr, bool, bool)>),
 }

 /// The starcast transport component
@@ -200,14 +204,25 @@ impl State {
 self.peers = match config {
 Config::Simple(peers) => peers
 .into_iter()
- .map(|(public_key, addr)| Peer {
- public_key,
- addr: SocketAddr::new(addr, MULTICAST_TRANSPORT_PORT),
- })
+ .map(
+ |(public_key, addr, allow_multicast, peer_allows_multicast)| Peer {
+ public_key,
+ addr: SocketAddr::new(addr, MULTICAST_TRANSPORT_PORT),
+ allow_multicast,
+ peer_allows_multicast,
+ },
+ )
 .collect(),
 Config::Full(peers) => peers
 .into_iter()
- .map(|(public_key, addr)| Peer { public_key, addr })
+ .map(
+ |(public_key, addr, allow_multicast, peer_allows_multicast)| Peer {
+ public_key,
+ addr,
+ allow_multicast,
+ peer_allows_multicast,
+ },
+ )
 .collect(),
 };
 }
@@ -216,11 +231,15 @@
 let Some(transport_socket) = self.transport_socket.as_ref() else {
 return Err(Error::TransportSocketNotOpen);
 };
- let failed_peers = join_all(self.peers.iter().map(|peer| {
- transport_socket
- .send_to(&packet, peer.addr)
- .map_err(|_| peer.public_key)
- }))
+ // If peer_allows_multicast is false for a peer, we cannot send multicast packets to that peer,
+ // but we can still receive multicast packets from that peer.
+ let failed_peers = join_all(self.peers.iter().filter(|p| p.peer_allows_multicast).map(
+ |peer| {
+ transport_socket
+ .send_to(&packet, peer.addr)
+ .map_err(|_| peer.public_key)
+ },
+ ))
 .await
 .into_iter()
 .filter_map(|res| match res {
@@ -256,6 +275,34 @@
 .map_err(|_| Error::SocketSendError)
 }

+ /// Separate method for handling starcast packets received on the transport socket from other
+ /// meshnet nodes and dropping those packets if multicast isn't allowed for those nodes.
+ async fn handle_incoming_packet(
+ &mut self,
+ mut packet: Vec<u8>,
+ send_permit: tokio::sync::mpsc::OwnedPermit<Vec<u8>>,
+ ) -> Result<(), Error> {
+ let peer_ip = self
+ .nat
+ .translate_incoming(&mut packet)
+ .map_err(Error::NatError)?;
+ if self
+ .peers
+ .iter()
+ .find(|peer| peer.addr.ip() == peer_ip)
+ // If allow_multicast is false for a peer, we drop any multicast packets that were
+ // received from that peer, but we can still send multicast packets to that peer.
+ .filter(|peer| peer.allow_multicast)
+ .is_some()
+ {
+ // If a starcast packet is received from a peer which is not present in the peer list,
+ // we assume that multicast is disallowed for it.
+ send_permit.send(packet);
+ };
+
+ Ok(())
+ }
+
 fn has_multicast_dst(&self, packet: &mut [u8]) -> Result<bool, Error> {
 let dst = match packet.first().ok_or(Error::InvalidIpPacket)?
>> 4 {
 4 => Self::get_packet_dst::<Ipv4Packet>(packet),
@@ -315,12 +362,8 @@ impl Runtime for State {
 }
 Some((permit, Ok(bytes_read))) = wait_for_tx(&self.packet_chan.tx, transport_socket.recv(&mut self.recv_buffer)) => {
 #[allow(clippy::expect_used)]
- let mut packet = self.recv_buffer.get(..bytes_read).expect("We know bytes_read bytes should be in the buffer at this point").to_vec();
- self.nat.translate_incoming(&mut packet)
- .map_err(Error::NatError)
- .map(|_| {
- let _ = permit.send(packet);
- })
+ let packet = self.recv_buffer.get(..bytes_read).expect("We know bytes_read bytes should be in the buffer at this point").to_vec();
+ self.handle_incoming_packet(packet, permit).await
 }
 else => {
 telio_log_warn!("MutlicastListener: no events to wait on");
@@ -351,23 +394,49 @@ mod tests {
 task: Task<State>,
 transport_socket: Arc<UdpSocket>,
 channel: Chan<Vec<u8>>,
- peers: Vec<(PublicKey, UdpSocket)>,
+ peers: Vec<(PublicKey, UdpSocket, bool, bool)>,
 }

 impl Scaffold {
 async fn start() -> Self {
 let transport_socket = Arc::new(Self::bind_local_socket().await);
- let mut peers = Vec::with_capacity(3);
- for _ in 0..3 {
- peers.push((SecretKey::gen().public(), Self::bind_local_socket().await));
- }
+ // Peers with all the different possible meshnet configurations:
+ let peers = vec![
+ (
+ SecretKey::gen().public(),
+ Self::bind_local_socket().await,
+ true,
+ true,
+ ),
+ (
+ SecretKey::gen().public(),
+ Self::bind_local_socket().await,
+ false,
+ true,
+ ),
+ (
+ SecretKey::gen().public(),
+ Self::bind_local_socket().await,
+ true,
+ false,
+ ),
+ (
+ SecretKey::gen().public(),
+ Self::bind_local_socket().await,
+ false,
+ false,
+ ),
+ ];
+
 let (packet_chan, channel) = Chan::pipe();
 let task_peers = peers
 .iter()
- .map(|(pk, s)| Peer {
+ .map(|(pk, s, allow_multicast, peer_allows_multicast)| Peer {
 public_key: *pk,
 addr: s.local_addr().unwrap(),
+ allow_multicast: *allow_multicast,
+ peer_allows_multicast: *peer_allows_multicast,
 })
 .collect();
 let multicast_ips = vec![IpNet::new("224.0.0.0".parse().unwrap(), 4).unwrap()];
@@ -459,11 +528,19 @@ mod tests {

 scaffold.channel.tx.send(packet.clone()).await.unwrap();

- for (_, socket) in &scaffold.peers {
+ for (_, socket, _, peer_allows_multicast) in &scaffold.peers {
 let mut buffer = vec![0; TEST_MAX_PACKET_SIZE];
- let bytes_read = socket.recv(&mut buffer).await.unwrap();
- buffer.truncate(bytes_read);
- assert_eq!(buffer, packet);
+ if *peer_allows_multicast {
+ let bytes_read = socket.recv(&mut buffer).await.unwrap();
+ buffer.truncate(bytes_read);
+ assert_eq!(buffer, packet);
+ } else {
+ // Using timeout here, because otherwise the socket will just wait forever.
+ let result =
+ tokio::time::timeout(Duration::from_millis(100), socket.recv_from(&mut buffer))
+ .await;
+ assert!(result.is_err());
+ }
 }

 scaffold.stop().await;
@@ -494,7 +571,7 @@ mod tests {

 tokio::task::yield_now().await;

- for (_, socket) in scaffold.peers.iter().skip(1) {
+ for (_, socket, _, _) in scaffold.peers.iter().skip(1) {
 let mut buffer = vec![0; TEST_MAX_PACKET_SIZE];
 assert!(socket.try_recv(&mut buffer).is_err());
 }
diff --git a/nat-lab/tests/mesh_api.py b/nat-lab/tests/mesh_api.py
index b9dcdbdbe..34ed0433c 100644
--- a/nat-lab/tests/mesh_api.py
+++ b/nat-lab/tests/mesh_api.py
@@ -167,8 +167,8 @@ def to_peer_config_for_node(self, node) -> Peer:
 is_local=node.is_local and self.is_local,
 allow_incoming_connections=firewall_config.allow_incoming_connections,
 allow_peer_send_files=firewall_config.allow_peer_send_files,
- allow_multicast=False,
- peer_allows_multicast=False,
+ allow_multicast=True,
+ peer_allows_multicast=True,
 )

 def set_peer_firewall_settings(
diff --git a/nat-lab/tests/test_mesh_api.py b/nat-lab/tests/test_mesh_api.py
index 11fd38b1d..f733b76d3 100644
--- a/nat-lab/tests/test_mesh_api.py
+++ b/nat-lab/tests/test_mesh_api.py
@@ -43,8 +43,8 @@ def test_to_peer_config(self) -> None:
 is_local=False,
 allow_incoming_connections=False,
 allow_peer_send_files=False,
- allow_multicast=False,
- peer_allows_multicast=False,
+ allow_multicast=True,
+ peer_allows_multicast=True,
 )

 assert expected == node.to_peer_config_for_node(node)
diff --git a/nat-lab/tests/test_multicast_connection.py b/nat-lab/tests/test_multicast_connection.py
index b07eae0c2..c0af87195 100644
--- a/nat-lab/tests/test_multicast_connection.py
+++ b/nat-lab/tests/test_multicast_connection.py
@@ -5,6 +5,7 @@ from utils.bindings import default_features, TelioAdapterType
 from utils.connection_util import ConnectionTag, Connection, TargetOS
 from utils.multicast import MulticastClient, MulticastServer
+from utils.process import ProcessExecError


 def generate_setup_parameter_pair(
@@ -99,3 +100,59 @@ async def test_multicast(setup_params: List[SetupParameters], protocol: str) ->
 async with MulticastServer(beta_connection, protocol).run() as server:
 await server.wait_till_ready()
 await MulticastClient(alpha_connection, protocol).execute()
+
+
+MULTICAST_DISALLOWED_TEST_PARAMS = [
+ pytest.param(
+ generate_setup_parameter_pair([
+ (ConnectionTag.DOCKER_FULLCONE_CLIENT_1, TelioAdapterType.BORING_TUN),
+ (ConnectionTag.DOCKER_FULLCONE_CLIENT_2, TelioAdapterType.BORING_TUN),
+ ]),
+ "ssdp",
+ ),
+ pytest.param(
+ generate_setup_parameter_pair([
+ (ConnectionTag.DOCKER_SYMMETRIC_CLIENT_1, TelioAdapterType.BORING_TUN),
+ (ConnectionTag.DOCKER_SYMMETRIC_CLIENT_2, TelioAdapterType.BORING_TUN),
+ ]),
+ "mdns",
+ ),
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("setup_params, protocol", MULTICAST_DISALLOWED_TEST_PARAMS)
+async def test_multicast_disallowed(
+ setup_params: List[SetupParameters], protocol: str
+) -> None:
+ async with AsyncExitStack() as exit_stack:
+ env = await setup_mesh_nodes(exit_stack, setup_params)
+
+ alpha_connection, beta_connection = [
+ conn.connection for conn in env.connections
+ ]
+
+ client_alpha, client_beta = env.clients
+
+ alpha, beta = env.nodes
+ mesh_config_alpha = env.api.get_meshnet_config(alpha.id)
+ if mesh_config_alpha.peers is not None:
+ for peer in mesh_config_alpha.peers:
+ if peer.base.hostname == beta.hostname:
+ peer.allow_multicast = False
+ await client_alpha.set_meshnet_config(mesh_config_alpha)
+
+ mesh_config_beta = env.api.get_meshnet_config(beta.id)
+ if mesh_config_beta.peers is not None: + for peer in mesh_config_beta.peers: + if peer.base.hostname == alpha.hostname: + peer.peer_allows_multicast = False + await client_beta.set_meshnet_config(mesh_config_beta) + + await add_multicast_route(alpha_connection) + await add_multicast_route(beta_connection) + + async with MulticastServer(beta_connection, protocol).run() as server: + with pytest.raises(ProcessExecError): + await server.wait_till_ready() + await MulticastClient(alpha_connection, protocol).execute() diff --git a/src/device.rs b/src/device.rs index c2ec82aed..e3c8a536e 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1790,6 +1790,11 @@ impl Runtime { .peers .unwrap_or_default() .iter() + .filter(|p| { + // If neither our node nor peer node allow multicast, there's no point in keeping + // that peer in the config. + p.allow_multicast || p.peer_allows_multicast + }) .filter_map(|p| { p.ip_addresses .to_owned() @@ -1797,7 +1802,14 @@ impl Runtime { .iter() // While IPV6 support is not added yet for multicast, only using IPV4 IPs .find(|ip| ip.is_ipv4()) - .map(|ip| (p.base.public_key, ip.to_owned())) + .map(|ip| { + ( + p.base.public_key, + ip.to_owned(), + p.allow_multicast, + p.peer_allows_multicast, + ) + }) }) .collect(); let starcast_transport_config = StarcastTransportConfig::Simple(multicast_peers);
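
Note on the semantics introduced by this patch: the two flags are deliberately asymmetric. peer_allows_multicast gates the send path (multicast packets are only fanned out to peers that accept our multicast traffic), while allow_multicast gates the receive path (inbound starcast packets are dropped unless the sender is a known peer with the flag set). The short standalone sketch below illustrates that split; the names PeerEntry, send_targets and accept_from are illustrative only and are not part of the telio-starcast API.

use std::net::{IpAddr, Ipv4Addr};

// Hypothetical stand-in for a starcast transport peer entry; the two flags mirror
// the meshnet config fields, but this is not the telio-starcast `Peer` struct.
#[derive(Debug, Clone)]
struct PeerEntry {
    addr: IpAddr,
    // We accept multicast traffic coming *from* this peer.
    allow_multicast: bool,
    // This peer accepts multicast traffic coming *from us*.
    peer_allows_multicast: bool,
}

// Fan-out targets for the send path: only peers that accept our multicast traffic.
fn send_targets(peers: &[PeerEntry]) -> Vec<IpAddr> {
    peers
        .iter()
        .filter(|p| p.peer_allows_multicast)
        .map(|p| p.addr)
        .collect()
}

// Receive path check: keep a packet only if its source is a known peer with
// allow_multicast set; unknown senders are treated as disallowed.
fn accept_from(peers: &[PeerEntry], src: IpAddr) -> bool {
    peers.iter().any(|p| p.addr == src && p.allow_multicast)
}

fn main() {
    let a = IpAddr::V4(Ipv4Addr::new(100, 64, 0, 2));
    let b = IpAddr::V4(Ipv4Addr::new(100, 64, 0, 3));
    let peers = vec![
        PeerEntry { addr: a, allow_multicast: true, peer_allows_multicast: false },
        PeerEntry { addr: b, allow_multicast: false, peer_allows_multicast: true },
    ];

    // We may send only to `b` ...
    assert_eq!(send_targets(&peers), vec![b]);
    // ... and we accept inbound multicast only from `a`.
    assert!(accept_from(&peers, a));
    assert!(!accept_from(&peers, b));
    // A sender that is not in the peer list at all is dropped.
    assert!(!accept_from(&peers, IpAddr::V4(Ipv4Addr::new(100, 64, 0, 9))));
}

Treating unknown senders as disallowed mirrors the find/filter/is_some check in handle_incoming_packet above, where a packet whose translated source IP does not match any configured peer is silently dropped.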