diff --git a/common/wireguard/src/udp_listener.rs b/common/wireguard/src/udp_listener.rs index ec15fbf3d1e..8435955670e 100644 --- a/common/wireguard/src/udp_listener.rs +++ b/common/wireguard/src/udp_listener.rs @@ -19,6 +19,7 @@ use crate::{ const MAX_PACKET: usize = 65535; +pub(crate) type PeerIdx = u32; pub(crate) type ActivePeers = DashMap>; pub(crate) type PeersByIp = NetworkTable>; @@ -32,10 +33,13 @@ pub(crate) async fn start_udp_listener( log::info!("Starting wireguard UDP listener on {wg_address}"); let udp_socket = Arc::new(UdpSocket::bind(wg_address).await?); - // Setup some static keys for development + // Setup static key for development let static_private = setup::server_static_private_key(); + + // A single hardcoded peer let peer_static_public = setup::peer_static_public_key(); let peer_allowed_ips = setup::peer_allowed_ips(); + let peer_index = 0; tokio::spawn(async move { // Each tunnel is run in its own task, and the task handle is stored here so we can remove @@ -84,6 +88,7 @@ pub(crate) async fn start_udp_listener( static_private.clone(), peer_static_public, peer_allowed_ips, + peer_index, tun_task_tx.clone(), ); diff --git a/common/wireguard/src/wg_tunnel.rs b/common/wireguard/src/wg_tunnel.rs index e164d9f7f85..2d44a03fb18 100644 --- a/common/wireguard/src/wg_tunnel.rs +++ b/common/wireguard/src/wg_tunnel.rs @@ -2,7 +2,7 @@ use std::{net::SocketAddr, sync::Arc, time::Duration}; use async_recursion::async_recursion; use boringtun::{ - noise::{errors::WireGuardError, Tunn, TunnResult}, + noise::{errors::WireGuardError, rate_limiter::RateLimiter, Tunn, TunnResult}, x25519, }; use bytes::Bytes; @@ -14,7 +14,11 @@ use tokio::{ time::timeout, }; -use crate::{error::WgError, event::Event, network_table::NetworkTable, TunTaskTx}; +use crate::{ + error::WgError, event::Event, network_table::NetworkTable, udp_listener::PeerIdx, TunTaskTx, +}; + +const HANDSHAKE_MAX_RATE: u64 = 10; const MAX_PACKET: usize = 65535; @@ -56,6 +60,7 @@ impl WireGuardTunnel { static_private: x25519::StaticSecret, peer_static_public: x25519::PublicKey, peer_allowed_ips: ip_network::IpNetwork, + index: PeerIdx, tunnel_tx: TunTaskTx, ) -> (Self, mpsc::UnboundedSender) { let local_addr = udp.local_addr().unwrap(); @@ -64,8 +69,12 @@ impl WireGuardTunnel { let preshared_key = None; let persistent_keepalive = None; - let index = 0; - let rate_limiter = None; + + let static_public = x25519::PublicKey::from(&static_private); + let rate_limiter = Some(Arc::new(RateLimiter::new( + &static_public, + HANDSHAKE_MAX_RATE, + ))); let wg_tunnel = Arc::new(tokio::sync::Mutex::new( Tunn::new( @@ -288,6 +297,7 @@ pub(crate) fn start_wg_tunnel( static_private: x25519::StaticSecret, peer_static_public: x25519::PublicKey, peer_allowed_ips: ip_network::IpNetwork, + peer_index: PeerIdx, tunnel_tx: TunTaskTx, ) -> ( tokio::task::JoinHandle, @@ -299,6 +309,7 @@ pub(crate) fn start_wg_tunnel( static_private, peer_static_public, peer_allowed_ips, + peer_index, tunnel_tx, ); let join_handle = tokio::spawn(async move {