diff --git a/src/reject.rs b/src/reject.rs deleted file mode 100644 index 533444321..000000000 --- a/src/reject.rs +++ /dev/null @@ -1,853 +0,0 @@ -//! Rejections -//! -//! Part of the power of the [`Filter`](../trait.Filter.html) system is being able to -//! reject a request from a filter chain. This allows for filters to be -//! combined with `or`, so that if one side of the chain finds that a request -//! doesn't fulfill its requirements, the other side can try to process -//! the request. -//! -//! Many of the built-in [`filters`](../filters) will automatically reject -//! the request with an appropriate rejection. However, you can also build -//! new custom [`Filter`](../trait.Filter.html)s and still want other routes to be -//! matchable in the case a predicate doesn't hold. -//! -//! As a request is processed by a Filter chain, the rejections are accumulated into -//! a list contained by the [`Rejection`](struct.Rejection.html) type. Rejections from -//! filters can be handled using [`Filter::recover`](../trait.Filter.html#method.recover). -//! This is a convenient way to map rejections into a [`Reply`](../reply/trait.Reply.html). -//! -//! For a more complete example see the -//! [Rejection Example](https://github.com/seanmonstar/warp/blob/master/examples/rejections.rs) -//! from the repository. -//! -//! # Example -//! -//! ``` -//! use warp::{reply, Reply, Filter, reject, Rejection, http::StatusCode}; -//! -//! #[derive(Debug)] -//! struct InvalidParameter; -//! -//! impl reject::Reject for InvalidParameter {}; -//! -//! // Custom rejection handler that maps rejections into responses. -//! async fn handle_rejection(err: Rejection) -> Result { -//! if err.is_not_found() { -//! Ok(reply::with_status("NOT_FOUND", StatusCode::NOT_FOUND)) -//! } else if let Some(e) = err.find::() { -//! Ok(reply::with_status("BAD_REQUEST", StatusCode::BAD_REQUEST)) -//! } else { -//! eprintln!("unhandled rejection: {:?}", err); -//! Ok(reply::with_status("INTERNAL_SERVER_ERROR", StatusCode::INTERNAL_SERVER_ERROR)) -//! } -//! } -//! -//! #[tokio::main] -//! async fn main() { -//! -//! // Filter on `/:id`, but reject with InvalidParameter if the `id` is `0`. -//! // Recover from this rejection using a custom rejection handler. -//! let route = warp::path::param() -//! .and_then(|id: u32| async move { -//! if id == 0 { -//! Err(warp::reject::custom(InvalidParameter)) -//! } else { -//! Ok("id is valid") -//! } -//! }) -//! .recover(handle_rejection); -//! -//! warp::serve(route).run(([127, 0, 0, 1], 3030)).await; -//! } -//! ``` - -use std::any::Any; -use std::convert::Infallible; -use std::error::Error as StdError; -use std::fmt; - -use http::{ - self, - header::{HeaderValue, CONTENT_TYPE}, - StatusCode, -}; -use hyper::Body; - -pub(crate) use self::sealed::{CombineRejection, IsReject}; - -/// Rejects a request with `404 Not Found`. -#[inline] -pub fn reject() -> Rejection { - not_found() -} - -/// Rejects a request with `404 Not Found`. -#[inline] -pub fn not_found() -> Rejection { - Rejection { - reason: Reason::NotFound, - } -} - -// 400 Bad Request -#[inline] -pub(crate) fn invalid_query() -> Rejection { - known(InvalidQuery { _p: () }) -} - -// 400 Bad Request -#[inline] -pub(crate) fn missing_header(name: &'static str) -> Rejection { - known(MissingHeader { name }) -} - -// 400 Bad Request -#[inline] -pub(crate) fn invalid_header(name: &'static str) -> Rejection { - known(InvalidHeader { name }) -} - -// 400 Bad Request -#[inline] -pub(crate) fn missing_cookie(name: &'static str) -> Rejection { - known(MissingCookie { name }) -} - -// 405 Method Not Allowed -#[inline] -pub(crate) fn method_not_allowed() -> Rejection { - known(MethodNotAllowed { _p: () }) -} - -// 411 Length Required -#[inline] -pub(crate) fn length_required() -> Rejection { - known(LengthRequired { _p: () }) -} - -// 413 Payload Too Large -#[inline] -pub(crate) fn payload_too_large() -> Rejection { - known(PayloadTooLarge { _p: () }) -} - -// 415 Unsupported Media Type -// -// Used by the body filters if the request payload content-type doesn't match -// what can be deserialized. -#[inline] -pub(crate) fn unsupported_media_type() -> Rejection { - known(UnsupportedMediaType { _p: () }) -} - -/// Rejects a request with a custom cause. -/// -/// A [`recover`][] filter should convert this `Rejection` into a `Reply`, -/// or else this will be returned as a `500 Internal Server Error`. -/// -/// [`recover`]: ../trait.Filter.html#method.recover -pub fn custom(err: T) -> Rejection { - Rejection::custom(Box::new(err)) -} - -/// Protect against re-rejecting a rejection. -/// -/// ```compile_fail -/// fn with(r: warp::Rejection) { -/// let _wat = warp::reject::custom(r); -/// } -/// ``` -fn __reject_custom_compilefail() {} - -/// A marker trait to ensure proper types are used for custom rejections. -/// -/// Can be converted into Rejection. -/// -/// # Example -/// -/// ``` -/// use warp::{Filter, reject::Reject}; -/// -/// #[derive(Debug)] -/// struct RateLimited; -/// -/// impl Reject for RateLimited {} -/// -/// let route = warp::any().and_then(|| async { -/// Err::<(), _>(warp::reject::custom(RateLimited)) -/// }); -/// ``` -// Require `Sized` for now to prevent passing a `Box`, since we -// would be double-boxing it, and the downcasting wouldn't work as expected. -pub trait Reject: fmt::Debug + Sized + Send + Sync + 'static {} - -trait Cause: fmt::Debug + Send + Sync + 'static { - fn as_any(&self) -> &dyn Any; -} - -impl Cause for T -where - T: fmt::Debug + Send + Sync + 'static, -{ - fn as_any(&self) -> &dyn Any { - self - } -} - -impl dyn Cause { - fn downcast_ref(&self) -> Option<&T> { - self.as_any().downcast_ref::() - } -} - -pub(crate) fn known>(err: T) -> Rejection { - Rejection::known(err.into()) -} - -/// Rejection of a request by a [`Filter`](crate::Filter). -/// -/// See the [`reject`](module@crate::reject) documentation for more. -pub struct Rejection { - reason: Reason, -} - -enum Reason { - NotFound, - Other(Box), -} - -enum Rejections { - Known(Known), - Custom(Box), - Combined(Box, Box), -} - -macro_rules! enum_known { - ($($(#[$attr:meta])* $var:ident($ty:path),)+) => ( - pub(crate) enum Known { - $( - $(#[$attr])* - $var($ty), - )+ - } - - impl Known { - fn inner_as_any(&self) -> &dyn Any { - match *self { - $( - $(#[$attr])* - Known::$var(ref t) => t, - )+ - } - } - } - - impl fmt::Debug for Known { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - $( - $(#[$attr])* - Known::$var(ref t) => t.fmt(f), - )+ - } - } - } - - impl fmt::Display for Known { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - $( - $(#[$attr])* - Known::$var(ref t) => t.fmt(f), - )+ - } - } - } - - $( - #[doc(hidden)] - $(#[$attr])* - impl From<$ty> for Known { - fn from(ty: $ty) -> Known { - Known::$var(ty) - } - } - )+ - ); -} - -enum_known! { - MethodNotAllowed(MethodNotAllowed), - InvalidHeader(InvalidHeader), - MissingHeader(MissingHeader), - MissingCookie(MissingCookie), - InvalidQuery(InvalidQuery), - LengthRequired(LengthRequired), - PayloadTooLarge(PayloadTooLarge), - UnsupportedMediaType(UnsupportedMediaType), - FileOpenError(crate::fs::FileOpenError), - FilePermissionError(crate::fs::FilePermissionError), - BodyReadError(crate::body::BodyReadError), - BodyDeserializeError(crate::body::BodyDeserializeError), - CorsForbidden(crate::cors::CorsForbidden), - #[cfg(feature = "websocket")] - MissingConnectionUpgrade(crate::ws::MissingConnectionUpgrade), - MissingExtension(crate::ext::MissingExtension), - BodyConsumedMultipleTimes(crate::body::BodyConsumedMultipleTimes), -} - -impl Rejection { - fn known(known: Known) -> Self { - Rejection { - reason: Reason::Other(Box::new(Rejections::Known(known))), - } - } - - fn custom(other: Box) -> Self { - Rejection { - reason: Reason::Other(Box::new(Rejections::Custom(other))), - } - } - - /// Searches this `Rejection` for a specific cause. - /// - /// A `Rejection` will accumulate causes over a `Filter` chain. This method - /// can search through them and return the first cause of this type. - /// - /// # Example - /// - /// ``` - /// #[derive(Debug)] - /// struct Nope; - /// - /// impl warp::reject::Reject for Nope {} - /// - /// let reject = warp::reject::custom(Nope); - /// - /// if let Some(nope) = reject.find::() { - /// println!("found it: {:?}", nope); - /// } - /// ``` - pub fn find(&self) -> Option<&T> { - if let Reason::Other(ref rejections) = self.reason { - return rejections.find(); - } - None - } - - /// Returns true if this Rejection was made via `warp::reject::not_found`. - /// - /// # Example - /// - /// ``` - /// let rejection = warp::reject(); - /// - /// assert!(rejection.is_not_found()); - /// ``` - pub fn is_not_found(&self) -> bool { - if let Reason::NotFound = self.reason { - true - } else { - false - } - } -} - -impl From for Rejection { - #[inline] - fn from(err: T) -> Rejection { - custom(err) - } -} - -impl From for Rejection { - #[inline] - fn from(infallible: Infallible) -> Rejection { - match infallible {} - } -} - -impl IsReject for Infallible { - fn status(&self) -> StatusCode { - match *self {} - } - - fn into_response(&self) -> crate::reply::Response { - match *self {} - } -} - -impl IsReject for Rejection { - fn status(&self) -> StatusCode { - match self.reason { - Reason::NotFound => StatusCode::NOT_FOUND, - Reason::Other(ref other) => other.status(), - } - } - - fn into_response(&self) -> crate::reply::Response { - match self.reason { - Reason::NotFound => { - let mut res = http::Response::default(); - *res.status_mut() = StatusCode::NOT_FOUND; - res - } - Reason::Other(ref other) => other.into_response(), - } - } -} - -impl fmt::Debug for Rejection { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_tuple("Rejection").field(&self.reason).finish() - } -} - -impl fmt::Debug for Reason { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Reason::NotFound => f.write_str("NotFound"), - Reason::Other(ref other) => match **other { - Rejections::Known(ref e) => fmt::Debug::fmt(e, f), - Rejections::Custom(ref e) => fmt::Debug::fmt(e, f), - Rejections::Combined(ref a, ref b) => { - let mut list = f.debug_list(); - a.debug_list(&mut list); - b.debug_list(&mut list); - list.finish() - } - }, - } - } -} - -// ===== Rejections ===== - -impl Rejections { - fn status(&self) -> StatusCode { - match *self { - Rejections::Known(ref k) => match *k { - Known::MethodNotAllowed(_) => StatusCode::METHOD_NOT_ALLOWED, - Known::InvalidHeader(_) - | Known::MissingHeader(_) - | Known::MissingCookie(_) - | Known::InvalidQuery(_) - | Known::BodyReadError(_) - | Known::BodyDeserializeError(_) => StatusCode::BAD_REQUEST, - #[cfg(feature = "websocket")] - Known::MissingConnectionUpgrade(_) => StatusCode::BAD_REQUEST, - Known::LengthRequired(_) => StatusCode::LENGTH_REQUIRED, - Known::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE, - Known::UnsupportedMediaType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE, - Known::FilePermissionError(_) | Known::CorsForbidden(_) => StatusCode::FORBIDDEN, - Known::FileOpenError(_) - | Known::MissingExtension(_) - | Known::BodyConsumedMultipleTimes(_) => StatusCode::INTERNAL_SERVER_ERROR, - }, - Rejections::Custom(..) => StatusCode::INTERNAL_SERVER_ERROR, - Rejections::Combined(ref a, ref b) => preferred(a, b).status(), - } - } - - fn into_response(&self) -> crate::reply::Response { - match *self { - Rejections::Known(ref e) => { - let mut res = http::Response::new(Body::from(e.to_string())); - *res.status_mut() = self.status(); - res.headers_mut().insert( - CONTENT_TYPE, - HeaderValue::from_static("text/plain; charset=utf-8"), - ); - res - } - Rejections::Custom(ref e) => { - tracing::error!( - "unhandled custom rejection, returning 500 response: {:?}", - e - ); - let body = format!("Unhandled rejection: {:?}", e); - let mut res = http::Response::new(Body::from(body)); - *res.status_mut() = self.status(); - res.headers_mut().insert( - CONTENT_TYPE, - HeaderValue::from_static("text/plain; charset=utf-8"), - ); - res - } - Rejections::Combined(ref a, ref b) => preferred(a, b).into_response(), - } - } - - fn find(&self) -> Option<&T> { - match *self { - Rejections::Known(ref e) => e.inner_as_any().downcast_ref(), - Rejections::Custom(ref e) => e.downcast_ref(), - Rejections::Combined(ref a, ref b) => a.find().or_else(|| b.find()), - } - } - - fn debug_list(&self, f: &mut fmt::DebugList<'_, '_>) { - match *self { - Rejections::Known(ref e) => { - f.entry(e); - } - Rejections::Custom(ref e) => { - f.entry(e); - } - Rejections::Combined(ref a, ref b) => { - a.debug_list(f); - b.debug_list(f); - } - } - } -} - -fn preferred<'a>(a: &'a Rejections, b: &'a Rejections) -> &'a Rejections { - // Compare status codes, with this priority: - // - NOT_FOUND is lowest - // - METHOD_NOT_ALLOWED is second - // - if one status code is greater than the other - // - otherwise, prefer A... - match (a.status(), b.status()) { - (_, StatusCode::NOT_FOUND) => a, - (StatusCode::NOT_FOUND, _) => b, - (_, StatusCode::METHOD_NOT_ALLOWED) => a, - (StatusCode::METHOD_NOT_ALLOWED, _) => b, - (sa, sb) if sa < sb => b, - _ => a, - } -} - -unit_error! { - /// Invalid query - pub InvalidQuery: "Invalid query string" -} - -unit_error! { - /// HTTP method not allowed - pub MethodNotAllowed: "HTTP method not allowed" -} - -unit_error! { - /// A content-length header is required - pub LengthRequired: "A content-length header is required" -} - -unit_error! { - /// The request payload is too large - pub PayloadTooLarge: "The request payload is too large" -} - -unit_error! { - /// The request's content-type is not supported - pub UnsupportedMediaType: "The request's content-type is not supported" -} - -/// Missing request header -#[derive(Debug)] -pub struct MissingHeader { - name: &'static str, -} - -impl MissingHeader { - /// Retrieve the name of the header that was missing - pub fn name(&self) -> &str { - self.name - } -} - -impl ::std::fmt::Display for MissingHeader { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "Missing request header {:?}", self.name) - } -} - -impl StdError for MissingHeader {} - -/// Invalid request header -#[derive(Debug)] -pub struct InvalidHeader { - name: &'static str, -} - -impl InvalidHeader { - /// Retrieve the name of the header that was invalid - pub fn name(&self) -> &str { - self.name - } -} - -impl ::std::fmt::Display for InvalidHeader { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "Invalid request header {:?}", self.name) - } -} - -impl StdError for InvalidHeader {} - -/// Missing cookie -#[derive(Debug)] -pub struct MissingCookie { - name: &'static str, -} - -impl MissingCookie { - /// Retrieve the name of the cookie that was missing - pub fn name(&self) -> &str { - self.name - } -} - -impl ::std::fmt::Display for MissingCookie { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "Missing request cookie {:?}", self.name) - } -} - -impl StdError for MissingCookie {} - -mod sealed { - use super::{Reason, Rejection, Rejections}; - use http::StatusCode; - use std::convert::Infallible; - use std::fmt; - - // This sealed trait exists to allow Filters to return either `Rejection` - // or `!`. There are no other types that make sense, and so it is sealed. - pub trait IsReject: fmt::Debug + Send + Sync { - fn status(&self) -> StatusCode; - fn into_response(&self) -> crate::reply::Response; - } - - fn _assert_object_safe() { - fn _assert(_: &dyn IsReject) {} - } - - // This weird trait is to allow optimizations of propagating when a - // rejection can *never* happen (currently with the `Never` type, - // eventually to be replaced with `!`). - // - // Using this trait means the `Never` gets propagated to chained filters, - // allowing LLVM to eliminate more code paths. Without it, such as just - // requiring that `Rejection::from(Never)` were used in those filters, - // would mean that links later in the chain may assume a rejection *could* - // happen, and no longer eliminate those branches. - pub trait CombineRejection: Send + Sized { - /// The type that should be returned when only 1 of the two - /// "rejections" occurs. - /// - /// # For example: - /// - /// `warp::any().and(warp::path("foo"))` has the following steps: - /// - /// 1. Since this is `and`, only **one** of the rejections will occur, - /// and as soon as it does, it will be returned. - /// 2. `warp::any()` rejects with `Never`. So, it will never return `Never`. - /// 3. `warp::path()` rejects with `Rejection`. It may return `Rejection`. - /// - /// Thus, if the above filter rejects, it will definitely be `Rejection`. - type One: IsReject + From + From + Into; - - /// The type that should be returned when both rejections occur, - /// and need to be combined. - type Combined: IsReject; - - fn combine(self, other: E) -> Self::Combined; - } - - impl CombineRejection for Rejection { - type One = Rejection; - type Combined = Rejection; - - fn combine(self, other: Rejection) -> Self::Combined { - let reason = match (self.reason, other.reason) { - (Reason::Other(left), Reason::Other(right)) => { - Reason::Other(Box::new(Rejections::Combined(left, right))) - } - (Reason::Other(other), Reason::NotFound) - | (Reason::NotFound, Reason::Other(other)) => { - // ignore the NotFound - Reason::Other(other) - } - (Reason::NotFound, Reason::NotFound) => Reason::NotFound, - }; - - Rejection { reason } - } - } - - impl CombineRejection for Rejection { - type One = Rejection; - type Combined = Infallible; - - fn combine(self, other: Infallible) -> Self::Combined { - match other {} - } - } - - impl CombineRejection for Infallible { - type One = Rejection; - type Combined = Infallible; - - fn combine(self, _: Rejection) -> Self::Combined { - match self {} - } - } - - impl CombineRejection for Infallible { - type One = Infallible; - type Combined = Infallible; - - fn combine(self, _: Infallible) -> Self::Combined { - match self {} - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use http::StatusCode; - - #[derive(Debug, PartialEq)] - struct Left; - - #[derive(Debug, PartialEq)] - struct Right; - - impl Reject for Left {} - impl Reject for Right {} - - #[test] - fn rejection_status() { - assert_eq!(not_found().status(), StatusCode::NOT_FOUND); - assert_eq!( - method_not_allowed().status(), - StatusCode::METHOD_NOT_ALLOWED - ); - assert_eq!(length_required().status(), StatusCode::LENGTH_REQUIRED); - assert_eq!(payload_too_large().status(), StatusCode::PAYLOAD_TOO_LARGE); - assert_eq!( - unsupported_media_type().status(), - StatusCode::UNSUPPORTED_MEDIA_TYPE - ); - assert_eq!(custom(Left).status(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[tokio::test] - async fn combine_rejection_causes_with_some_left_and_none_right() { - let left = custom(Left); - let right = not_found(); - let reject = left.combine(right); - let resp = reject.into_response(); - - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!( - response_body_string(resp).await, - "Unhandled rejection: Left" - ) - } - - #[tokio::test] - async fn combine_rejection_causes_with_none_left_and_some_right() { - let left = not_found(); - let right = custom(Right); - let reject = left.combine(right); - let resp = reject.into_response(); - - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!( - response_body_string(resp).await, - "Unhandled rejection: Right" - ) - } - - #[tokio::test] - async fn unhandled_customs() { - let reject = not_found().combine(custom(Right)); - - let resp = reject.into_response(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!( - response_body_string(resp).await, - "Unhandled rejection: Right" - ); - - // There's no real way to determine which is worse, since both are a 500, - // so pick the first one. - let reject = custom(Left).combine(custom(Right)); - - let resp = reject.into_response(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!( - response_body_string(resp).await, - "Unhandled rejection: Left" - ); - - // With many rejections, custom still is top priority. - let reject = not_found() - .combine(not_found()) - .combine(not_found()) - .combine(custom(Right)) - .combine(not_found()); - - let resp = reject.into_response(); - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!( - response_body_string(resp).await, - "Unhandled rejection: Right" - ); - } - - async fn response_body_string(resp: crate::reply::Response) -> String { - let (_, body) = resp.into_parts(); - let body_bytes = hyper::body::to_bytes(body).await.expect("failed concat"); - String::from_utf8_lossy(&body_bytes).to_string() - } - - #[test] - fn find_cause() { - let rej = custom(Left); - - assert_eq!(rej.find::(), Some(&Left)); - - let rej = rej.combine(method_not_allowed()); - - assert_eq!(rej.find::(), Some(&Left)); - assert!(rej.find::().is_some(), "MethodNotAllowed"); - } - - #[test] - fn size_of_rejection() { - assert_eq!( - ::std::mem::size_of::(), - ::std::mem::size_of::(), - ); - } - - #[derive(Debug)] - struct X(u32); - impl Reject for X {} - - fn combine_n(n: u32, new_reject: F) -> Rejection - where - F: Fn(u32) -> R, - R: Reject, - { - let mut rej = not_found(); - - for i in 0..n { - rej = rej.combine(custom(new_reject(i))); - } - - rej - } - - #[test] - fn test_debug() { - let rej = combine_n(3, X); - - let s = format!("{:?}", rej); - assert_eq!(s, "Rejection([X(0), X(1), X(2)])"); - } -} diff --git a/src/reject/mod.rs b/src/reject/mod.rs new file mode 100644 index 000000000..75270d2ae --- /dev/null +++ b/src/reject/mod.rs @@ -0,0 +1,635 @@ +//! Mismatch +//! +//! Part of the power of the [`Filter`](../trait.Filter.html) system is being able to +//! reject a request from a filter chain. This allows for filters to be +//! combined with `or`, so that if one side of the chain finds that a request +//! doesn't fulfill its requirements, the other side can try to process +//! the request. +//! +//! Many of the built-in [`filters`](../filters) will automatically reject +//! the request with an appropriate rejection. However, you can also build +//! new custom [`Filter`](../trait.Filter.html)s and still want other routes to be +//! matchable in the case a predicate doesn't hold. +//! +//! As a request is processed by a Filter chain, the rejections are accumulated into +//! a list contained by the [`Rejection`](struct.Rejection.html) type. Mismatch from +//! filters can be handled using [`Filter::recover`](../trait.Filter.html#method.recover). +//! This is a convenient way to map rejections into a [`Reply`](../reply/trait.Reply.html). +//! +//! For a more complete example see the +//! [Rejection Example](https://github.com/seanmonstar/warp/blob/master/examples/rejections.rs) +//! from the repository. +//! +//! # Example +//! +//! ``` +//! use warp::{reply, Reply, Filter, reject, Rejection, http::StatusCode}; +//! +//! #[derive(Debug)] +//! struct InvalidParameter; +//! +//! impl reject::Reject for InvalidParameter {}; +//! +//! // Custom rejection handler that maps rejections into responses. +//! async fn handle_rejection(err: Rejection) -> Result { +//! if err.is_not_found() { +//! Ok(reply::with_status("NOT_FOUND", StatusCode::NOT_FOUND)) +//! } else if let Some(e) = err.find::() { +//! Ok(reply::with_status("BAD_REQUEST", StatusCode::BAD_REQUEST)) +//! } else { +//! eprintln!("unhandled rejection: {:?}", err); +//! Ok(reply::with_status("INTERNAL_SERVER_ERROR", StatusCode::INTERNAL_SERVER_ERROR)) +//! } +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! +//! // Filter on `/:id`, but reject with InvalidParameter if the `id` is `0`. +//! // Recover from this rejection using a custom rejection handler. +//! let route = warp::path::param() +//! .and_then(|id: u32| async move { +//! if id == 0 { +//! Err(warp::reject::custom(InvalidParameter)) +//! } else { +//! Ok("id is valid") +//! } +//! }) +//! .recover(handle_rejection); +//! +//! warp::serve(route).run(([127, 0, 0, 1], 3030)).await; +//! } +//! ``` + +use std::any::Any; +use std::convert::Infallible; +use std::error::Error as StdError; +use std::fmt; + +use http::{ + self, + header::{HeaderValue, CONTENT_TYPE}, + StatusCode, +}; +use hyper::Body; + +use crate::Reply; + +pub(crate) use self::sealed::{CombineRejection, IsReject}; + +/// Rejects a request with `404 Not Found`. +#[inline] +pub fn reject() -> Rejection { + not_found() +} + +/// Rejects a request with `404 Not Found`. +#[inline] +pub fn not_found() -> Rejection { + Rejection { + reason: Reason::Mismatch(Vec::new()), + } +} + +// 400 Bad Request +#[inline] +pub(crate) fn invalid_query() -> Rejection { + known(InvalidQuery { _p: () }) +} + +// 400 Bad Request +#[inline] +pub(crate) fn missing_header(name: &'static str) -> Rejection { + known(MissingHeader { name }) +} + +// 400 Bad Request +#[inline] +pub(crate) fn invalid_header(name: &'static str) -> Rejection { + known(InvalidHeader { name }) +} + +// 400 Bad Request +#[inline] +pub(crate) fn missing_cookie(name: &'static str) -> Rejection { + known(MissingCookie { name }) +} + +// 405 Method Not Allowed +#[inline] +pub(crate) fn method_not_allowed() -> Rejection { + known(MethodNotAllowed { _p: () }) +} + +// 411 Length Required +#[inline] +pub(crate) fn length_required() -> Rejection { + known(LengthRequired { _p: () }) +} + +// 413 Payload Too Large +#[inline] +pub(crate) fn payload_too_large() -> Rejection { + known(PayloadTooLarge { _p: () }) +} + +// 415 Unsupported Media Type +// +// Used by the body filters if the request payload content-type doesn't match +// what can be deserialized. +#[inline] +pub(crate) fn unsupported_media_type() -> Rejection { + known(UnsupportedMediaType { _p: () }) +} + +/// Rejects a request with a custom cause. +/// +/// A [`recover`][] filter should convert this `Rejection` into a `Reply`, +/// or else this will be returned as a `500 Internal Server Error`. +/// +/// [`recover`]: ../trait.Filter.html#method.recover +pub fn custom(err: T) -> Rejection { + Rejection::custom(Box::new(err)) +} + +/// Creates a fatal rejection, which prevents any further matching +/// +/// It is converted directly to a Response and sent. +pub fn fatal(err: T) -> Rejection{ + Rejection::fatal(err) +} + +/// Protect against re-rejecting a rejection. +/// +/// ```compile_fail +/// fn with(r: warp::Rejection) { +/// let _wat = warp::reject::custom(r); +/// } +/// ``` +fn __reject_custom_compilefail() {} + +/// A marker trait to ensure proper types are used for custom rejections. +/// +/// Can be converted into Rejection. +/// +/// # Example +/// +/// ``` +/// use warp::{Filter, reject::Reject}; +/// +/// #[derive(Debug)] +/// struct RateLimited; +/// +/// impl Reject for RateLimited {} +/// +/// let route = warp::any().and_then(|| async { +/// Err::<(), _>(warp::reject::custom(RateLimited)) +/// }); +/// ``` +// Require `Sized` for now to prevent passing a `Box`, since we +// would be double-boxing it, and the downcasting wouldn't work as expected. +pub trait Reject: fmt::Debug + Sized + Send + Sync + 'static {} + +/// A marker trait to ensure proper types are used for fatal rejections. +/// +/// Can be converted into Rejection. +/// +/// # Example +/// +/// ``` +/// use warp::{Filter, reply::{Reply, Response, html, with_status}, http::StatusCode, reject::Fatal}; +/// +/// #[derive(Debug)] +/// struct DbError; +/// impl Reply for DbError { +/// fn into_response(self) -> Response { with_status(html("DB error"), StatusCode::INTERNAL_SERVER_ERROR).into_response() } +/// } +/// impl Fatal for DbError {} +/// +/// let route = warp::any().and_then(|| async { +/// Err::<(), _>(warp::reject::fatal(DbError)) +/// }); +/// ``` +pub trait Fatal: Reply + Send + 'static {} + + +trait Cause: fmt::Debug + Send + Sync + 'static { + fn as_any(&self) -> &dyn Any; +} + +impl Cause for T +where + T: fmt::Debug + Send + Sync + 'static, +{ + fn as_any(&self) -> &dyn Any { + self + } +} + +impl dyn Cause { + fn downcast_ref(&self) -> Option<&T> { + self.as_any().downcast_ref::() + } +} + +pub(crate) fn known>(err: T) -> Rejection { + Rejection::known(err.into()) +} + +/// Rejection of a request by a [`Filter`](crate::Filter). +/// +/// See the [`reject`](module@crate::reject) documentation for more. +#[derive(Debug)] +pub struct Rejection { + reason: Reason, +} + +#[derive(Debug)] +enum Reason { + Mismatch(Vec), + Fatal(Box), +} + +#[derive(Debug)] +enum Mismatch{ + Custom(Box), + Known(Known), +} + +macro_rules! enum_known { + ($($(#[$attr:meta])* $var:ident($ty:path),)+) => ( + pub(crate) enum Known { + $( + $(#[$attr])* + $var($ty), + )+ + } + + impl Known { + fn inner_as_any(&self) -> &dyn Any { + match *self { + $( + $(#[$attr])* + Known::$var(ref t) => t, + )+ + } + } + } + + impl fmt::Debug for Known { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + $( + $(#[$attr])* + Known::$var(ref t) => t.fmt(f), + )+ + } + } + } + + impl fmt::Display for Known { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + $( + $(#[$attr])* + Known::$var(ref t) => t.fmt(f), + )+ + } + } + } + + $( + #[doc(hidden)] + $(#[$attr])* + impl From<$ty> for Known { + fn from(ty: $ty) -> Known { + Known::$var(ty) + } + } + )+ + ); +} + +enum_known! { + MethodNotAllowed(MethodNotAllowed), + InvalidHeader(InvalidHeader), + MissingHeader(MissingHeader), + MissingCookie(MissingCookie), + InvalidQuery(InvalidQuery), + LengthRequired(LengthRequired), + PayloadTooLarge(PayloadTooLarge), + UnsupportedMediaType(UnsupportedMediaType), + FileOpenError(crate::fs::FileOpenError), + FilePermissionError(crate::fs::FilePermissionError), + BodyReadError(crate::body::BodyReadError), + BodyDeserializeError(crate::body::BodyDeserializeError), + CorsForbidden(crate::cors::CorsForbidden), + #[cfg(feature = "websocket")] + MissingConnectionUpgrade(crate::ws::MissingConnectionUpgrade), + MissingExtension(crate::ext::MissingExtension), + BodyConsumedMultipleTimes(crate::body::BodyConsumedMultipleTimes), +} + +impl Rejection { + fn known(known: Known) -> Self { + Rejection { + reason: Reason::Mismatch(vec![Mismatch::Known(known)]), + } + } + + fn custom(other: Box) -> Self { + Rejection { + reason: Reason::Mismatch(vec![Mismatch::Custom(other)]), + } + } + + fn fatal(err: T) -> Self { + Rejection{ + reason:Reason::Fatal(Box::new(err.into_response())), + } + } + + /// Searches this `Rejection` for a specific cause. + /// + /// A `Rejection` will accumulate causes over a `Filter` chain. This method + /// can search through them and return the first cause of this type. + /// + /// # Example + /// + /// ``` + /// #[derive(Debug)] + /// struct Nope; + /// + /// impl warp::reject::Reject for Nope {} + /// + /// let reject = warp::reject::custom(Nope); + /// + /// if let Some(nope) = reject.find::() { + /// println!("found it: {:?}", nope); + /// } + /// ``` + pub fn find(&self) -> Option<&T> { + if let Reason::Mismatch(ref mismatches) = self.reason { + for mismatch in mismatches { + if let Some(requested) = mismatch.find() { + return Some(requested); + } + } + } + None + } + + /// Returns true if this Rejection was made via `warp::reject::not_found`. + /// + /// # Example + /// + /// ``` + /// let rejection = warp::reject(); + /// + /// assert!(rejection.is_not_found()); + /// ``` + pub fn is_not_found(&self) -> bool { + match &self.reason { + Reason::Mismatch(m) => m.is_empty(), + _ => false, + } + } +} + +//impl From for Rejection { +// #[inline] +// fn from(err: T) -> Rejection { +// fatal(err) +// } +//} + +impl From for Rejection { + #[inline] + fn from(err: T) -> Rejection { + custom(err) + } +} + +impl From for Rejection { + #[inline] + fn from(infallible: Infallible) -> Rejection { + match infallible {} + } +} + +fn preferred<'a>(mismatches: &'a Vec) -> Option<&'a Mismatch> { + // Compare status codes, with this priority: + // - NOT_FOUND is lowest + // - METHOD_NOT_ALLOWED is second + // - if one status code is greater than the other + // - otherwise, prefer A... + let mut tmp = None; + let mut tmp_status = StatusCode::NOT_FOUND; + for mismatch in mismatches { + match (tmp_status, Some(mismatch).status()) { + (_, StatusCode::NOT_FOUND) => (), + (StatusCode::NOT_FOUND, _) => { + tmp = Some(mismatch); + tmp_status = tmp.status(); + }, + (_, StatusCode::METHOD_NOT_ALLOWED) => (), + (StatusCode::METHOD_NOT_ALLOWED, _) => { + tmp = Some(mismatch); + tmp_status = tmp.status(); + }, + (sa, sb) if sa < sb => { + tmp = Some(mismatch); + tmp_status = tmp.status(); + }, + _ => (), + } + } + tmp +} + + +impl IsReject for Infallible { + fn status(&self) -> StatusCode { + match *self {} + } + + fn into_response(self) -> crate::reply::Response { + match self {} + } +} + +impl IsReject for Rejection { + fn status(&self) -> StatusCode { + match self.reason { + Reason::Fatal(ref response) => response.status(), + Reason::Mismatch(ref other) => preferred(other).status(), + } + } + + fn into_response(self) -> crate::reply::Response { + match self.reason { + Reason::Fatal(response) => *response, + Reason::Mismatch(other) => preferred(&other).into_response(), + } + } +} + +// ===== Mismatch ===== + +impl IsReject for Option<&Mismatch> { + fn status(&self) -> StatusCode { + match *self { + Some(inner) => match inner { + Mismatch::Known(ref k) => match *k { + Known::MethodNotAllowed(_) => StatusCode::METHOD_NOT_ALLOWED, + Known::InvalidHeader(_) + | Known::MissingHeader(_) + | Known::MissingCookie(_) + | Known::InvalidQuery(_) + | Known::BodyReadError(_) + | Known::BodyDeserializeError(_) => StatusCode::BAD_REQUEST, + #[cfg(feature = "websocket")] + Known::MissingConnectionUpgrade(_) => StatusCode::BAD_REQUEST, + Known::LengthRequired(_) => StatusCode::LENGTH_REQUIRED, + Known::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE, + Known::UnsupportedMediaType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Known::FilePermissionError(_) | Known::CorsForbidden(_) => StatusCode::FORBIDDEN, + Known::FileOpenError(_) + | Known::MissingExtension(_) + | Known::BodyConsumedMultipleTimes(_) => StatusCode::INTERNAL_SERVER_ERROR, + }, + Mismatch::Custom(..) => StatusCode::INTERNAL_SERVER_ERROR, + }, + None => { + StatusCode::NOT_FOUND + }, + } + } + + fn into_response(self) -> crate::reply::Response { + match self { + Some(inner) => match inner { + Mismatch::Known(ref e) => { + let mut res = http::Response::new(Body::from(e.to_string())); + *res.status_mut() = self.status(); + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/plain; charset=utf-8"), + ); + res + } + Mismatch::Custom(ref e) => { + tracing::error!( + "unhandled custom rejection, returning 500 response: {:?}", + e + ); + let body = format!("Unhandled rejection: {:?}", e); + let mut res = http::Response::new(Body::from(body)); + *res.status_mut() = self.status(); + res.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/plain; charset=utf-8"), + ); + res + } + }, + None => { + let mut response = http::Response::new(Body::from("Not found")); + *response.status_mut() = self.status(); + response.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/plain; charset=utf-8"), + ); + response + }, + } + } +} +impl Mismatch { + fn find(&self) -> Option<&T> { + match *self { + Mismatch::Known(ref e) => e.inner_as_any().downcast_ref(), + Mismatch::Custom(ref e) => e.downcast_ref(), + } + } +} + +unit_error! { + /// Invalid query + pub InvalidQuery: "Invalid query string" +} +unit_error! { + /// HTTP method not allowed + pub MethodNotAllowed: "HTTP method not allowed" +} +unit_error! { + /// A content-length header is required + pub LengthRequired: "A content-length header is required" +} +unit_error! { + /// The request payload is too large + pub PayloadTooLarge: "The request payload is too large" +} +unit_error! { + /// The request's content-type is not supported + pub UnsupportedMediaType: "The request's content-type is not supported" +} + +/// Missing request header +#[derive(Debug)] +pub struct MissingHeader { + name: &'static str, +} +impl MissingHeader { + /// Retrieve the name of the header that was missing + pub fn name(&self) -> &str { + self.name + } +} +impl ::std::fmt::Display for MissingHeader { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Missing request header {:?}", self.name) + } +} +impl StdError for MissingHeader {} + +/// Invalid request header +#[derive(Debug)] +pub struct InvalidHeader { + name: &'static str, +} +impl InvalidHeader { + /// Retrieve the name of the header that was invalid + pub fn name(&self) -> &str { + self.name + } +} +impl ::std::fmt::Display for InvalidHeader { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Invalid request header {:?}", self.name) + } +} +impl StdError for InvalidHeader {} + +/// Missing cookie +#[derive(Debug)] +pub struct MissingCookie { + name: &'static str, +} +impl MissingCookie { + /// Retrieve the name of the cookie that was missing + pub fn name(&self) -> &str { + self.name + } +} +impl ::std::fmt::Display for MissingCookie { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Missing request cookie {:?}", self.name) + } +} +impl StdError for MissingCookie {} + +mod sealed; + +#[cfg(test)] +mod tests; diff --git a/src/reject/sealed.rs b/src/reject/sealed.rs new file mode 100644 index 000000000..08d489d3d --- /dev/null +++ b/src/reject/sealed.rs @@ -0,0 +1,95 @@ +use super::{Reason, Rejection}; +use http::StatusCode; +use std::convert::Infallible; +use std::fmt; + +// This sealed trait exists to allow Filters to return either `Rejection` +// or `!`. There are no other types that make sense, and so it is sealed. +pub trait IsReject: fmt::Debug + Send + Sync { + fn status(&self) -> StatusCode; + fn into_response(self) -> crate::reply::Response; +} + +fn _assert_object_safe() { + fn _assert(_: &dyn IsReject) {} +} + +// This weird trait is to allow optimizations of propagating when a +// rejection can *never* happen (currently with the `Never` type, +// eventually to be replaced with `!`). +// +// Using this trait means the `Never` gets propagated to chained filters, +// allowing LLVM to eliminate more code paths. Without it, such as just +// requiring that `Rejection::from(Never)` were used in those filters, +// would mean that links later in the chain may assume a rejection *could* +// happen, and no longer eliminate those branches. +pub trait CombineRejection: Send + Sized { + /// The type that should be returned when only 1 of the two + /// "rejections" occurs. + /// + /// # For example: + /// + /// `warp::any().and(warp::path("foo"))` has the following steps: + /// + /// 1. Since this is `and`, only **one** of the rejections will occur, + /// and as soon as it does, it will be returned. + /// 2. `warp::any()` rejects with `Never`. So, it will never return `Never`. + /// 3. `warp::path()` rejects with `Rejection`. It may return `Rejection`. + /// + /// Thus, if the above filter rejects, it will definitely be `Rejection`. + type One: IsReject + From + From + Into; + + /// The type that should be returned when both rejections occur, + /// and need to be combined. + type Combined: IsReject; + + fn combine(self, other: E) -> Self::Combined; +} + +impl CombineRejection for Rejection { + type One = Rejection; + type Combined = Rejection; + + fn combine(self, other: Rejection) -> Self::Combined { + let reason = match (self.reason, other.reason) { + // If both are mismatches, add together their mismatches + // Using a vector may take a bit longer, but copying may well take less + // time than allocating a new Box would. + (Reason::Mismatch(mut this), Reason::Mismatch(mut that)) => { + this.append(&mut that); + Reason::Mismatch(this) + }, + // There should never occur two fatal errors, since the first should cause return + // As such the implied priority of self should never matter. + (Reason::Fatal(err), _) | (_, Reason::Fatal(err)) => Reason::Fatal(err), + }; + Rejection{ reason } + } +} + +impl CombineRejection for Rejection { + type One = Rejection; + type Combined = Infallible; + + fn combine(self, other: Infallible) -> Self::Combined { + match other {} + } +} + +impl CombineRejection for Infallible { + type One = Rejection; + type Combined = Infallible; + + fn combine(self, _: Rejection) -> Self::Combined { + match self {} + } +} + +impl CombineRejection for Infallible { + type One = Infallible; + type Combined = Infallible; + + fn combine(self, _: Infallible) -> Self::Combined { + match self {} + } +} diff --git a/src/reject/tests.rs b/src/reject/tests.rs new file mode 100644 index 000000000..66ce98377 --- /dev/null +++ b/src/reject/tests.rs @@ -0,0 +1,154 @@ +use super::*; +use http::StatusCode; + +#[derive(Debug, PartialEq)] +struct Left; + +#[derive(Debug, PartialEq)] +struct Right; + +impl Reject for Left {} +impl Reject for Right {} + +#[test] +fn rejection_status() { + assert_eq!(not_found().status(), StatusCode::NOT_FOUND); + assert_eq!( + method_not_allowed().status(), + StatusCode::METHOD_NOT_ALLOWED + ); + assert_eq!(length_required().status(), StatusCode::LENGTH_REQUIRED); + assert_eq!(payload_too_large().status(), StatusCode::PAYLOAD_TOO_LARGE); + assert_eq!( + unsupported_media_type().status(), + StatusCode::UNSUPPORTED_MEDIA_TYPE + ); + assert_eq!(custom(Left).status(), StatusCode::INTERNAL_SERVER_ERROR); +} + +#[tokio::test] +async fn combine_rejection_causes_with_some_left_and_none_right() { + let left = custom(Left); + let right = not_found(); + let reject = left.combine(right); + let resp = reject.into_response(); + + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Left" + ) +} + +#[tokio::test] +async fn combine_rejection_causes_with_none_left_and_some_right() { + let left = not_found(); + let right = custom(Right); + let reject = left.combine(right); + let resp = reject.into_response(); + + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Right" + ) +} + +#[tokio::test] +async fn unhandled_customs() { + let reject = not_found().combine(custom(Right)); + + let resp = reject.into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Right" + ); + + // There's no real way to determine which is worse, since both are a 500, + // so pick the first one. + let reject = custom(Left).combine(custom(Right)); + + let resp = reject.into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Left" + ); + + // With many rejections, custom still is top priority. + let reject = not_found() + .combine(not_found()) + .combine(not_found()) + .combine(custom(Right)) + .combine(not_found()); + + let resp = reject.into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response_body_string(resp).await, + "Unhandled rejection: Right" + ); +} + +async fn response_body_string(resp: crate::reply::Response) -> String { + let (_, body) = resp.into_parts(); + let body_bytes = hyper::body::to_bytes(body).await.expect("failed concat"); + String::from_utf8_lossy(&body_bytes).to_string() +} + +#[test] +fn find_cause() { + let rej = custom(Left); + + assert_eq!(rej.find::(), Some(&Left)); + + let rej = rej.combine(method_not_allowed()); + + assert_eq!(rej.find::(), Some(&Left)); + assert!(rej.find::().is_some(), "MethodNotAllowed"); +} + +#[test] +fn size_of_rejection() { + assert_eq!( + ::std::mem::size_of::(), + ::std::mem::size_of::() * 4, + ); +} + +#[test] +fn basic_fatal_creation() { + let err = "test"; + assert_eq!( + format!("{:?}", fatal(err).into_response().body()), + format!("{:?}", err.into_response().body()) + ); + assert_eq!(fatal(err).status(), StatusCode::OK); +} + +#[derive(Debug)] +struct X(u32); +impl Reject for X {} + +fn combine_n(n: u32, new_reject: F) -> Rejection +where + F: Fn(u32) -> R, + R: Reject, +{ + let mut rej = not_found(); + + for i in 0..n { + rej = rej.combine(custom(new_reject(i))); + } + + rej +} + +#[test] +fn test_debug() { + let rej = combine_n(3, X); + + let s = format!("{:?}", rej); + assert_eq!(s, "Rejection { reason: Mismatch([Custom(X(0)), Custom(X(1)), Custom(X(2))]) }"); +}