diff --git a/src/client.rs b/src/client.rs index 40231d08..2ee55642 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,14 +6,14 @@ use crate::protocol::{ self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd, DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES, }; -use crate::transport::{TcpTransport, Transport}; +use crate::transport::{TcpTransport, Transport, TransportStream}; use anyhow::{anyhow, bail, Context, Result}; use backoff::ExponentialBackoff; use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite}; use tokio::net::{TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::time::{self, Duration}; @@ -148,6 +148,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> { } } + struct RunDataChannelArgs { session_key: Nonce, remote_addr: String, @@ -157,7 +158,7 @@ struct RunDataChannelArgs { async fn do_data_channel_handshake( args: Arc>, -) -> Result { +) -> Result<(TransportStream)> { // Retry at least every 100ms, at most for 10 seconds let backoff = ExponentialBackoff { max_interval: Duration::from_millis(100), @@ -167,7 +168,7 @@ async fn do_data_channel_handshake( // FIXME: Respect control channel shutdown here // Connect to remote_addr - let mut conn: T::Stream = backoff::future::retry_notify( + let mut conn = backoff::future::retry_notify( backoff, || async { Ok(args @@ -182,11 +183,12 @@ async fn do_data_channel_handshake( ) .await?; - // Send nonce + // Send nonce using reliable stream let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap(); let hello = Hello::DataChannelHello(CURRENT_PROTO_VERSION, v.to_owned()); - conn.write_all(&bincode::serialize(&hello).unwrap()).await?; + conn.write_all_reliably(&bincode::serialize(&hello).unwrap()).await?; + // return the unreliable connection if one has been provided. Ok(conn) } @@ -195,12 +197,28 @@ async fn run_data_channel(args: Arc>) -> Res let mut conn = do_data_channel_handshake(args.clone()).await?; // Forward - match read_data_cmd(&mut conn).await? { + match read_data_cmd(&mut conn.get_reliable_stream()).await? { DataChannelCmd::StartForwardTcp => { - run_data_channel_for_tcp::(conn, &args.local_addr).await?; + match conn { + TransportStream::StrictlyReliable(reliable) => { + run_data_channel_for_tcp::(reliable, &args.local_addr).await?; + } + TransportStream::PartiallyReliable(_, unreliable) => { + run_data_channel_for_tcp::(unreliable, &args.local_addr).await?; + } + } + } DataChannelCmd::StartForwardUdp => { - run_data_channel_for_udp::(conn, &args.local_addr).await?; + match conn { + TransportStream::StrictlyReliable(reliable) => { + run_data_channel_for_udp::(reliable, &args.local_addr).await?; + } + TransportStream::PartiallyReliable(_, unreliable) => { + run_data_channel_for_udp::(unreliable, &args.local_addr).await?; + } + } + } } Ok(()) @@ -208,8 +226,8 @@ async fn run_data_channel(args: Arc>) -> Res // Simply copying back and forth for TCP #[instrument(skip(conn))] -async fn run_data_channel_for_tcp( - mut conn: T::Stream, +async fn run_data_channel_for_tcp( + mut conn: S, local_addr: &str, ) -> Result<()> { debug!("New data channel starts forwarding"); @@ -228,7 +246,8 @@ async fn run_data_channel_for_tcp( type UdpPortMap = Arc>>>; #[instrument(skip(conn))] -async fn run_data_channel_for_udp(conn: T::Stream, local_addr: &str) -> Result<()> { +async fn run_data_channel_for_udp( mut conn: S, + local_addr: &str) -> Result<()> { debug!("New data channel starts forwarding"); let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new())); @@ -375,12 +394,14 @@ struct ControlChannelHandle { impl ControlChannel { #[instrument(skip_all)] async fn run(&mut self) -> Result<()> { - let mut conn = self + // ignore unreliable stream, it is not needed for running the control channel + let mut conn_both = self .transport .connect(&self.remote_addr) .await .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?; + let mut conn = conn_both.get_reliable_stream(); // Send hello debug!("Sending hello"); let hello_send = diff --git a/src/server.rs b/src/server.rs index 26e0b966..c5ae13de 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,7 +7,7 @@ use crate::protocol::{ self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic, HASH_WIDTH_IN_BYTES, }; -use crate::transport::{TcpTransport, Transport}; +use crate::transport::{TcpTransport, Transport, TransportStream}; use anyhow::{anyhow, bail, Context, Result}; use backoff::backoff::Backoff; use backoff::ExponentialBackoff; @@ -229,15 +229,15 @@ impl<'a, T: 'static + Transport> Server<'a, T> { // Handle connections to `server.bind_addr` async fn handle_connection( - mut conn: T::Stream, + mut conn: TransportStream, services: Arc>>, control_channels: Arc>>, ) -> Result<()> { // Read hello - let hello = read_hello(&mut conn).await?; + let hello = read_hello(&mut conn.get_reliable_stream()).await?; match hello { ControlChannelHello(_, service_digest) => { - do_control_channel_handshake(conn, services, control_channels, service_digest).await?; + do_control_channel_handshake(conn.into_reliable_stream(), services, control_channels, service_digest).await?; } DataChannelHello(_, nonce) => { do_data_channel_handshake(conn, control_channels, nonce).await?; @@ -247,7 +247,7 @@ async fn handle_connection( } async fn do_control_channel_handshake( - mut conn: T::Stream, + mut conn: T::ReliableStream, services: Arc>>, control_channels: Arc>>, service_digest: ServiceDigest, @@ -326,7 +326,7 @@ async fn do_control_channel_handshake( } async fn do_data_channel_handshake( - conn: T::Stream, + conn: TransportStream, control_channels: Arc>>, nonce: Nonce, ) -> Result<()> { @@ -353,7 +353,7 @@ async fn do_data_channel_handshake( pub struct ControlChannelHandle { // Shutdown the control channel by dropping it _shutdown_tx: broadcast::Sender, - data_ch_tx: mpsc::Sender, + data_ch_tx: mpsc::Sender>, } impl ControlChannelHandle @@ -363,7 +363,7 @@ where // Create a control channel handle, where the control channel handling task // and the connection pool task are created. #[instrument(skip_all, fields(service = %service.name))] - fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle { + fn new(conn: T::ReliableStream, service: ServerServiceConfig) -> ControlChannelHandle { // Create a shutdown channel let (shutdown_tx, shutdown_rx) = broadcast::channel::(1); @@ -449,7 +449,7 @@ where // Control channel, using T as the transport layer. P is TcpStream or UdpTraffic struct ControlChannel { - conn: T::Stream, // The connection of control channel + conn: T::ReliableStream, // The connection of control channel service: ServerServiceConfig, // A copy of the corresponding service config shutdown_rx: broadcast::Receiver, // Receives the shutdown signal data_ch_req_rx: mpsc::UnboundedReceiver, // Receives visitor connections @@ -571,19 +571,26 @@ fn tcp_listen_and_send( } #[instrument(skip_all)] -async fn run_tcp_connection_pool( +async fn run_tcp_connection_pool( bind_addr: String, - mut data_ch_rx: mpsc::Receiver, + mut data_ch_rx: mpsc::Receiver>, data_ch_req_tx: mpsc::UnboundedSender, shutdown_rx: broadcast::Receiver, ) -> Result<()> { let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx, shutdown_rx); while let Some(mut visitor) = visitor_rx.recv().await { - if let Some(mut ch) = data_ch_rx.recv().await { + if let Some(mut conn) = data_ch_rx.recv().await { tokio::spawn(async move { let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap(); - if ch.write_all(&cmd).await.is_ok() { - let _ = copy_bidirectional(&mut ch, &mut visitor).await; + if conn.write_all_reliably(&cmd).await.is_ok() { + match conn { + TransportStream::StrictlyReliable(mut reliable) => { + let _ = copy_bidirectional(&mut reliable, &mut visitor).await; + } + TransportStream::PartiallyReliable(_, mut unreliable) => { + let _ = copy_bidirectional(&mut unreliable, &mut visitor).await; + } + } } }); } else { @@ -598,7 +605,7 @@ async fn run_tcp_connection_pool( #[instrument(skip_all)] async fn run_udp_connection_pool( bind_addr: String, - mut data_ch_rx: mpsc::Receiver, + mut data_ch_rx: mpsc::Receiver>, _data_ch_req_tx: mpsc::UnboundedSender, mut shutdown_rx: broadcast::Receiver, ) -> Result<()> { @@ -616,8 +623,8 @@ async fn run_udp_connection_pool( warn!("{:?}", e); }, ) - .await - .with_context(|| "Failed to listen for the service")?; + .await + .with_context(|| "Failed to listen for the service")?; info!("Listening at {}", &bind_addr); @@ -628,13 +635,28 @@ async fn run_udp_connection_pool( .recv() .await .ok_or(anyhow!("No available data channels"))?; - conn.write_all(&cmd).await?; + conn.write_all_reliably(&cmd).await?; + + match conn { + TransportStream::StrictlyReliable(reliable) => + udp_copy_bidirectional::(reliable, shutdown_rx, l).await, + TransportStream::PartiallyReliable(_, unreliable) => + udp_copy_bidirectional::(unreliable, shutdown_rx, l).await, + } +} +#[instrument(skip_all)] +async fn udp_copy_bidirectional ( + mut conn: S, + mut shutdown_rx: broadcast::Receiver, + l: UdpSocket, + ) -> Result<()> { + // after sending forward CMD, use unreliable stream if available for actual forwarding let mut buf = [0u8; UDP_BUFFER_SIZE]; loop { tokio::select! { // Forward inbound traffic to the client - val = l.recv_from(&mut buf) => { + val = l.recv_from(&mut buf) => { let (n, from) = val?; UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?; }, diff --git a/src/transport/mod.rs b/src/transport/mod.rs index fd2064a2..3275dc9b 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -2,8 +2,12 @@ use crate::config::TransportConfig; use anyhow::Result; use async_trait::async_trait; use std::fmt::Debug; +use std::io; +use std::io::Error; use std::net::SocketAddr; -use tokio::io::{AsyncRead, AsyncWrite}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::ToSocketAddrs; // Specify a transport layer, like TCP, TLS @@ -11,15 +15,38 @@ use tokio::net::ToSocketAddrs; pub trait Transport: Debug + Send + Sync { type Acceptor: Send + Sync; type RawStream: Send + Sync; - type Stream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug; + type ReliableStream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug; + type UnreliableStream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug; + async fn new(config: &TransportConfig) -> Result where Self: Sized; async fn bind(&self, addr: T) -> Result; async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)>; - async fn handshake(&self, conn: Self::RawStream) -> Result; - async fn connect(&self, addr: &str) -> Result; + + /// Perform handshake using a newly initiated raw stream (tcp/udp) + /// return a properly configured connection for protocol that transport uses. + /// + /// The returned connection may either be Reliable or Partially Reliable + /// (wholly unreliable transport are not currently supported). + /// + /// Both Partially reliable and strictly reliable transport must provide a reliable stream + /// If partially reliable, then an unreliable stream must additionally be provided + async fn handshake(&self, conn: Self::RawStream) -> Result>; + + /// Connection to Server + /// return + /// - A reliable ordered stream used for control channel communication + /// - Optionally an unordered and unreliable stream used for data channel forwarding + /// If no such stream is provided, then data will be sent using the reliable stream. + async fn connect(&self, addr: &str) -> Result>; +} + +#[derive(Debug)] +pub enum TransportStream { + StrictlyReliable(T::ReliableStream), + PartiallyReliable(T::ReliableStream, T::UnreliableStream), } mod tcp; @@ -33,3 +60,54 @@ pub use tls::TlsTransport; mod noise; #[cfg(feature = "noise")] pub use noise::NoiseTransport; + +impl TransportStream + where T: Transport +{ + pub async fn write_all_reliably<'a>(&'a mut self, src: &'a [u8]) -> std::io::Result<()> { + let r = match self { + TransportStream::StrictlyReliable(s) => s.write_all(src).await, + TransportStream::PartiallyReliable(s, _) => s.write_all(src).await, + }; + r + } + + pub(crate) fn get_reliable_stream(&mut self) -> &mut T::ReliableStream + { + match self { + TransportStream::StrictlyReliable(s) => s, + TransportStream::PartiallyReliable(s, _) => s, + } + } + + pub fn into_reliable_stream(self) -> T::ReliableStream { + match self { + TransportStream::StrictlyReliable(s) => s, + TransportStream::PartiallyReliable(s, _) => s, + } + } +} + +/// A dummy struct for use with transports that are strictly reliable +#[derive(Debug)] +pub struct UnimplementedUnreliableStream; + +impl AsyncRead for UnimplementedUnreliableStream { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + todo!() + } +} + +impl AsyncWrite for UnimplementedUnreliableStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + todo!() + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + todo!() + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + todo!() + } +} \ No newline at end of file diff --git a/src/transport/noise.rs b/src/transport/noise.rs index b83609a4..fa299f61 100644 --- a/src/transport/noise.rs +++ b/src/transport/noise.rs @@ -9,6 +9,8 @@ use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use snowstorm::{Builder, NoiseParams, NoiseStream}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use crate::transport::{TransportStream, UnimplementedUnreliableStream}; +use crate::transport::TransportStream::StrictlyReliable; pub struct NoiseTransport { config: NoiseConfig, @@ -36,8 +38,9 @@ impl NoiseTransport { #[async_trait] impl Transport for NoiseTransport { type Acceptor = TcpListener; + type ReliableStream = snowstorm::stream::NoiseStream; + type UnreliableStream = UnimplementedUnreliableStream; type RawStream = TcpStream; - type Stream = snowstorm::stream::NoiseStream; async fn new(config: &TransportConfig) -> Result { let config = match &config.noise { @@ -81,14 +84,14 @@ impl Transport for NoiseTransport { Ok((conn, addr)) } - async fn handshake(&self, conn: Self::RawStream) -> Result { + async fn handshake(&self, conn: Self::RawStream) -> Result> { let conn = NoiseStream::handshake(conn, self.builder().build_responder()?) .await .with_context(|| "Failed to do noise handshake")?; - Ok(conn) + Ok(StrictlyReliable(conn)) } - async fn connect(&self, addr: &str) -> Result { + async fn connect(&self, addr: &str) -> Result> { let conn = TcpStream::connect(addr) .await .with_context(|| "Failed to connect TCP socket")?; @@ -97,6 +100,6 @@ impl Transport for NoiseTransport { let conn = NoiseStream::handshake(conn, self.builder().build_initiator()?) .await .with_context(|| "Failed to do noise handshake")?; - return Ok(conn); + return Ok(StrictlyReliable(conn)); } } diff --git a/src/transport/quic.rs b/src/transport/quic.rs new file mode 100644 index 00000000..40d817e2 --- /dev/null +++ b/src/transport/quic.rs @@ -0,0 +1,319 @@ +use futures::lock::Mutex; +use std::borrow::{Borrow, BorrowMut}; +use std::fmt::{Debug, Formatter}; +use std::io; +use std::io::{Error, IoSlice}; +use std::net::SocketAddr; +use std::ops::Deref; +use std::ops::DerefMut; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; + +use super::Transport; +use crate::config::{TlsConfig, TransportConfig}; +use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::{AsyncWriteExt, ready, StreamExt, Stream}; +use openssl::pkcs12::Pkcs12; +use quinn::{Connection, ConnectionError, Datagrams, Endpoint, EndpointConfig, Incoming, NewConnection, SendDatagramError}; +use rustls::internal::msgs::codec::Codec; +use rustls::ClientConfig; +use tokio::fs; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::{ToSocketAddrs, UdpSocket}; +use tokio_native_tls::native_tls::Certificate; +use crate::transport::TransportStream; + +pub const ALPN_QUIC_TUNNEL: &[&[u8]] = &[b"qt"]; +pub const NONE: &str = "None"; + +pub struct QuicTransport { + config: TlsConfig, + client_crypto: Option, +} + +impl Debug for QuicTransport { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let client_crypto = &self.client_crypto.as_ref().map(|_| "ClientConfig{}"); + + f.debug_struct("QuicTransport") + .field("config", &self.config) + .field("client_crypto", client_crypto) + .finish() + } +} + +#[derive(Debug)] +pub struct QuicBiStream { + send: quinn::SendStream, + recv: quinn::RecvStream, + conn: Connection, +} + +#[derive(Debug)] +pub struct QuicDatagramStream { + conn: Connection, + datagrams: Datagrams, +} + +impl QuicBiStream { + fn new((mut send, recv): (quinn::SendStream, quinn::RecvStream), conn: Connection) -> Self { + Self { send, recv, conn} + } +} + + +impl QuicDatagramStream { + fn new(mut datagrams: Datagrams, conn: Connection) -> Self { + QuicDatagramStream{ + datagrams, + conn, + } + } +} + +impl tokio::io::AsyncRead for QuicBiStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // ready!(AsyncRead::poll_read(self.recv., cx, buf))?; + // Poll::Ready(Ok(())) + Pin::new(self.get_mut().recv.borrow_mut()).poll_read(cx, buf) + } +} + +impl tokio::io::AsyncRead for QuicDatagramStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match Pin::new(self.get_mut().datagrams.borrow_mut()).poll_next(cx) { + Poll::Ready(Some(Ok(b))) => { + buf.put_slice(b.as_ref()); + return Poll::Ready(std::io::Result::Ok(())) + } + Poll::Ready(Some(Err(err))) => Poll::Ready(std::io::Result::Err(Error::from(std::io::ErrorKind::BrokenPipe))), + Poll::Ready(None) => Poll::Ready(std::io::Result::Err(Error::from(std::io::ErrorKind::BrokenPipe))), + Poll::Pending => {Poll::Pending} + } + } +} + +impl AsyncWrite for QuicBiStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(self.get_mut().send.borrow_mut()).poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(self.get_mut().send.borrow_mut()).poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(self.get_mut().send.borrow_mut()).poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(self.get_mut().send.borrow_mut()).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.send.is_write_vectored() + } +} + +impl AsyncWrite for QuicDatagramStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { + match self.conn.send_datagram(Bytes::from(buf)) { + Ok(_) => Poll::Ready(Ok(buf.len())), + Err(e) => Poll::Ready(std::io::Result::Err(Error::from(std::io::ErrorKind::BrokenPipe))), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + todo!() + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + todo!() + } +} + + +pub struct QuicAcceptor(Arc>); + +impl Drop for QuicAcceptor { + fn drop(&mut self) { + if let Some(guard) = self.0.try_lock() { + guard.0.close(0u8.into(), &[]) + } + } +} + +impl Drop for QuicBiStream { + fn drop(&mut self) { + self.conn.close(0u8.into(), &[]); + } +} + +#[async_trait] +impl Transport for QuicTransport { + type Acceptor = QuicAcceptor; + type ReliableStream = QuicBiStream; + type UnreliableStream = QuicDatagramStream; + + async fn new(config: &TransportConfig) -> Result { + let config = match &config.tls { + Some(v) => v, + None => { + return Err(anyhow!("Missing tls config")); + } + }; + + let client_crypto = match config.trusted_root.as_ref() { + Some(path) => { + let s = fs::read_to_string(path) + .await + .with_context(|| "Failed to read the `tls.trusted_root`")?; + let cert = Certificate::from_pem(s.as_bytes()) + .with_context(|| "Failed to read certificate from `tls.trusted_root`")?; + + let mut roots = rustls::RootCertStore::empty(); + + roots.add(&rustls::Certificate( + cert.to_der() + .with_context(|| "could not encode trust root as DER")?, + )); + + let mut client_crypto = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(roots) + .with_no_client_auth(); + client_crypto.alpn_protocols = ALPN_QUIC_TUNNEL.iter().map(|&x| x.into()).collect(); + Some(client_crypto) + } + None => None, + }; + + Ok(QuicTransport { + config: config.clone(), + client_crypto, + }) + } + + async fn bind(&self, addr: A) -> Result { + let buf = fs::read(self.config.pkcs12.as_ref().unwrap()) + .await + .with_context(|| "Failed to read the `tls.pkcs12`")?; + + let pkcs12 = + Pkcs12::from_der(buf.as_slice()).with_context(|| "Failed to open `tls.pkcs12`")?; + + let parsed = pkcs12 + .parse(self.config.pkcs12_password.as_ref().unwrap()) + .with_context(|| "Could not decrypt `tls.pkcs12` using `tls.pkcs12_password`")?; + + let mut chain: Vec = parsed + .chain + .unwrap() + .into_iter() + .map(|cert| rustls::Certificate(cert.to_der().unwrap())) + .rev() + .collect(); + chain.insert( + 0, + rustls::Certificate( + parsed + .cert + .to_der() + .with_context(|| "Could not encode server cert as PEM")?, + ), + ); + + let key = rustls::PrivateKey(parsed.pkey.private_key_to_der().unwrap()); + + let mut server_crypto = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(chain, key) + .with_context(|| "Server keys invalid")?; + + server_crypto.alpn_protocols = ALPN_QUIC_TUNNEL.iter().map(|&x| x.into()).collect(); + + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + Arc::get_mut(&mut server_config.transport) + .unwrap() + .datagram_receive_buffer_size(Some(65536)) + .datagram_send_buffer_size(65536) + .max_idle_timeout(Some(Duration::from_secs(10).try_into()?)); + + server_config.use_retry(true); + let socket = UdpSocket::bind(addr).await?.into_std()?; + quinn::Endpoint::new(EndpointConfig::default(), Some(server_config), socket) + .with_context(|| "Failed to start server") + .map(|e_i| QuicAcceptor(Arc::new(Mutex::new(e_i)))) + } + + async fn accept(&self, a: &Self::Acceptor) -> Result<(TransportStream, SocketAddr)> { + // let a_guard = a.lock().unwrap(); + while let Some(connecting) = a.0.lock().await.1.next().await { + let addr = connecting.remote_address(); + let mut conn = connecting.await?; + if let Some(stream) = conn.bi_streams.next().await { + + return Ok((TransportStream::PartiallyReliable( + QuicBiStream::new(stream.unwrap(), conn.connection.clone()), + QuicDatagramStream::new(conn.datagrams, conn.connection) + ), addr)); + } + } + Err(anyhow!("endpoint closed")) + } + + async fn connect(&self, addr: &str) -> Result> { + let mut endpoint = quinn::Endpoint::client("[::]:0".parse().unwrap()) + .with_context(|| "could not open client socket")?; + let mut config = + quinn::ClientConfig::new(Arc::new(self.client_crypto.as_ref().unwrap().clone())); + // server times out afte 10 sec, so send keepalive every 5 sec + Arc::get_mut(&mut config.transport) + .unwrap() + .keep_alive_interval(Some(Duration::from_secs(5).try_into()?)) + .datagram_receive_buffer_size(Some(65536)) + .datagram_send_buffer_size(65536); + endpoint.set_default_client_config(config); + let connecting = endpoint.connect( + addr.parse().with_context(|| "server address not valid")?, + self.config + .hostname + .as_ref() + .unwrap_or(&String::from(addr.split(':').next().unwrap())), + )?; + let new_conn = connecting.await?; + + Ok(TransportStream::PartiallyReliable( + QuicBiStream::new(new_conn.connection.open_bi().await?, new_conn.connection.clone()), + QuicDatagramStream::new(new_conn.datagrams, new_conn.connection), + )) + } +} + diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 7f50c578..02d50272 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -6,6 +6,8 @@ use anyhow::Result; use async_trait::async_trait; use std::net::SocketAddr; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use crate::transport::{TransportStream, UnimplementedUnreliableStream}; +use crate::transport::TransportStream::StrictlyReliable; #[derive(Debug)] pub struct TcpTransport {} @@ -13,7 +15,8 @@ pub struct TcpTransport {} #[async_trait] impl Transport for TcpTransport { type Acceptor = TcpListener; - type Stream = TcpStream; + type ReliableStream = TcpStream; + type UnreliableStream = UnimplementedUnreliableStream; type RawStream = TcpStream; async fn new(_config: &TransportConfig) -> Result { @@ -30,13 +33,13 @@ impl Transport for TcpTransport { Ok((s, addr)) } - async fn handshake(&self, conn: Self::RawStream) -> Result { - Ok(conn) + async fn handshake(&self, conn: Self::RawStream) -> Result> { + Ok(StrictlyReliable(conn)) } - async fn connect(&self, addr: &str) -> Result { + async fn connect(&self, addr: &str) -> Result> { let s = TcpStream::connect(addr).await?; set_tcp_keepalive(&s); - Ok(s) + Ok(StrictlyReliable(s)) // TCP cannot provide unreliable stream } } diff --git a/src/transport/tls.rs b/src/transport/tls.rs index eda2216f..37655cb5 100644 --- a/src/transport/tls.rs +++ b/src/transport/tls.rs @@ -9,6 +9,8 @@ use tokio::fs; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_native_tls::native_tls::{self, Certificate, Identity}; use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream}; +use crate::transport::{TransportStream, UnimplementedUnreliableStream}; +use crate::transport::TransportStream::StrictlyReliable; #[derive(Debug)] pub struct TlsTransport { @@ -20,8 +22,9 @@ pub struct TlsTransport { #[async_trait] impl Transport for TlsTransport { type Acceptor = TcpListener; + type ReliableStream = TlsStream; + type UnreliableStream = UnimplementedUnreliableStream; type RawStream = TcpStream; - type Stream = TlsStream; async fn new(config: &TransportConfig) -> Result { let config = match &config.tls { @@ -81,17 +84,17 @@ impl Transport for TlsTransport { Ok((conn, addr)) } - async fn handshake(&self, conn: Self::RawStream) -> Result { + async fn handshake(&self, conn: Self::RawStream) -> Result> { let conn = self.tls_acceptor.as_ref().unwrap().accept(conn).await?; - Ok(conn) + Ok(StrictlyReliable(conn)) } - async fn connect(&self, addr: &str) -> Result { + async fn connect(&self, addr: &str) -> Result> { let conn = TcpStream::connect(&addr).await?; set_tcp_keepalive(&conn); let connector = self.connector.as_ref().unwrap(); - Ok(connector + Ok(StrictlyReliable(connector .connect( self.config .hostname @@ -99,6 +102,6 @@ impl Transport for TlsTransport { .unwrap_or(&String::from(addr.split(':').next().unwrap())), conn, ) - .await?) + .await?)) } } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 18115753..4624cfdb 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -127,6 +127,8 @@ async fn test(config_path: &'static str, t: Type) -> Result<()> { client_shutdown_tx.send(true)?; let _ = tokio::join!(client); + // Wait for the server connection to be closed (quic) + time::sleep(Duration::from_millis(2500)).await; info!("restart the client"); let client_shutdown_rx = client_shutdown_tx.subscribe(); let client = tokio::spawn(async move {