diff --git a/theia/proxy-rs/Cargo.lock b/theia/proxy-rs/Cargo.lock index 5b50c69c..205b136b 100644 --- a/theia/proxy-rs/Cargo.lock +++ b/theia/proxy-rs/Cargo.lock @@ -470,6 +470,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -490,6 +501,7 @@ checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1098,6 +1110,7 @@ dependencies = [ "anyhow", "axum", "axum-extra", + "futures-util", "hyper", "hyper-util", "jsonwebtoken", @@ -1105,6 +1118,7 @@ dependencies = [ "serde", "sqlx", "tokio", + "tokio-tungstenite", "tower-http", "tracing", "tracing-subscriber", diff --git a/theia/proxy-rs/Cargo.toml b/theia/proxy-rs/Cargo.toml index 6723d491..2d776877 100644 --- a/theia/proxy-rs/Cargo.toml +++ b/theia/proxy-rs/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" anyhow = "1.0.82" axum = { version = "0.7.5", features = ["ws"] } axum-extra = { version = "0.9.3", features = ["cookie"] } +futures-util = "0.3.30" hyper = "1.3.1" hyper-util = { version = "0.1.3", features = ["tokio"] } jsonwebtoken = "9.3.0" @@ -16,6 +17,7 @@ lazy_static = "1.4.0" serde = "1.0.200" sqlx = { version = "0.7.4", features = ["mysql", "runtime-tokio"] } tokio = { version = "1.37.0", features = ["rt-multi-thread"] } +tokio-tungstenite = "0.21.0" tower-http = { version = "0.5.2", features = ["trace"] } tracing = "0.1.40" tracing-subscriber = "0.3.18" diff --git a/theia/proxy-rs/src/main.rs b/theia/proxy-rs/src/main.rs index 4743834b..c12e60f5 100644 --- a/theia/proxy-rs/src/main.rs +++ b/theia/proxy-rs/src/main.rs @@ -1,12 +1,13 @@ +mod ws; + use anyhow::Result; use axum::{ - extract::{Path, Request}, - http::StatusCode, - response::{IntoResponse, Redirect}, body::Body, + extract::{Path, Request, WebSocketUpgrade}, + http::StatusCode, + response::{IntoResponse, Redirect, Response}, routing::get, Extension, Router, - response::Response, }; use axum_extra::extract::cookie::{Cookie, CookieJar}; use hyper::upgrade::Upgraded; @@ -15,11 +16,11 @@ use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validati use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use sqlx::{mysql::MySqlPoolOptions, prelude::FromRow, MySqlPool}; +use std::time::Duration; +use std::{env::var, sync::Arc}; +use tokio::net::TcpStream; use tower_http::trace::{self, TraceLayer}; use tracing::Level; -use tokio::net::TcpStream; -use std::{env::var, sync::Arc}; -use std::time::Duration; const PROXY_SERVER_PORT: u64 = 5000; const MAX_PROXY_PORT: u64 = 8010; @@ -98,11 +99,13 @@ async fn ping(jar: CookieJar, Extension(pool): Extension>) -> (St match jar.get("ide") { Some(cookie) => match authenticate_jwt(cookie.value()) { Ok(claims) => { - update_last_proxy_time(&claims.session_id, &pool).await.unwrap(); - }, + update_last_proxy_time(&claims.session_id, &pool) + .await + .unwrap(); + } Err(_) => {} }, - None => {}, + None => {} }; (StatusCode::OK, "pong".to_string()) @@ -171,10 +174,9 @@ async fn get_cluster_address(pool: &MySqlPool, session_id: &str) -> Result, + ws: WebSocketUpgrade, Extension(pool): Extension>, jar: CookieJar, req: Request, @@ -197,7 +199,7 @@ async fn handle( } }; - let _cluster_address = get_cluster_address(&pool, &token.session_id) + let cluster_address = get_cluster_address(&pool, &token.session_id) .await .map_err(|e| { eprintln!("Error: {}", e); @@ -208,36 +210,45 @@ async fn handle( }) .unwrap(); - proxy(req).await.unwrap(); + let host = format!("ws://{}:{}", cluster_address, port); + ws.on_upgrade(move |socket| ws::forward(&host, ws)); (StatusCode::OK, "authorized".to_string()) } -async fn proxy(req: Request) -> Result { - if let Some(host_addr) = req.uri().authority().map(|auth| auth.to_string()) { - tokio::task::spawn(async move { - match hyper::upgrade::on(req).await { - Ok(upgraded) => { - if let Err(e) = tunnel(upgraded, host_addr).await { - tracing::warn!("server io error: {}", e); - }; +async fn proxy(req: Request) -> Result { + let authority = req.uri().authority().map(|auth| auth.to_string()); + + let host_addr = match authority { + Some(addr) => addr, + None => { + return Ok(( + StatusCode::BAD_REQUEST, + "CONNECT must be to a socket address", + ) + .into_response()); + } + }; + + tokio::task::spawn(async move { + match hyper::upgrade::on(req).await { + Ok(upgraded) => { + let res = tunnel(upgraded, host_addr).await; + + if let Err(e) = res { + tracing::warn!("tunnel error: {}", e); } - Err(e) => tracing::warn!("upgrade error: {}", e), } - }); - - Ok(Response::new(Body::empty())) - } else { - tracing::warn!("CONNECT host is not socket addr: {:?}", req.uri()); - Ok(( - StatusCode::BAD_REQUEST, - "CONNECT must be to a socket address", - ) - .into_response()) - } + Err(e) => { + tracing::warn!("upgrade error: {}", e); + } + }; + }); + + Ok(Response::new(Body::empty())) } -async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { +async fn tunnel(upgraded: Upgraded, addr: String) -> anyhow::Result<()> { let mut server = TcpStream::connect(addr).await?; let mut upgraded = TokioIo::new(upgraded); @@ -268,9 +279,11 @@ async fn main() { .map_err(|e| { tracing::error!("Failed to connect to database: {}", e); panic!("Failed to connect to database"); - }).map(|_| { + }) + .map(|_| { tracing::info!("Connected to database"); - }).unwrap(); + }) + .unwrap(); let pool = Arc::new(pool); @@ -290,6 +303,8 @@ async fn main() { tracing::info!("Server started on port {}", PROXY_SERVER_PORT); - let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", PROXY_SERVER_PORT)).await.unwrap(); + let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", PROXY_SERVER_PORT)) + .await + .unwrap(); axum::serve(listener, app).await.unwrap(); } diff --git a/theia/proxy-rs/src/ws.rs b/theia/proxy-rs/src/ws.rs new file mode 100644 index 00000000..31ded65a --- /dev/null +++ b/theia/proxy-rs/src/ws.rs @@ -0,0 +1,120 @@ +use anyhow::Result; +use axum::extract::ws::{CloseFrame, Message as AxumMessage, WebSocket}; +use axum::extract::WebSocketUpgrade; +use futures_util::{SinkExt, StreamExt}; +use tokio_tungstenite::tungstenite; +use tokio_tungstenite::{connect_async, tungstenite::protocol::Message as TsMessage}; + +enum WebSocketMessageType { + Axum(AxumMessage), + Tungstenite(TsMessage), +} + +struct WebSocketMessage { + message: WebSocketMessageType, +} + +impl WebSocketMessage { + fn tungstenite(message: TsMessage) -> Self { + Self { + message: WebSocketMessageType::Tungstenite(message), + } + } + + fn axum(message: AxumMessage) -> Self { + Self { + message: WebSocketMessageType::Axum(message), + } + } + + fn into_tungstenite(self) -> TsMessage { + match self.message { + WebSocketMessageType::Axum(message) => match message { + AxumMessage::Text(text) => TsMessage::Text(text), + AxumMessage::Binary(binary) => TsMessage::Binary(binary), + AxumMessage::Ping(ping) => TsMessage::Ping(ping), + AxumMessage::Pong(pong) => TsMessage::Pong(pong), + AxumMessage::Close(Some(close)) => { + TsMessage::Close(Some(tungstenite::protocol::frame::CloseFrame { + code: tungstenite::protocol::frame::coding::CloseCode::from(close.code), + reason: close.reason, + })) + } + AxumMessage::Close(None) => TsMessage::Close(None), + }, + WebSocketMessageType::Tungstenite(message) => message, + } + } + + fn into_axum(self) -> AxumMessage { + match self.message { + WebSocketMessageType::Axum(message) => message, + WebSocketMessageType::Tungstenite(message) => match message { + TsMessage::Text(text) => AxumMessage::Text(text), + TsMessage::Binary(binary) => AxumMessage::Binary(binary), + TsMessage::Ping(ping) => AxumMessage::Ping(ping), + TsMessage::Pong(pong) => AxumMessage::Pong(pong), + TsMessage::Close(Some(close)) => AxumMessage::Close(Some(CloseFrame { + code: close.code.into(), + reason: close.reason, + })), + TsMessage::Close(None) => AxumMessage::Close(None), + TsMessage::Frame(frame) => { + tracing::warn!("unexpected frame: {:?}", frame); + AxumMessage::Close(None) + } + }, + } + } +} + +pub async fn forward(url: &str, client_ws: WebSocketUpgrade) { + let server_ws = match connect_async(url).await { + Ok((ws, _)) => ws, + Err(e) => { + tracing::warn!("connect error: {}", e); + return; + } + }; + + tokio::spawn(async move { + let (mut client_write, mut client_read) = client_ws.split(); + let (mut server_write, mut server_read) = server_ws.split(); + + tokio::spawn(async move { + while let Some(message) = client_read.next().await { + match message { + Ok(message) => { + let message = WebSocketMessage::axum(message); + let res = server_write.send(message.into_tungstenite()).await; + if let Err(e) = res { + tracing::warn!("client write error: {}", e); + continue; + } + } + Err(e) => { + tracing::warn!("client read error: {}", e); + continue; + } + } + } + }); + + while let Some(message) = server_read.next().await { + match message { + Ok(message) => { + let message = WebSocketMessage::tungstenite(message); + let res = client_write.send(message.into_axum()).await; + if let Err(e) = res { + tracing::warn!("client write error: {}", e); + continue; + } + } + Err(e) => { + tracing::warn!("client read error: {}", e); + continue; + } + } + } + }); +} diff --git a/theia/proxy-rs/stress.py b/theia/proxy-rs/stress.py deleted file mode 100644 index bd1fa2be..00000000 --- a/theia/proxy-rs/stress.py +++ /dev/null @@ -1,21 +0,0 @@ -import multiprocessing as mp -import time - -import requests -import urllib3 - -urllib3.disable_warnings() - - -def func(_): - requests.get('http://localhost:8080/', verify=False) - - -n = 10000 -start = time.time() -with mp.Pool(10) as pool: - pool.map(func, [None] * n) - pool.close() -elapsed = time.time() - start -print('took {:0.2f}s for {} requests'.format(elapsed, n)) -print('{:0.2f} req/s'.format(float(n) / elapsed))