Skip to content

Commit

Permalink
FIX ws proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
synoet committed May 6, 2024
1 parent ede00e5 commit 3504979
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 59 deletions.
14 changes: 14 additions & 0 deletions theia/proxy-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions theia/proxy-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ edition = "2021"
anyhow = "1.0.82"
axum = { version = "0.7.5", features = ["ws"] }
axum-extra = { version = "0.9.3", features = ["cookie"] }
futures-util = "0.3.30"
hyper = "1.3.1"
hyper-util = { version = "0.1.3", features = ["tokio"] }
jsonwebtoken = "9.3.0"
lazy_static = "1.4.0"
serde = "1.0.200"
sqlx = { version = "0.7.4", features = ["mysql", "runtime-tokio"] }
tokio = { version = "1.37.0", features = ["rt-multi-thread"] }
tokio-tungstenite = "0.21.0"
tower-http = { version = "0.5.2", features = ["trace"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
91 changes: 53 additions & 38 deletions theia/proxy-rs/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
mod ws;

use anyhow::Result;
use axum::{
extract::{Path, Request},
http::StatusCode,
response::{IntoResponse, Redirect},
body::Body,
extract::{Path, Request, WebSocketUpgrade},
http::StatusCode,
response::{IntoResponse, Redirect, Response},
routing::get,
Extension, Router,
response::Response,
};
use axum_extra::extract::cookie::{Cookie, CookieJar};
use hyper::upgrade::Upgraded;
Expand All @@ -15,11 +16,11 @@ use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validati
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use sqlx::{mysql::MySqlPoolOptions, prelude::FromRow, MySqlPool};
use std::time::Duration;
use std::{env::var, sync::Arc};
use tokio::net::TcpStream;
use tower_http::trace::{self, TraceLayer};
use tracing::Level;
use tokio::net::TcpStream;
use std::{env::var, sync::Arc};
use std::time::Duration;

const PROXY_SERVER_PORT: u64 = 5000;
const MAX_PROXY_PORT: u64 = 8010;
Expand Down Expand Up @@ -98,11 +99,13 @@ async fn ping(jar: CookieJar, Extension(pool): Extension<Arc<MySqlPool>>) -> (St
match jar.get("ide") {
Some(cookie) => match authenticate_jwt(cookie.value()) {
Ok(claims) => {
update_last_proxy_time(&claims.session_id, &pool).await.unwrap();
},
update_last_proxy_time(&claims.session_id, &pool)
.await
.unwrap();
}
Err(_) => {}
},
None => {},
None => {}
};

(StatusCode::OK, "pong".to_string())
Expand Down Expand Up @@ -171,10 +174,9 @@ async fn get_cluster_address(pool: &MySqlPool, session_id: &str) -> Result<Strin
}
}



async fn handle(
Path(path): Path<String>,
ws: WebSocketUpgrade,
Extension(pool): Extension<Arc<MySqlPool>>,
jar: CookieJar,
req: Request,
Expand All @@ -197,7 +199,7 @@ async fn handle(
}
};

let _cluster_address = get_cluster_address(&pool, &token.session_id)
let cluster_address = get_cluster_address(&pool, &token.session_id)
.await
.map_err(|e| {
eprintln!("Error: {}", e);
Expand All @@ -208,36 +210,45 @@ async fn handle(
})
.unwrap();

proxy(req).await.unwrap();
let host = format!("ws://{}:{}", cluster_address, port);

ws.on_upgrade(move |socket| ws::forward(&host, ws));
(StatusCode::OK, "authorized".to_string())
}

async fn proxy(req: Request) -> Result<Response, hyper::Error> {
if let Some(host_addr) = req.uri().authority().map(|auth| auth.to_string()) {
tokio::task::spawn(async move {
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
if let Err(e) = tunnel(upgraded, host_addr).await {
tracing::warn!("server io error: {}", e);
};
async fn proxy(req: Request) -> Result<Response> {
let authority = req.uri().authority().map(|auth| auth.to_string());

let host_addr = match authority {
Some(addr) => addr,
None => {
return Ok((
StatusCode::BAD_REQUEST,
"CONNECT must be to a socket address",
)
.into_response());
}
};

tokio::task::spawn(async move {
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
let res = tunnel(upgraded, host_addr).await;

if let Err(e) = res {
tracing::warn!("tunnel error: {}", e);
}
Err(e) => tracing::warn!("upgrade error: {}", e),
}
});

Ok(Response::new(Body::empty()))
} else {
tracing::warn!("CONNECT host is not socket addr: {:?}", req.uri());
Ok((
StatusCode::BAD_REQUEST,
"CONNECT must be to a socket address",
)
.into_response())
}
Err(e) => {
tracing::warn!("upgrade error: {}", e);
}
};
});

