diff --git a/Cargo.toml b/Cargo.toml index 5da677c5c..4c4e33994 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,15 +16,24 @@ brotli = ["brotli2"] gzip = ["flate2"] [dependencies] +arrayvec = "0.3.23" +atoi = "0.2.2" base64 = "0.7.0" brotli2 = { version = "0.2.1", optional = true } chrono = "0.2.0" +crossbeam = "0.3.0" filetime = "0.1.10" flate2 = { version = "0.2.14", optional = true } +httparse = "1.2.3" +itoa = "0.3" +mio = "0.6.10" multipart = { version = "0.5.1", default-features = false, features = ["server"] } +num_cpus = "1.6.2" rand = "0.3.11" rustc-serialize = "0.3" +rustls = "0.9.0" sha1 = "0.2.0" +slab = "0.4.0" term = "0.2" time = "0.1.31" tiny_http = "0.5.6" diff --git a/src/lib.rs b/src/lib.rs index 9733a77d4..997cf1e49 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,17 +52,26 @@ #![deny(unsafe_code)] +extern crate arrayvec; +extern crate atoi; extern crate base64; #[cfg(feature = "brotli2")] extern crate brotli2; extern crate chrono; +extern crate crossbeam; extern crate filetime; #[cfg(feature = "flate2")] extern crate flate2; +extern crate httparse; +extern crate itoa; +extern crate mio; extern crate multipart; +extern crate num_cpus; extern crate rand; extern crate rustc_serialize; +extern crate rustls; extern crate sha1; +extern crate slab; extern crate time; extern crate tiny_http; extern crate url; @@ -71,21 +80,20 @@ pub use assets::extension_to_mime; pub use assets::match_assets; pub use log::log; pub use response::{Response, ResponseBody}; +pub use server::Server; +pub use server::SslConfig; pub use tiny_http::ReadWrite; -use std::error::Error; +use arrayvec::ArrayString; use std::io::Cursor; use std::io::Result as IoResult; use std::io::Read; use std::marker::PhantomData; use std::net::SocketAddr; use std::net::ToSocketAddrs; -use std::panic; -use std::panic::AssertUnwindSafe; use std::slice::Iter as SliceIter; use std::sync::Arc; use std::sync::Mutex; -use std::thread; use std::ascii::AsciiExt; use std::fmt; @@ -101,6 +109,8 @@ mod find_route; mod log; mod response; mod router; +mod server; +mod socket_handler; #[doc(hidden)] pub mod try_or_400; @@ -208,169 +218,6 @@ pub fn start_server(addr: A, handler: F) -> ! panic!("The server socket closed unexpectedly") } -/// A listening server. -/// -/// This struct is the more manual server creation API of rouille and can be used as an alternative -/// to the `start_server` function. -/// -/// The `start_server` function is just a shortcut for `Server::new` followed with `run`. See the -/// documentation of the `start_server` function for more details about the handler. -/// -/// # Example -/// -/// ```no_run -/// use rouille::Server; -/// use rouille::Response; -/// -/// let server = Server::new("localhost:0", |request| { -/// Response::text("hello world") -/// }).unwrap(); -/// println!("Listening on {:?}", server.server_addr()); -/// server.run(); -/// ``` -pub struct Server { - server: tiny_http::Server, - handler: Arc>, -} - -impl Server where F: Send + Sync + 'static + Fn(&Request) -> Response { - /// Builds a new `Server` object. - /// - /// After this function returns, the HTTP server is listening. - /// - /// Returns an error if there was an error while creating the listening socket, for example if - /// the port is already in use. 
- pub fn new(addr: A, handler: F) -> Result, Box> - where A: ToSocketAddrs - { - let server = try!(tiny_http::Server::http(addr)); - - Ok(Server { - server: server, - handler: Arc::new(AssertUnwindSafe(handler)), // TODO: using AssertUnwindSafe here is wrong, but unwind safety has some usability problems in Rust in general - }) - } - - /// Returns the address of the listening socket. - #[inline] - pub fn server_addr(&self) -> SocketAddr { - self.server.server_addr() - } - - /// Runs the server forever, or until the listening socket is somehow force-closed by the - /// operating system. - #[inline] - pub fn run(self) { - for request in self.server.incoming_requests() { - self.process(request); - } - } - - /// Processes all the client requests waiting to be processed, then returns. - /// - /// This function executes very quickly, as each client requests that needs to be processed - /// is processed in a separate thread. - #[inline] - pub fn poll(&self) { - while let Ok(Some(request)) = self.server.try_recv() { - self.process(request); - } - } - - // Internal function, called when we got a request from tiny-http that needs to be processed. - fn process(&self, request: tiny_http::Request) { - // We spawn a thread so that requests are processed in parallel. - let handler = self.handler.clone(); - thread::spawn(move || { - // Small helper struct that makes it possible to put - // a `tiny_http::Request` inside a `Box`. - struct RequestRead(Arc>>); - impl Read for RequestRead { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> IoResult { - self.0.lock().unwrap().as_mut().unwrap().as_reader().read(buf) - } - } - - // Building the `Request` object. - let tiny_http_request; - let rouille_request = { - let url = request.url().to_owned(); - let method = request.method().as_str().to_owned(); - let headers = request.headers().iter().map(|h| (h.field.to_string(), h.value.clone().into())).collect(); - let remote_addr = request.remote_addr().clone(); - - tiny_http_request = Arc::new(Mutex::new(Some(request))); - - Request { - url: url, - method: method, - headers: headers, - https: false, - data: Arc::new(Mutex::new(Some(Box::new(RequestRead(tiny_http_request.clone())) as Box<_>))), - remote_addr: remote_addr, - } - }; - - // Calling the handler ; this most likely takes a lot of time. - // If the handler panics, we build a dummy response. - let mut rouille_response = { - // We don't use the `rouille_request` anymore after the panic, so it's ok to assert - // it's unwind safe. - let rouille_request = AssertUnwindSafe(rouille_request); - let res = panic::catch_unwind(move || { - let rouille_request = rouille_request; - handler(&rouille_request) - }); - - match res { - Ok(r) => r, - Err(_) => { - Response::html("

<h1>Internal Server Error</h1>\
-                                        <p>An internal error has occurred on the server.</p>
") - .with_status_code(500) - } - } - }; - - // writing the response - let (res_data, res_len) = rouille_response.data.into_reader_and_size(); - let mut response = tiny_http::Response::empty(rouille_response.status_code) - .with_data(res_data, res_len); - - let mut upgrade_header = "".into(); - - for (key, value) in rouille_response.headers { - if key.eq_ignore_ascii_case("Content-Length") { - continue; - } - - if key.eq_ignore_ascii_case("Upgrade") { - upgrade_header = value; - continue; - } - - if let Ok(header) = tiny_http::Header::from_bytes(key.as_bytes(), value.as_bytes()) { - response.add_header(header); - } else { - // TODO: ? - } - } - - if let Some(ref mut upgrade) = rouille_response.upgrade { - let trq = tiny_http_request.lock().unwrap().take().unwrap(); - let socket = trq.upgrade(&upgrade_header, response); - upgrade.build(socket); - - } else { - // We don't really care if we fail to send the response to the client, as there's - // nothing we can do anyway. - let _ = tiny_http_request.lock().unwrap().take().unwrap().respond(response); - } - }); - } -} - /// Trait for objects that can take ownership of a raw connection to the client data. /// /// The purpose of this trait is to be used with the `Connection: Upgrade` header, hence its name. @@ -384,7 +231,9 @@ pub trait Upgrade { /// This can be either a real request (received by the HTTP server) or a mock object created with /// one of the `fake_*` constructors. pub struct Request { - method: String, + // The method (`GET`, `POST`, ..). The longest registered method know to the author is + // `UPDATEREDIRECTREF` and is 17 bytes long. + method: ArrayString<[u8; 17]>, url: String, headers: Vec<(String, String)>, https: bool, @@ -414,7 +263,7 @@ impl Request { { Request { url: url.into(), - method: method.into(), + method: ArrayString::from(&method.into()).expect("Method too long"), https: false, data: Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>))), headers: headers, @@ -429,7 +278,7 @@ impl Request { { Request { url: url.into(), - method: method.into(), + method: ArrayString::from(&method.into()).expect("Method too long"), https: false, data: Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>))), headers: headers, @@ -446,7 +295,7 @@ impl Request { { Request { url: url.into(), - method: method.into(), + method: ArrayString::from(&method.into()).expect("Method too long"), https: true, data: Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>))), headers: headers, @@ -461,7 +310,7 @@ impl Request { { Request { url: url.into(), - method: method.into(), + method: ArrayString::from(&method.into()).expect("Method too long"), https: true, data: Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>))), headers: headers, diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 000000000..d40e508fb --- /dev/null +++ b/src/server.rs @@ -0,0 +1,440 @@ +// Copyright (c) 2017 The Rouille developers +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , +// at your option. All files in the project carrying such +// notice may not be copied, modified, or distributed except +// according to those terms. 
+ +use std::error::Error; +use std::io::ErrorKind; +use std::io::Read; +use std::io::Write; +use std::net::SocketAddr; +use std::net::ToSocketAddrs; +use std::panic::AssertUnwindSafe; +use std::sync::Arc; +use std::sync::Mutex; +use std::thread; +use mio::{Events, Poll, Ready, PollOpt}; +use mio::tcp::{TcpListener, TcpStream}; +use num_cpus; +use slab::Slab; + +pub use socket_handler::RustlsConfig as SslConfig; + +use socket_handler::RegistrationState; +use socket_handler::TaskPool; +use socket_handler::SocketHandler; +use socket_handler::SocketHandlerDispatch; +use socket_handler::Update as SocketHandlerUpdate; + +use Request; +use Response; + +/// A listening server. +/// +/// This struct is the more manual server creation API of rouille and can be used as an alternative +/// to the `start_server` function. +/// +/// The `start_server` function is just a shortcut for `Server::new` followed with `run`. See the +/// documentation of the `start_server` function for more details about the handler. +/// +/// # Example +/// +/// ```no_run +/// use rouille::Server; +/// use rouille::Response; +/// +/// let server = Server::new("localhost:0", |request| { +/// Response::text("hello world") +/// }).unwrap(); +/// println!("Listening on {:?}", server.server_addr()); +/// server.run(); +/// ``` +pub struct Server { + inner: Arc>, + local_events: Mutex, +} + +// Data shared between threads. +struct ThreadsShare { + // The main poll event. + poll: Poll, + // Storage for all the objects registered towards the `Poll`. + sockets: Mutex>, + // The function that handles requests. + handler: AssertUnwindSafe, + // Pool used to dispatch tasks. + task_pool: TaskPool, +} + +enum Socket { + Listener{ + listener: TcpListener, + https: Option, + }, + Stream { + stream: TcpStream, + read_closed: bool, + write_flush_suggested: bool, + handler: SocketHandlerDispatch, + update: SocketHandlerUpdate, + }, +} + +impl Server where F: Send + Sync + 'static + Fn(&Request) -> Response { + /// Builds a new `Server` object. + /// + /// After this function returns, the HTTP server is listening. + /// + /// Returns an error if there was an error while creating the listening socket, for example if + /// the port is already in use. + pub fn new
(addr: A, handler: F) -> Result, Box> + where A: ToSocketAddrs, + F: Fn(&Request) -> Response + Send + 'static + { + let server = Server::empty(handler)?; + server.add_http_listeners(addr)?; + Ok(server) + } + + /// Builds a new `Server` but without any listener. + pub fn empty(handler: F) -> Result, Box> + where F: Fn(&Request) -> Response + Send + 'static + { + let share = Arc::new(ThreadsShare { + poll: Poll::new()?, + sockets: Mutex::new(Slab::new()), + handler: AssertUnwindSafe(handler), // TODO: using AssertUnwindSafe here is wrong, but unwind safety has some usability problems in Rust in general + task_pool: TaskPool::new(), + }); + + for _ in 0 .. num_cpus::get() - 1 { + let share = share.clone(); + thread::spawn(move || { + // Each thread has its own local MIO events. + let mut events = Events::with_capacity(128); + + // TODO: The docs say that two events can be generated, one for read and one for + // write, presumably even if we pass one_shot(). Is this code ready for this + // situation? + + loop { + one_poll(&share, &mut events); + } + }); + } + + Ok(Server { + inner: share, + local_events: Mutex::new(Events::with_capacity(128)), + }) + } + + // Adds new HTTP listening addresses to the server. + pub fn add_http_listeners(&self, addr: A) -> Result<(), Box> + where A: ToSocketAddrs + { + for addr in addr.to_socket_addrs()? { + // TODO: error recovery? what if the 2nd address returns an error? do we keep the first running? + self.add_http_listener(&addr)?; + } + Ok(()) + } + + /// Adds a new listening addr to the server. + pub fn add_http_listener(&self, addr: &SocketAddr) -> Result<(), Box> { + let listener = TcpListener::bind(addr)?; + + let mut slab = self.inner.sockets.lock().unwrap(); + let entry = slab.vacant_entry(); + + self.inner.poll.register(&listener, entry.key().into(), + Ready::readable(), PollOpt::edge() | PollOpt::oneshot())?; + + entry.insert(Socket::Listener { + listener: listener, + https: None, + }); + + Ok(()) + } + + /// Adds a new listening addr to the server. + pub fn add_https_listener(&self, addr: &SocketAddr, config: SslConfig) + -> Result<(), Box> { + let listener = TcpListener::bind(addr)?; + + let mut slab = self.inner.sockets.lock().unwrap(); + let entry = slab.vacant_entry(); + + self.inner.poll.register(&listener, entry.key().into(), + Ready::readable(), PollOpt::edge() | PollOpt::oneshot())?; + + entry.insert(Socket::Listener { + listener: listener, + https: Some(config), + }); + + Ok(()) + } + + /// Returns the address of the listening socket. + #[inline] + pub fn server_addr(&self) -> SocketAddr { + unimplemented!() // FIXME: restore? + //self.server.server_addr() + } + + /// Runs the server forever, or until the listening socket is somehow force-closed by the + /// operating system. + #[inline] + pub fn run(self) { + let mut local_events = self.local_events.lock().unwrap(); + loop { + one_poll(&self.inner, &mut local_events); + } + } + + /// Processes all the client requests waiting to be processed, then returns. + /// + /// This function executes very quickly, as each client requests that needs to be processed + /// is processed in a separate thread. 
+ #[inline] + pub fn poll(&self) { + let mut local_events = self.local_events.lock().unwrap(); + one_poll(&self.inner, &mut local_events); + } +} + +fn one_poll(share: &Arc>, events: &mut Events) + where F: Fn(&Request) -> Response + Send + Sync + 'static +{ + share.poll.poll(events, None).expect("Error with the system selector"); + + for event in events.iter() { + // We handle reading before writing, as handling reading can generate data to write. + + if event.readiness().is_readable() { + let socket = { + let mut slab = share.sockets.lock().unwrap(); + if !slab.contains(event.token().into()) { + continue; + } + slab.remove(event.token().into()) + }; + + handle_read(share, socket); + } + + if event.readiness().is_writable() { + let socket = { + let mut slab = share.sockets.lock().unwrap(); + if !slab.contains(event.token().into()) { + continue; + } + slab.remove(event.token().into()) + }; + + handle_write(share, socket); + } + } +} + +fn handle_read(share: &Arc>, socket: Socket) + where F: Fn(&Request) -> Response + Send + Sync + 'static +{ + match socket { + Socket::Listener { listener, https } => { + // Call `accept` repeatidely and register the newly-created sockets, + // until `WouldBlock` is returned. + loop { + match listener.accept() { + Ok((stream, client_addr)) => { + let mut slab = share.sockets.lock().unwrap(); + let entry = slab.vacant_entry(); + share.poll.register(&stream, entry.key().into(), Ready::readable(), + PollOpt::edge() | PollOpt::oneshot()) + .expect("Error while registering TCP stream"); + let handler = { + let share = share.clone(); + if let Some(ref https) = https { + SocketHandlerDispatch::https(https.clone(), client_addr, + share.task_pool.clone(), + move |rq| (share.handler)(&rq)) + } else { + SocketHandlerDispatch::http(client_addr, share.task_pool.clone(), + move |rq| (share.handler)(&rq)) + } + }; + entry.insert(Socket::Stream { + stream: stream, + read_closed: false, + write_flush_suggested: false, + handler: handler, + update: SocketHandlerUpdate::empty(), + }); + }, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => break, + Err(_) => { + // Handle errors with the listener by returning without re-registering it. + // This drops the listener. + return; + }, + }; + }; + + // Re-register the listener for the next time. + let mut slab = share.sockets.lock().unwrap(); + let entry = slab.vacant_entry(); + share.poll.reregister(&listener, entry.key().into(), Ready::readable(), + PollOpt::edge() | PollOpt::oneshot()) + .expect("Error while reregistering TCP listener"); + entry.insert(Socket::Listener { + listener: listener, + https: https, + }); + }, + + Socket::Stream { mut stream, mut read_closed, mut write_flush_suggested, mut handler, + mut update } => + { + // Read into `update.pending_read_buffer` until `WouldBlock` is returned. + loop { + let old_pr_len = update.pending_read_buffer.len(); + update.pending_read_buffer.resize(old_pr_len + 256, 0); + + match stream.read(&mut update.pending_read_buffer[old_pr_len..]) { + Ok(0) => { + update.pending_read_buffer.resize(old_pr_len, 0); + break; + }, + Ok(n) => { + update.pending_read_buffer.resize(old_pr_len + n, 0); + }, + Err(ref e) if e.kind() == ErrorKind::Interrupted => { + update.pending_read_buffer.resize(old_pr_len, 0); + }, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + update.pending_read_buffer.resize(old_pr_len, 0); + break; + }, + Err(_) => { + // Handle errors with the stream by returning without re-registering it. + // This drops the stream. 
+ return; + }, + }; + } + + // Dispatch to handler. + let mut update_result = handler.update(&mut update); + if update_result.close_read { + read_closed = true; + } + if update_result.write_flush_suggested { + write_flush_suggested = true; + } + + // Re-register stream for next time. + let mut ready = Ready::empty(); + if !read_closed { + ready = ready | Ready::readable(); + } + if !update.pending_write_buffer.is_empty() { + ready = ready | Ready::writable(); + } + + let mut slab = share.sockets.lock().unwrap(); + let entry = slab.vacant_entry(); + + let mut insert_entry = false; + + if let Some((registration, state)) = update_result.registration.take() { + match state { + RegistrationState::FirstTime => { + share.poll.register(&*registration, entry.key().into(), + Ready::readable() | Ready::writable(), + PollOpt::edge() | PollOpt::oneshot()) + .expect("Error while registering registration"); + }, + RegistrationState::Reregister => { + share.poll.reregister(&*registration, entry.key().into(), + Ready::readable() | Ready::writable(), + PollOpt::edge() | PollOpt::oneshot()) + .expect("Error while registering registration"); + }, + } + + insert_entry = true; + } + + if !ready.is_empty() { + share.poll.reregister(&stream, entry.key().into(), ready, + PollOpt::edge() | PollOpt::oneshot()) + .expect("Error while reregistering TCP stream"); + insert_entry = true; + } + + if insert_entry { + entry.insert(Socket::Stream { stream, read_closed, write_flush_suggested, + handler, update }); + } + }, + } +} + +fn handle_write(share: &ThreadsShare, socket: Socket) { + // Write events can't happen for listeners. + let (mut stream, read_closed, mut write_flush_suggested, handler, mut update) = match socket { + Socket::Listener { .. } => unreachable!(), + Socket::Stream { stream, read_closed, write_flush_suggested, handler, update } => + (stream, read_closed, write_flush_suggested, handler, update), + }; + + // Write from `update.pending_write_buffer` to `stream`. + while !update.pending_write_buffer.is_empty() { + match stream.write(&update.pending_write_buffer) { + Ok(0) => break, + Ok(written) => { + let cut_len = update.pending_write_buffer.len() - written; + for n in 0 .. cut_len { + update.pending_write_buffer[n] = update.pending_write_buffer[n + written]; + } + update.pending_write_buffer.resize(cut_len, 0); + }, + Err(ref e) if e.kind() == ErrorKind::Interrupted => {}, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => break, + Err(_) => { + // Handle errors with the stream by returning without re-registering it. This + // drops the stream. + return; + }, + }; + }; + + if write_flush_suggested { + let _ = stream.flush(); + write_flush_suggested = false; + } + + // Re-register the stream for the next event. 
+ let mut ready = Ready::empty(); + if !read_closed { + ready = ready | Ready::readable(); + } + if !update.pending_write_buffer.is_empty() { + ready = ready | Ready::writable(); + } + if !ready.is_empty() { + let mut slab = share.sockets.lock().unwrap(); + let entry = slab.vacant_entry(); + share.poll.reregister(&stream, entry.key().into(), ready, + PollOpt::edge() | PollOpt::oneshot()) + .expect("Error while reregistering TCP stream"); + entry.insert(Socket::Stream { stream, read_closed, write_flush_suggested, + handler, update }); + } +} diff --git a/src/socket_handler/http1.rs b/src/socket_handler/http1.rs new file mode 100644 index 000000000..f4c458313 --- /dev/null +++ b/src/socket_handler/http1.rs @@ -0,0 +1,651 @@ +// Copyright (c) 2017 The Rouille developers +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , +// at your option. All files in the project carrying such +// notice may not be copied, modified, or distributed except +// according to those terms. + +use std::ascii::AsciiExt; +use std::borrow::Cow; +use std::io::copy; +use std::io::Cursor; +use std::io::Error as IoError; +use std::io::ErrorKind; +use std::io::Read; +use std::io::Write; +use std::mem; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::mpsc::{channel, Sender, Receiver}; +use std::sync::mpsc::TryRecvError; +use std::str; +use arrayvec::ArrayString; +use httparse; +use itoa::write as itoa; +use mio::Ready; +use mio::Registration; +use mio::SetReadiness; + +use socket_handler::Protocol; +use socket_handler::RegistrationState; +use socket_handler::SocketHandler; +use socket_handler::Update; +use socket_handler::UpdateResult; +use socket_handler::request_body_analyzer::RequestBodyAnalyzer; +use socket_handler::task_pool::TaskPool; +use Request; +use Response; + +/// Handles the processing of a client connection. +pub struct Http1Handler { + // The handler is a state machine. + state: Http1HandlerState, + + // Maximum number of bytes in the buffer while waiting for the request line or headers. + max_buffer_size: usize, + + // Address of the client. Necessary for the request objects. + client_addr: SocketAddr, + + // Protocol of the original server. Necessary for the request objects. + original_protocol: Protocol, + + // Object that handles the request and returns a response. + handler: Arc Response + Send + 'static>>, + + // The pool where to dispatch the handler. + task_pool: TaskPool, +} + +// Current status of the handler. +enum Http1HandlerState { + // A panic happened during the processing. In this state any call to `update` will panic. + Poisonned, + + // The `pending_read_buffer` doesn't have enough bytes to contain the initial request line. + WaitingForRqLine { + // Offset within `pending_read_buffer` where new data is available. Everything before this + // offset was already in `pending_read_buffer` the last time `update` returned. + new_data_start: usize, + }, + + // The request line has been parsed (its informations are inside the variant), but the + // `pending_read_buffer` doesn't have enough bytes to contain the headers. + WaitingForHeaders { + // Offset within `pending_read_buffer` where new data is available. Everything before this + // offset was already in `pending_read_buffer` the last time `update` returned. + new_data_start: usize, + // HTTP method (eg. GET, POST, ...) parsed from the request line. + method: ArrayString<[u8; 17]>, + // URL requested by the HTTP client parsed from the request line. 
+ path: String, + // HTTP version parsed from the request line. + version: HttpVersion, + }, + + // The handler is currently being executed in the task pool and is streaming data. + ExecutingHandler { + // True if `Connection: close` was requested by the client as part of the headers. + connection_close: bool, + // Analyzes and decodes the client input. + input_analyzer: RequestBodyAnalyzer, + // Used to send buffers containing the body of the request. `None` if no more data. + input_data: Option>>, + // Contains blocks of output data streamed by the handler. Closed when the handler doesn't + // have any more data to send. + response_getter: Receiver>, + // Registration that is triggered by the background thread whenever some data is available + // in `response_getter`. + registration: (Arc, RegistrationState), + }, + + // Happens after a request with `Connection: close`. The connection is considered as closed by + // the handler and nothing more will be processed. + Closed, +} + +impl Http1Handler { + /// Starts handling a new HTTP client connection. + /// + /// `client_addr` and `original_protocol` are necessary for building the `Request` objects. + /// `task_pool` and `handler` indicate how the requests must be processed. + pub fn new(client_addr: SocketAddr, original_protocol: Protocol, task_pool: TaskPool, + handler: F) -> Http1Handler + where F: FnMut(Request) -> Response + Send + 'static + { + Http1Handler { + state: Http1HandlerState::WaitingForRqLine { new_data_start: 0 }, + max_buffer_size: 10240, + client_addr: client_addr, + original_protocol: original_protocol, + handler: Arc::new(Mutex::new(handler)), + task_pool: task_pool, + } + } +} + +impl SocketHandler for Http1Handler { + fn update(&mut self, update: &mut Update) -> UpdateResult { + loop { + match mem::replace(&mut self.state, Http1HandlerState::Poisonned) { + Http1HandlerState::Poisonned => { + panic!("Poisonned request handler"); + }, + + Http1HandlerState::WaitingForRqLine { new_data_start } => { + // Try to find a \r\n in the buffer. + let off = new_data_start.saturating_sub(1); + let rn = update.pending_read_buffer[off..].windows(2) + .position(|w| w == b"\r\n"); + if let Some(rn) = rn { + // Found a request line! + let method; + let path; + let version; + { + let (method_raw, path_raw, version_raw) = match parse_request_line(&update.pending_read_buffer[..rn]) { + Ok(v) => v, + Err(_) => { + write_status_and_headers(&mut update.pending_write_buffer, 400, &[], Some(0)); + self.state = Http1HandlerState::Closed; + break UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: true, + }; + }, + }; + + method = match ArrayString::from(method_raw) { + Ok(m) => m, + Err(_) => { + write_status_and_headers(&mut update.pending_write_buffer, 501, &[], Some(0)); + self.state = Http1HandlerState::Closed; + break UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: true, + }; + }, + }; + + path = path_raw.to_owned(); + version = version_raw; + }; + + // Remove the request line from the head of the buffer. + let cut_len = update.pending_read_buffer.len() - (rn + 2); + for n in 0 .. cut_len { + update.pending_read_buffer[n] = update.pending_read_buffer[n + rn + 2]; + } + update.pending_read_buffer.resize(cut_len, 0); + + self.state = Http1HandlerState::WaitingForHeaders { + new_data_start: 0, + method, + path, + version + }; + + } else { + // No full request line in the buffer yet. + + // Handle buffer too large. 
+ if update.pending_read_buffer.len() > self.max_buffer_size { + write_status_and_headers(&mut update.pending_write_buffer, 413, + &[], Some(0)); + self.state = Http1HandlerState::Closed; + break UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: true, + }; + } + + self.state = Http1HandlerState::WaitingForRqLine { + new_data_start: update.pending_read_buffer.len(), + }; + + break UpdateResult { + registration: None, + close_read: false, + write_flush_suggested: false, + }; + } + }, + + Http1HandlerState::WaitingForHeaders { new_data_start, method, path, version } => { + // Try to find a `\r\n\r\n` in the buffer which would indicate the end of the + // headers. + let off = new_data_start.saturating_sub(3); + let rnrn = update.pending_read_buffer[off..].windows(4) + .position(|w| w == b"\r\n\r\n"); + if let Some(rnrn) = rnrn { + // Found headers! Parse them. + let headers: Vec<(String, String)> = { + let mut out_headers = Vec::new(); + let mut headers = [httparse::EMPTY_HEADER; 32]; + let (_, parsed_headers) = httparse::parse_headers(&update.pending_read_buffer, &mut headers).unwrap().unwrap(); // TODO: + for parsed in parsed_headers { + out_headers.push((parsed.name.to_owned(), String::from_utf8_lossy(parsed.value).into())); // TODO: wrong + } + out_headers + }; + + // Remove the headers from the head of the buffer. + let cut_len = update.pending_read_buffer.len() - (off + rnrn + 4); + for n in 0 .. cut_len { + update.pending_read_buffer[n] = update.pending_read_buffer[n + off + rnrn + 4]; + } + update.pending_read_buffer.resize(cut_len, 0); + + let input_analyzer = { + let iter = headers + .iter() + .map(|&(ref h, ref v)| (h.as_str(), v.as_str())); + RequestBodyAnalyzer::new(iter) + }; + + // We now create a new task for our task pool in which the request is + // built, the handler is called, and the response is sent as Vecs through + // a channel. + let (data_out_tx, data_out_rx) = channel(); + let (data_in_tx, data_in_rx) = channel(); + let (registration, set_ready) = Registration::new2(); + spawn_handler_task(&self.task_pool, self.handler.clone(), method, path, + headers, self.original_protocol, + self.client_addr.clone(), data_out_tx, data_in_rx, + set_ready); + + self.state = Http1HandlerState::ExecutingHandler { + connection_close: false, // TODO: + input_analyzer: input_analyzer, + input_data: Some(data_in_tx), + response_getter: data_out_rx, + registration: (Arc::new(registration), RegistrationState::FirstTime), + }; + + } else { + // No full headers in the buffer yet. + + // Handle buffer too large. + if update.pending_read_buffer.len() > self.max_buffer_size { + write_status_and_headers(&mut update.pending_write_buffer, 413, + &[], Some(0)); + self.state = Http1HandlerState::Closed; + break UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: true, + }; + } + + self.state = Http1HandlerState::WaitingForHeaders { + new_data_start: update.pending_read_buffer.len(), + method, + path, + version + }; + + break UpdateResult { + registration: None, + close_read: false, + write_flush_suggested: false, + }; + } + }, + + Http1HandlerState::ExecutingHandler { connection_close, mut input_data, + mut input_analyzer, response_getter, + registration } => + { + { + let analysis = input_analyzer.feed(&mut update.pending_read_buffer); + if analysis.body_data >= 1 { + // TODO: more optimal + let body_data = update.pending_read_buffer[0 .. 
analysis.body_data].to_owned(); + update.pending_read_buffer = update.pending_read_buffer[analysis.body_data + analysis.unused_trailing..].to_owned(); + let _ = input_data.as_mut().unwrap().send(body_data); + } else { + assert_eq!(analysis.unused_trailing, 0); + } + if analysis.finished { + input_data = None; + } + } + + match response_getter.try_recv() { + Ok(mut data) => { + // Got some data for the response. + if update.pending_write_buffer.is_empty() { + update.pending_write_buffer = data; + } else { + update.pending_write_buffer.append(&mut data); + } + self.state = Http1HandlerState::ExecutingHandler { + connection_close: connection_close, + input_data: input_data, + input_analyzer: input_analyzer, + response_getter: response_getter, + registration: registration, + }; + }, + Err(TryRecvError::Disconnected) => { + // The handler has finished streaming the response. + if connection_close { + self.state = Http1HandlerState::Closed; + } else { + self.state = Http1HandlerState::WaitingForRqLine { + new_data_start: 0 + }; + break UpdateResult { + registration: None, + close_read: false, + write_flush_suggested: true, + }; + } + }, + Err(TryRecvError::Empty) => { + // Spurious wakeup. + self.state = Http1HandlerState::ExecutingHandler { + connection_close: connection_close, + input_data: input_data, + input_analyzer: input_analyzer, + response_getter: response_getter, + registration: (registration.0.clone(), RegistrationState::Reregister), + }; + break UpdateResult { + registration: Some(registration), + close_read: false, + write_flush_suggested: false, + }; + }, + } + }, + + Http1HandlerState::Closed => { + update.pending_read_buffer.clear(); + self.state = Http1HandlerState::Closed; + break UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: true, + }; + }, + } + } + } +} + +// Starts the task of handling a request. 
+fn spawn_handler_task(task_pool: &TaskPool, + handler: Arc Response + Send + 'static>>, + method: ArrayString<[u8; 17]>, path: String, + headers: Vec<(String, String)>, original_protocol: Protocol, + remote_addr: SocketAddr, data_out_tx: Sender>, + data_in_rx: Receiver>, set_ready: SetReadiness) +{ + let https = original_protocol == Protocol::Https; + + struct ReadWrapper(Receiver>, Cursor>); + impl Read for ReadWrapper { + fn read(&mut self, buf: &mut [u8]) -> Result { + let initial_buf_len = buf.len() as u64; + let mut total_written = 0; + let mut buf = Cursor::new(buf); + + total_written += copy(&mut self.1, &mut buf).unwrap(); + debug_assert!(total_written <= initial_buf_len); + if total_written == initial_buf_len { + return Ok(total_written as usize); + } + + match self.0.recv() { + Ok(data) => self.1 = Cursor::new(data), + Err(_) => return Ok(total_written as usize), + }; + + total_written += copy(&mut self.1, &mut buf).unwrap(); + debug_assert!(total_written <= initial_buf_len); + if total_written == initial_buf_len { + return Ok(total_written as usize); + } + + loop { + match self.0.try_recv() { + Ok(data) => self.1 = Cursor::new(data), + Err(_) => return Ok(total_written as usize), + }; + + total_written += copy(&mut self.1, &mut buf).unwrap(); + debug_assert!(total_written <= initial_buf_len); + if total_written == initial_buf_len { + return Ok(total_written as usize); + } + } + } + } + + let reader = ReadWrapper(data_in_rx, Cursor::new(Vec::new())); + + task_pool.spawn(move || { + let request = Request { + method: method, + url: path, + headers: headers, + https: https, + data: Arc::new(Mutex::new(Some(Box::new(reader)))), + remote_addr: remote_addr, + }; + + let response = { + let mut handler = handler.lock().unwrap(); + (&mut *handler)(request) + }; + assert!(response.upgrade.is_none()); // TODO: unimplemented + + let (mut body_data, body_size) = response.data.into_reader_and_size(); + + let mut out_buffer = Vec::new(); + write_status_and_headers(&mut out_buffer, + response.status_code, + &response.headers, + body_size); + + match data_out_tx.send(out_buffer) { + Ok(_) => (), + Err(_) => return, + }; + + let _ = set_ready.set_readiness(Ready::readable()); + + loop { + let mut out_data = vec![0; 1024]; + match body_data.read(&mut out_data) { + Ok(0) => break, + Ok(n) => out_data.truncate(n), + Err(ref e) if e.kind() == ErrorKind::Interrupted => {}, + Err(_) => { + // Handle errors by silently stopping the stream. + // TODO: better way? + return; + }, + }; + + // Encoding as chunks if relevant. + // TODO: more optimized + if body_size.is_none() { + let len = out_data.len(); + let data = mem::replace(&mut out_data, Vec::with_capacity(len + 7)); + write!(&mut out_data, "{:x}", len).unwrap(); + out_data.extend_from_slice(b"\r\n"); + out_data.extend(data); + out_data.extend_from_slice(b"\r\n"); + } + + match data_out_tx.send(out_data) { + Ok(_) => (), + Err(_) => return, + }; + let _ = set_ready.set_readiness(Ready::readable()); + } + + if body_size.is_none() { + let _ = data_out_tx.send(b"0\r\n\r\n".to_vec()); + } + + let _ = set_ready.set_readiness(Ready::readable()); + }); +} + +// HTTP version (usually 1.0 or 1.1). +#[derive(Debug, Clone, PartialEq, Eq)] +struct HttpVersion(pub u8, pub u8); + +// Parses a "HTTP/1.1" string. 
+// TODO: handle [u8] correctly +fn parse_http_version(version: &str) -> Result { + let mut elems = version.splitn(2, '/'); + + elems.next(); + let vers = match elems.next() { + Some(v) => v, + None => return Err(()), + }; + + let mut elems = vers.splitn(2, '.'); + let major = elems.next().and_then(|n| n.parse().ok()); + let minor = elems.next().and_then(|n| n.parse().ok()); + + match (major, minor) { + (Some(ma), Some(mi)) => Ok(HttpVersion(ma, mi)), + _ => return Err(()), + } +} + +// Parses the request line of the request. +// eg. GET / HTTP/1.1 +// TODO: handle [u8] correctly +fn parse_request_line(line: &[u8]) -> Result<(&str, &str, HttpVersion), ()> { + let line = str::from_utf8(line).unwrap(); // TODO: + let mut words = line.split(' '); + + let method = words.next(); + let path = words.next(); + let version = words.next(); + + let (method, path, version) = match (method, path, version) { + (Some(m), Some(p), Some(v)) => (m, p, v), + _ => return Err(()) + }; + + let version = parse_http_version(version)?; + Ok((method, path, version)) +} + +// Writes the status line and headers of the response to `out`. +fn write_status_and_headers(mut out: &mut Vec, status_code: u16, + headers: &[(Cow<'static, str>, Cow<'static, str>)], + body_size: Option) +{ + out.extend_from_slice(b"HTTP/1.1 "); + itoa(&mut out, status_code).unwrap(); + out.push(b' '); + out.extend_from_slice(default_reason_phrase(status_code).as_bytes()); + out.extend_from_slice(b"\r\n"); + + let mut found_server_header = false; + let mut found_date_header = false; + for &(ref header, ref value) in headers { + if !found_server_header && header.eq_ignore_ascii_case("Server") { + found_server_header = true; + } + if !found_date_header && header.eq_ignore_ascii_case("Date") { + found_date_header = true; + } + + // Some headers can't be written with the response, as they are too "low-level". + if header.eq_ignore_ascii_case("Content-Length") || + header.eq_ignore_ascii_case("Transfer-Encoding") || + header.eq_ignore_ascii_case("Connection") || + header.eq_ignore_ascii_case("Trailer") + { + continue; + } + + out.extend_from_slice(header.as_bytes()); + out.extend_from_slice(b": "); + out.extend_from_slice(value.as_bytes()); + out.extend_from_slice(b"\r\n"); + } + + if !found_server_header { + out.extend_from_slice(b"Server: rouille\r\n"); + } + if !found_date_header { + out.extend_from_slice(b"Date: TODO\r\n"); // TODO: + } + + if let Some(body_size) = body_size { + out.extend_from_slice(b"Content-Length: "); + itoa(&mut out, body_size).unwrap(); + out.extend_from_slice(b"\r\n"); + } else { + out.extend_from_slice(b"Transfer-Encoding: chunked\r\n"); + } + out.extend_from_slice(b"\r\n"); +} + +// Returns the phrase corresponding to a status code. 
+fn default_reason_phrase(status_code: u16) -> &'static str { + match status_code { + 100 => "Continue", + 101 => "Switching Protocols", + 102 => "Processing", + 118 => "Connection timed out", + 200 => "OK", + 201 => "Created", + 202 => "Accepted", + 203 => "Non-Authoritative Information", + 204 => "No Content", + 205 => "Reset Content", + 206 => "Partial Content", + 207 => "Multi-Status", + 210 => "Content Different", + 300 => "Multiple Choices", + 301 => "Moved Permanently", + 302 => "Found", + 303 => "See Other", + 304 => "Not Modified", + 305 => "Use Proxy", + 307 => "Temporary Redirect", + 400 => "Bad Request", + 401 => "Unauthorized", + 402 => "Payment Required", + 403 => "Forbidden", + 404 => "Not Found", + 405 => "Method Not Allowed", + 406 => "Not Acceptable", + 407 => "Proxy Authentication Required", + 408 => "Request Time-out", + 409 => "Conflict", + 410 => "Gone", + 411 => "Length Required", + 412 => "Precondition Failed", + 413 => "Request Entity Too Large", + 414 => "Reques-URI Too Large", + 415 => "Unsupported Media Type", + 416 => "Request range not satisfiable", + 417 => "Expectation Failed", + 500 => "Internal Server Error", + 501 => "Not Implemented", + 502 => "Bad Gateway", + 503 => "Service Unavailable", + 504 => "Gateway Time-out", + 505 => "HTTP Version not supported", + _ => "Unknown" + } +} diff --git a/src/socket_handler/mod.rs b/src/socket_handler/mod.rs new file mode 100644 index 000000000..43a4bba2c --- /dev/null +++ b/src/socket_handler/mod.rs @@ -0,0 +1,136 @@ +// Copyright (c) 2017 The Rouille developers +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , +// at your option. All files in the project carrying such +// notice may not be copied, modified, or distributed except +// according to those terms. + +use std::net::SocketAddr; +use std::sync::Arc; +use mio::Registration; + +use Request; +use Response; + +use self::http1::Http1Handler; +use self::rustls::RustlsHandler; +pub use self::rustls::RustlsConfig; +pub use self::task_pool::TaskPool; // TODO: shouldn't be pub, but is used by Server, move it somewher else + +mod http1; +mod request_body_analyzer; +mod rustls; +mod task_pool; + +/// Parses the data received by a socket and returns the data to send back. +pub struct SocketHandlerDispatch { + inner: SocketHandlerDispatchInner, +} + +enum SocketHandlerDispatchInner { + Http(Http1Handler), + Https(RustlsHandler), +} + +impl SocketHandlerDispatch { + /// Initialization for HTTP. + pub fn http(client_addr: SocketAddr, task_pool: TaskPool, handler: F) + -> SocketHandlerDispatch + where F: FnMut(Request) -> Response + Send + 'static + { + let http_handler = Http1Handler::new(client_addr, Protocol::Http, task_pool, handler); + let inner = SocketHandlerDispatchInner::Http(http_handler); + SocketHandlerDispatch { + inner: inner, + } + } + + /// Initialization for HTTPS. 
+ pub fn https(config: RustlsConfig, client_addr: SocketAddr, task_pool: TaskPool, handler: F) + -> SocketHandlerDispatch + where F: FnMut(Request) -> Response + Send + 'static + { + let http_handler = Http1Handler::new(client_addr, Protocol::Https, task_pool, handler); + let inner = SocketHandlerDispatchInner::Https(RustlsHandler::new(config, http_handler)); + SocketHandlerDispatch { + inner: inner, + } + } +} + +impl SocketHandler for SocketHandlerDispatch { + fn update(&mut self, update: &mut Update) -> UpdateResult { + match self.inner { + SocketHandlerDispatchInner::Http(ref mut http) => http.update(update), + SocketHandlerDispatchInner::Https(ref mut https) => https.update(update), + } + } +} + +/// Protocol that can serve HTTP. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Protocol { + Http, + Https, +} + +pub trait SocketHandler { + /// Call this function whenever new data is received on the socket, or when the registration + /// wakes up. + fn update(&mut self, update: &mut Update) -> UpdateResult; +} + +#[derive(Debug)] +pub struct UpdateResult { + /// When `Some`, means that the user must call `update` when the `Registration` becomes ready + /// (either for reading or writing). The registration should be registered with `oneshot()`. + pub registration: Option<(Arc, RegistrationState)>, + + /// Set to true if the socket handler will no longer process incoming data. If + /// `close_read` is true, `pending_write_buffer` is empty, and `registration` is empty, + /// then you can drop the socket. + pub close_read: bool, + + /// If true, the socket handler suggests to flush the content of the write buffer to the + /// socket. + pub write_flush_suggested: bool, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum RegistrationState { + /// It is the first time this registration is returned. + FirstTime, + /// This registration has been registered before, and `reregister` should be used. + Reregister, +} + +/// Represents the communication between the `SocketHandler` and the outside. +/// +/// The "outside" is supposed to fill `pending_read_buffer` with incoming data, and remove data +/// from `pending_write_buffer`, then call `update`. +#[derive(Debug)] +pub struct Update { + /// Filled by the handler user and emptied by `update()`. Contains the data that comes from + /// the client. + // TODO: try VecDeque and check perfs + pub pending_read_buffer: Vec, + + /// Filled by `SocketHandler::update()` and emptied by the user. Contains the data that must + /// be sent back to the client. + // TODO: try VecDeque and check perfs + pub pending_write_buffer: Vec, + +} + +impl Update { + /// Builds a new empty `Update`. + pub fn empty() -> Update { + // TODO: don't create two Vecs for each socket + Update { + pending_read_buffer: Vec::with_capacity(1024), + pending_write_buffer: Vec::with_capacity(1024), + } + } +} diff --git a/src/socket_handler/request_body_analyzer.rs b/src/socket_handler/request_body_analyzer.rs new file mode 100644 index 000000000..6d1f6bbc1 --- /dev/null +++ b/src/socket_handler/request_body_analyzer.rs @@ -0,0 +1,232 @@ +// Copyright (c) 2017 The Rouille developers +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , +// at your option. All files in the project carrying such +// notice may not be copied, modified, or distributed except +// according to those terms. 
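As a side note to the `socket_handler/mod.rs` interface above: a minimal, illustrative driver loop (not part of the patch) showing how the `Update`/`SocketHandler` contract is meant to be used — the caller appends incoming bytes to `pending_read_buffer`, calls `update`, then drains `pending_write_buffer`. The `read_from_client`/`write_to_client` functions are hypothetical stand-ins for the non-blocking socket I/O that `server.rs` performs with mio.

```rust
// Illustrative only; inside the crate this would use the private `socket_handler` module.
use socket_handler::{SocketHandler, Update};

// Hypothetical stand-ins for the real non-blocking socket calls.
fn read_from_client() -> Vec<u8> { Vec::new() }
fn write_to_client(_data: &[u8]) {}

fn drive<H: SocketHandler>(handler: &mut H) {
    let mut update = Update::empty();

    loop {
        // 1. The outside fills the read buffer with whatever arrived on the socket.
        update.pending_read_buffer.extend_from_slice(&read_from_client());

        // 2. The handler consumes the read buffer and may append to the write buffer.
        let result = handler.update(&mut update);

        // 3. The outside sends the write buffer back to the client.
        write_to_client(&update.pending_write_buffer);
        update.pending_write_buffer.clear();

        // 4. Per the `UpdateResult` docs: once reading is closed, nothing is pending and
        //    no registration is being waited upon, the socket can be dropped.
        if result.close_read && result.registration.is_none() {
            break;
        }
    }
}
```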
+ +use atoi::atoi; +use std::ascii::AsciiExt; +use std::cmp; +use std::mem; + +pub struct RequestBodyAnalyzer { + inner: RequestBodyAnalyzerInner, +} + +enum RequestBodyAnalyzerInner { + ContentLength { + // Remaining body length. + remaining_content_length: u64, + }, + ChunkedTransferEncoding { + // Remaining size of the chunk being read. `None` if we are not in a chunk. + remaining_chunk_size: Option, + }, +} + +impl RequestBodyAnalyzer { + /// Reads the request's headers to determine how the body will need to be handled. + pub fn new<'a, I>(headers: I) -> RequestBodyAnalyzer + where I: Iterator // TODO: should be [u8] eventually + { + let mut content_length = None; + let mut chunked = false; + + for (header, value) in headers { + if header.eq_ignore_ascii_case("Transfer-Encoding") { + if value.eq_ignore_ascii_case("chunked") { + chunked = true; + } + } + if header.eq_ignore_ascii_case("Content-Length") { + content_length = atoi(value.as_bytes()); + } + } + + RequestBodyAnalyzer { + inner: match (content_length, chunked) { + (_, true) => RequestBodyAnalyzerInner::ChunkedTransferEncoding { + remaining_chunk_size: None, + }, + (Some(len), _) => RequestBodyAnalyzerInner::ContentLength { + remaining_content_length: len, + }, + _ => { + // If we have neither a Content-Length nor a Transfer-Encoding, + // assuming that we have no data. + // TODO: could also be multipart/byteranges + RequestBodyAnalyzerInner::ContentLength { + remaining_content_length: 0, + } + }, + }, + } + } + + /// Processes some data. Call this method with a slice containing data received by the socket. + /// This method will "decode" them in place. The decoding always takes less space than the + /// input, so there's no buffering of any sort. + pub fn feed(&mut self, data: &mut [u8]) -> FeedOutcome { + match self.inner { + RequestBodyAnalyzerInner::ContentLength { ref mut remaining_content_length } => { + if (data.len() as u64) < *remaining_content_length { + *remaining_content_length -= data.len() as u64; + FeedOutcome { + body_data: data.len(), + unused_trailing: 0, + finished: *remaining_content_length == 0, + } + + } else { + FeedOutcome { + body_data: mem::replace(&mut *remaining_content_length, 0) as usize, + unused_trailing: 0, + finished: true, + } + } + }, + + RequestBodyAnalyzerInner::ChunkedTransferEncoding { ref mut remaining_chunk_size } => { + // `out_body_data` contains the number of bytes from the start of `data` that are + // already final. + // + // `out_unused_trailing` contains the number of bytes after `out_body_data` that + // are garbage. + // + // Therefore at any point during this algorithm, + // `out_body_data + out_unused_trailing` is the offset of the next byte of input. + // + // Incrementing `out_unused_trailing` means that we skip bytes from the input. 
+ let mut out_body_data = 0; + let mut out_unused_trailing = 0; + + loop { + if remaining_chunk_size.is_none() { + match try_read_chunk_size(&data[out_body_data + out_unused_trailing..]) { + Some((skip, chunk_size)) => { + *remaining_chunk_size = Some(chunk_size); + debug_assert_ne!(skip, 0); + out_unused_trailing += skip; + }, + None => return FeedOutcome { + body_data: out_body_data, + unused_trailing: out_unused_trailing, + finished: false, + }, + } + } + + if *remaining_chunk_size == Some(0) { + return FeedOutcome { + body_data: out_body_data, + unused_trailing: out_unused_trailing, + finished: true, + } + } + + debug_assert!(out_body_data + out_unused_trailing <= data.len()); + if data.len() == out_body_data + out_unused_trailing { + return FeedOutcome { + body_data: out_body_data, + unused_trailing: out_unused_trailing, + finished: false, + }; + } + + let copy_len = cmp::min(data.len() - out_body_data - out_unused_trailing, + remaining_chunk_size.unwrap()); + if out_unused_trailing != 0 { + for n in 0 .. copy_len { + data[out_body_data + n] = data[out_body_data + out_unused_trailing + n]; + } + } + // FIXME: wrong because ignores trailing \r\n at end of chunks + out_body_data += copy_len; + *remaining_chunk_size.as_mut().unwrap() -= copy_len; + if *remaining_chunk_size == Some(0) { + *remaining_chunk_size = None; + } + } + }, + } + } +} + +/// Result of the `feed` method. +pub struct FeedOutcome { + /// Number of bytes from the start of `data` that contain the body of the request. If + /// `finished` is true, then any further byte is part of the next request. If `finished` is + /// false, then any further byte is still part of this request but hasn't been decoded yet. + pub body_data: usize, + + /// Number of bytes following `body_data` that are irrelevant and that should be discarded. + pub unused_trailing: usize, + + /// True if the request is finished. Calling `feed` again would return a `FeedOutcome` with a + /// `body_data` of 0. + pub finished: bool, +} + +// Tries to read a chunk size from `data`. Returns `None` if not enough data. +// Returns the number of bytes that make up the chunk size, and the chunk size value. 
+fn try_read_chunk_size(data: &[u8]) -> Option<(usize, usize)> { + let crlf_pos = match data.windows(2).position(|n| n == b"\r\n") { + Some(p) => p, + None => return None, + }; + + let chunk_size = match atoi(&data[..crlf_pos]) { + Some(s) => s, + None => return None, // TODO: error instead + }; + + Some((crlf_pos + 2, chunk_size)) +} + +#[cfg(test)] +mod tests { + use super::RequestBodyAnalyzer; + + #[test] + fn chunked_decode_one_buf() { + let mut analyzer = { + let headers = vec![("Transfer-Encoding", "chunked")]; + RequestBodyAnalyzer::new(headers.into_iter()) + }; + + let mut buffer = b"6\r\nhello 5\r\nworld0\r\n".to_vec(); + let outcome = analyzer.feed(&mut buffer); + + assert_eq!(outcome.body_data, 11); + assert_eq!(outcome.unused_trailing, 20 - 11); + assert!(outcome.finished); + assert_eq!(&buffer[..11], &b"hello world"[..]); + } + + #[test] + fn chunked_decode_multi_buf() { + let mut analyzer = { + let headers = vec![("Transfer-Encoding", "chunked")]; + RequestBodyAnalyzer::new(headers.into_iter()) + }; + + let mut buf1 = b"6\r\nhel".to_vec(); + let out1 = analyzer.feed(&mut buf1); + + let mut buf2 = b"lo 5\r\nworld0\r\n".to_vec(); + let out2 = analyzer.feed(&mut buf2); + + assert_eq!(out1.body_data, 3); + assert_eq!(out1.unused_trailing, 6 - 3); + assert!(!out1.finished); + assert_eq!(&buf1[..3], &b"hel"[..]); + + assert_eq!(out2.body_data, 8); + assert_eq!(out2.unused_trailing, 14 - 8); + assert!(out2.finished); + assert_eq!(&buf2[..8], &b"lo world"[..]); + } +} diff --git a/src/socket_handler/rustls.rs b/src/socket_handler/rustls.rs new file mode 100644 index 000000000..5bfdb545d --- /dev/null +++ b/src/socket_handler/rustls.rs @@ -0,0 +1,245 @@ +// Copyright (c) 2017 The Rouille developers +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , +// at your option. All files in the project carrying such +// notice may not be copied, modified, or distributed except +// according to those terms. + +use std::collections::HashMap; +use std::error::Error; +use std::fs::File; +use std::io::BufReader; +use std::io::Read; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use std::sync::Mutex; + +use rustls::Certificate; +use rustls::PrivateKey; +use rustls::ResolvesServerCert; +use rustls::ServerConfig; +use rustls::ServerSession; +use rustls::ServerSessionMemoryCache; +use rustls::Session; +use rustls::SignatureScheme; +use rustls::internal::pemfile; +use rustls::sign::CertChainAndSigner; +use rustls::sign::RSASigner; + +use socket_handler::SocketHandler; +use socket_handler::Update; +use socket_handler::UpdateResult; + +/// Configuration for HTTPS handling. +/// +/// This struct internally contains `Arc`s, which means that you can clone it for a cheap cost. +/// +/// Note that this configuration can be updated at runtime. Certificates can be added or removed +/// while the server is running. This will only affect new HTTP connections though. +#[derive(Clone)] +pub struct RustlsConfig { + config: Arc, + inner: RustlsConfigInner, +} + +#[derive(Clone)] +struct RustlsConfigInner { + certificates: Arc>>, +} + +impl ResolvesServerCert for RustlsConfigInner { + fn resolve(&self, server_name: Option<&str>, _: &[SignatureScheme]) + -> Option + { + let server_name = match server_name { + Some(s) => s, + None => return None, + }; + + let certificates = self.certificates.lock().unwrap(); + certificates + .get(server_name) + .map(|v| v.clone()) + } +} + +impl RustlsConfig { + /// Builds a new configuration. You should do this at initialization only. 
+ /// + /// Once the configuration is created, you should add certificates to it. Otherwise people + /// won't be able to connect to it. + pub fn new() -> RustlsConfig { + let inner = RustlsConfigInner { + certificates: Arc::new(Mutex::new(HashMap::new())), + }; + + let mut config = ServerConfig::new(); + //config.alpn_protocols = vec!["http/1.1".to_owned()]; // TODO: + config.cert_resolver = Box::new(inner.clone()); + config.session_storage = Mutex::new(ServerSessionMemoryCache::new(1024)); + + RustlsConfig { + config: Arc::new(config), + inner: inner, + } + } + + /// Removes the certificate of a domain name. + pub fn remove_certificate(&self, domain_name: &str) { + let mut certificates = self.inner.certificates.lock().unwrap(); + certificates.remove(domain_name); + } + + /// Sets the certificate of a domain name. The certificates and private key are parsed from + /// PEM files whose path is passed as parameter. + /// + /// Replaces the existing certificate for this domain name if one has been set earlier. + pub fn set_certificate_from_pem(&self, domain_name: S, pub_pem: Pu, priv_pem: Pr) + -> Result<(), Box> + where S: Into, + Pu: AsRef, + Pr: AsRef + { + let pub_chain = load_certificates(pub_pem)?; + let priv_key = load_private_key(priv_pem)?; + let signer = RSASigner::new(&priv_key) + .map_err(|_| String::from("Failed to create RSASigner"))?; + + let mut certificates = self.inner.certificates.lock().unwrap(); + certificates.insert(domain_name.into(), (pub_chain, Arc::new(Box::new(signer) as Box<_>))); + Ok(()) + } +} + +/// Handles the processing of a client connection through TLS. +pub struct RustlsHandler { + // The inner handler. + handler: H, + // The Rustls session. + session: ServerSession, + // The update object to communicate with the handler. + handler_update: Update, +} + +impl RustlsHandler { + /// Starts handling a TLS connection. + /// + /// This struct only performs the encoding and decoding, while the actual handling is performed + /// by `inner`. + pub fn new(config: RustlsConfig, inner: H) -> RustlsHandler { + RustlsHandler { + handler: inner, + session: ServerSession::new(&config.config), + handler_update: Update::empty(), + } + } +} + +impl SocketHandler for RustlsHandler + where H: SocketHandler +{ + fn update(&mut self, update: &mut Update) -> UpdateResult { + // Pass outside data to the ServerSession. + match self.session.read_tls(&mut (&update.pending_read_buffer[..])) { + Ok(read_num) => { + assert_eq!(read_num, update.pending_read_buffer.len()); + update.pending_read_buffer.clear(); + }, + Err(_) => { + return UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: false, + }; + }, + }; + + if let Err(_) = self.session.process_new_packets() { + // Drop the socket in case of an error. + update.pending_write_buffer.clear(); + return UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: false, + }; + } + + // Pass data from the ServerSession to the inner handler. + if let Err(_) = self.session.read_to_end(&mut self.handler_update.pending_read_buffer) { + return UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: false, + }; + } + + // Call the inner handler. + let result = self.handler.update(&mut self.handler_update); + + // Pass data from the inner handler to the ServerSession. 
+ match self.session.write_all(&self.handler_update.pending_write_buffer) { + Ok(_) => self.handler_update.pending_write_buffer.clear(), + Err(_) => { + return UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: false, + }; + } + }; + + // Pass data from the ServerSession to the outside. + while self.session.wants_write() { + if let Err(_) = self.session.write_tls(&mut update.pending_write_buffer) { + return UpdateResult { + registration: None, + close_read: true, + write_flush_suggested: true, + }; + } + } + + result + } +} + +// Load certificates chain from a PEM file. +fn load_certificates

<P>(path: P) -> Result<Vec<Certificate>, Box<Error>>
+    where P: AsRef<Path>
+{
+    let file = File::open(path)?;
+    let mut reader = BufReader::new(file);
+    let certs = pemfile::certs(&mut reader)
+        .map_err(|_| String::from("Certificates PEM file contains invalid keys"))?;
+    Ok(certs)
+}
+
+// Load private key from a PEM file.
+fn load_private_key<P>(path: P) -> Result<PrivateKey, Box<Error>>
+    where P: AsRef<Path>
+{
+    let path = path.as_ref();
+
+    let mut rsa_keys = {
+        let file = File::open(path)?;
+        let mut reader = BufReader::new(file);
+        pemfile::rsa_private_keys(&mut reader)
+            .map_err(|_| String::from("Private key PEM file contains invalid keys"))?
+    };
+
+    let mut pkcs8_keys = {
+        let file = File::open(path)?;
+        let mut reader = BufReader::new(file);
+        pemfile::pkcs8_private_keys(&mut reader)
+            .map_err(|_| String::from("Private key PEM file contains invalid keys"))?
+    };
+
+    Ok(if !pkcs8_keys.is_empty() {
+        pkcs8_keys.remove(0)
+    } else {
+        rsa_keys.remove(0)
+    })
+}
diff --git a/src/socket_handler/task_pool.rs b/src/socket_handler/task_pool.rs
new file mode 100644
index 000000000..0b436d04e
--- /dev/null
+++ b/src/socket_handler/task_pool.rs
@@ -0,0 +1,79 @@
+// Copyright 2015 The tiny-http Contributors
+// Copyright (c) 2017 The Rouille developers
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+use std::thread;
+use crossbeam::sync::MsQueue;
+use num_cpus;
+
+/// Manages a collection of threads.
+#[derive(Clone)]
+pub struct TaskPool {
+    sharing: Arc<Sharing>,
+}
+
+struct Sharing {
+    // List of the tasks to be done by worker threads.
+    //
+    // If the task returns `true` then the worker thread must continue. Otherwise it must stop.
+    // This feature is necessary in order to be able to stop worker threads.
+    todo: MsQueue<Box<FnMut() -> bool + Send>>,
+}
+
+impl TaskPool {
+    /// Initializes a new task pool.
+    pub fn new() -> TaskPool {
+        let pool = TaskPool {
+            sharing: Arc::new(Sharing {
+                todo: MsQueue::new(),
+            }),
+        };
+
+        for _ in 0..num_cpus::get() {
+            let sharing = pool.sharing.clone();
+            thread::spawn(move || {
+                loop {
+                    let mut task = sharing.todo.pop();
+                    if !task() {
+                        break;
+                    }
+                }
+            });
+        }
+
+        pool
+    }
+
+    /// Executes a function in a worker thread.
+    #[inline]
+    pub fn spawn<F>(&self, code: F)
+        where F: FnOnce() + Send + 'static
+    {
+        let mut code = Some(code);
+        self.sharing.todo.push(Box::new(move || {
+            let code = code.take().unwrap();
+            code();
+            true
+        }));
+    }
+}
+
+impl Drop for Sharing {
+    fn drop(&mut self) {
+        for _ in 0 .. num_cpus::get() {
+            self.todo.push(Box::new(|| false));
+        }
+    }
+}
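For reference, a minimal sketch of how the pieces introduced by this PR fit together (not part of the diff; the addresses, domain name and PEM paths are placeholders, and errors are simply unwrapped):

```rust
extern crate rouille;

use rouille::{Response, Server, SslConfig};
use std::net::SocketAddr;

fn main() {
    // Build a server with no listeners, then attach them one by one.
    let server = Server::empty(|_request| Response::text("hello world")).unwrap();

    // Plain HTTP.
    server.add_http_listeners("0.0.0.0:8000").unwrap();

    // HTTPS through rustls: certificates can be added (or swapped) while the server runs.
    let tls = SslConfig::new();
    tls.set_certificate_from_pem("example.com", "cert.pem", "key.pem").unwrap();
    let https_addr: SocketAddr = "0.0.0.0:8443".parse().unwrap();
    server.add_https_listener(&https_addr, tls).unwrap();

    // Blocks forever, polling on the calling thread in addition to the worker threads.
    server.run();
}
```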