Skip to content

Commit

Permalink
prepare trait mod::Transport for partially reliable transport
Browse files Browse the repository at this point in the history
change handshake and connect to allow implementer to opt into providing a
second unreliable unordered stream to carry the forwarded packets.

rathole commands will still be sent over reliable ordered streams.
  • Loading branch information
emillynge committed Jan 12, 2022
1 parent edbb5ce commit 41866de
Show file tree
Hide file tree
Showing 8 changed files with 503 additions and 52 deletions.
47 changes: 34 additions & 13 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use crate::protocol::{
self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES,
};
use crate::transport::{TcpTransport, Transport};
use crate::transport::{TcpTransport, Transport, TransportStream};
use anyhow::{anyhow, bail, Context, Result};
use backoff::ExponentialBackoff;
use bytes::{Bytes, BytesMut};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt, AsyncRead, AsyncWrite};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration};
Expand Down Expand Up @@ -148,6 +148,7 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
}
}


struct RunDataChannelArgs<T: Transport> {
session_key: Nonce,
remote_addr: String,
Expand All @@ -157,7 +158,7 @@ struct RunDataChannelArgs<T: Transport> {

async fn do_data_channel_handshake<T: Transport>(
args: Arc<RunDataChannelArgs<T>>,
) -> Result<T::Stream> {
) -> Result<(TransportStream<T>)> {
// Retry at least every 100ms, at most for 10 seconds
let backoff = ExponentialBackoff {
max_interval: Duration::from_millis(100),
Expand All @@ -167,7 +168,7 @@ async fn do_data_channel_handshake<T: Transport>(

// FIXME: Respect control channel shutdown here
// Connect to remote_addr
let mut conn: T::Stream = backoff::future::retry_notify(
let mut conn = backoff::future::retry_notify(
backoff,
|| async {
Ok(args
Expand All @@ -182,11 +183,12 @@ async fn do_data_channel_handshake<T: Transport>(
)
.await?;

// Send nonce
// Send nonce using reliable stream
let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
let hello = Hello::DataChannelHello(CURRENT_PROTO_VERSION, v.to_owned());
conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
conn.write_all_reliably(&bincode::serialize(&hello).unwrap()).await?;

// return the unreliable connection if one has been provided.
Ok(conn)
}

Expand All @@ -195,21 +197,37 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
let mut conn = do_data_channel_handshake(args.clone()).await?;

// Forward
match read_data_cmd(&mut conn).await? {
match read_data_cmd(&mut conn.get_reliable_stream()).await? {
DataChannelCmd::StartForwardTcp => {
run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
match conn {
TransportStream::StrictlyReliable(reliable) => {
run_data_channel_for_tcp::<T, T::ReliableStream>(reliable, &args.local_addr).await?;
}
TransportStream::PartiallyReliable(_, unreliable) => {
run_data_channel_for_tcp::<T, T::UnreliableStream>(unreliable, &args.local_addr).await?;
}
}

}
DataChannelCmd::StartForwardUdp => {
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
match conn {
TransportStream::StrictlyReliable(reliable) => {
run_data_channel_for_udp::<T, T::ReliableStream>(reliable, &args.local_addr).await?;
}
TransportStream::PartiallyReliable(_, unreliable) => {
run_data_channel_for_udp::<T, T::UnreliableStream>(unreliable, &args.local_addr).await?;
}
}

}
}
Ok(())
}

// Simply copying back and forth for TCP
#[instrument(skip(conn))]
async fn run_data_channel_for_tcp<T: Transport>(
mut conn: T::Stream,
async fn run_data_channel_for_tcp<T: Transport, S: AsyncRead + AsyncWrite + Unpin>(
mut conn: S,
local_addr: &str,
) -> Result<()> {
debug!("New data channel starts forwarding");
Expand All @@ -228,7 +246,8 @@ async fn run_data_channel_for_tcp<T: Transport>(
type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;

#[instrument(skip(conn))]
async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
async fn run_data_channel_for_udp<T: Transport, S: 'static + AsyncRead + AsyncWrite + Send>( mut conn: S,
local_addr: &str) -> Result<()> {
debug!("New data channel starts forwarding");

let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
Expand Down Expand Up @@ -375,12 +394,14 @@ struct ControlChannelHandle {
impl<T: 'static + Transport> ControlChannel<T> {
#[instrument(skip_all)]
async fn run(&mut self) -> Result<()> {
let mut conn = self
// ignore unreliable stream, it is not needed for running the control channel
let mut conn_both = self
.transport
.connect(&self.remote_addr)
.await
.with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;

let mut conn = conn_both.get_reliable_stream();
// Send hello
debug!("Sending hello");
let hello_send =
Expand Down
60 changes: 41 additions & 19 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::protocol::{
self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
HASH_WIDTH_IN_BYTES,
};
use crate::transport::{TcpTransport, Transport};
use crate::transport::{TcpTransport, Transport, TransportStream};
use anyhow::{anyhow, bail, Context, Result};
use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
Expand Down Expand Up @@ -229,15 +229,15 @@ impl<'a, T: 'static + Transport> Server<'a, T> {

// Handle connections to `server.bind_addr`
async fn handle_connection<T: 'static + Transport>(
mut conn: T::Stream,
mut conn: TransportStream<T>,
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
) -> Result<()> {
// Read hello
let hello = read_hello(&mut conn).await?;
let hello = read_hello(&mut conn.get_reliable_stream()).await?;
match hello {
ControlChannelHello(_, service_digest) => {
do_control_channel_handshake(conn, services, control_channels, service_digest).await?;
do_control_channel_handshake(conn.into_reliable_stream(), services, control_channels, service_digest).await?;
}
DataChannelHello(_, nonce) => {
do_data_channel_handshake(conn, control_channels, nonce).await?;
Expand All @@ -247,7 +247,7 @@ async fn handle_connection<T: 'static + Transport>(
}

async fn do_control_channel_handshake<T: 'static + Transport>(
mut conn: T::Stream,
mut conn: T::ReliableStream,
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
service_digest: ServiceDigest,
Expand Down Expand Up @@ -326,7 +326,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
}

async fn do_data_channel_handshake<T: 'static + Transport>(
conn: T::Stream,
conn: TransportStream<T>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
nonce: Nonce,
) -> Result<()> {
Expand All @@ -353,7 +353,7 @@ async fn do_data_channel_handshake<T: 'static + Transport>(
pub struct ControlChannelHandle<T: Transport> {
// Shutdown the control channel by dropping it
_shutdown_tx: broadcast::Sender<bool>,
data_ch_tx: mpsc::Sender<T::Stream>,
data_ch_tx: mpsc::Sender<TransportStream<T>>,
}

impl<T> ControlChannelHandle<T>
Expand All @@ -363,7 +363,7 @@ where
// Create a control channel handle, where the control channel handling task
// and the connection pool task are created.
#[instrument(skip_all, fields(service = %service.name))]
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
fn new(conn: T::ReliableStream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
// Create a shutdown channel
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);

Expand Down Expand Up @@ -449,7 +449,7 @@ where

// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic
struct ControlChannel<T: Transport> {
conn: T::Stream, // The connection of control channel
conn: T::ReliableStream, // The connection of control channel
service: ServerServiceConfig, // A copy of the corresponding service config
shutdown_rx: broadcast::Receiver<bool>, // Receives the shutdown signal
data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
Expand Down Expand Up @@ -571,19 +571,26 @@ fn tcp_listen_and_send(
}

#[instrument(skip_all)]
async fn run_tcp_connection_pool<T: Transport>(
async fn run_tcp_connection_pool<T: 'static + Transport>(
bind_addr: String,
mut data_ch_rx: mpsc::Receiver<T::Stream>,
mut data_ch_rx: mpsc::Receiver<TransportStream<T>>,
data_ch_req_tx: mpsc::UnboundedSender<bool>,
shutdown_rx: broadcast::Receiver<bool>,
) -> Result<()> {
let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx, shutdown_rx);
while let Some(mut visitor) = visitor_rx.recv().await {
if let Some(mut ch) = data_ch_rx.recv().await {
if let Some(mut conn) = data_ch_rx.recv().await {
tokio::spawn(async move {
let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
if ch.write_all(&cmd).await.is_ok() {
let _ = copy_bidirectional(&mut ch, &mut visitor).await;
if conn.write_all_reliably(&cmd).await.is_ok() {
match conn {
TransportStream::StrictlyReliable(mut reliable) => {
let _ = copy_bidirectional(&mut reliable, &mut visitor).await;
}
TransportStream::PartiallyReliable(_, mut unreliable) => {
let _ = copy_bidirectional(&mut unreliable, &mut visitor).await;
}
}
}
});
} else {
Expand All @@ -598,7 +605,7 @@ async fn run_tcp_connection_pool<T: Transport>(
#[instrument(skip_all)]
async fn run_udp_connection_pool<T: Transport>(
bind_addr: String,
mut data_ch_rx: mpsc::Receiver<T::Stream>,
mut data_ch_rx: mpsc::Receiver<TransportStream<T>>,
_data_ch_req_tx: mpsc::UnboundedSender<bool>,
mut shutdown_rx: broadcast::Receiver<bool>,
) -> Result<()> {
Expand All @@ -616,8 +623,8 @@ async fn run_udp_connection_pool<T: Transport>(
warn!("{:?}", e);
},
)
.await
.with_context(|| "Failed to listen for the service")?;
.await
.with_context(|| "Failed to listen for the service")?;

info!("Listening at {}", &bind_addr);

Expand All @@ -628,13 +635,28 @@ async fn run_udp_connection_pool<T: Transport>(
.recv()
.await
.ok_or(anyhow!("No available data channels"))?;
conn.write_all(&cmd).await?;
conn.write_all_reliably(&cmd).await?;

match conn {
TransportStream::StrictlyReliable(reliable) =>
udp_copy_bidirectional::<T, T::ReliableStream>(reliable, shutdown_rx, l).await,
TransportStream::PartiallyReliable(_, unreliable) =>
udp_copy_bidirectional::<T, T::UnreliableStream>(unreliable, shutdown_rx, l).await,
}
}

#[instrument(skip_all)]
async fn udp_copy_bidirectional<T: Transport, S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> (
mut conn: S,
mut shutdown_rx: broadcast::Receiver<bool>,
l: UdpSocket,
) -> Result<()> {
// after sending forward CMD, use unreliable stream if available for actual forwarding
let mut buf = [0u8; UDP_BUFFER_SIZE];
loop {
tokio::select! {
// Forward inbound traffic to the client
val = l.recv_from(&mut buf) => {
val = l.recv_from(&mut buf) => {
let (n, from) = val?;
UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
},
Expand Down
86 changes: 82 additions & 4 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,51 @@ use crate::config::TransportConfig;
use anyhow::Result;
use async_trait::async_trait;
use std::fmt::Debug;
use std::io;
use std::io::Error;
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::ToSocketAddrs;

// Specify a transport layer, like TCP, TLS
#[async_trait]
pub trait Transport: Debug + Send + Sync {
type Acceptor: Send + Sync;
type RawStream: Send + Sync;
type Stream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug;
type ReliableStream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug;
type UnreliableStream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug;


async fn new(config: &TransportConfig) -> Result<Self>
where
Self: Sized;
async fn bind<T: ToSocketAddrs + Send + Sync>(&self, addr: T) -> Result<Self::Acceptor>;
async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)>;
async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream>;
async fn connect(&self, addr: &str) -> Result<Self::Stream>;

/// Perform handshake using a newly initiated raw stream (tcp/udp)
/// return a properly configured connection for protocol that transport uses.
///
/// The returned connection may either be Reliable or Partially Reliable
/// (wholly unreliable transport are not currently supported).
///
/// Both Partially reliable and strictly reliable transport must provide a reliable stream
/// If partially reliable, then an unreliable stream must additionally be provided
async fn handshake(&self, conn: Self::RawStream) -> Result<TransportStream<Self>>;

/// Connection to Server
/// return
/// - A reliable ordered stream used for control channel communication
/// - Optionally an unordered and unreliable stream used for data channel forwarding
/// If no such stream is provided, then data will be sent using the reliable stream.
async fn connect(&self, addr: &str) -> Result<TransportStream<Self>>;
}

#[derive(Debug)]
pub enum TransportStream<T: Transport + ?Sized> {
StrictlyReliable(T::ReliableStream),
PartiallyReliable(T::ReliableStream, T::UnreliableStream),
}

mod tcp;
Expand All @@ -33,3 +60,54 @@ pub use tls::TlsTransport;
mod noise;
#[cfg(feature = "noise")]
pub use noise::NoiseTransport;

impl<T> TransportStream<T>
where T: Transport
{
pub async fn write_all_reliably<'a>(&'a mut self, src: &'a [u8]) -> std::io::Result<()> {
let r = match self {
TransportStream::StrictlyReliable(s) => s.write_all(src).await,
TransportStream::PartiallyReliable(s, _) => s.write_all(src).await,
};
r
}

pub(crate) fn get_reliable_stream(&mut self) -> &mut T::ReliableStream
{
match self {
TransportStream::StrictlyReliable(s) => s,
TransportStream::PartiallyReliable(s, _) => s,
}
}

pub fn into_reliable_stream(self) -> T::ReliableStream {
match self {
TransportStream::StrictlyReliable(s) => s,
TransportStream::PartiallyReliable(s, _) => s,
}
}
}

/// A dummy struct for use with transports that are strictly reliable
#[derive(Debug)]
pub struct UnimplementedUnreliableStream;

impl AsyncRead for UnimplementedUnreliableStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
todo!()
}
}

impl AsyncWrite for UnimplementedUnreliableStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::result::Result<usize, Error>> {
todo!()
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Error>> {
todo!()
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Error>> {
todo!()
}
}
Loading

0 comments on commit 41866de

Please sign in to comment.