Ok(Response::new(Body::empty()))
}

async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
async fn tunnel(upgraded: Upgraded, addr: String) -> anyhow::Result<()> {
let mut server = TcpStream::connect(addr).await?;
let mut upgraded = TokioIo::new(upgraded);

Expand Down Expand Up @@ -268,9 +279,11 @@ async fn main() {
.map_err(|e| {
tracing::error!("Failed to connect to database: {}", e);
panic!("Failed to connect to database");
}).map(|_| {
})
.map(|_| {
tracing::info!("Connected to database");
}).unwrap();
})
.unwrap();

let pool = Arc::new(pool);

Expand All @@ -290,6 +303,8 @@ async fn main() {

tracing::info!("Server started on port {}", PROXY_SERVER_PORT);

let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", PROXY_SERVER_PORT)).await.unwrap();
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", PROXY_SERVER_PORT))
.await
.unwrap();
axum::serve(listener, app).await.unwrap();
}
120 changes: 120 additions & 0 deletions theia/proxy-rs/src/ws.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use anyhow::Result;
use axum::extract::ws::{CloseFrame, Message as AxumMessage, WebSocket};
use axum::extract::WebSocketUpgrade;
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message as TsMessage};

enum WebSocketMessageType {
Axum(AxumMessage),
Tungstenite(TsMessage),
}

struct WebSocketMessage {
message: WebSocketMessageType,
}

impl WebSocketMessage {
fn tungstenite(message: TsMessage) -> Self {
Self {
message: WebSocketMessageType::Tungstenite(message),
}
}

fn axum(message: AxumMessage) -> Self {
Self {
message: WebSocketMessageType::Axum(message),
}
}

fn into_tungstenite(self) -> TsMessage {
match self.message {
WebSocketMessageType::Axum(message) => match message {
AxumMessage::Text(text) => TsMessage::Text(text),
AxumMessage::Binary(binary) => TsMessage::Binary(binary),
AxumMessage::Ping(ping) => TsMessage::Ping(ping),
AxumMessage::Pong(pong) => TsMessage::Pong(pong),
AxumMessage::Close(Some(close)) => {
TsMessage::Close(Some(tungstenite::protocol::frame::CloseFrame {
code: tungstenite::protocol::frame::coding::CloseCode::from(close.code),
reason: close.reason,
}))
}
AxumMessage::Close(None) => TsMessage::Close(None),
},
WebSocketMessageType::Tungstenite(message) => message,
}
}

fn into_axum(self) -> AxumMessage {
match self.message {
WebSocketMessageType::Axum(message) => message,
WebSocketMessageType::Tungstenite(message) => match message {
TsMessage::Text(text) => AxumMessage::Text(text),
TsMessage::Binary(binary) => AxumMessage::Binary(binary),
TsMessage::Ping(ping) => AxumMessage::Ping(ping),
TsMessage::Pong(pong) => AxumMessage::Pong(pong),
TsMessage::Close(Some(close)) => AxumMessage::Close(Some(CloseFrame {
code: close.code.into(),
reason: close.reason,
})),
TsMessage::Close(None) => AxumMessage::Close(None),
TsMessage::Frame(frame) => {
tracing::warn!("unexpected frame: {:?}", frame);
AxumMessage::Close(None)
}
},
}
}
}

pub async fn forward(url: &str, client_ws: WebSocketUpgrade) {
let server_ws = match connect_async(url).await {
Ok((ws, _)) => ws,
Err(e) => {
tracing::warn!("connect error: {}", e);
return;
}
};

tokio::spawn(async move {
let (mut client_write, mut client_read) = client_ws.split();
let (mut server_write, mut server_read) = server_ws.split();

tokio::spawn(async move {
while let Some(message) = client_read.next().await {
match message {
Ok(message) => {
let message = WebSocketMessage::axum(message);
let res = server_write.send(message.into_tungstenite()).await;
if let Err(e) = res {
tracing::warn!("client write error: {}", e);
continue;
}
}
Err(e) => {
tracing::warn!("client read error: {}", e);
continue;
}
}
}
});

while let Some(message) = server_read.next().await {
match message {
Ok(message) => {
let message = WebSocketMessage::tungstenite(message);
let res = client_write.send(message.into_axum()).await;
if let Err(e) = res {
tracing::warn!("client write error: {}", e);
continue;
}
}
Err(e) => {
tracing::warn!("client read error: {}", e);
continue;
}
}
}
});
}
21 changes: 0 additions & 21 deletions theia/proxy-rs/stress.py

This file was deleted.

0 comments on commit 3504979

Please sign in to comment.