Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Change Transport to allow for implementing partially reliable transport like QUIC #95

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use crate::protocol::{
self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES,
};
use crate::transport::{TcpTransport, Transport};
use crate::transport::{TcpTransport, Transport, TransportStream};
use anyhow::{anyhow, bail, Context, Result};
use backoff::ExponentialBackoff;
use bytes::{Bytes, BytesMut};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration};
Expand Down Expand Up @@ -148,6 +148,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
}
}


struct RunDataChannelArgs<T: Transport> {
session_key: Nonce,
remote_addr: String,
Expand All @@ -157,7 +158,7 @@ struct RunDataChannelArgs<T: Transport> {

async fn do_data_channel_handshake<T: Transport>(
args: Arc<RunDataChannelArgs<T>>,
) -> Result<T::Stream> {
) -> Result<(TransportStream<T>)> {
// Retry at least every 100ms, at most for 10 seconds
let backoff = ExponentialBackoff {
max_interval: Duration::from_millis(100),
Expand All @@ -167,7 +168,7 @@ async fn do_data_channel_handshake<T: Transport>(

// FIXME: Respect control channel shutdown here
// Connect to remote_addr
let mut conn: T::Stream = backoff::future::retry_notify(
let mut conn = backoff::future::retry_notify(
backoff,
|| async {
Ok(args
Expand All @@ -182,11 +183,12 @@ async fn do_data_channel_handshake<T: Transport>(
)
.await?;

// Send nonce
// Send nonce using reliable stream
let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
let hello = Hello::DataChannelHello(CURRENT_PROTO_VERSION, v.to_owned());
conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
conn.write_all_reliably(&bincode::serialize(&hello).unwrap()).await?;

// return the unreliable connection if one has been provided.
Ok(conn)
}

Expand All @@ -195,21 +197,37 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
let mut conn = do_data_channel_handshake(args.clone()).await?;

// Forward
match read_data_cmd(&mut conn).await? {
match read_data_cmd(&mut conn.get_reliable_stream()).await? {
DataChannelCmd::StartForwardTcp => {
run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
match conn {
TransportStream::StrictlyReliable(reliable) => {
run_data_channel_for_tcp::<T, T::ReliableStream>(reliable, &args.local_addr).await?;
}
TransportStream::PartiallyReliable(_, unreliable) => {
run_data_channel_for_tcp::<T, T::UnreliableStream>(unreliable, &args.local_addr).await?;
Copy link
Owner

@rapiz1 rapiz1 Jan 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem right. After reading quinn::Datagrams, I see it's really just datagram, just like what UDP provides. And your AsyncWrite wrapper simply wrap poll_write to send a datagram. But when you do the forwarding for a TCP service, copy_bidirectional is called and byte stream is forwarded, in which every byte is from the application layer, not a TCP packet, and should be guaranteed to be sent to the destination, while with T::UnreliableStream, it can be lost.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I must have misunderstood what the rathole client<->server data channel was carrying.

If we are ACK'ing data from the end user, then we absolutely have to ensure that the packet does indeed arrive (reliability).

But what we can still try to take advantage of is the ability to send QUIC packet unordered.

}
}

}
DataChannelCmd::StartForwardUdp => {
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
match conn {
TransportStream::StrictlyReliable(reliable) => {
run_data_channel_for_udp::<T, T::ReliableStream>(reliable, &args.local_addr).await?;
}
TransportStream::PartiallyReliable(_, unreliable) => {
run_data_channel_for_udp::<T, T::UnreliableStream>(unreliable, &args.local_addr).await?;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may also not work as expect. run_data_channel_for_udp call write_all to write packets. And write_all can call poll_write multiple times and there's no guarantee that a packet passed to write_all will not be fragmented and passed to poll_write at a whole, which means it's possible multiple quic datagrams are used to carry one UdpTraffic, and some of them are lost, causing the deserialization of UdpTraffic to fail.

}
}

}
}
Ok(())
}

// Simply copying back and forth for TCP
#[instrument(skip(conn))]
async fn run_data_channel_for_tcp<T: Transport>(
mut conn: T::Stream,
async fn run_data_channel_for_tcp<T: Transport, S: AsyncRead + AsyncWrite + Unpin>(
mut conn: S,
local_addr: &str,
) -> Result<()> {
debug!("New data channel starts forwarding");
Expand All @@ -228,7 +246,8 @@ async fn run_data_channel_for_tcp<T: Transport>(
type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;

#[instrument(skip(conn))]
async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
async fn run_data_channel_for_udp<T: Transport, S: 'static + AsyncRead + AsyncWrite + Send>( mut conn: S,
local_addr: &str) -> Result<()> {
debug!("New data channel starts forwarding");

let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
Expand Down Expand Up @@ -375,12 +394,14 @@ struct ControlChannelHandle {
impl<T: 'static + Transport> ControlChannel<T> {
#[instrument(skip_all)]
async fn run(&mut self) -> Result<()> {
let mut conn = self
// ignore unreliable stream, it is not needed for running the control channel
let mut conn_both = self
.transport
.connect(&self.remote_addr)
.await
.with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;

let mut conn = conn_both.get_reliable_stream();
// Send hello
debug!("Sending hello");
let hello_send =
Expand Down
60 changes: 41 additions & 19 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::protocol::{
self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
HASH_WIDTH_IN_BYTES,
};
use crate::transport::{TcpTransport, Transport};
use crate::transport::{TcpTransport, Transport, TransportStream};
use anyhow::{anyhow, bail, Context, Result};
use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
Expand Down Expand Up @@ -229,15 +229,15 @@ impl<'a, T: 'static + Transport> Server<'a, T> {

// Handle connections to `server.bind_addr`
async fn handle_connection<T: 'static + Transport>(
mut conn: T::Stream,
mut conn: TransportStream<T>,
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
) -> Result<()> {
// Read hello
let hello = read_hello(&mut conn).await?;
let hello = read_hello(&mut conn.get_reliable_stream()).await?;
match hello {
ControlChannelHello(_, service_digest) => {
do_control_channel_handshake(conn, services, control_channels, service_digest).await?;
do_control_channel_handshake(conn.into_reliable_stream(), services, control_channels, service_digest).await?;
}
DataChannelHello(_, nonce) => {
do_data_channel_handshake(conn, control_channels, nonce).await?;
Expand All @@ -247,7 +247,7 @@ async fn handle_connection<T: 'static + Transport>(
}

async fn do_control_channel_handshake<T: 'static + Transport>(
mut conn: T::Stream,
mut conn: T::ReliableStream,
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
service_digest: ServiceDigest,
Expand Down Expand Up @@ -326,7 +326,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
}

async fn do_data_channel_handshake<T: 'static + Transport>(
conn: T::Stream,
conn: TransportStream<T>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
nonce: Nonce,
) -> Result<()> {
Expand All @@ -353,7 +353,7 @@ async fn do_data_channel_handshake<T: 'static + Transport>(
pub struct ControlChannelHandle<T: Transport> {
// Shutdown the control channel by dropping it
_shutdown_tx: broadcast::Sender<bool>,
data_ch_tx: mpsc::Sender<T::Stream>,
data_ch_tx: mpsc::Sender<TransportStream<T>>,
}

impl<T> ControlChannelHandle<T>
Expand All @@ -363,7 +363,7 @@ where
// Create a control channel handle, where the control channel handling task
// and the connection pool task are created.
#[instrument(skip_all, fields(service = %service.name))]
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
fn new(conn: T::ReliableStream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
// Create a shutdown channel
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);

Expand Down Expand Up @@ -449,7 +449,7 @@ where

// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic
struct ControlChannel<T: Transport> {
conn: T::Stream, // The connection of control channel
conn: T::ReliableStream, // The connection of control channel
service: ServerServiceConfig, // A copy of the corresponding service config
shutdown_rx: broadcast::Receiver<bool>, // Receives the shutdown signal
data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
Expand Down Expand Up @@ -571,19 +571,26 @@ fn tcp_listen_and_send(
}

#[instrument(skip_all)]
async fn run_tcp_connection_pool<T: Transport>(
async fn run_tcp_connection_pool<T: 'static + Transport>(
bind_addr: String,
mut data_ch_rx: mpsc::Receiver<T::Stream>,
mut data_ch_rx: mpsc::Receiver<TransportStream<T>>,
data_ch_req_tx: mpsc::UnboundedSender<bool>,
shutdown_rx: broadcast::Receiver<bool>,
) -> Result<()> {
let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx, shutdown_rx);
while let Some(mut visitor) = visitor_rx.recv().await {
if let Some(mut ch) = data_ch_rx.recv().await {
if let Some(mut conn) = data_ch_rx.recv().await {
tokio::spawn(async move {
let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
if ch.write_all(&cmd).await.is_ok() {
let _ = copy_bidirectional(&mut ch, &mut visitor).await;
if conn.write_all_reliably(&cmd).await.is_ok() {
match conn {
TransportStream::StrictlyReliable(mut reliable) => {
let _ = copy_bidirectional(&mut reliable, &mut visitor).await;
}
TransportStream::PartiallyReliable(_, mut unreliable) => {
let _ = copy_bidirectional(&mut unreliable, &mut visitor).await;
}
}
}
});
} else {
Expand All @@ -598,7 +605,7 @@ async fn run_tcp_connection_pool<T: Transport>(
#[instrument(skip_all)]
async fn run_udp_connection_pool<T: Transport>(
bind_addr: String,
mut data_ch_rx: mpsc::Receiver<T::Stream>,
mut data_ch_rx: mpsc::Receiver<TransportStream<T>>,
_data_ch_req_tx: mpsc::UnboundedSender<bool>,
mut shutdown_rx: broadcast::Receiver<bool>,
) -> Result<()> {
Expand All @@ -616,8 +623,8 @@ async fn run_udp_connection_pool<T: Transport>(
warn!("{:?}", e);
},
)
.await
.with_context(|| "Failed to listen for the service")?;
.await
.with_context(|| "Failed to listen for the service")?;

info!("Listening at {}", &bind_addr);

Expand All @@ -628,13 +635,28 @@ async fn run_udp_connection_pool<T: Transport>(
.recv()
.await
.ok_or(anyhow!("No available data channels"))?;
conn.write_all(&cmd).await?;
conn.write_all_reliably(&cmd).await?;

match conn {
TransportStream::StrictlyReliable(reliable) =>
udp_copy_bidirectional::<T, T::ReliableStream>(reliable, shutdown_rx, l).await,
TransportStream::PartiallyReliable(_, unreliable) =>
udp_copy_bidirectional::<T, T::UnreliableStream>(unreliable, shutdown_rx, l).await,
}
}

#[instrument(skip_all)]
async fn udp_copy_bidirectional<T: Transport, S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> (
mut conn: S,
mut shutdown_rx: broadcast::Receiver<bool>,
l: UdpSocket,
) -> Result<()> {
// after sending forward CMD, use unreliable stream if available for actual forwarding
let mut buf = [0u8; UDP_BUFFER_SIZE];
loop {
tokio::select! {
// Forward inbound traffic to the client
val = l.recv_from(&mut buf) => {
val = l.recv_from(&mut buf) => {
let (n, from) = val?;
UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
},
Expand Down
86 changes: 82 additions & 4 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,51 @@ use crate::config::TransportConfig;
use anyhow::Result;
use async_trait::async_trait;
use std::fmt::Debug;
use std::io;
use std::io::Error;
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::ToSocketAddrs;

// Specify a transport layer, like TCP, TLS
#[async_trait]
pub trait Transport: Debug + Send + Sync {
type Acceptor: Send + Sync;
type RawStream: Send + Sync;
type Stream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug;
type ReliableStream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug;
type UnreliableStream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug;


async fn new(config: &TransportConfig) -> Result<Self>
where
Self: Sized;
async fn bind<T: ToSocketAddrs + Send + Sync>(&self, addr: T) -> Result<Self::Acceptor>;
async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)>;
async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream>;
async fn connect(&self, addr: &str) -> Result<Self::Stream>;

/// Perform handshake using a newly initiated raw stream (tcp/udp)
/// return a properly configured connection for protocol that transport uses.
///
/// The returned connection may either be Reliable or Partially Reliable
/// (wholly unreliable transport are not currently supported).
///
/// Both Partially reliable and strictly reliable transport must provide a reliable stream
/// If partially reliable, then an unreliable stream must additionally be provided
async fn handshake(&self, conn: Self::RawStream) -> Result<TransportStream<Self>>;

/// Connection to Server
/// return
/// - A reliable ordered stream used for control channel communication
/// - Optionally an unordered and unreliable stream used for data channel forwarding
/// If no such stream is provided, then data will be sent using the reliable stream.
async fn connect(&self, addr: &str) -> Result<TransportStream<Self>>;
}

#[derive(Debug)]
pub enum TransportStream<T: Transport + ?Sized> {
StrictlyReliable(T::ReliableStream),
PartiallyReliable(T::ReliableStream, T::UnreliableStream),
}

mod tcp;
Expand All @@ -33,3 +60,54 @@ pub use tls::TlsTransport;
mod noise;
#[cfg(feature = "noise")]
pub use noise::NoiseTransport;

impl<T> TransportStream<T>
where T: Transport
{
pub async fn write_all_reliably<'a>(&'a mut self, src: &'a [u8]) -> std::io::Result<()> {
let r = match self {
TransportStream::StrictlyReliable(s) => s.write_all(src).await,
TransportStream::PartiallyReliable(s, _) => s.write_all(src).await,
};
r
}

pub(crate) fn get_reliable_stream(&mut self) -> &mut T::ReliableStream
{
match self {
TransportStream::StrictlyReliable(s) => s,
TransportStream::PartiallyReliable(s, _) => s,
}
}

pub fn into_reliable_stream(self) -> T::ReliableStream {
match self {
TransportStream::StrictlyReliable(s) => s,
TransportStream::PartiallyReliable(s, _) => s,
}
}
}

/// A dummy struct for use with transports that are strictly reliable
#[derive(Debug)]
pub struct UnimplementedUnreliableStream;

impl AsyncRead for UnimplementedUnreliableStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
todo!()
}
}

impl AsyncWrite for UnimplementedUnreliableStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::result::Result<usize, Error>> {
todo!()
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Error>> {
todo!()
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Error>> {
todo!()
}
}
Loading