diff --git a/src/main.rs b/src/main.rs index 902cb31..90e8b2f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,7 +32,7 @@ use std::{ time::Duration, }; use tower::{BoxError, ServiceBuilder}; -use tower_http::trace::{DefaultMakeSpan, TraceLayer}; +use tower_http::trace::TraceLayer; use tracing::instrument; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use uuid::Uuid; @@ -68,12 +68,12 @@ async fn main() { // get config from env let config: Config = envy::from_env().unwrap(); - tracing::debug!("env config: {:?}", config); + tracing::info!("env config: {:?}", config); // determine server's public ip for local servers let server_ip = match public_ip::addr().await { Some(ip) => { - tracing::debug!("found server's public ip: {}", ip); + tracing::info!("found server's public ip: {}", ip); ip } None => panic!("unable to find server's public ip address, please make sure it has a connection to the internet"), @@ -88,15 +88,10 @@ async fn main() { let app = Router::new() .route("/api/list/servers", get(get_servers)) // websocket route - .route("/api/list/ws", get(ws_handler)) + .route("/api/list/ws", get(websocket_handler)) // determine the secure ip source from the env .layer(config.ip_source.into_extension()) - // logging so we can see whats going on - .layer( - TraceLayer::new_for_http() - .make_span_with(DefaultMakeSpan::default().include_headers(true)), - ) - // add fallback option + // add default services for error handling, timeout and tracing .layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(|error: BoxError| async move { @@ -117,89 +112,107 @@ async fn main() { // run the server let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); - tracing::debug!("listening on {}", addr); + tracing::info!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service_with_connect_info::()) .await .unwrap(); } +#[instrument(skip(app_state))] async fn get_servers( pagination: Option>, + SecureClientIp(ip): SecureClientIp, State(app_state): State, ) -> impl IntoResponse { + tracing::info!("sending server list"); let Query(pagination) = pagination.unwrap_or_default(); Json(app_state.server_list.get(&pagination)) } -async fn ws_handler( +#[instrument(level = "debug", skip(ws, app_state))] +async fn websocket_handler( ws: WebSocketUpgrade, SecureClientIp(ip): SecureClientIp, State(app_state): State, ) -> impl IntoResponse { - tracing::debug!("new connection from: {}", ip); + tracing::info!("new websocket connection"); ws.protocols(["json"]).on_upgrade(move |socket| { handle_socket(socket, ip, app_state.server_list, app_state.server_ip) }) } -#[instrument(level = "debug", name = "socket_handler", skip(socket, server_list))] +#[instrument(level = "debug", name = "websocket_handler", skip(socket, server_list))] async fn handle_socket( mut socket: WebSocket, ip: IpAddr, mut server_list: ServerList, server_ip: IpAddr, ) { - let mut game_id = Uuid::nil(); + let game_id; - // loop until the first message is received, which should be the name - while let Some(Ok(msg)) = socket.recv().await { - tracing::debug!("got msg: {:?}", msg); - match msg { - Message::Text(txt) => match parse_connect_message(txt, ip, server_ip) { - Ok(server) => { - tracing::info!("created new game server: {:?}", server); - game_id = server_list.add(server); + // wait for the first message with initial server info + match socket.recv().await { + Some(result) => match result { + Ok(msg) => match msg { + Message::Text(txt) => match parse_connect_message(txt, ip, server_ip) { + Ok(server) => { + tracing::info!("created new game server: {:?}", server); + game_id = server_list.add(server); + } + Err(e) => { + tracing::error!("{:?}", e); + return; + } + }, + Message::Close(_) => { + tracing::info!("connection closed while waiting for server info"); + return; } - Err(e) => { - tracing::error!(e); + _ => { + tracing::warn!( + "got invalid message type while waiting for server info: {:?}", + msg + ); + return; } }, - Message::Close(_) => { - tracing::debug!("connection closed: {}", ip); + Err(e) => { + tracing::error!("error while waiting for server info: {:?}", e); return; } - _ => { - tracing::warn!("got invalid message type: {:?}", msg) - } + }, + None => { + tracing::warn!("connection closed unexpectedly while waiting for server info"); + return; } } - // don't allow the game_id to change after this - let game_id = game_id; // begin the main loop to update the game server state loop { - if let Some(msg) = socket.recv().await { - if let Ok(msg) = msg { - match msg { - Message::Text(t) => { - parse_game_message(&server_list, &game_id, &t); - } + if let Some(msg_type) = socket.recv().await { + match msg_type { + Ok(msg) => match msg { + Message::Text(t) => parse_game_message(&server_list, &game_id, &t), Message::Close(_) => { - tracing::debug!("connection closed: {}", ip); - remove_server(server_list, &game_id); - return; + tracing::debug!("connection closed"); + break; } _ => { tracing::warn!("got invalid message type: {:?}", msg) } + }, + Err(e) => { + tracing::error!("error while waiting for game message: {:?}", e); + break; } - } else { - tracing::debug!("unexpected error: {:?}", msg); - remove_server(server_list, &game_id); - return; } + } else { + tracing::warn!("connection closed unexpectedly"); + break; } } + // Make sure server is always removed if the loop finishes + remove_server(server_list, &game_id); } fn is_local_ipv4(ip: IpAddr) -> bool { @@ -253,7 +266,10 @@ fn parse_connect_message(txt: String, ip: IpAddr, server_ip: IpAddr) -> Result