Skip to content

Commit

Permalink
wireguard: try to have a flow where we parse the public key up front
Browse files Browse the repository at this point in the history
  • Loading branch information
octol committed Oct 11, 2023
1 parent 3913e3e commit f3098b4
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 28 deletions.
8 changes: 7 additions & 1 deletion common/wireguard/src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use std::fmt::{Display, Formatter};
use bytes::Bytes;

#[allow(unused)]
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum Event {
/// IP packet received from the WireGuard tunnel that should be passed through to the corresponding virtual device/internet.
/// Original implementation also has protocol here since it understands it, but we'll have to infer it downstream
WgPacket(Bytes),
/// IP packet received from the UDP listener that was verified as part of the handshake
WgVerifiedPacket(Bytes),
/// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device.
IpPacket(Bytes),
}
Expand All @@ -19,6 +21,10 @@ impl Display for Event {
let size = data.len();
write!(f, "WgPacket{{ size={size} }}")
}
Event::WgVerifiedPacket(data) => {
let size = data.len();
write!(f, "WgVerifiedPacket{{ size={size} }}")
}
Event::IpPacket(data) => {
let size = data.len();
write!(f, "IpPacket{{ size={size} }}")
Expand Down
112 changes: 91 additions & 21 deletions common/wireguard/src/udp_listener.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::{net::SocketAddr, sync::Arc};
use std::{collections::HashMap, net::SocketAddr, sync::Arc};

use boringtun::{
noise::{handshake::parse_handshake_anon, rate_limiter::RateLimiter, TunnResult},
x25519,
};
use dashmap::DashMap;
use futures::StreamExt;
use ip_network::IpNetwork;
use log::error;
use nym_task::TaskClient;
use tap::TapFallible;
Expand All @@ -19,11 +24,19 @@ use crate::{

const MAX_PACKET: usize = 65535;

pub(crate) type ActivePeers = DashMap<SocketAddr, mpsc::UnboundedSender<Event>>;
pub(crate) type ActivePeers = DashMap<x25519::PublicKey, mpsc::UnboundedSender<Event>>;
pub(crate) type PeersByIp = NetworkTable<mpsc::UnboundedSender<Event>>;

struct RegisteredPeer {
// peer_tx: mpsc::UnboundedSender<Event>,
public_key: x25519::PublicKey,
allowed_ips: IpNetwork,
// endpoint: SocketAddr,
}

pub(crate) async fn start_udp_listener(
tun_task_tx: TunTaskTx,
// active_peers: Arc<ActivePeers>,
active_peers: Arc<ActivePeers>,
peers_by_ip: Arc<std::sync::Mutex<PeersByIp>>,
mut task_client: TaskClient,
Expand All @@ -32,16 +45,38 @@ pub(crate) async fn start_udp_listener(
log::info!("Starting wireguard UDP listener on {wg_address}");
let udp_socket = Arc::new(UdpSocket::bind(wg_address).await?);

// Setup some static keys for development
// Setup our own keys
let static_private = setup::server_static_private_key();
let static_public = x25519::PublicKey::from(&static_private);
let handshake_max_rate = 100u64;
let rate_limiter = RateLimiter::new(&static_public, handshake_max_rate);

// Test peer
let peer_static_public = setup::peer_static_public_key();
let peer_allowed_ips = setup::peer_allowed_ips();
let test_peer = Arc::new(tokio::sync::Mutex::new(RegisteredPeer {
public_key: peer_static_public,
allowed_ips: peer_allowed_ips,
}));

type PeerIdx = u32;
let mut registered_peers: HashMap<x25519::PublicKey, Arc<tokio::sync::Mutex<RegisteredPeer>>> =
HashMap::new();
let mut registered_peers_by_idx: HashMap<PeerIdx, Arc<tokio::sync::Mutex<RegisteredPeer>>> =
HashMap::new();

registered_peers.insert(peer_static_public, Arc::clone(&test_peer));
registered_peers_by_idx.insert(0, test_peer);

tokio::spawn(async move {
// Each tunnel is run in its own task, and the task handle is stored here so we can remove
// it from `active_peers` when the tunnel is closed
let mut active_peers_task_handles = futures::stream::FuturesUnordered::new();

let mut buf = [0u8; MAX_PACKET];
let mut dst_buf = [0u8; MAX_PACKET];

// let mut buf2 = [0u8; MAX_PACKET];

while !task_client.is_shutdown() {
tokio::select! {
Expand All @@ -50,11 +85,11 @@ pub(crate) async fn start_udp_listener(
break;
}
// Handle tunnel closing
Some(addr) = active_peers_task_handles.next() => {
match addr {
Ok(addr) => {
log::info!("Removing peer: {addr:?}");
active_peers.remove(&addr);
Some(public_key) = active_peers_task_handles.next() => {
match public_key {
Ok(public_key) => {
log::info!("Removing peer: {public_key:?}");
active_peers.remove(&public_key);
// TODO: remove from peers_by_ip
}
Err(err) => {
Expand All @@ -65,38 +100,73 @@ pub(crate) async fn start_udp_listener(
// Handle incoming packets
Ok((len, addr)) = udp_socket.recv_from(&mut buf) => {
log::trace!("udp: received {} bytes from {}", len, addr);
let verified_packet = match rate_limiter.verify_packet(Some(addr.ip()), &buf[..len], &mut dst_buf) {
Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => {
log::info!("WireGuard UDP listener: send back cookie");
udp_socket.send_to(cookie, addr).await.unwrap();
return;
}
Err(err) => {
log::warn!("{err:?}");
return;
}
};

if let Some(peer_tx) = active_peers.get_mut(&addr) {
// Check if this is a registered peer, if not just drop
let registered_peer = match verified_packet {
boringtun::noise::Packet::HandshakeInit(ref packet) => {
let Ok(handshake) = parse_handshake_anon(&static_private, &static_public, &packet) else {
log::warn!("Handshake failed");
return;
};
registered_peers.get(&x25519::PublicKey::from(handshake.peer_static_public))
},
boringtun::noise::Packet::HandshakeResponse(packet) => {
let peer_idx = packet.receiver_idx >> 8;
registered_peers_by_idx.get(&peer_idx)
},
boringtun::noise::Packet::PacketCookieReply(packet) => {
let peer_idx = packet.receiver_idx >> 8;
registered_peers_by_idx.get(&peer_idx)
},
boringtun::noise::Packet::PacketData(packet) => {
let peer_idx = packet.receiver_idx >> 8;
registered_peers_by_idx.get(&peer_idx)
},
};

let Some(registered_peer) = registered_peer else {
log::warn!("Peer not registered");
return;
};
let registered_peer = registered_peer.lock().await;

// Look up if the peer is already connected
if let Some(peer_tx) = active_peers.get_mut(&registered_peer.public_key) {
log::info!("udp: received {len} bytes from {addr} from known peer");
peer_tx.send(Event::WgPacket(buf[..len].to_vec().into()))
peer_tx.send(Event::WgVerifiedPacket(buf[..len].to_vec().into()))
.tap_err(|err| log::error!("{err}"))
.unwrap();
} else {
log::info!("udp: received {len} bytes from {addr} from unknown peer, starting tunnel");
// TODO: this is a temporary solution for development since this
// assumes we know the peer_static_public this corresponds to.
// TODO: rework this before production! This is likely not secure!
log::warn!("Assuming peer_static_public is known");
log::warn!("SECURITY: Rework me to do proper handshake before creating the tunnel!");
let (join_handle, peer_tx) = crate::wg_tunnel::start_wg_tunnel(
addr,
udp_socket.clone(),
static_private.clone(),
peer_static_public,
peer_allowed_ips,
registered_peer.public_key,
registered_peer.allowed_ips,
tun_task_tx.clone(),
);

peers_by_ip.lock().unwrap().insert(peer_allowed_ips, peer_tx.clone());
peers_by_ip.lock().unwrap().insert(registered_peer.allowed_ips, peer_tx.clone());

peer_tx.send(Event::WgPacket(buf[..len].to_vec().into()))
.tap_err(|err| log::error!("{err}"))
.unwrap();

// WIP(JON): active peers should probably be keyed by peer_static_public
// instead. Does this current setup lead to any issues?
log::info!("Adding peer: {addr}");
active_peers.insert(addr, peer_tx);
active_peers.insert(registered_peer.public_key, peer_tx);
active_peers_task_handles.push(join_handle);
}
},
Expand Down
22 changes: 16 additions & 6 deletions common/wireguard/src/wg_tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ impl WireGuardTunnel {
static_private: x25519::StaticSecret,
peer_static_public: x25519::PublicKey,
peer_allowed_ips: ip_network::IpNetwork,
index: PeerIdx,
rate_limiter: Option<RateLimiter>,
tunnel_tx: TunTaskTx,
) -> (Self, mpsc::UnboundedSender<Event>) {
let local_addr = udp.local_addr().unwrap();
Expand Down Expand Up @@ -122,6 +124,11 @@ impl WireGuardTunnel {
.await
.tap_err(|err| error!("WireGuard tunnel: consume_wg error: {err}"));
},
Event::WgVerifiedPacket(data) => {
let _ = self.consume_verified_wg(&data)
.await
.tap_err(|err| error!("WireGuard tunnel: consume_verified_wg error: {err}"));
}
Event::IpPacket(data) => self.consume_eth(&data).await,
}
},
Expand Down Expand Up @@ -182,17 +189,13 @@ impl WireGuardTunnel {
}
}
TunnResult::WriteToTunnelV4(packet, addr) => {
// TODO: once the flow is redone, we should add updating the endpoint dynamically
// self.set_endpoint(addr);
if self.allowed_ips.longest_match(addr).is_some() {
self.tun_task_tx.send(packet.to_vec()).unwrap();
} else {
warn!("Packet from {addr} not in allowed_ips");
}
}
TunnResult::WriteToTunnelV6(packet, addr) => {
// TODO: once the flow is redone, we should add updating the endpoint dynamically
// self.set_endpoint(addr);
if self.allowed_ips.longest_match(addr).is_some() {
self.tun_task_tx.send(packet.to_vec()).unwrap();
} else {
Expand All @@ -209,6 +212,13 @@ impl WireGuardTunnel {
Ok(())
}

async fn consume_verified_wg(&mut self, data: &[u8]) -> Result<(), WgError> {
// Potentially we could take some shortcuts here in the name of performance, but currently
// I don't see that the needed functions in boringtun is exposed in the public API.
// TODO: make sure we don't put double pressure on the rate limiter!
self.consume_wg(data).await
}

async fn consume_eth(&self, data: &Bytes) {
info!("consume_eth: raw packet size: {}", data.len());
let encapsulated_packet = self.encapsulate_packet(data).await;
Expand Down Expand Up @@ -290,7 +300,7 @@ pub(crate) fn start_wg_tunnel(
peer_allowed_ips: ip_network::IpNetwork,
tunnel_tx: TunTaskTx,
) -> (
tokio::task::JoinHandle<SocketAddr>,
tokio::task::JoinHandle<x25519::PublicKey>,
mpsc::UnboundedSender<Event>,
) {
let (mut tunnel, peer_tx) = WireGuardTunnel::new(
Expand All @@ -303,7 +313,7 @@ pub(crate) fn start_wg_tunnel(
);
let join_handle = tokio::spawn(async move {
tunnel.spin_off().await;
endpoint
peer_static_public
});
(join_handle, peer_tx)
}

0 comments on commit f3098b4

Please sign in to comment.