From 809c7d2b8d31f5cc682f0e2480dbd4aef4385caf Mon Sep 17 00:00:00 2001 From: Sudo Dios Date: Mon, 9 Sep 2024 17:18:38 +0330 Subject: [PATCH] Added support for dual-stack socket #3 --- Cargo.lock | 1 + Cargo.toml | 1 + src/http/http_server.rs | 13 ++++---- src/http/mod.rs | 1 + src/http/tcp_socket.rs | 72 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 src/http/tcp_socket.rs diff --git a/Cargo.lock b/Cargo.lock index 7960ea6..94d1198 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1179,6 +1179,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "socket2", "tera", "tokio", "tokio-rustls", diff --git a/Cargo.toml b/Cargo.toml index bbe02fb..391b062 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ tokio = {version = "1.40.0", features = ["net","io-util","rt","macros","rt-multi tokio-rustls = {version = "0.26.0", features = ["logging","tls12","ring"], default-features = false} webpki-roots = "0.26.5" rustls-pemfile = "2.1.3" +socket2 = "0.5.7" #ip maxminddb = "0.24.0" #image processing diff --git a/src/http/http_server.rs b/src/http/http_server.rs index b26398d..9dff9e5 100644 --- a/src/http/http_server.rs +++ b/src/http/http_server.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use log::{info, trace}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter, split}; -use tokio::net::TcpListener; use tokio::sync::Mutex; use tokio_rustls::TlsAcceptor; use crate::config::{ROUTES, SERVER_CONFIG}; @@ -11,12 +10,13 @@ use crate::http::request::handle_socket; use crate::http::response::Response; use crate::http::routes::*; +use crate::http::tcp_socket::TcpSocket; use crate::http::tls::setup_tls_acceptor; use crate::ip::ip_info::IPInfo; use crate::results::stats::handle_stat_page; pub struct HttpServer { - pub tcp_listener: TcpListener, + pub tcp_socket: TcpSocket, pub tls_acceptor: Option } @@ -24,16 +24,15 @@ impl HttpServer { pub async fn init () -> std::io::Result { let config = SERVER_CONFIG.get().unwrap(); - let addr = format!("{}:{}",config.bind_address,config.listen_port); - let listener = TcpListener::bind(addr.clone()).await?; - info!("Server started on {}",addr); + let tcp_socket = TcpSocket::bind(config)?; + info!("Server started on {}",tcp_socket.to_string()); info!("Server base url : {}/",config.base_url); let mut tls_acceptor = None; if config.enable_tls { tls_acceptor = Some(setup_tls_acceptor(&config.tls_cet_file,&config.tls_key_file)?); } Ok(HttpServer { - tcp_listener : listener, + tcp_socket, tls_acceptor }) } @@ -41,7 +40,7 @@ impl HttpServer { pub async fn listen (&mut self, database : &mut Arc>) { loop { - let tcp_accept = self.tcp_listener.accept().await; + let tcp_accept = self.tcp_socket.accept().await; let mut database = database.clone(); let tls_acceptor = self.tls_acceptor.clone(); diff --git a/src/http/mod.rs b/src/http/mod.rs index 0468cca..2d7761c 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -11,6 +11,7 @@ pub mod response; pub mod cookie; pub mod tls; pub mod http_client; +mod tcp_socket; #[derive(Debug)] pub enum Method { diff --git a/src/http/tcp_socket.rs b/src/http/tcp_socket.rs new file mode 100644 index 0000000..a54f37c --- /dev/null +++ b/src/http/tcp_socket.rs @@ -0,0 +1,72 @@ +use core::fmt; +use std::fmt::Formatter; +use std::io; +use std::io::{Error, ErrorKind}; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use socket2::{Domain, Type}; +use tokio::net::{TcpListener, TcpStream}; +use crate::config::ServerConfig; + +pub struct TcpSocket { + tcp_listener: TcpListener, + addr : TcpAddr +} + +impl TcpSocket { + pub fn bind(config: &ServerConfig) -> io::Result { + let tcp_addr = TcpAddr::new(config)?; + let socket = socket2::Socket::new(tcp_addr.domain,Type::STREAM,None)?; + if !tcp_addr.is_only_v6 { + socket.set_only_v6(false)?; + } + socket.bind(&tcp_addr.sock_addr.into())?; + socket.listen(128)?; + let tcp_listener = TcpListener::from_std(socket.into())?; + Ok(TcpSocket { + tcp_listener, + addr : tcp_addr + }) + } + + pub async fn accept (&self) -> io::Result<(TcpStream, SocketAddr)> { + self.tcp_listener.accept().await + } + +} + +impl fmt::Display for TcpSocket { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f,"{}",self.addr.sock_addr) + } +} + +#[derive(Debug)] +pub struct TcpAddr { + sock_addr: SocketAddr, + domain: Domain, + is_only_v6: bool, +} + +impl TcpAddr { + pub fn new(config: &ServerConfig) -> io::Result { + let bind_addr = config.bind_address.as_str(); + let parse_addr = Self::parse_addr(bind_addr)?; + let addr = SocketAddr::new(parse_addr.0, config.listen_port); + Ok(TcpAddr { + sock_addr: addr, + domain: parse_addr.1, + is_only_v6: bind_addr != "::" && bind_addr != "::0", + }) + } + + fn parse_addr(ip_str: &str) -> io::Result<(IpAddr, Domain)> { + match IpAddr::from_str(ip_str) { + Ok(ip) => match ip { + IpAddr::V4(_) => Ok((ip, Domain::IPV4)), + IpAddr::V6(_) => Ok((ip, Domain::IPV6)), + }, + Err(e) => Err(Error::new(ErrorKind::Other, e)), + } + } +} \ No newline at end of file