diff --git a/contrib/dyn_templates/src/fairing.rs b/contrib/dyn_templates/src/fairing.rs index 8decf1de1e..6cef441315 100644 --- a/contrib/dyn_templates/src/fairing.rs +++ b/contrib/dyn_templates/src/fairing.rs @@ -1,6 +1,7 @@ use rocket::{Rocket, Build, Orbit}; use rocket::fairing::{self, Fairing, Info, Kind}; use rocket::figment::{Source, value::magic::RelativePathBuf}; +use rocket::catcher::TypedError; use rocket::trace::Trace; use crate::context::{Callback, Context, ContextManager}; diff --git a/contrib/dyn_templates/src/template.rs b/contrib/dyn_templates/src/template.rs index 97a73b7b76..ab9f79c4ec 100644 --- a/contrib/dyn_templates/src/template.rs +++ b/contrib/dyn_templates/src/template.rs @@ -265,19 +265,21 @@ impl Template { /// extension and a fixed-size body containing the rendered template. If /// rendering fails, an `Err` of `Status::InternalServerError` is returned. impl<'r> Responder<'r, 'static> for Template { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { - let ctxt = req.rocket() - .state::() - .ok_or_else(|| { - error!( - "uninitialized template context: missing `Template::fairing()`.\n\ - To use templates, you must attach `Template::fairing()`." - ); - - Status::InternalServerError - })?; + type Error = std::convert::Infallible; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { + if let Some(ctxt) = req.rocket().state::() { + match self.finalize(&ctxt.context()) { + Ok(v) => v.respond_to(req), + Err(s) => response::Outcome::Forward(s), + } + } else { + error!( + "uninitialized template context: missing `Template::fairing()`.\n\ + To use templates, you must attach `Template::fairing()`." + ); - self.finalize(&ctxt.context())?.respond_to(req) + response::Outcome::Forward(Status::InternalServerError) + } } } diff --git a/contrib/ws/src/websocket.rs b/contrib/ws/src/websocket.rs index 361550441e..75d557b9b3 100644 --- a/contrib/ws/src/websocket.rs +++ b/contrib/ws/src/websocket.rs @@ -238,7 +238,8 @@ impl<'r> FromRequest<'r> for WebSocket { } impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { Response::build() .raw_header("Sec-Websocket-Version", "13") .raw_header("Sec-WebSocket-Accept", self.ws.key.clone()) @@ -250,7 +251,8 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> { impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S> where S: futures::Stream> + Send + 'o { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { Response::build() .raw_header("Sec-Websocket-Version", "13") .raw_header("Sec-WebSocket-Accept", self.ws.key.clone()) diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index 57c898a059..bd8022130c 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -1,19 +1,64 @@ mod parse; -use devise::ext::SpanDiagnosticExt; -use devise::{Spanned, Result}; +use devise::{Result, Spanned}; use proc_macro2::{TokenStream, Span}; use crate::http_codegen::Optional; -use crate::syn_ext::ReturnTypeExt; +use crate::syn_ext::{IdentExt, ReturnTypeExt}; use crate::exports::*; +use self::parse::ErrorGuard; + +use super::param::Guard; + +fn error_type(guard: &ErrorGuard) -> TokenStream { + let ty = &guard.ty; + quote! { + (#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>()) + } +} + +fn error_guard_decl(guard: &ErrorGuard) -> TokenStream { + let (ident, ty) = (guard.ident.rocketized(), &guard.ty); + quote_spanned! { ty.span() => + let #ident: &#ty = match #_catcher::downcast(__error_init) { + Some(v) => v, + None => return #_Result::Err(#__status), + }; + } +} + +fn request_guard_decl(guard: &Guard) -> TokenStream { + let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); + quote_spanned! { ty.span() => + let #ident: #ty = match <#ty as #FromError>::from_error( + #__status, + #__req, + __error_init + ).await { + #_Result::Ok(__v) => __v, + #_Result::Err(__e) => { + ::rocket::trace::info!( + name: "forward", + target: concat!("rocket::codegen::catch::", module_path!()), + parameter = stringify!(#ident), + type_name = stringify!(#ty), + status = __e.code, + "error guard forwarding; trying next catcher" + ); + + return #_Err(#__status); + }, + }; + } +} + pub fn _catch( args: proc_macro::TokenStream, input: proc_macro::TokenStream ) -> Result { // Parse and validate all of the user's input. - let catch = parse::Attribute::parse(args.into(), input)?; + let catch = parse::Attribute::parse(args.into(), input.into())?; // Gather everything we'll need to generate the catcher. let user_catcher_fn = &catch.function; @@ -22,35 +67,28 @@ pub fn _catch( let status_code = Optional(catch.status.map(|s| s.code)); let deprecated = catch.function.attrs.iter().find(|a| a.path().is_ident("deprecated")); - // Determine the number of parameters that will be passed in. - if catch.function.sig.inputs.len() > 2 { - return Err(catch.function.sig.paren_token.span.join() - .error("invalid number of arguments: must be zero, one, or two") - .help("catchers optionally take `&Request` or `Status, &Request`")); - } - // This ensures that "Responder not implemented" points to the return type. let return_type_span = catch.function.sig.output.ty() .map(|ty| ty.span()) .unwrap_or_else(Span::call_site); - // Set the `req` and `status` spans to that of their respective function - // arguments for a more correct `wrong type` error span. `rev` to be cute. - let codegen_args = &[__req, __status]; - let inputs = catch.function.sig.inputs.iter().rev() - .zip(codegen_args.iter()) - .map(|(fn_arg, codegen_arg)| match fn_arg { - syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()), - syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span()) - }).rev(); + let error_guard = catch.error_guard.as_ref().map(error_guard_decl); + let error_type = Optional(catch.error_guard.as_ref().map(error_type)); + let request_guards = catch.request_guards.iter().map(request_guard_decl); + let parameter_names = catch.arguments.map.values() + .map(|(ident, _)| ident.rocketized()); // We append `.await` to the function call if this is `async`. let dot_await = catch.function.sig.asyncness .map(|a| quote_spanned!(a.span() => .await)); let catcher_response = quote_spanned!(return_type_span => { - let ___responder = #user_catcher_fn_name(#(#inputs),*) #dot_await; - #_response::Responder::respond_to(___responder, #__req)? + let ___responder = #user_catcher_fn_name(#(#parameter_names),*) #dot_await; + match #_response::Responder::respond_to(___responder, #__req) { + #Outcome::Success(v) => v, + // If the responder fails, we drop any typed error, and convert to 500 + #Outcome::Error(_) | #Outcome::Forward(_) => return Err(#Status::InternalServerError), + } }); // Generate the catcher, keeping the user's input around. @@ -68,20 +106,26 @@ pub fn _catch( fn into_info(self) -> #_catcher::StaticInfo { fn monomorphized_function<'__r>( #__status: #Status, - #__req: &'__r #Request<'_> + #__req: &'__r #Request<'_>, + __error_init: #_Option<&'__r (dyn #TypedError<'__r> + '__r)>, ) -> #_catcher::BoxFuture<'__r> { #_Box::pin(async move { + #error_guard + #(#request_guards)* let __response = #catcher_response; - #Response::build() - .status(#__status) - .merge(__response) - .ok() + #_Result::Ok( + #Response::build() + .status(#__status) + .merge(__response) + .finalize() + ) }) } #_catcher::StaticInfo { name: ::core::stringify!(#user_catcher_fn_name), code: #status_code, + error_type: #error_type, handler: monomorphized_function, location: (::core::file!(), ::core::line!(), ::core::column!()), } diff --git a/core/codegen/src/attribute/catch/parse.rs b/core/codegen/src/attribute/catch/parse.rs index 34125c9c74..e20e92f830 100644 --- a/core/codegen/src/attribute/catch/parse.rs +++ b/core/codegen/src/attribute/catch/parse.rs @@ -1,7 +1,12 @@ -use devise::ext::SpanDiagnosticExt; -use devise::{MetaItem, Spanned, Result, FromMeta, Diagnostic}; -use proc_macro2::TokenStream; +use devise::ext::{SpanDiagnosticExt, TypeExt}; +use devise::{Diagnostic, FromMeta, MetaItem, Result, SpanWrapped, Spanned}; +use proc_macro2::{Span, TokenStream, Ident}; +use quote::ToTokens; +use crate::attribute::param::{Dynamic, Guard}; +use crate::name::{ArgumentMap, Arguments, Name}; +use crate::proc_macro_ext::Diagnostics; +use crate::syn_ext::FnArgExt; use crate::{http, http_codegen}; /// This structure represents the parsed `catch` attribute and associated items. @@ -10,6 +15,56 @@ pub struct Attribute { pub status: Option, /// The function that was decorated with the `catch` attribute. pub function: syn::ItemFn, + pub arguments: Arguments, + pub error_guard: Option, + pub request_guards: Vec, +} + +pub struct ErrorGuard { + pub span: Span, + pub name: Name, + pub ident: syn::Ident, + pub ty: syn::Type, +} + +impl ErrorGuard { + fn new(param: SpanWrapped, args: &Arguments) -> Result { + if let Some((ident, ty)) = args.map.get(¶m.name) { + match ty { + syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(Self { + span: param.span(), + name: param.name.clone(), + ident: ident.clone(), + ty: elem.as_ref().clone(), + }), + ty => { + let msg = format!( + "Error argument must be a reference, found `{}`", + ty.to_token_stream() + ); + let diag = param.span() + .error("invalid type") + .span_note(ty.span(), msg) + .help(format!("Perhaps use `&{}` instead", ty.to_token_stream())); + Err(diag) + } + } + } else { + let msg = format!("expected argument named `{}` here", param.name); + let diag = param.span().error("unused parameter").span_note(args.span, msg); + Err(diag) + } + } +} + +fn status_guard(param: SpanWrapped, args: &Arguments) -> Result<(Name, Ident)> { + if let Some((ident, _)) = args.map.get(¶m.name) { + Ok((param.name.clone(), ident.clone())) + } else { + let msg = format!("expected argument named `{}` here", param.name); + let diag = param.span().error("unused parameter").span_note(args.span, msg); + Err(diag) + } } /// We generate a full parser for the meta-item for great error messages. @@ -17,6 +72,8 @@ pub struct Attribute { struct Meta { #[meta(naked)] code: Code, + error: Option>, + status: Option>, } /// `Some` if there's a code, `None` if it's `default`. @@ -43,16 +100,60 @@ impl FromMeta for Code { impl Attribute { pub fn parse(args: TokenStream, input: proc_macro::TokenStream) -> Result { + let mut diags = Diagnostics::new(); + let function: syn::ItemFn = syn::parse(input) .map_err(Diagnostic::from) .map_err(|diag| diag.help("`#[catch]` can only be used on functions"))?; let attr: MetaItem = syn::parse2(quote!(catch(#args)))?; - let status = Meta::from_meta(&attr) - .map(|meta| meta.code.0) + let attr = Meta::from_meta(&attr) + .map(|meta| meta) .map_err(|diag| diag.help("`#[catch]` expects a status code int or `default`: \ `#[catch(404)]` or `#[catch(default)]`"))?; - Ok(Attribute { status, function }) + let span = function.sig.paren_token.span.join(); + let mut arguments = Arguments { map: ArgumentMap::new(), span }; + for arg in function.sig.inputs.iter() { + if let Some((ident, ty)) = arg.typed() { + let value = (ident.clone(), ty.with_stripped_lifetimes()); + arguments.map.insert(Name::from(ident), value); + } else { + let span = arg.span(); + let diag = if arg.wild().is_some() { + span.error("handler arguments must be named") + .help("to name an ignored handler argument, use `_name`") + } else { + span.error("handler arguments must be of the form `ident: Type`") + }; + + diags.push(diag); + } + } + let error_guard = attr.error.clone() + .map(|p| ErrorGuard::new(p, &arguments)) + .and_then(|p| p.map_err(|e| diags.push(e)).ok()); + let request_guards = arguments.map.iter() + .filter(|(name, _)| { + let mut all_other_guards = error_guard.iter() + .map(|g| &g.name); + + all_other_guards.all(|n| n != *name) + }) + .enumerate() + .map(|(index, (name, (ident, ty)))| Guard { + source: Dynamic { index, name: name.clone(), trailing: false }, + fn_ident: ident.clone(), + ty: ty.clone(), + }) + .collect(); + + diags.head_err_or(Attribute { + status: attr.code.0, + function, + arguments, + error_guard, + request_guards, + }) } } diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index cb501e43ed..7d3356a7f7 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -41,7 +41,7 @@ fn query_decls(route: &Route) -> Option { } define_spanned_export!(Span::call_site() => - __req, __data, _form, Outcome, _Ok, _Err, _Some, _None, Status + __req, __data, _form, Outcome, _Ok, _Err, _Some, _None, Status, resolve_error ); // Record all of the static parameters for later filtering. @@ -108,13 +108,23 @@ fn query_decls(route: &Route) -> Option { ::rocket::trace::span_info!( "codegen", "query string failed to match route declaration" => - { for _err in __e { ::rocket::trace::info!( + { for _err in __e.iter() { ::rocket::trace::info!( target: concat!("rocket::codegen::route::", module_path!()), "{_err}" ); } } ); + let __e = #resolve_error!(__e); + ::rocket::trace::info!( + target: concat!("rocket::codegen::route::", module_path!()), + error_type = __e.name, + "Forwarding error" + ); - return #Outcome::Forward((#__data, #Status::UnprocessableEntity)); + return #Outcome::Forward(( + #__data, + #Status::UnprocessableEntity, + __e.val + )); } (#(#ident.unwrap()),*) @@ -125,7 +135,7 @@ fn query_decls(route: &Route) -> Option { fn request_guard_decl(guard: &Guard) -> TokenStream { let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); define_spanned_export!(ty.span() => - __req, __data, _request, display_hack, FromRequest, Outcome + __req, __data, _request, display_hack, FromRequest, Outcome, resolve_error, _None ); quote_spanned! { ty.span() => @@ -141,20 +151,22 @@ fn request_guard_decl(guard: &Guard) -> TokenStream { "request guard forwarding" ); - return #Outcome::Forward((#__data, __e)); + return #Outcome::Forward((#__data, __e, #_None)); }, #[allow(unreachable_code)] #Outcome::Error((__c, __e)) => { + let __err = #resolve_error!(__e); ::rocket::trace::info!( name: "failure", target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), - reason = %#display_hack!(__e), + // reason = %#display_hack!(&__e), + error_type = __err.name, "request guard failed" ); - return #Outcome::Error(__c); + return #Outcome::Error((__c, __err.val)); } }; } @@ -164,21 +176,23 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { let (i, name, ty) = (guard.index, &guard.name, &guard.ty); define_spanned_export!(ty.span() => __req, __data, _None, _Some, _Ok, _Err, - Outcome, FromSegments, FromParam, Status, display_hack + Outcome, FromSegments, FromParam, Status, display_hack, resolve_error ); // Returned when a dynamic parameter fails to parse. let parse_error = quote!({ + let __err = #resolve_error!(__error); ::rocket::trace::info!( name: "forward", target: concat!("rocket::codegen::route::", module_path!()), parameter = #name, type_name = stringify!(#ty), - reason = %#display_hack!(__error), + // reason = %#display_hack!(&__error), + error_type = __err.name, "path guard forwarding" ); - #Outcome::Forward((#__data, #Status::UnprocessableEntity)) + #Outcome::Forward((#__data, #Status::UnprocessableEntity, __err.val)) }); // All dynamic parameters should be found if this function is being called; @@ -200,7 +214,11 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { #i ); - return #Outcome::Forward((#__data, #Status::InternalServerError)); + return #Outcome::Forward(( + #__data, + #Status::InternalServerError, + #_None + )); } } }, @@ -219,7 +237,8 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { fn data_guard_decl(guard: &Guard) -> TokenStream { let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); - define_spanned_export!(ty.span() => __req, __data, display_hack, FromData, Outcome); + define_spanned_export!(ty.span() => + __req, __data, display_hack, FromData, Outcome, resolve_error, _None); quote_spanned! { ty.span() => let #ident: #ty = match <#ty as #FromData>::from_data(#__req, #__data).await { @@ -234,20 +253,22 @@ fn data_guard_decl(guard: &Guard) -> TokenStream { "data guard forwarding" ); - return #Outcome::Forward((__d, __e)); + return #Outcome::Forward((__d, __e, #_None)); } #[allow(unreachable_code)] #Outcome::Error((__c, __e)) => { + let __e = #resolve_error!(__e); ::rocket::trace::info!( name: "failure", target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), - reason = %#display_hack!(__e), + // reason = %#display_hack!(&__e), + error_type = __e.name, "data guard failed" ); - return #Outcome::Error(__c); + return #Outcome::Error((__c, __e.val)); } }; } diff --git a/core/codegen/src/attribute/route/parse.rs b/core/codegen/src/attribute/route/parse.rs index 13f3b93d0f..60dcd261b5 100644 --- a/core/codegen/src/attribute/route/parse.rs +++ b/core/codegen/src/attribute/route/parse.rs @@ -1,6 +1,6 @@ use devise::{Spanned, SpanWrapped, Result, FromMeta}; use devise::ext::{SpanDiagnosticExt, TypeExt}; -use indexmap::{IndexSet, IndexMap}; +use indexmap::IndexSet; use proc_macro2::Span; use crate::attribute::suppress::Lint; @@ -8,7 +8,7 @@ use crate::proc_macro_ext::Diagnostics; use crate::http_codegen::{Method, MediaType}; use crate::attribute::param::{Parameter, Dynamic, Guard}; use crate::syn_ext::FnArgExt; -use crate::name::Name; +use crate::name::{ArgumentMap, Arguments, Name}; use crate::http::ext::IntoOwned; use crate::http::uri::{Origin, fmt}; @@ -31,14 +31,6 @@ pub struct Route { pub arguments: Arguments, } -type ArgumentMap = IndexMap; - -#[derive(Debug)] -pub struct Arguments { - pub span: Span, - pub map: ArgumentMap -} - /// The parsed `#[route(..)]` attribute. #[derive(Debug, FromMeta)] pub struct Attribute { diff --git a/core/codegen/src/derive/mod.rs b/core/codegen/src/derive/mod.rs index 134279e733..9a6141a6b4 100644 --- a/core/codegen/src/derive/mod.rs +++ b/core/codegen/src/derive/mod.rs @@ -3,3 +3,4 @@ pub mod from_form; pub mod from_form_field; pub mod responder; pub mod uri_display; +pub mod typed_error; diff --git a/core/codegen/src/derive/responder.rs b/core/codegen/src/derive/responder.rs index 736ee97d42..e4484bd109 100644 --- a/core/codegen/src/derive/responder.rs +++ b/core/codegen/src/derive/responder.rs @@ -1,8 +1,9 @@ use quote::ToTokens; use devise::{*, ext::{TypeExt, SpanDiagnosticExt}}; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; +use syn::{Ident, Lifetime, Type}; -use crate::exports::*; +use crate::{exports::*, syn_ext::IdentExt}; use crate::syn_ext::{TypeExt as _, GenericsExt as _}; use crate::http_codegen::{ContentType, Status}; @@ -25,32 +26,7 @@ pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream { .type_bound_mapper(MapperBuild::new() .try_enum_map(|m, e| mapper::enum_null(m, e)) .try_fields_map(|_, fields| { - let generic_idents = fields.parent.input().generics().type_idents(); - let lifetime = |ty: &syn::Type| syn::Lifetime::new("'o", ty.span()); - let mut types = fields.iter() - .map(|f| (f, &f.field.inner.ty)) - .map(|(f, ty)| (f, ty.with_replaced_lifetimes(lifetime(ty)))); - - let mut bounds = vec![]; - if let Some((_, ty)) = types.next() { - if !ty.is_concrete(&generic_idents) { - let span = ty.span(); - bounds.push(quote_spanned!(span => #ty: #_response::Responder<'r, 'o>)); - } - } - - for (f, ty) in types { - let attr = FieldAttr::one_from_attrs("response", &f.attrs)?.unwrap_or_default(); - if ty.is_concrete(&generic_idents) || attr.ignore { - continue; - } - - bounds.push(quote_spanned! { ty.span() => - #ty: ::std::convert::Into<#_http::Header<'o>> - }); - } - - Ok(quote!(#(#bounds,)*)) + bounds_from_fields(fields) }) ) .validator(ValidatorBuild::new() @@ -65,7 +41,9 @@ pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream { ) .inner_mapper(MapperBuild::new() .with_output(|_, output| quote! { - fn respond_to(self, __req: &'r #Request<'_>) -> #_response::Result<'o> { + fn respond_to(self, __req: &'r #Request<'_>) + -> #_response::Outcome<'o, Self::Error> + { #output } }) @@ -74,15 +52,29 @@ pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream { quote_spanned!(item.span() => __res.set_header(#item);) } + let error_outcome = match fields.parent { + FieldParent::Variant(p) => { + // let name = p.parent.ident.append("Error"); + // let var_name = &p.ident; + // quote! { #name::#var_name(e) } + quote! { #_catcher::AnyError(#_Box::new(e)) } + }, + _ => quote! { e }, + }; + let attr = ItemAttr::one_from_attrs("response", fields.parent.attrs())? .unwrap_or_default(); let responder = fields.iter().next().map(|f| { let (accessor, ty) = (f.accessor(), f.ty.with_stripped_lifetimes()); quote_spanned! { f.span() => - let mut __res = <#ty as #_response::Responder>::respond_to( + let mut __res = match <#ty as #_response::Responder>::respond_to( #accessor, __req - )?; + ) { + #Outcome::Success(val) => val, + #Outcome::Error(e) => return #Outcome::Error(#error_outcome), + #Outcome::Forward(f) => return #Outcome::Forward(f), + }; } }).expect("have at least one field"); @@ -106,9 +98,169 @@ pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream { #(#headers)* #content_type #status - #_Ok(__res) + #Outcome::Success(__res) }) }) ) + // TODO: What's the proper way to do this? + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + type Error = #output; + }) + .try_struct_map(|_, item| { + let (old, ty) = item.fields.iter().next().map(|f| { + let ty = f.ty.with_replaced_lifetimes(Lifetime::new("'o", Span::call_site())); + let old = f.ty.with_replaced_lifetimes(Lifetime::new("'a", Span::call_site())); + (old, ty) + }).expect("have at least one field"); + let type_params: Vec<_> = item.generics.type_params().map(|p| &p.ident).collect(); + let output_life = if old == ty && ty.is_concrete(&type_params) { + quote! { 'static } + } else { + quote! { 'o } + }; + + Ok(quote! { + <#ty as #_response::Responder<'r, #output_life>>::Error + }) + }) + .enum_map(|_, _item| { + // let name = item.ident.append("Error"); + // let response_types: Vec<_> = item.variants() + // .flat_map(|f| responder_types(f.fields()).into_iter()).collect(); + // // TODO: add where clauses, and filter for the type params I need + // let type_params: Vec<_> = item.generics + // .type_params() + // .map(|p| &p.ident) + // .filter(|p| generic_used(p, &response_types)) + // .collect(); + // quote!{ #name<'r, 'o, #(#type_params,)*> } + quote!{ #_catcher::AnyError<'r> } + }) + ) + // .outer_mapper(MapperBuild::new() + // .enum_map(|_, item| { + // let name = item.ident.append("Error"); + // let variants = item.variants().map(|d| { + // let var_name = &d.ident; + // let (old, ty) = d.fields().iter().next().map(|f| { + // let ty = f.ty.with_replaced_lifetimes( + // Lifetime::new("'o", Span::call_site())); + // (f.ty.clone(), ty) + // }).expect("have at least one field"); + // let output_life = if old == ty { + // quote! { 'static } + // } else { + // quote! { 'o } + // }; + // quote!{ + // #var_name(<#ty as #_response::Responder<'r, #output_life>>::Error), + // } + // }); + // let source = item.variants().map(|d| { + // let var_name = &d.ident; + // quote!{ + // Self::#var_name(v) => #_Some(v), + // } + // }); + // let response_types: Vec<_> = item.variants() + // .flat_map(|f| responder_types(f.fields()).into_iter()).collect(); + // // TODO: add where clauses, and filter for the type params I need + // let type_params: Vec<_> = item.generics + // .type_params() + // .map(|p| &p.ident) + // .filter(|p| generic_used(p, &response_types)) + // .collect(); + // let bounds: Vec<_> = item.variants() + // .map(|f| bounds_from_fields(f.fields()).expect("Bounds must be valid")) + // .collect(); + // let bounds: Vec<_> = item.variants() + // .flat_map(|f| responder_types(f.fields()).into_iter()) + // .map(|t| quote!{#t: #_response::Responder<'r, 'o>,}) + // .collect(); + // quote!{ + // pub enum #name<'r, 'o, #(#type_params: 'r,)*> + // where #(#bounds)* + // { + // #(#variants)* + // UnusedVariant( + // // Make this variant impossible to construct + // ::std::convert::Infallible, + // ::std::marker::PhantomData<&'o ()>, + // ), + // } + // // TODO: validate this impl - roughly each variant must be (at least) inv + // // wrt a lifetime, since they impl CanTransendTo> + // // TODO: also need to add requirements on the type parameters + // unsafe impl<'r, 'o: 'r, #(#type_params: 'r,)*> ::rocket::catcher::Transient + // for #name<'r, 'o, #(#type_params,)*> + // where #(#bounds)* + // { + // type Static = #name<'static, 'static>; + // type Transience = ::rocket::catcher::Inv<'r>; + // } + // impl<'r, 'o: 'r, #(#type_params,)*> #TypedError<'r> + // for #name<'r, 'o, #(#type_params,)*> + // where #(#bounds)* + // { + // fn source(&self) -> #_Option<&dyn #TypedError<'r>> { + // match self { + // #(#source)* + // Self::UnusedVariant(f, ..) => match *f { } + // } + // } + // } + // } + // }) + // ) .to_tokens() } + +fn generic_used(ident: &Ident, res_types: &[Type]) -> bool { + res_types.iter().any(|t| !t.is_concrete(&[ident])) +} + +fn responder_types(fields: Fields<'_>) -> Vec { + let generic_idents = fields.parent.input().generics().type_idents(); + let lifetime = |ty: &syn::Type| syn::Lifetime::new("'o", ty.span()); + let mut types = fields.iter() + .map(|f| (f, &f.field.inner.ty)) + .map(|(f, ty)| (f, ty.with_replaced_lifetimes(lifetime(ty)))); + + let mut bounds = vec![]; + if let Some((_, ty)) = types.next() { + if !ty.is_concrete(&generic_idents) { + bounds.push(ty); + } + } + bounds +} + +fn bounds_from_fields(fields: Fields<'_>) -> Result { + let generic_idents = fields.parent.input().generics().type_idents(); + let lifetime = |ty: &syn::Type| syn::Lifetime::new("'o", ty.span()); + let mut types = fields.iter() + .map(|f| (f, &f.field.inner.ty)) + .map(|(f, ty)| (f, ty.with_replaced_lifetimes(lifetime(ty)))); + + let mut bounds = vec![]; + if let Some((_, ty)) = types.next() { + if !ty.is_concrete(&generic_idents) { + let span = ty.span(); + bounds.push(quote_spanned!(span => #ty: #_response::Responder<'r, 'o>)); + } + } + + for (f, ty) in types { + let attr = FieldAttr::one_from_attrs("response", &f.attrs)?.unwrap_or_default(); + if ty.is_concrete(&generic_idents) || attr.ignore { + continue; + } + + bounds.push(quote_spanned! { ty.span() => + #ty: ::std::convert::Into<#_http::Header<'o>> + }); + } + + Ok(quote!(#(#bounds,)*)) +} diff --git a/core/codegen/src/derive/typed_error.rs b/core/codegen/src/derive/typed_error.rs new file mode 100644 index 0000000000..2b95ef856b --- /dev/null +++ b/core/codegen/src/derive/typed_error.rs @@ -0,0 +1,154 @@ +use devise::{*, ext::SpanDiagnosticExt}; +use proc_macro2::TokenStream; +use syn::{ConstParam, Index, LifetimeParam, Member, TypeParam}; + +use crate::exports::{*, Status as _Status}; +use crate::http_codegen::Status; + +#[derive(Debug, Default, FromMeta)] +struct ItemAttr { + status: Option>, + // TODO: support an option to avoid implementing Transient + // no_transient: bool, +} + +#[derive(Default, FromMeta)] +struct FieldAttr { + source: bool, +} + +pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { + let impl_tokens = quote!(impl<'r> #TypedError<'r>); + let typed_error: TokenStream = DeriveGenerator::build_for(input.clone(), impl_tokens) + .support(Support::Struct | Support::Enum | Support::Lifetime | Support::Type) + .replace_generic(0, 0) + .type_bound_mapper(MapperBuild::new() + .input_map(|_, i| { + let bounds = i.generics().type_params().map(|g| &g.ident); + quote! { #(#bounds: 'static,)* } + }) + ) + .validator(ValidatorBuild::new() + .input_validate(|_, i| match i.generics().lifetimes().count() > 1 { + true => Err(i.generics().span().error("only one lifetime is supported")), + false => Ok(()) + }) + ) + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + #[allow(unused_variables)] + fn respond_to(&self, request: &'r #Request<'_>) + -> #_Result<#Response<'r>, #_Status> + { + #output + } + }) + .try_fields_map(|_, fields| { + let item = ItemAttr::one_from_attrs("error", fields.parent.attrs())?; + Ok(item.map_or_else(|| quote! { + #_Err(#_Status::InternalServerError) + }, |ItemAttr { status, ..}| quote! { + #_Err(#status) + })) + }) + ) + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + fn source(&'r self) -> #_Option<&'r (dyn #TypedError<'r> + 'r)> { + #output + } + }) + .try_fields_map(|_, fields| { + let mut source = None; + for field in fields.iter() { + if FieldAttr::one_from_attrs("error", &field.attrs)?.is_some_and(|a| a.source) { + if source.is_some() { + return Err(Diagnostic::spanned( + field.span(), + Level::Error, + "Only one field may be declared as `#[error(source)]`")); + } + if let FieldParent::Variant(_) = field.parent { + let name = field.match_ident(); + source = Some(quote! { #_Some(#name as &dyn #TypedError<'r>) }) + } else { + let span = field.field.span().into(); + let member = match field.ident { + Some(ref ident) => Member::Named(ident.clone()), + None => Member::Unnamed(Index { index: field.index as u32, span }) + }; + + source = Some(quote_spanned!( + span => #_Some(&self.#member as &dyn #TypedError<'r> + ))); + } + } + } + Ok(source.unwrap_or_else(|| quote! { #_None })) + }) + ) + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + fn status(&self) -> #_Status { #output } + }) + .try_fields_map(|_, fields| { + let item = ItemAttr::one_from_attrs("error", fields.parent.attrs())?; + Ok(item.map_or_else(|| quote! { + #_Status::InternalServerError + }, |ItemAttr { status, ..}| quote! { + #status + })) + }) + ) + .to_tokens(); + let impl_tokens = quote!(unsafe impl #_catcher::Transient); + let transient: TokenStream = DeriveGenerator::build_for(input, impl_tokens) + .support(Support::Struct | Support::Enum | Support::Lifetime | Support::Type) + .replace_generic(1, 0) + .type_bound_mapper(MapperBuild::new() + .input_map(|_, i| { + let bounds = i.generics().type_params().map(|g| &g.ident); + quote! { #(#bounds: 'static,)* } + }) + ) + .validator(ValidatorBuild::new() + .input_validate(|_, i| match i.generics().lifetimes().count() > 1 { + true => Err(i.generics().span().error("only one lifetime is supported")), + false => Ok(()) + }) + ) + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + #output + }) + .input_map(|_, input| { + let name = input.ident(); + let args = input.generics() + .params + .iter() + .map(|g| { + match g { + syn::GenericParam::Lifetime(_) => quote!{ 'static }, + syn::GenericParam::Type(TypeParam { ident, .. }) => quote! { #ident }, + syn::GenericParam::Const(ConstParam { .. }) => todo!(), + } + }); + let trans = input.generics() + .lifetimes() + .map(|LifetimeParam { lifetime, .. }| quote!{#_catcher::Inv<#lifetime>}); + quote!{ + type Static = #name <#(#args)*>; + type Transience = (#(#trans,)*); + } + }) + ) + // TODO: hack to generate unsafe impl + .outer_mapper(MapperBuild::new() + .input_map(|_, _| quote!{ unsafe }) + ) + .to_tokens(); + quote!{ + #typed_error + #transient + } +} diff --git a/core/codegen/src/exports.rs b/core/codegen/src/exports.rs index 50470b46b9..6d70522efe 100644 --- a/core/codegen/src/exports.rs +++ b/core/codegen/src/exports.rs @@ -86,11 +86,14 @@ define_exported_paths! { _Vec => ::std::vec::Vec, _Cow => ::std::borrow::Cow, _ExitCode => ::std::process::ExitCode, + _trace => ::rocket::trace, display_hack => ::rocket::error::display_hack, + try_outcome => ::rocket::outcome::try_outcome, BorrowMut => ::std::borrow::BorrowMut, Outcome => ::rocket::outcome::Outcome, FromForm => ::rocket::form::FromForm, FromRequest => ::rocket::request::FromRequest, + FromError => ::rocket::catcher::FromError, FromData => ::rocket::data::FromData, FromSegments => ::rocket::request::FromSegments, FromParam => ::rocket::request::FromParam, @@ -102,6 +105,8 @@ define_exported_paths! { Route => ::rocket::Route, Catcher => ::rocket::Catcher, Status => ::rocket::http::Status, + resolve_error => ::rocket::catcher::resolve_typed_catcher, + TypedError => ::rocket::catcher::TypedError, } macro_rules! define_spanned_export { diff --git a/core/codegen/src/lib.rs b/core/codegen/src/lib.rs index 39401f1c5d..d3976292bb 100644 --- a/core/codegen/src/lib.rs +++ b/core/codegen/src/lib.rs @@ -294,17 +294,16 @@ route_attribute!(options => Method::Options); /// ```rust /// # #[macro_use] extern crate rocket; /// # -/// use rocket::Request; -/// use rocket::http::Status; +/// use rocket::http::{Status, uri::Origin}; /// /// #[catch(404)] -/// fn not_found(req: &Request) -> String { -/// format!("Sorry, {} does not exist.", req.uri()) +/// fn not_found(uri: &Origin) -> String { +/// format!("Sorry, {} does not exist.", uri) /// } /// -/// #[catch(default)] -/// fn default(status: Status, req: &Request) -> String { -/// format!("{} ({})", status, req.uri()) +/// #[catch(default, status = "")] +/// fn default(status: Status, uri: &Origin) -> String { +/// format!("{} ({})", status, uri) /// } /// ``` /// @@ -313,19 +312,59 @@ route_attribute!(options => Method::Options); /// The grammar for the `#[catch]` attributes is defined as: /// /// ```text -/// catch := STATUS | 'default' +/// catch := STATUS | 'default' (',' parameter)* /// /// STATUS := valid HTTP status code (integer in [200, 599]) +/// parameter := 'rank' '=' INTEGER +/// | 'status' '=' '"' SINGLE_PARAM '"' +/// | 'error' '=' '"' SINGLE_PARAM '"' +/// SINGLE_PARAM := '<' IDENT '>' /// ``` /// /// # Typing Requirements /// -/// The decorated function may take zero, one, or two arguments. It's type -/// signature must be one of the following, where `R:`[`Responder`]: +/// Every identifier, except for `_`, that appears in a dynamic parameter, must appear +/// as an argument to the function. +/// +/// The type of each function argument corresponding to a dynamic parameter is required to +/// meet specific requirements. +/// +/// - `status`: Must be [`Status`]. +/// - `error`: Must be a reference to a type that implements `Transient`. See +/// [Typed catchers](Self#Typed-catchers) for more info. +/// +/// All other arguments must implement [`FromRequest`]. +/// +/// A route argument declared a `_` must not appear in the function argument list and has no typing requirements. +/// +/// The return type of the decorated function must implement the [`Responder`] trait. +/// +/// # Typed catchers +/// +/// To make catchers more expressive and powerful, they can catch specific +/// error types. This is accomplished using the [`transient`] crate as a +/// replacement for [`std::any::Any`]. When a [`FromRequest`], [`FromParam`], +/// [`FromSegments`], [`FromForm`], or [`FromData`] implementation fails or +/// forwards, Rocket will convert to the error type to `dyn Any>`, if the +/// error type implements `Transient`. /// -/// * `fn() -> R` -/// * `fn(`[`&Request`]`) -> R` -/// * `fn(`[`Status`]`, `[`&Request`]`) -> R` +/// Only a single error type can be carried by a request - if a route forwards, +/// and another route is attempted, any error produced by the second route +/// overwrites the first. +/// +/// ## Custom error types +/// +/// All[^transient-impls] error types that Rocket itself produces implement +/// `Transient`, and can therefore be caught by a typed catcher. If you have +/// a custom guard of any type, you can implement `Transient` using the derive +/// macro provided by the `transient` crate. If the error type has lifetimes, +/// please read the documentation for the `Transient` derive macro - although it +/// prevents any unsafe implementation, it's not the easiest to use. Note that +/// Rocket upcasts the type to `dyn Any>`, where `'r` is the lifetime of +/// the `Request`, so any `Transient` impl must be able to trancend to `Co<'r>`, +/// and desend from `Co<'r>` at the catcher. +/// +/// [^transient-impls]: As of writing, this is a WIP. /// /// # Semantics /// @@ -333,10 +372,12 @@ route_attribute!(options => Method::Options); /// /// 1. An error [`Handler`]. /// -/// The generated handler calls the decorated function, passing in the -/// [`Status`] and [`&Request`] values if requested. The returned value is -/// used to generate a [`Response`] via the type's [`Responder`] -/// implementation. +/// The generated handler validates and generates all arguments for the generated function according +/// to their specific requirements. The order in which arguments are processed is: +/// +/// 1. The `error` type. This means no other guards will be evaluated if the error type does not match. +/// 2. Request guards, from left to right. If a Request guards forwards, the next catcher will be tried. +/// If the Request guard fails, the error is instead routed to the `500` catcher. /// /// 2. A static structure used by [`catchers!`] to generate a [`Catcher`]. /// @@ -351,6 +392,7 @@ route_attribute!(options => Method::Options); /// [`Catcher`]: ../rocket/struct.Catcher.html /// [`Response`]: ../rocket/struct.Response.html /// [`Responder`]: ../rocket/response/trait.Responder.html +/// [`FromRequest`]: ../rocket/request/trait.FromRequest.html #[proc_macro_attribute] pub fn catch(args: TokenStream, input: TokenStream) -> TokenStream { emit!(attribute::catch::catch_attribute(args, input)) @@ -967,6 +1009,15 @@ pub fn derive_responder(input: TokenStream) -> TokenStream { emit!(derive::responder::derive_responder(input)) } +/// Derive for the [`TypedError`] trait. +/// +/// TODO: Full documentation +/// [`TypedError`]: ../rocket/catcher/trait.TypedError.html +#[proc_macro_derive(TypedError, attributes(error))] +pub fn derive_typed_error(input: TokenStream) -> TokenStream { + emit!(derive::typed_error::derive_typed_error(input)) +} + /// Derive for the [`UriDisplay`] trait. /// /// The [`UriDisplay`] derive can be applied to enums and structs. When diff --git a/core/codegen/src/name.rs b/core/codegen/src/name.rs index c5b8e2b1a3..bed617aabe 100644 --- a/core/codegen/src/name.rs +++ b/core/codegen/src/name.rs @@ -1,8 +1,18 @@ use crate::http::uncased::UncasedStr; +use indexmap::IndexMap; use syn::{Ident, ext::IdentExt}; use proc_macro2::{Span, TokenStream}; +pub type ArgumentMap = IndexMap; + +#[derive(Debug)] +pub struct Arguments { + pub span: Span, + pub map: ArgumentMap +} + + /// A "name" read by codegen, which may or may not be an identifier. A `Name` is /// typically constructed indirectly via FromMeta, or From or directly /// from a string via `Name::new()`. A name is tokenized as a string. diff --git a/core/codegen/tests/async-routes.rs b/core/codegen/tests/async-routes.rs index 3a9ff56416..eadce62367 100644 --- a/core/codegen/tests/async-routes.rs +++ b/core/codegen/tests/async-routes.rs @@ -2,7 +2,6 @@ #[macro_use] extern crate rocket; use rocket::http::uri::Origin; -use rocket::request::Request; async fn noop() { } @@ -19,7 +18,7 @@ async fn repeated_query(sort: Vec<&str>) -> &str { } #[catch(404)] -async fn not_found(req: &Request<'_>) -> String { +async fn not_found(uri: &Origin<'_>) -> String { noop().await; - format!("{} not found", req.uri()) + format!("{} not found", uri) } diff --git a/core/codegen/tests/catcher.rs b/core/codegen/tests/catcher.rs index ddc59cb175..7c36e3b86c 100644 --- a/core/codegen/tests/catcher.rs +++ b/core/codegen/tests/catcher.rs @@ -5,14 +5,18 @@ #[macro_use] extern crate rocket; -use rocket::{Request, Rocket, Build}; +use rocket::{Rocket, Build}; use rocket::local::blocking::Client; -use rocket::http::Status; +use rocket::http::{Status, uri::Origin}; -#[catch(404)] fn not_found_0() -> &'static str { "404-0" } -#[catch(404)] fn not_found_1(_: &Request<'_>) -> &'static str { "404-1" } -#[catch(404)] fn not_found_2(_: Status, _: &Request<'_>) -> &'static str { "404-2" } -#[catch(default)] fn all(_: Status, r: &Request<'_>) -> String { r.uri().to_string() } +#[catch(404)] +fn not_found_0() -> &'static str { "404-0" } +#[catch(404)] +fn not_found_1() -> &'static str { "404-1" } +#[catch(404, status = "<_s>")] +fn not_found_2(_s: Status) -> &'static str { "404-2" } +#[catch(default, status = "<_s>")] +fn all(_s: Status, uri: &Origin<'_>) -> String { uri.to_string() } #[test] fn test_simple_catchers() { @@ -37,10 +41,14 @@ fn test_simple_catchers() { } #[get("/")] fn forward(code: u16) -> Status { Status::new(code) } -#[catch(400)] fn forward_400(status: Status, _: &Request<'_>) -> String { status.code.to_string() } -#[catch(404)] fn forward_404(status: Status, _: &Request<'_>) -> String { status.code.to_string() } -#[catch(444)] fn forward_444(status: Status, _: &Request<'_>) -> String { status.code.to_string() } -#[catch(500)] fn forward_500(status: Status, _: &Request<'_>) -> String { status.code.to_string() } +#[catch(400, status = "")] +fn forward_400(status: Status) -> String { status.code.to_string() } +#[catch(404, status = "")] +fn forward_404(status: Status) -> String { status.code.to_string() } +#[catch(444, status = "")] +fn forward_444(status: Status) -> String { status.code.to_string() } +#[catch(500, status = "")] +fn forward_500(status: Status) -> String { status.code.to_string() } #[test] fn test_status_param() { @@ -58,3 +66,23 @@ fn test_status_param() { assert_eq!(response.into_string().unwrap(), code.to_string()); } } + +#[catch(404)] +fn bad_req_untyped() -> &'static str { "404" } +#[catch(404, error = "<_e>")] +fn bad_req_string(_e: &String) -> &'static str { "404 String" } +#[catch(404, error = "<_e>")] +fn bad_req_tuple(_e: &()) -> &'static str { "404 ()" } + +#[test] +fn test_typed_catchers() { + fn rocket() -> Rocket { + rocket::build() + .register("/", catchers![bad_req_untyped, bad_req_string, bad_req_tuple]) + } + + // Assert the catchers do not collide. They are only differentiated by their error type. + let client = Client::debug(rocket()).unwrap(); + let response = client.get("/").dispatch(); + assert_eq!(response.status(), Status::NotFound); +} diff --git a/core/codegen/tests/route-raw.rs b/core/codegen/tests/route-raw.rs index 9ce65167b2..536162fb6a 100644 --- a/core/codegen/tests/route-raw.rs +++ b/core/codegen/tests/route-raw.rs @@ -1,6 +1,7 @@ #[macro_use] extern crate rocket; use rocket::local::blocking::Client; +use rocket_http::Method; // Test that raw idents can be used for route parameter names @@ -15,8 +16,8 @@ fn swap(r#raw: String, bare: String) -> String { } #[catch(400)] -fn catch(r#raw: &rocket::Request<'_>) -> String { - format!("{}", raw.method()) +fn catch(r#raw: Method) -> String { + format!("{}", raw) } #[test] diff --git a/core/codegen/tests/typed_error.rs b/core/codegen/tests/typed_error.rs new file mode 100644 index 0000000000..b407811ff4 --- /dev/null +++ b/core/codegen/tests/typed_error.rs @@ -0,0 +1,69 @@ +#[macro_use] extern crate rocket; +use rocket::catcher::TypedError; +use rocket::http::Status; + +fn boxed_error<'r>(_val: Box + 'r>) {} + +#[derive(TypedError)] +pub enum Foo<'r> { + First(String), + Second(Vec), + Third { + #[error(source)] + responder: std::io::Error, + }, + #[error(status = 400)] + Fourth { + string: &'r str, + }, +} + +#[test] +fn validate_foo() { + let first = Foo::First("".into()); + assert_eq!(first.status(), Status::InternalServerError); + assert!(first.source().is_none()); + boxed_error(Box::new(first)); + let second = Foo::Second(vec![]); + assert_eq!(second.status(), Status::InternalServerError); + assert!(second.source().is_none()); + boxed_error(Box::new(second)); + let third = Foo::Third { + responder: std::io::Error::new(std::io::ErrorKind::NotFound, ""), + }; + assert_eq!(third.status(), Status::InternalServerError); + assert!(std::ptr::eq( + third.source().unwrap(), + if let Foo::Third { responder } = &third { responder } else { panic!() } + )); + boxed_error(Box::new(third)); + let fourth = Foo::Fourth { string: "" }; + assert_eq!(fourth.status(), Status::BadRequest); + assert!(fourth.source().is_none()); + boxed_error(Box::new(fourth)); +} + +#[derive(TypedError)] +pub struct InfallibleError { + #[error(source)] + _inner: std::convert::Infallible, +} + +#[derive(TypedError)] +pub struct StaticError { + #[error(source)] + inner: std::string::FromUtf8Error, +} + +#[test] +fn validate_static() { + let val = StaticError { + inner: String::from_utf8(vec![0xFF]).unwrap_err(), + }; + assert_eq!(val.status(), Status::InternalServerError); + assert!(std::ptr::eq( + val.source().unwrap(), + &val.inner, + )); + boxed_error(Box::new(val)); +} diff --git a/core/codegen/tests/ui-fail/typed_error.rs b/core/codegen/tests/ui-fail/typed_error.rs new file mode 100644 index 0000000000..9ede6a1153 --- /dev/null +++ b/core/codegen/tests/ui-fail/typed_error.rs @@ -0,0 +1,32 @@ +#[macro_use] extern crate rocket; + +#[derive(TypedError)] +struct InnerError; +struct InnerNonError; + +#[derive(TypedError)] +struct Thing1<'a, 'b> { + a: &'a str, + b: &'b str, +} + +#[derive(TypedError)] +struct Thing2 { + #[error(source)] + inner: InnerNonError, +} + +#[derive(TypedError)] +enum Thing3<'a, 'b> { + A(&'a str), + B(&'b str), +} + +#[derive(TypedError)] +enum Thing4 { + A(#[error(source)] InnerNonError), + B(#[error(source)] InnerError), +} + +#[derive(TypedError)] +enum EmptyEnum { } diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index ff62a0ae87..c8b86fc666 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -36,6 +36,7 @@ memchr = "2" stable-pattern = "0.1" cookie = { version = "0.18", features = ["percent-encode"] } state = "0.6" +transient = { version = "0.4", path = "/code/matthew/transient" } [dependencies.serde] version = "1.0" diff --git a/core/http/src/status.rs b/core/http/src/status.rs index f90e40f240..20a944bc51 100644 --- a/core/http/src/status.rs +++ b/core/http/src/status.rs @@ -1,5 +1,7 @@ use std::fmt; +use transient::Transient; + /// Enumeration of HTTP status classes. #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub enum StatusClass { @@ -112,7 +114,7 @@ impl StatusClass { /// } /// # } /// ``` -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Transient)] pub struct Status { /// The HTTP status code associated with this status. pub code: u16, diff --git a/core/http/src/uri/error.rs b/core/http/src/uri/error.rs index 06705d27f2..7c01d79790 100644 --- a/core/http/src/uri/error.rs +++ b/core/http/src/uri/error.rs @@ -1,6 +1,7 @@ //! Errors arising from parsing invalid URIs. use std::fmt; +use transient::Static; pub use crate::parse::uri::Error; @@ -29,6 +30,8 @@ pub enum PathError { BadEnd(char), } +impl Static for PathError {} + impl fmt::Display for PathError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 004115a681..2d66ec5bf7 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -27,9 +27,9 @@ default = ["http2", "tokio-macros", "trace"] http2 = ["hyper/http2", "hyper-util/http2"] http3-preview = ["s2n-quic", "s2n-quic-h3", "tls"] secrets = ["cookie/private", "cookie/key-expansion"] -json = ["serde_json"] -msgpack = ["rmp-serde"] -uuid = ["uuid_", "rocket_http/uuid"] +json = ["serde_json", "transient/serde_json"] +msgpack = ["rmp-serde", "transient/rmp-serde"] +uuid = ["uuid_", "rocket_http/uuid", "transient/uuid"] tls = ["rustls", "tokio-rustls", "rustls-pemfile"] mtls = ["tls", "x509-parser"] tokio-macros = ["tokio/macros"] @@ -74,6 +74,7 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] } cookie = { version = "0.18", features = ["percent-encode"] } futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" +transient = { version = "0.4", features = ["either"], path = "/code/matthew/transient" } # tracing tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] } @@ -128,6 +129,7 @@ optional = true [dependencies.s2n-quic-h3] git = "https://github.com/SergioBenitez/s2n-quic-h3.git" +rev = "865fd25" optional = true [target.'cfg(unix)'.dependencies] diff --git a/core/lib/src/catcher/catcher.rs b/core/lib/src/catcher/catcher.rs index 2aa1402ada..1c74b0fda1 100644 --- a/core/lib/src/catcher/catcher.rs +++ b/core/lib/src/catcher/catcher.rs @@ -1,12 +1,14 @@ use std::fmt; use std::io::Cursor; +use transient::TypeId; + use crate::http::uri::Path; use crate::http::ext::IntoOwned; use crate::response::Response; use crate::request::Request; use crate::http::{Status, ContentType, uri}; -use crate::catcher::{Handler, BoxFuture}; +use crate::catcher::{BoxFuture, TypedError, Handler}; /// An error catching route. /// @@ -72,8 +74,7 @@ use crate::catcher::{Handler, BoxFuture}; /// ```rust,no_run /// #[macro_use] extern crate rocket; /// -/// use rocket::Request; -/// use rocket::http::Status; +/// use rocket::http::{Status, uri::Origin}; /// /// #[catch(500)] /// fn internal_error() -> &'static str { @@ -81,13 +82,13 @@ use crate::catcher::{Handler, BoxFuture}; /// } /// /// #[catch(404)] -/// fn not_found(req: &Request) -> String { -/// format!("I couldn't find '{}'. Try something else?", req.uri()) +/// fn not_found(uri: &Origin) -> String { +/// format!("I couldn't find '{}'. Try something else?", uri) /// } /// -/// #[catch(default)] -/// fn default(status: Status, req: &Request) -> String { -/// format!("{} ({})", status, req.uri()) +/// #[catch(default, status = "")] +/// fn default(status: Status, uri: &Origin) -> String { +/// format!("{} ({})", status, uri) /// } /// /// #[launch] @@ -96,13 +97,6 @@ use crate::catcher::{Handler, BoxFuture}; /// } /// ``` /// -/// A function decorated with `#[catch]` may take zero, one, or two arguments. -/// It's type signature must be one of the following, where `R:`[`Responder`]: -/// -/// * `fn() -> R` -/// * `fn(`[`&Request`]`) -> R` -/// * `fn(`[`Status`]`, `[`&Request`]`) -> R` -/// /// See the [`catch`] documentation for full details. /// /// [`catch`]: crate::catch @@ -120,6 +114,9 @@ pub struct Catcher { /// The catcher's associated error handler. pub handler: Box, + /// Catcher error type + pub(crate) error_type: Option<(TypeId, &'static str)>, + /// The mount point. pub(crate) base: uri::Origin<'static>, @@ -132,8 +129,9 @@ pub struct Catcher { pub(crate) location: Option<(&'static str, u32, u32)>, } -// The rank is computed as -(number of nonempty segments in base) => catchers +// The rank is computed as -(number of nonempty segments in base) *2 => catchers // with more nonempty segments have lower ranks => higher precedence. +// Doubled to provide space between for typed catchers. fn rank(base: Path<'_>) -> isize { -(base.segments().filter(|s| !s.is_empty()).count() as isize) } @@ -147,22 +145,26 @@ impl Catcher { /// /// ```rust /// use rocket::request::Request; - /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::catcher::{Catcher, BoxFuture, TypedError}; /// use rocket::response::Responder; /// use rocket::http::Status; /// - /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: Option<&'r dyn TypedError<'r>>) + /// -> BoxFuture<'r> + /// { /// let res = (status, format!("404: {}", req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).responder_error() }) /// } /// - /// fn handle_500<'r>(_: Status, req: &'r Request<'_>) -> BoxFuture<'r> { - /// Box::pin(async move{ "Whoops, we messed up!".respond_to(req) }) + /// fn handle_500<'r>(_: Status, req: &'r Request<'_>, _e: Option<&'r dyn TypedError<'r>>) -> BoxFuture<'r> { + /// Box::pin(async move{ "Whoops, we messed up!".respond_to(req).responder_error() }) /// } /// - /// fn handle_default<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_default<'r>(status: Status, req: &'r Request<'_>, _e: Option<&'r dyn TypedError<'r>>) + /// -> BoxFuture<'r> + /// { /// let res = (status, format!("{}: {}", status, req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).responder_error() }) /// } /// /// let not_found_catcher = Catcher::new(404, handle_404); @@ -187,6 +189,7 @@ impl Catcher { name: None, base: uri::Origin::root().clone(), handler: Box::new(handler), + error_type: None, rank: rank(uri::Origin::root().path()), code, location: None, @@ -199,13 +202,15 @@ impl Catcher { /// /// ```rust /// use rocket::request::Request; - /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::catcher::{Catcher, BoxFuture, TypedError}; /// use rocket::response::Responder; /// use rocket::http::Status; /// - /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: Option<&'r dyn TypedError<'r>>) + /// -> BoxFuture<'r> + /// { /// let res = (status, format!("404: {}", req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).responder_error() }) /// } /// /// let catcher = Catcher::new(404, handle_404); @@ -225,14 +230,16 @@ impl Catcher { /// /// ```rust /// use rocket::request::Request; - /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::catcher::{Catcher, BoxFuture, TypedError}; /// use rocket::response::Responder; /// use rocket::http::Status; /// # use rocket::uri; /// - /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: Option<&'r dyn TypedError<'r>>) + /// -> BoxFuture<'r> + /// { /// let res = (status, format!("404: {}", req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).responder_error() }) /// } /// /// let catcher = Catcher::new(404, handle_404); @@ -279,13 +286,15 @@ impl Catcher { /// /// ```rust /// use rocket::request::Request; - /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::catcher::{Catcher, BoxFuture, TypedError}; /// use rocket::response::Responder; /// use rocket::http::Status; /// - /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: Option<&'r dyn TypedError<'r>>) + /// -> BoxFuture<'r> + /// { /// let res = (status, format!("404: {}", req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).responder_error() }) /// } /// /// let catcher = Catcher::new(404, handle_404); @@ -313,7 +322,9 @@ impl Catcher { impl Default for Catcher { fn default() -> Self { - fn handler<'r>(s: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + fn handler<'r>(s: Status, req: &'r Request<'_>, _e: Option<&'r dyn TypedError<'r>>) + -> BoxFuture<'r> + { Box::pin(async move { Ok(default_handler(s, req)) }) } @@ -330,8 +341,11 @@ pub struct StaticInfo { pub name: &'static str, /// The catcher's status code. pub code: Option, + /// The catcher's error type. + pub error_type: Option<(TypeId, &'static str)>, /// The catcher's handler, i.e, the annotated function. - pub handler: for<'r> fn(Status, &'r Request<'_>) -> BoxFuture<'r>, + pub handler: for<'r> fn(Status, &'r Request<'_>, Option<&'r dyn TypedError<'r>>) + -> BoxFuture<'r>, /// The file, line, and column where the catcher was defined. pub location: (&'static str, u32, u32), } @@ -342,6 +356,7 @@ impl From for Catcher { fn from(info: StaticInfo) -> Catcher { let mut catcher = Catcher::new(info.code, info.handler); catcher.name = Some(info.name.into()); + catcher.error_type = info.error_type; catcher.location = Some(info.location); catcher } @@ -352,6 +367,7 @@ impl fmt::Debug for Catcher { f.debug_struct("Catcher") .field("name", &self.name) .field("base", &self.base) + .field("error_type", &self.error_type.as_ref().map(|(_, n)| n)) .field("code", &self.code) .field("rank", &self.rank) .finish() @@ -418,7 +434,7 @@ macro_rules! default_handler_fn { pub(crate) fn default_handler<'r>( status: Status, - req: &'r Request<'_> + req: &'r Request<'_>, ) -> Response<'r> { let preferred = req.accept().map(|a| a.preferred()); let (mime, text) = if preferred.map_or(false, |a| a.is_json()) { diff --git a/core/lib/src/catcher/from_error.rs b/core/lib/src/catcher/from_error.rs new file mode 100644 index 0000000000..7e84a14cd6 --- /dev/null +++ b/core/lib/src/catcher/from_error.rs @@ -0,0 +1,96 @@ +use async_trait::async_trait; + +use crate::http::Status; +use crate::outcome::Outcome; +use crate::request::FromRequest; +use crate::Request; + +use crate::catcher::TypedError; + +// TODO: update docs and do links +/// Trait used to extract types for an error catcher. You should +/// pretty much never implement this yourself. There are several +/// existing implementations, that should cover every need. +/// +/// - `Status`: Extracts the HTTP status that this error is catching. +/// - `&Request<'_>`: Extracts a reference to the entire request that +/// triggered this error to begin with. +/// - `T: FromRequest<'_>`: Extracts type that implements `FromRequest` +/// - `&dyn TypedError<'_>`: Extracts the typed error, as a dynamic +/// trait object. +/// - `Option<&dyn TypedError<'_>>`: Same as previous, but succeeds even +/// if there is no typed error to extract. +#[async_trait] +pub trait FromError<'r>: Sized { + async fn from_error( + status: Status, + request: &'r Request<'r>, + error: Option<&'r dyn TypedError<'r>> + ) -> Result; +} + +#[async_trait] +impl<'r> FromError<'r> for Status { + async fn from_error( + status: Status, + _r: &'r Request<'r>, + _e: Option<&'r dyn TypedError<'r>> + ) -> Result { + Ok(status) + } +} + +#[async_trait] +impl<'r> FromError<'r> for &'r Request<'r> { + async fn from_error( + _s: Status, + req: &'r Request<'r>, + _e: Option<&'r dyn TypedError<'r>> + ) -> Result { + Ok(req) + } +} + +#[async_trait] +impl<'r, T: FromRequest<'r>> FromError<'r> for T { + async fn from_error( + _s: Status, + req: &'r Request<'r>, + _e: Option<&'r dyn TypedError<'r>> + ) -> Result { + match T::from_request(req).await { + Outcome::Success(val) => Ok(val), + Outcome::Error((s, e)) => { + info!(status = %s, "Catcher guard error: {:?}", e); + Err(s) + }, + Outcome::Forward(s) => { + info!(status = %s, "Catcher guard forwarding"); + Err(s) + }, + } + } +} + +#[async_trait] +impl<'r> FromError<'r> for &'r dyn TypedError<'r> { + async fn from_error( + _s: Status, + _r: &'r Request<'r>, + error: Option<&'r dyn TypedError<'r>> + ) -> Result { + // TODO: what's the correct status here? Not Found? + error.ok_or(Status::InternalServerError) + } +} + +#[async_trait] +impl<'r> FromError<'r> for Option<&'r dyn TypedError<'r>> { + async fn from_error( + _s: Status, + _r: &'r Request<'r>, + error: Option<&'r dyn TypedError<'r>> + ) -> Result { + Ok(error) + } +} diff --git a/core/lib/src/catcher/handler.rs b/core/lib/src/catcher/handler.rs index f33ceba0e3..ef15f1f93e 100644 --- a/core/lib/src/catcher/handler.rs +++ b/core/lib/src/catcher/handler.rs @@ -1,4 +1,5 @@ use crate::{Request, Response}; +use crate::catcher::TypedError; use crate::http::Status; /// Type alias for the return type of a [`Catcher`](crate::Catcher)'s @@ -29,7 +30,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>; /// and used as follows: /// /// ```rust,no_run -/// use rocket::{Request, Catcher, catcher}; +/// use rocket::{Request, Catcher, catcher::{self, TypedError}}; /// use rocket::response::{Response, Responder}; /// use rocket::http::Status; /// @@ -45,14 +46,16 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>; /// /// #[rocket::async_trait] /// impl catcher::Handler for CustomHandler { -/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>) -> catcher::Result<'r> { +/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, _e: Option<&'r dyn TypedError<'r>>) +/// -> catcher::Result<'r> +/// { /// let inner = match self.0 { -/// Kind::Simple => "simple".respond_to(req)?, -/// Kind::Intermediate => "intermediate".respond_to(req)?, -/// Kind::Complex => "complex".respond_to(req)?, +/// Kind::Simple => "simple".respond_to(req).responder_error()?, +/// Kind::Intermediate => "intermediate".respond_to(req).responder_error()?, +/// Kind::Complex => "complex".respond_to(req).responder_error()?, /// }; /// -/// Response::build_from(inner).status(status).ok() +/// Response::build_from(inner).status(status).ok::<()>().responder_error() /// } /// } /// @@ -97,30 +100,38 @@ pub trait Handler: Cloneable + Send + Sync + 'static { /// Nevertheless, failure is allowed, both for convenience and necessity. If /// an error handler fails, Rocket's default `500` catcher is invoked. If it /// succeeds, the returned `Response` is used to respond to the client. - async fn handle<'r>(&self, status: Status, req: &'r Request<'_>) -> Result<'r>; + async fn handle<'r>( + &self, + status: Status, + req: &'r Request<'_>, + error: Option<&'r dyn TypedError<'r>> + ) -> Result<'r>; } // We write this manually to avoid double-boxing. impl Handler for F - where for<'x> F: Fn(Status, &'x Request<'_>) -> BoxFuture<'x>, + where for<'x> F: Fn(Status, &'x Request<'_>, Option<&'x dyn TypedError<'x>>) -> BoxFuture<'x>, { - fn handle<'r, 'life0, 'life1, 'async_trait>( + fn handle<'r, 'life0, 'life1, 'life2, 'async_trait>( &'life0 self, status: Status, req: &'r Request<'life1>, + error: Option<&'r (dyn TypedError<'r> + 'r)>, ) -> BoxFuture<'r> where 'r: 'async_trait, 'life0: 'async_trait, 'life1: 'async_trait, Self: 'async_trait, { - self(status, req) + self(status, req, error) } } // Used in tests! Do not use, please. #[doc(hidden)] -pub fn dummy_handler<'r>(_: Status, _: &'r Request<'_>) -> BoxFuture<'r> { +pub fn dummy_handler<'r>(_: Status, _: &'r Request<'_>, _: Option<&'r dyn TypedError<'r>>) + -> BoxFuture<'r> +{ Box::pin(async move { Ok(Response::new()) }) } diff --git a/core/lib/src/catcher/mod.rs b/core/lib/src/catcher/mod.rs index 4f5fefa19d..f3127049b8 100644 --- a/core/lib/src/catcher/mod.rs +++ b/core/lib/src/catcher/mod.rs @@ -2,6 +2,10 @@ mod catcher; mod handler; +mod types; +mod from_error; pub use catcher::*; pub use handler::*; +pub use types::*; +pub use from_error::*; diff --git a/core/lib/src/catcher/types.rs b/core/lib/src/catcher/types.rs new file mode 100644 index 0000000000..2841d5734d --- /dev/null +++ b/core/lib/src/catcher/types.rs @@ -0,0 +1,218 @@ +use either::Either; +use transient::{Any, CanRecoverFrom, CanTranscendTo, Downcast, Transience}; +use crate::{http::Status, response::{self, Responder}, Request, Response}; +#[doc(inline)] +pub use transient::{Static, Transient, TypeId, Inv}; + +/// Polyfill for trait upcasting to [`Any`] +pub trait AsAny: Any + Sealed { + /// The actual upcast + fn as_any(&self) -> &dyn Any; + /// convience typeid of the inner typeid + fn trait_obj_typeid(&self) -> TypeId; +} + +use sealed::Sealed; +mod sealed { + use transient::{Any, Inv, Transient, TypeId}; + + use super::AsAny; + + pub trait Sealed {} + impl<'r, T: Any>> Sealed for T { } + impl<'r, T: Any> + Transient> AsAny> for T { + fn as_any(&self) -> &dyn Any> { + self + } + fn trait_obj_typeid(&self) -> transient::TypeId { + TypeId::of::() + } + } +} + +/// This is the core of typed catchers. If an error type (returned by +/// FromParam, FromRequest, FromForm, FromData, or Responder) implements +/// this trait, it can be caught by a typed catcher. (TODO) This trait +/// can be derived. +pub trait TypedError<'r>: AsAny> + Send + Sync + 'r { + /// Generates a default response for this type (or forwards to a default catcher) + #[allow(unused_variables)] + fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { + Err(Status::InternalServerError) + } + + /// A descriptive name of this error type. Defaults to the type name. + fn name(&self) -> &'static str { std::any::type_name::() } + + /// The error that caused this error. Defaults to None. + /// + /// # Warning + /// A typed catcher will not attempt to follow the source of an error + /// more than once. + fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { None } + + /// Status code + // TODO: This is currently only used for errors produced by Fairings + fn status(&self) -> Status { Status::InternalServerError } +} + +impl<'r> TypedError<'r> for std::convert::Infallible { } + +impl<'r> TypedError<'r> for () { } + +impl<'r> TypedError<'r> for std::io::Error { + fn status(&self) -> Status { + match self.kind() { + std::io::ErrorKind::NotFound => Status::NotFound, + _ => Status::InternalServerError, + } + } +} + +impl<'r> TypedError<'r> for std::num::ParseIntError {} +impl<'r> TypedError<'r> for std::num::ParseFloatError {} +impl<'r> TypedError<'r> for std::string::FromUtf8Error {} + +impl TypedError<'_> for Status { + fn status(&self) -> Status { *self } +} + +#[cfg(feature = "json")] +impl<'r> TypedError<'r> for serde_json::Error {} + +#[cfg(feature = "msgpack")] +impl<'r> TypedError<'r> for rmp_serde::encode::Error {} +#[cfg(feature = "msgpack")] +impl<'r> TypedError<'r> for rmp_serde::decode::Error {} + +// TODO: This is a hack to make any static type implement Transient +impl<'r, T: std::fmt::Debug + Send + Sync + 'static> TypedError<'r> for response::Debug { + fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { + format!("{:?}", self.0).respond_to(request).responder_error() + } +} + +impl<'r, L, R> TypedError<'r> for Either + where L: TypedError<'r> + Transient, + L::Transience: CanTranscendTo>, + R: TypedError<'r> + Transient, + R::Transience: CanTranscendTo>, +{ + fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { + match self { + Self::Left(v) => v.respond_to(request), + Self::Right(v) => v.respond_to(request), + } + } + + fn name(&self) -> &'static str { std::any::type_name::() } + + fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { + match self { + Self::Left(v) => Some(v), + Self::Right(v) => Some(v), + } + } + + fn status(&self) -> Status { + match self { + Self::Left(v) => v.status(), + Self::Right(v) => v.status(), + } + } +} + +// TODO: This cannot be used as a bound on an untyped catcher to get any error type. +// This is mostly an implementation detail (and issue with double boxing) for +// the responder derive +#[derive(Transient)] +pub struct AnyError<'r>(pub Box + 'r>); + +impl<'r> TypedError<'r> for AnyError<'r> { + fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { + Some(self.0.as_ref()) + } +} + +pub fn downcast<'r, T: Transient + 'r>(v: Option<&'r dyn TypedError<'r>>) -> Option<&'r T> + where T::Transience: CanRecoverFrom> +{ + // if v.is_none() { + // crate::trace::error!("No value to downcast from"); + // } + let v = v?; + // crate::trace::error!("Downcasting error from {}", v.name()); + v.as_any().downcast_ref() +} + +/// Upcasts a value to `Box>`, falling back to a default if it doesn't implement +/// `Error` +#[doc(hidden)] +#[macro_export] +macro_rules! resolve_typed_catcher { + ($T:expr) => ({ + #[allow(unused_imports)] + use $crate::catcher::resolution::{Resolve, DefaultTypeErase, ResolvedTypedError}; + + let inner = Resolve::new($T).cast(); + ResolvedTypedError { + name: inner.as_ref().map(|e| e.name()), + val: inner, + } + }); +} + +pub use resolve_typed_catcher; + +pub mod resolution { + use std::marker::PhantomData; + + use transient::{CanTranscendTo, Transient}; + + use super::*; + + /// The *magic*. + /// + /// `Resolve::item` for `T: Transient` is `::item`. + /// `Resolve::item` for `T: !Transient` is `DefaultTypeErase::item`. + /// + /// This _must_ be used as `Resolve:::item` for resolution to work. This + /// is a fun, static dispatch hack for "specialization" that works because + /// Rust prefers inherent methods over blanket trait impl methods. + pub struct Resolve<'r, T: 'r>(T, PhantomData<&'r ()>); + + impl<'r, T: 'r> Resolve<'r, T> { + pub fn new(val: T) -> Self { + Self(val, PhantomData) + } + } + + /// Fallback trait "implementing" `Transient` for all types. This is what + /// Rust will resolve `Resolve::item` to when `T: !Transient`. + pub trait DefaultTypeErase<'r>: Sized { + const SPECIALIZED: bool = false; + + fn cast(self) -> Option>> { None } + } + + impl<'r, T: 'r> DefaultTypeErase<'r> for Resolve<'r, T> {} + + /// "Specialized" "implementation" of `Transient` for `T: Transient`. This is + /// what Rust will resolve `Resolve::item` to when `T: Transient`. + impl<'r, T: TypedError<'r> + Transient> Resolve<'r, T> + where T::Transience: CanTranscendTo> + { + pub const SPECIALIZED: bool = true; + + pub fn cast(self) -> Option>> { Some(Box::new(self.0)) } + } + + /// Wrapper type to hold the return type of `resolve_typed_catcher`. + #[doc(hidden)] + pub struct ResolvedTypedError<'r> { + /// The return value from `TypedError::name()`, if Some + pub name: Option<&'static str>, + /// The upcast error, if it supports it + pub val: Option + 'r>>, + } +} diff --git a/core/lib/src/data/capped.rs b/core/lib/src/data/capped.rs index 804a42d486..9a7b070dec 100644 --- a/core/lib/src/data/capped.rs +++ b/core/lib/src/data/capped.rs @@ -205,7 +205,8 @@ use crate::response::{self, Responder}; use crate::request::Request; impl<'r, 'o: 'r, T: Responder<'r, 'o>> Responder<'r, 'o> for Capped { - fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> { + type Error = T::Error; + fn respond_to(self, request: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { self.value.respond_to(request) } } diff --git a/core/lib/src/erased.rs b/core/lib/src/erased.rs index 964f954dda..3d7dc99b80 100644 --- a/core/lib/src/erased.rs +++ b/core/lib/src/erased.rs @@ -8,6 +8,7 @@ use futures::future::BoxFuture; use http::request::Parts; use tokio::io::{AsyncRead, ReadBuf}; +use crate::catcher::TypedError; use crate::data::{Data, IoHandler, RawStream}; use crate::{Request, Response, Rocket, Orbit}; @@ -34,10 +35,40 @@ impl Drop for ErasedRequest { fn drop(&mut self) { } } -#[derive(Debug)] +pub struct ErasedError<'r> { + error: Option + 'r>>>, +} + +impl<'r> ErasedError<'r> { + pub fn new() -> Self { + Self { error: None } + } + + pub fn write(&mut self, error: Option + 'r>>) { + // SAFETY: To meet the requirements of `Pin`, we never drop + // the inner Box. This is enforced by only allowing writing + // to the Option when it is None. + assert!(self.error.is_none()); + if let Some(error) = error { + self.error = Some(unsafe { Pin::new_unchecked(error) }); + } + } + + pub fn is_some(&self) -> bool { + self.error.is_some() + } + + pub fn get(&'r self) -> Option<&'r dyn TypedError<'r>> { + self.error.as_ref().map(|e| &**e) + } +} + +// TODO: #[derive(Debug)] pub struct ErasedResponse { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'static>, + // XXX: SAFETY: This (dependent) field must come second due to drop order! + error: ErasedError<'static>, _request: Arc, } @@ -68,8 +99,13 @@ impl ErasedRequest { let parts: Box = Box::new(parts); let request: Request<'_> = { let rocket: &Rocket = &rocket; + // SAFETY: The `Request` can borrow from `Rocket` because it has a stable + // address (due to `Arc`) and it is kept alive by the containing + // `ErasedRequest`. The `Request` is always dropped before the + // `Arc` due to drop order. let rocket: &'static Rocket = unsafe { transmute(rocket) }; let parts: &Parts = &parts; + // SAFETY: Same as above, but for `Box`. let parts: &'static Parts = unsafe { transmute(parts) }; constructor(rocket, parts) }; @@ -88,38 +124,56 @@ impl ErasedRequest { preprocess: impl for<'r, 'x> FnOnce( &'r Rocket, &'r mut Request<'x>, - &'r mut Data<'x> + &'r mut Data<'x>, + &'r mut ErasedError<'r>, ) -> BoxFuture<'r, T>, dispatch: impl for<'r> FnOnce( T, &'r Rocket, &'r Request<'r>, - Data<'r> + Data<'r>, + &'r mut ErasedError<'r>, ) -> BoxFuture<'r, Response<'r>>, ) -> ErasedResponse where T: Send + Sync + 'static, D: for<'r> Into> { let mut data: Data<'_> = Data::from(raw_stream); + // SAFETY: At this point, ErasedRequest contains a request, which is permitted + // to borrow from `Rocket` and `Parts`. They both have stable addresses (due to + // `Arc` and `Box`), and the Request will be dropped first (due to drop order). + // SAFETY: Here, we place the `ErasedRequest` (i.e. the `Request`) behind an `Arc` + // to ensure it has a stable address, and we again use drop order to ensure the `Request` + // is dropped before the values that can borrow from it. let mut parent = Arc::new(self); + // SAFETY: This error is permitted to borrow from the `Request` (as well as `Rocket` and + // `Parts`). + let mut error = ErasedError { error: None }; let token: T = { let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap(); let rocket: &Rocket = &parent._rocket; let request: &mut Request<'_> = &mut parent.request; let data: &mut Data<'_> = &mut data; - preprocess(rocket, request, data).await + // SAFETY: As below, `error` must be reborrowed with the correct lifetimes. + preprocess(rocket, request, data, unsafe { transmute(&mut error) }).await }; let parent = parent; let response: Response<'_> = { let parent: &ErasedRequest = &parent; + // SAFETY: This static reference is immediatly reborrowed for the correct lifetime. + // The Response type is permitted to borrow from the `Request`, `Rocket`, `Parts`, and + // `error`. All of these types have stable addresses, and will not be dropped until + // after Response, due to drop order. let parent: &'static ErasedRequest = unsafe { transmute(parent) }; let rocket: &Rocket = &parent._rocket; let request: &Request<'_> = &parent.request; - dispatch(token, rocket, request, data).await + // SAFETY: As above, `error` must be reborrowed with the correct lifetimes. + dispatch(token, rocket, request, data, unsafe { transmute(&mut error) }).await }; ErasedResponse { + error, _request: parent, response, } @@ -147,9 +201,21 @@ impl ErasedResponse { &'a mut Response<'r>, ) -> Option<(T, Box)> ) -> Option<(T, ErasedIoHandler)> { + // SAFETY: If an error has been thrown, the `IoHandler` could + // technically borrow from it, so we must ensure that this is + // not the case. This could be handled safely by changing `error` + // to be an `Arc` internally, and cloning the Arc to get a copy + // (like `ErasedRequest`), however it's unclear this is actually + // useful, and we can avoid paying the cost of an `Arc` + if self.error.is_some() { + warn!("Attempting to upgrade after throwing a typed error is not supported"); + return None; + } let parent: Arc = self._request.clone(); let io: Option<(T, Box)> = { let parent: &ErasedRequest = &parent; + // SAFETY: As in other cases, the request is kept alive by the `Erased...` + // type. let parent: &'static ErasedRequest = unsafe { transmute(parent) }; let request: &Request<'_> = &parent.request; constructor(request, &mut self.response) diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 808b79d213..ca1eebe87b 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -5,7 +5,9 @@ use std::error::Error as StdError; use std::sync::Arc; use figment::Profile; +use transient::{Static, Transient}; +use crate::catcher::TypedError; use crate::listener::Endpoint; use crate::{Catcher, Ignite, Orbit, Phase, Rocket, Route}; use crate::trace::Trace; @@ -85,10 +87,14 @@ pub enum ErrorKind { Shutdown(Arc>), } +impl Static for ErrorKind {} + /// An error that occurs when a value was unexpectedly empty. -#[derive(Clone, Copy, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Clone, Copy, Default, PartialEq, Eq, Hash, PartialOrd, Ord, Transient)] pub struct Empty; +impl TypedError<'_> for Empty {} + impl Error { #[inline(always)] pub(crate) fn new(kind: ErrorKind) -> Error { diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs index b6dfe16b78..5e1c9226f1 100644 --- a/core/lib/src/fairing/ad_hoc.rs +++ b/core/lib/src/fairing/ad_hoc.rs @@ -1,6 +1,7 @@ use parking_lot::Mutex; use futures::future::{Future, BoxFuture, FutureExt}; +use crate::catcher::TypedError; use crate::{Rocket, Request, Response, Data, Build, Orbit}; use crate::fairing::{Fairing, Kind, Info, Result}; use crate::route::RouteUri; @@ -63,6 +64,10 @@ enum AdHocKind { Request(Box Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()> + Send + Sync + 'static>), + /// An ad-hoc **request_filter** fairing. Called when a request is received. + RequestFilter(Box Fn(&'a Request<'_>, &'b Data<'_>) + -> BoxFuture<'a, Result<(), Box + 'a>>> + Send + Sync + 'static>), + /// An ad-hoc **response** fairing. Called when a response is ready to be /// sent to a client. Response(Box Fn(&'r Request<'_>, &'b mut Response<'r>) @@ -154,11 +159,36 @@ impl AdHoc { /// }); /// ``` pub fn on_request(name: &'static str, f: F) -> AdHoc - where F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()> + where F: for<'a> Fn(&'a mut Request<'_>, &'a mut Data<'_>) + -> BoxFuture<'a, ()> { AdHoc { name, kind: AdHocKind::Request(Box::new(f)) } } + /// Constructs an `AdHoc` request fairing named `name`. The function `f` + /// will be called and the returned `Future` will be `await`ed by Rocket + /// when a new request is received. + /// + /// # Example + /// + /// ```rust + /// use rocket::fairing::AdHoc; + /// + /// // The no-op request fairing. + /// let fairing = AdHoc::on_request("Dummy", |req, data| { + /// Box::pin(async move { + /// // do something with the request and data... + /// # let (_, _) = (req, data); + /// }) + /// }); + /// ``` + pub fn filter_request(name: &'static str, f: F) -> AdHoc + where F: for<'a, 'b> Fn(&'a Request<'_>, &'b Data<'_>) + -> BoxFuture<'a, Result<(), Box + 'a>>> + { + AdHoc { name, kind: AdHocKind::RequestFilter(Box::new(f)) } + } + // FIXME(rustc): We'd like to allow passing `async fn` to these methods... // https://github.com/rust-lang/rust/issues/64552#issuecomment-666084589 @@ -380,7 +410,7 @@ impl AdHoc { async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) { // If the URI has no trailing slash, it routes as before. if req.uri().is_normalized_nontrailing() { - return + return; } // Otherwise, check if there's a route that matches the request @@ -407,6 +437,7 @@ impl Fairing for AdHoc { AdHocKind::Ignite(_) => Kind::Ignite, AdHocKind::Liftoff(_) => Kind::Liftoff, AdHocKind::Request(_) => Kind::Request, + AdHocKind::RequestFilter(_) => Kind::RequestFilter, AdHocKind::Response(_) => Kind::Response, AdHocKind::Shutdown(_) => Kind::Shutdown, }; @@ -433,6 +464,16 @@ impl Fairing for AdHoc { } } + async fn filter_request<'r>(&self, req: &'r Request<'_>, data: &Data<'_>) + -> Result<(), Box + 'r>> + { + if let AdHocKind::RequestFilter(ref f) = self.kind { + f(req, data).await + } else { + Ok(()) + } + } + async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { if let AdHocKind::Response(ref f) = self.kind { f(req, res).await diff --git a/core/lib/src/fairing/fairings.rs b/core/lib/src/fairing/fairings.rs index 16316c50e5..8b243ed5e3 100644 --- a/core/lib/src/fairing/fairings.rs +++ b/core/lib/src/fairing/fairings.rs @@ -1,5 +1,6 @@ use std::collections::HashSet; +use crate::erased::ErasedError; use crate::{Rocket, Request, Response, Data, Build, Orbit}; use crate::fairing::{Fairing, Info, Kind}; @@ -15,6 +16,7 @@ pub struct Fairings { ignite: Vec, liftoff: Vec, request: Vec, + filter_request: Vec, response: Vec, shutdown: Vec, } @@ -43,6 +45,7 @@ impl Fairings { self.ignite.iter() .chain(self.liftoff.iter()) .chain(self.request.iter()) + .chain(self.filter_request.iter()) .chain(self.response.iter()) .chain(self.shutdown.iter()) } @@ -104,6 +107,7 @@ impl Fairings { if this_info.kind.is(Kind::Ignite) { self.ignite.push(index); } if this_info.kind.is(Kind::Liftoff) { self.liftoff.push(index); } if this_info.kind.is(Kind::Request) { self.request.push(index); } + if this_info.kind.is(Kind::RequestFilter) { self.filter_request.push(index); } if this_info.kind.is(Kind::Response) { self.response.push(index); } if this_info.kind.is(Kind::Shutdown) { self.shutdown.push(index); } } @@ -147,9 +151,33 @@ impl Fairings { } #[inline(always)] - pub async fn handle_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) { + pub async fn handle_request<'r>( + &self, + req: &'r mut Request<'_>, + data: &mut Data<'_>, + ) { for fairing in iter!(self.request) { - fairing.on_request(req, data).await + fairing.on_request(req, data).await; + } + } + + #[inline(always)] + pub async fn handle_filter<'r>( + &self, + req: &'r Request<'_>, + data: &Data<'_>, + error: &mut ErasedError<'r>, + ) { + for fairing in iter!(self.filter_request) { + match fairing.filter_request(req, data).await { + Ok(()) => (), + Err(e) => { + // SAFETY: `e` can only contain *immutable* borrows of + // `req`. + error.write(Some(e)); + return; + }, + } } } diff --git a/core/lib/src/fairing/info_kind.rs b/core/lib/src/fairing/info_kind.rs index 74ab3a4827..a7f93b6f96 100644 --- a/core/lib/src/fairing/info_kind.rs +++ b/core/lib/src/fairing/info_kind.rs @@ -64,15 +64,18 @@ impl Kind { /// `Kind` flag representing a request for a 'request' callback. pub const Request: Kind = Kind(1 << 2); + /// `Kind` flag representing a request for a 'filter_request' callback. + pub const RequestFilter: Kind = Kind(1 << 3); + /// `Kind` flag representing a request for a 'response' callback. - pub const Response: Kind = Kind(1 << 3); + pub const Response: Kind = Kind(1 << 4); /// `Kind` flag representing a request for a 'shutdown' callback. - pub const Shutdown: Kind = Kind(1 << 4); + pub const Shutdown: Kind = Kind(1 << 5); /// `Kind` flag representing a /// [singleton](crate::fairing::Fairing#singletons) fairing. - pub const Singleton: Kind = Kind(1 << 5); + pub const Singleton: Kind = Kind(1 << 6); /// Returns `true` if `self` is a superset of `other`. In other words, /// returns `true` if all of the kinds in `other` are also in `self`. diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index ad9aaca40f..181fce7554 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -51,6 +51,7 @@ use std::any::Any; +use crate::catcher::TypedError; use crate::{Rocket, Request, Response, Data, Build, Orbit}; mod fairings; @@ -149,9 +150,18 @@ pub type Result, E = Rocket> = std::result::ResultRequest filter (`filter_request`)** +/// +/// A request callback, represented by the [`Fairing::filter_request()`] method, +/// called after `on_request` callbacks have run, but before any handlers have +/// been attempted. This type of fairing can choose to prematurly reject requests, +/// skipping handlers all together, and moving it straight to error handling. This +/// should generally only be used to apply filter that apply to the entire server, +/// e.g. CORS processing. /// /// * **Response (`on_response`)** /// @@ -501,7 +511,28 @@ pub trait Fairing: Send + Sync + Any + 'static { /// ## Default Implementation /// /// The default implementation of this method does nothing. - async fn on_request(&self, _req: &mut Request<'_>, _data: &mut Data<'_>) {} + async fn on_request(&self, _req: &mut Request<'_>, _data: &mut Data<'_>) { } + + /// The request filter callback. + /// + /// See [Fairing Callbacks](#filter_request) for complete semantics. + /// + /// This method is called when a new request is received if `Kind::RequestFilter` + /// is in the `kind` field of the `Info` structure for this fairing. The + /// `&Request` parameter is the incoming request, and the `&Data` + /// parameter is the incoming data in the request. + /// + /// If this method returns `Ok`, the request routed as normal (assuming no other + /// fairing filters it). Otherwise, the request is routed to an error handler + /// based on the error type returned. + /// + /// ## Default Implementation + /// + /// The default implementation of this method does not filter any request, + /// by always returning `Ok(())` + async fn filter_request<'r>(&self, _req: &'r Request<'_>, _data: &Data<'_>) + -> Result<(), Box + 'r>> + { Ok(()) } /// The response callback. /// @@ -555,6 +586,13 @@ impl Fairing for std::sync::Arc { (self as &T).on_request(req, data).await } + #[inline] + async fn filter_request<'r>(&self, req: &'r Request<'_>, data: &Data<'_>) + -> Result<(), Box + 'r>> + { + (self as &T).filter_request(req, data).await + } + #[inline] async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { (self as &T).on_response(req, res).await diff --git a/core/lib/src/form/error.rs b/core/lib/src/form/error.rs index b2c2c06e30..3603953f12 100644 --- a/core/lib/src/form/error.rs +++ b/core/lib/src/form/error.rs @@ -8,7 +8,9 @@ use std::net::AddrParseError; use std::borrow::Cow; use serde::{Serialize, ser::{Serializer, SerializeStruct}}; +use transient::Transient; +use crate::catcher::TypedError; use crate::http::Status; use crate::form::name::{NameBuf, Name}; use crate::data::ByteUnit; @@ -54,10 +56,12 @@ use crate::data::ByteUnit; /// Ok(i) /// } /// ``` -#[derive(Default, Debug, PartialEq, Serialize)] +#[derive(Default, Debug, PartialEq, Serialize, Transient)] #[serde(transparent)] pub struct Errors<'v>(Vec>); +impl<'r> TypedError<'r> for Errors<'r> { } + /// A form error, potentially tied to a specific form field. /// /// An `Error` is returned by [`FromForm`], [`FromFormField`], and [`validate`] @@ -131,7 +135,7 @@ pub struct Errors<'v>(Vec>); /// | `value` | `Option<&str>` | the erroring field's value, if known | /// | `entity` | `&str` | string representation of the erroring [`Entity`] | /// | `msg` | `&str` | concise message of the error | -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Transient)] pub struct Error<'v> { /// The name of the field, if it is known. pub name: Option>, @@ -143,6 +147,8 @@ pub struct Error<'v> { pub entity: Entity, } +// impl<'r> TypedError<'r> for Error<'r> { } + /// The kind of form error that occurred. /// /// ## Constructing @@ -196,7 +202,8 @@ pub enum ErrorKind<'v> { Unknown, /// A custom error occurred. Status defaults to /// [`Status::UnprocessableEntity`] if one is not directly specified. - Custom(Status, Box), + // TODO: This needs to be sync for TypedError + Custom(Status, Box), /// An error while parsing a multipart form occurred. Multipart(multer::Error), /// A string was invalid UTF-8. @@ -451,9 +458,9 @@ impl<'v> Error<'v> { /// } /// ``` pub fn custom(error: E) -> Self - where E: std::error::Error + Send + 'static + where E: std::error::Error + Send + Sync + 'static { - (Box::new(error) as Box).into() + (Box::new(error) as Box).into() } /// Creates a new `Error` with `ErrorKind::Validation` and message `msg`. @@ -966,14 +973,14 @@ impl<'a, 'v: 'a, const N: usize> From<&'static [Cow<'v, str>; N]> for ErrorKind< } } -impl<'a> From> for ErrorKind<'a> { - fn from(e: Box) -> Self { +impl<'a> From> for ErrorKind<'a> { + fn from(e: Box) -> Self { ErrorKind::Custom(Status::UnprocessableEntity, e) } } -impl<'a> From<(Status, Box)> for ErrorKind<'a> { - fn from((status, e): (Status, Box)) -> Self { +impl<'a> From<(Status, Box)> for ErrorKind<'a> { + fn from((status, e): (Status, Box)) -> Self { ErrorKind::Custom(status, e) } } diff --git a/core/lib/src/form/from_form_field.rs b/core/lib/src/form/from_form_field.rs index 2a7f5ab22b..4468765e19 100644 --- a/core/lib/src/form/from_form_field.rs +++ b/core/lib/src/form/from_form_field.rs @@ -412,7 +412,7 @@ static DATE_TIME_FMT2: &[FormatItem<'_>] = impl<'v> FromFormField<'v> for Date { fn from_value(field: ValueField<'v>) -> Result<'v, Self> { let date = Self::parse(field.value, &DATE_FMT) - .map_err(|e| Box::new(e) as Box)?; + .map_err(|e| Box::new(e) as Box)?; Ok(date) } @@ -422,7 +422,7 @@ impl<'v> FromFormField<'v> for Time { fn from_value(field: ValueField<'v>) -> Result<'v, Self> { let time = Self::parse(field.value, &TIME_FMT1) .or_else(|_| Self::parse(field.value, &TIME_FMT2)) - .map_err(|e| Box::new(e) as Box)?; + .map_err(|e| Box::new(e) as Box)?; Ok(time) } @@ -432,7 +432,7 @@ impl<'v> FromFormField<'v> for PrimitiveDateTime { fn from_value(field: ValueField<'v>) -> Result<'v, Self> { let dt = Self::parse(field.value, &DATE_TIME_FMT1) .or_else(|_| Self::parse(field.value, &DATE_TIME_FMT2)) - .map_err(|e| Box::new(e) as Box)?; + .map_err(|e| Box::new(e) as Box)?; Ok(dt) } diff --git a/core/lib/src/fs/named_file.rs b/core/lib/src/fs/named_file.rs index d4eed82a92..4b982d493a 100644 --- a/core/lib/src/fs/named_file.rs +++ b/core/lib/src/fs/named_file.rs @@ -4,8 +4,9 @@ use std::ops::{Deref, DerefMut}; use tokio::fs::{File, OpenOptions}; +use crate::outcome::try_outcome; use crate::request::Request; -use crate::response::{self, Responder}; +use crate::response::{Responder, Outcome}; use crate::http::ContentType; /// A [`Responder`] that sends file data with a Content-Type based on its @@ -152,15 +153,16 @@ impl NamedFile { /// you would like to stream a file with a different Content-Type than that /// implied by its extension, use a [`File`] directly. impl<'r> Responder<'r, 'static> for NamedFile { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { - let mut response = self.1.respond_to(req)?; + type Error = std::convert::Infallible; + fn respond_to(self, req: &'r Request<'_>) -> Outcome<'static, Self::Error> { + let mut response = try_outcome!(self.1.respond_to(req)); if let Some(ext) = self.0.extension() { if let Some(ct) = ContentType::from_extension(&ext.to_string_lossy()) { response.set_header(ct); } } - Ok(response) + Outcome::Success(response) } } diff --git a/core/lib/src/fs/server.rs b/core/lib/src/fs/server.rs index faa95f11d0..47bb52ccaa 100644 --- a/core/lib/src/fs/server.rs +++ b/core/lib/src/fs/server.rs @@ -205,7 +205,7 @@ impl Handler for FileServer { if segments.is_empty() { let file = NamedFile::open(&self.root).await; - return file.respond_to(req).or_forward((data, Status::NotFound)); + return file.respond_to(req).ok().or_forward((data, Status::NotFound, None)); } else { return Outcome::forward(data, Status::NotFound); } @@ -227,7 +227,8 @@ impl Handler for FileServer { return Redirect::permanent(normal) .respond_to(req) - .or_forward((data, Status::InternalServerError)); + .ok() + .or_forward((data, Status::InternalServerError, None)); } if !options.contains(Options::Index) { @@ -235,11 +236,11 @@ impl Handler for FileServer { } let index = NamedFile::open(p.join("index.html")).await; - index.respond_to(req).or_forward((data, Status::NotFound)) + index.respond_to(req).ok().or_forward((data, Status::NotFound, None)) }, Some(p) => { let file = NamedFile::open(p).await; - file.respond_to(req).or_forward((data, Status::NotFound)) + file.respond_to(req).ok().or_forward((data, Status::NotFound, None)) } None => Outcome::forward(data, Status::NotFound), } diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index 6f51c959e7..3ad5cb0149 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -1,12 +1,14 @@ use futures::future::{FutureExt, Future}; +use crate::catcher::TypedError; +use crate::erased::ErasedError; use crate::trace::Trace; use crate::util::Formatter; use crate::data::IoHandler; use crate::http::{Method, Status, Header}; use crate::outcome::Outcome; use crate::form::Form; -use crate::{route, catcher, Rocket, Orbit, Request, Response, Data}; +use crate::{catcher, route, Catcher, Data, Orbit, Request, Response, Rocket}; // A token returned to force the execution of one method before another. pub(crate) struct RequestToken; @@ -51,10 +53,11 @@ impl Rocket { /// /// This is the only place during lifecycle processing that `Request` is /// mutable. Keep this in-sync with the `FromForm` derive. - pub(crate) async fn preprocess( + pub(crate) async fn preprocess<'r>( &self, - req: &mut Request<'_>, - data: &mut Data<'_> + req: &'r mut Request<'_>, + data: &mut Data<'_>, + error: &mut ErasedError<'r>, ) -> RequestToken { // Check if this is a form and if the form contains the special _method // field which we use to reinterpret the request's method. @@ -72,6 +75,7 @@ impl Rocket { // Run request fairings. self.fairings.handle_request(req, data).await; + self.fairings.handle_filter(req, data, error).await; RequestToken } @@ -93,27 +97,49 @@ impl Rocket { _token: RequestToken, request: &'r Request<'s>, data: Data<'r>, + error_ptr: &'r mut ErasedError<'r>, // io_stream: impl Future> + Send, ) -> Response<'r> { // Remember if the request is `HEAD` for later body stripping. let was_head_request = request.method() == Method::Head; + // Route the request and run the user's handlers. - let mut response = match self.route(request, data).await { - Outcome::Success(response) => response, - Outcome::Forward((data, _)) if request.method() == Method::Head => { - tracing::Span::current().record("autohandled", true); - - // Dispatch the request again with Method `GET`. - request._set_method(Method::Get); - match self.route(request, data).await { - Outcome::Success(response) => response, - Outcome::Error(status) => self.dispatch_error(status, request).await, - Outcome::Forward((_, status)) => self.dispatch_error(status, request).await, + let mut response = if error_ptr.is_some() { + // error_ptr is always some here, we just checked. + self.dispatch_error(error_ptr.get().unwrap().status(), request, error_ptr.get()).await + // We MUST wait until we are inside this block to call `get`, since we HAVE to borrow + // it for `'r`. (And it's invariant, so we can't downcast the borrow to a shorter + // lifetime) + } else { + match self.route(request, data).await { + Outcome::Success(response) => response, + Outcome::Forward((data, _, _)) if request.method() == Method::Head => { + tracing::Span::current().record("autohandled", true); + + // Dispatch the request again with Method `GET`. + request._set_method(Method::Get); + match self.route(request, data).await { + Outcome::Success(response) => response, + Outcome::Error((status, error)) => { + error_ptr.write(error); + self.dispatch_error(status, request, error_ptr.get()).await + }, + Outcome::Forward((_, status, error)) => { + error_ptr.write(error); + self.dispatch_error(status, request, error_ptr.get()).await + }, + } } + Outcome::Forward((_, status, error)) => { + error_ptr.write(error); + self.dispatch_error(status, request, error_ptr.get()).await + }, + Outcome::Error((status, error)) => { + error_ptr.write(error); + self.dispatch_error(status, request, error_ptr.get()).await + }, } - Outcome::Forward((_, status)) => self.dispatch_error(status, request).await, - Outcome::Error(status) => self.dispatch_error(status, request).await, }; // Set the cookies. Note that error responses will only include cookies @@ -197,6 +223,7 @@ impl Rocket { // Go through all matching routes until we fail or succeed or run out of // routes to try, in which case we forward with the last status. let mut status = Status::NotFound; + let mut error = None; for route in self.router.route(request) { // Retrieve and set the requests parameters. route.trace_info(); @@ -204,18 +231,18 @@ impl Rocket { let name = route.name.as_deref(); let outcome = catch_handle(name, || route.handler.handle(request, data)).await - .unwrap_or(Outcome::Error(Status::InternalServerError)); + .unwrap_or(Outcome::error(Status::InternalServerError)); // Check if the request processing completed (Some) or if the // request needs to be forwarded. If it does, continue the loop outcome.trace_info(); match outcome { o@Outcome::Success(_) | o@Outcome::Error(_) => return o, - Outcome::Forward(forwarded) => (data, status) = forwarded, + Outcome::Forward(forwarded) => (data, status, error) = forwarded, } } - Outcome::Forward((data, status)) + Outcome::Forward((data, status, error)) } // Invokes the catcher for `status`. Returns the response on success. @@ -229,17 +256,19 @@ impl Rocket { pub(crate) async fn dispatch_error<'r, 's: 'r>( &'s self, mut status: Status, - req: &'r Request<'s> + req: &'r Request<'s>, + mut error: Option<&'r dyn TypedError<'r>>, ) -> Response<'r> { // We may wish to relax this in the future. req.cookies().reset_delta(); loop { // Dispatch to the `status` catcher. - match self.invoke_catcher(status, req).await { + match self.invoke_catcher(status, error, req).await { Ok(r) => return r, // If the catcher failed, try `500` catcher, unless this is it. Err(e) if status.code != 500 => { + error = None; warn!(status = e.map(|r| r.code), "catcher failed: trying 500 catcher"); status = Status::InternalServerError; } @@ -262,20 +291,66 @@ impl Rocket { /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))` /// if the handler ran to completion but failed. Returns `Ok(None)` if the /// handler panicked while executing. + /// + /// # TODO: updated semantics: + /// + /// Selects and invokes a specific catcher, with the following preference: + /// - Best matching error type (prefers calling `.source()` the fewest number + /// of times) + /// - The longest path base + /// - Matching status + /// - The error's built-in responder (TODO: should this be before untyped catchers?) + /// - If no catcher is found, Rocket's default handler is invoked + /// + /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))` + /// if the handler ran to completion but failed. Returns `Ok(None)` if the + /// handler panicked while executing. + /// + /// TODO: These semantics should (ideally) match the old semantics in the case where + /// `error` is `None`. async fn invoke_catcher<'s, 'r: 's>( &'s self, status: Status, + error: Option<&'r dyn TypedError<'r>>, req: &'r Request<'s> ) -> Result, Option> { - if let Some(catcher) = self.router.catch(status, req) { - catcher.trace_info(); - catch_handle(catcher.name.as_deref(), || catcher.handler.handle(status, req)).await - .map(|result| result.map_err(Some)) - .unwrap_or_else(|| Err(None)) + // Lists error types by repeatedly calling `.source()` + let catchers = std::iter::successors(error, |e| e.source()) + // Only go up to 5 levels deep (to prevent an endless cycle) + .take(5) + // Map to catchers + .filter_map(|e| { + self.router.catch(status, req, Some(e.trait_obj_typeid())).map(|c| (c, e)) + }) + // Select the minimum by the catcher's rank + .min_by_key(|(c, _)| c.rank); + if let Some((catcher, e)) = catchers { + self.invoke_specific_catcher(catcher, status, Some(e), req).await + } else if let Some(catcher) = self.router.catch(status, req, None) { + self.invoke_specific_catcher(catcher, status, error, req).await + } else if let Some(res) = error.and_then(|e| e.respond_to(req).ok()) { + Ok(res) } else { info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = status.code, "no registered catcher: using Rocket default"); Ok(catcher::default_handler(status, req)) } } + + /// Invokes a specific catcher + async fn invoke_specific_catcher<'s, 'r: 's>( + &'s self, + catcher: &Catcher, + status: Status, + error: Option<&'r dyn TypedError<'r>>, + req: &'r Request<'s> + ) -> Result, Option> { + catcher.trace_info(); + catch_handle( + catcher.name.as_deref(), + || catcher.handler.handle(status, req, error) + ).await + .map(|result| result.map_err(Some)) + .unwrap_or_else(|| Err(None)) + } } diff --git a/core/lib/src/local/asynchronous/request.rs b/core/lib/src/local/asynchronous/request.rs index 4c85c02024..17d25534c0 100644 --- a/core/lib/src/local/asynchronous/request.rs +++ b/core/lib/src/local/asynchronous/request.rs @@ -1,5 +1,6 @@ use std::fmt; +use crate::request::RequestErrors; use crate::{Request, Data}; use crate::http::{Status, Method}; use crate::http::uri::Origin; @@ -75,7 +76,7 @@ impl<'c> LocalRequest<'c> { } // Performs the actual dispatch. - async fn _dispatch(mut self) -> LocalResponse<'c> { + async fn _dispatch(self) -> LocalResponse<'c> { // First, revalidate the URI, returning an error response (generated // from an error catcher) immediately if it's invalid. If it's valid, // then `request` already contains a correct URI. @@ -85,17 +86,23 @@ impl<'c> LocalRequest<'c> { // _shouldn't_ error. Check that now and error only if not. if self.inner().uri() == invalid { error!("invalid request URI: {:?}", invalid.path()); - return LocalResponse::new(self.request, move |req| { - rocket.dispatch_error(Status::BadRequest, req) - }).await + return LocalResponse::error(self.request, move |req, error_ptr| { + // TODO: Ideally the RequestErrors should contain actual information. + error_ptr.write(Some(Box::new(RequestErrors::new(&[])))); + rocket.dispatch_error(Status::BadRequest, req, error_ptr.get()) + }).await; } } // Actually dispatch the request. - let mut data = Data::local(self.data); - let token = rocket.preprocess(&mut self.request, &mut data).await; - let response = LocalResponse::new(self.request, move |req| { - rocket.dispatch(token, req, data) + let data = Data::local(self.data); + // let token = rocket.preprocess(&mut self.request, &mut data, &mut self.error).await; + let response = LocalResponse::new(self.request, data, + move |req, data, error_ptr| { + rocket.preprocess(req, data, error_ptr) + }, + move |token, req, data, error_ptr| { + rocket.dispatch(token, req, data, error_ptr) }).await; // If the client is tracking cookies, updates the internal cookie jar diff --git a/core/lib/src/local/asynchronous/response.rs b/core/lib/src/local/asynchronous/response.rs index 06ae18e3b7..0faaa690c3 100644 --- a/core/lib/src/local/asynchronous/response.rs +++ b/core/lib/src/local/asynchronous/response.rs @@ -4,8 +4,10 @@ use std::{pin::Pin, task::{Context, Poll}}; use tokio::io::{AsyncRead, ReadBuf}; +use crate::erased::ErasedError; use crate::http::CookieJar; -use crate::{Request, Response}; +use crate::lifecycle::RequestToken; +use crate::{Data, Request, Response}; /// An `async` response from a dispatched [`LocalRequest`](super::LocalRequest). /// @@ -55,6 +57,7 @@ use crate::{Request, Response}; pub struct LocalResponse<'c> { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'c>, + _error: ErasedError<'c>, cookies: CookieJar<'c>, _request: Box>, } @@ -64,14 +67,86 @@ impl Drop for LocalResponse<'_> { } impl<'c> LocalResponse<'c> { - pub(crate) fn new(req: Request<'c>, f: F) -> impl Future> - where F: FnOnce(&'c Request<'c>) -> O + Send, - O: Future> + Send + pub(crate) fn new(req: Request<'c>, mut data: Data<'c>, preprocess: P, f: F) + -> impl Future> + where P: FnOnce(&'c mut Request<'c>, &'c mut Data<'c>, &'c mut ErasedError<'c>) + -> PO + Send, + PO: Future + Send + 'c, + F: FnOnce(RequestToken, &'c Request<'c>, Data<'c>, &'c mut ErasedError<'c>) + -> O + Send, + O: Future> + Send + 'c { // `LocalResponse` is a self-referential structure. In particular, // `response` and `cookies` can refer to `_request` and its contents. As // such, we must - // 1) Ensure `Request` has a stable address. + // 1) Ensure `Request` and `TypedError` have a stable address. + // + // This is done by `Box`ing the `Request`, using only the stable + // address thereafter. + // + // 2) Ensure no refs to `Request` or its contents leak with a lifetime + // extending beyond that of `&self`. + // + // We have no methods that return an `&Request`. However, we must + // also ensure that `Response` doesn't leak any such references. To + // do so, we don't expose the `Response` directly in any way; + // otherwise, methods like `.headers()` could, in conjunction with + // particular crafted `Responder`s, potentially be used to obtain a + // reference to contents of `Request`. All methods, instead, return + // references bounded by `self`. This is easily verified by noting + // that 1) `LocalResponse` fields are private, and 2) all `impl`s + // of `LocalResponse` aside from this method abstract the lifetime + // away as `'_`, ensuring it is not used for any output value. + let mut boxed_req = Box::new(req); + let mut error = ErasedError::new(); + + async move { + use std::mem::transmute; + + let token = { + // SAFETY: Much like request above, error can borrow from request, and + // response can borrow from request or error. TODO + let request: &'c mut Request<'c> = unsafe { &mut *(&mut *boxed_req as *mut _) }; + // SAFETY: The type of `preprocess` ensures that all of these types have the correct + // lifetime ('c). + preprocess( + request, + unsafe { transmute(&mut data) }, + unsafe { transmute(&mut error) }, + ).await + }; + // SAFETY: Much like request above, error can borrow from request, and + // response can borrow from request or error. TODO + let request: &'c Request<'c> = unsafe { &*(&*boxed_req as *const _) }; + // NOTE: The cookie jar `secure` state will not reflect the last + // known value in `request.cookies()`. This is okay: new cookies + // should never be added to the resulting jar which is the only time + // the value is used to set cookie defaults. + // SAFETY: The type of `preprocess` ensures that all of these types have the correct + // lifetime ('c). + let response: Response<'c> = f( + token, + request, + data, + unsafe { transmute(&mut error) } + ).await; + let mut cookies = CookieJar::new(None, request.rocket()); + for cookie in response.cookies() { + cookies.add_original(cookie.into_owned()); + } + + LocalResponse { _request: boxed_req, _error: error, cookies, response, } + } + } + + pub(crate) fn error(req: Request<'c>, f: F) -> impl Future> + where F: FnOnce(&'c Request<'c>, &'c mut ErasedError<'c>) -> O + Send, + O: Future> + Send + 'c + { + // `LocalResponse` is a self-referential structure. In particular, + // `response` and `cookies` can refer to `_request` and its contents. As + // such, we must + // 1) Ensure `Request` and `TypedError` have a stable address. // // This is done by `Box`ing the `Request`, using only the stable // address thereafter. @@ -90,20 +165,25 @@ impl<'c> LocalResponse<'c> { // of `LocalResponse` aside from this method abstract the lifetime // away as `'_`, ensuring it is not used for any output value. let boxed_req = Box::new(req); - let request: &'c Request<'c> = unsafe { &*(&*boxed_req as *const _) }; async move { + use std::mem::transmute; + let mut error = ErasedError::new(); + // NOTE: The cookie jar `secure` state will not reflect the last // known value in `request.cookies()`. This is okay: new cookies // should never be added to the resulting jar which is the only time // the value is used to set cookie defaults. - let response: Response<'c> = f(request).await; + // SAFETY: Much like request above, error can borrow from request, and + // response can borrow from request or error. TODO + let request: &'c Request<'c> = unsafe { &*(&*boxed_req as *const _) }; + let response: Response<'c> = f(request, unsafe { transmute(&mut error) }).await; let mut cookies = CookieJar::new(None, request.rocket()); for cookie in response.cookies() { cookies.add_original(cookie.into_owned()); } - LocalResponse { _request: boxed_req, cookies, response, } + LocalResponse { _request: boxed_req, _error: error, cookies, response, } } } } diff --git a/core/lib/src/mtls/error.rs b/core/lib/src/mtls/error.rs index 703835f299..4452469f7d 100644 --- a/core/lib/src/mtls/error.rs +++ b/core/lib/src/mtls/error.rs @@ -2,6 +2,7 @@ use std::fmt; use std::num::NonZeroUsize; use crate::mtls::x509::{self, nom}; +use transient::Static; /// An error returned by the [`Certificate`](crate::mtls::Certificate) guard. /// @@ -41,6 +42,8 @@ pub enum Error { Trailing(usize), } +impl Static for Error {} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/core/lib/src/outcome.rs b/core/lib/src/outcome.rs index 35521aa36a..acb20087eb 100644 --- a/core/lib/src/outcome.rs +++ b/core/lib/src/outcome.rs @@ -86,7 +86,8 @@ //! a type of `Option`. If an `Outcome` is a `Forward`, the `Option` will be //! `None`. -use crate::{route, request, response}; +use crate::catcher::TypedError; +use crate::request; use crate::data::{self, Data, FromData}; use crate::http::Status; @@ -611,6 +612,41 @@ impl Outcome { Outcome::Forward(v) => Err(v), } } + + /// Converts `Outcome` to `Option` by dropping error + /// and forward variants, and returning `None` + #[inline] + pub fn ok(self) -> Option { + match self { + Self::Success(v) => Some(v), + _ => None, + } + } +} + +impl Outcome { + /// Convenience function to convert the error type from `Infallible` + /// to any other type. This is trivially possible, since `Infallible` + /// cannot be constructed, so this cannot be an Error variant + pub(crate) fn map_err_type(self) -> Outcome { + match self { + Self::Success(v) => Outcome::Success(v), + Self::Forward(v) => Outcome::Forward(v), + Self::Error(e) => match e {}, + } + } +} + +impl<'r, S, E: TypedError<'r>> Outcome { + /// Convenience function to convert the outcome from a Responder impl to + /// the result used by TypedError + pub fn responder_error(self) -> Result { + match self { + Self::Success(v) => Ok(v), + Self::Forward(v) => Err(v), + Self::Error(e) => Err(e.status()), + } + } } impl<'a, S: Send + 'a, E: Send + 'a, F: Send + 'a> Outcome { @@ -788,23 +824,53 @@ impl IntoOutcome> for Result { } } -impl<'r, 'o: 'r> IntoOutcome> for response::Result<'o> { - type Error = (); - type Forward = (Data<'r>, Status); - - #[inline] - fn or_error(self, _: ()) -> route::Outcome<'r> { - match self { - Ok(val) => Success(val), - Err(status) => Error(status), - } - } - - #[inline] - fn or_forward(self, (data, forward): (Data<'r>, Status)) -> route::Outcome<'r> { - match self { - Ok(val) => Success(val), - Err(_) => Forward((data, forward)) - } - } -} +// impl<'r, 'o: 'r> IntoOutcome> for response::Result<'o> { +// type Error = (); +// type Forward = (Data<'r>, Status); + +// #[inline] +// fn or_error(self, _: ()) -> route::Outcome<'r> { +// match self { +// Ok(val) => Success(val), +// Err(status) => Error((status, default_error_type())), +// } +// } + +// #[inline] +// fn or_forward(self, (data, forward): (Data<'r>, Status)) -> route::Outcome<'r> { +// match self { +// Ok(val) => Success(val), +// Err(_) => Forward((data, forward, default_error_type())) +// } +// } +// } + +// type RoutedOutcome<'r, T> = Outcome< +// T, +// (Status, ErasedError<'r>), +// (Data<'r>, Status, ErasedError<'r>) +// >; + +// impl<'r, T, E: Transient> IntoOutcome> for Option> +// where E::Transience: CanTranscendTo>, +// E: Send + Sync + 'r, +// { +// type Error = Status; +// type Forward = (Data<'r>, Status); + +// fn or_error(self, error: Self::Error) -> RoutedOutcome<'r, T> { +// match self { +// Some(Ok(v)) => Outcome::Success(v), +// Some(Err(e)) => Outcome::Error((error, Box::new(e))), +// None => Outcome::Error((error, default_error_type())), +// } +// } + +// fn or_forward(self, forward: Self::Forward) -> RoutedOutcome<'r, T> { +// match self { +// Some(Ok(v)) => Outcome::Success(v), +// Some(Err(e)) => Outcome::Forward((forward.0, forward.1, Box::new(e))), +// None => Outcome::Forward((forward.0, forward.1, default_error_type())), +// } +// } +// } diff --git a/core/lib/src/request/mod.rs b/core/lib/src/request/mod.rs index 0393f96b51..a09bae68c2 100644 --- a/core/lib/src/request/mod.rs +++ b/core/lib/src/request/mod.rs @@ -8,7 +8,7 @@ mod atomic_method; #[cfg(test)] mod tests; -pub use self::request::Request; +pub use self::request::{Request, RequestErrors}; pub use self::from_request::{FromRequest, Outcome}; pub use self::from_param::{FromParam, FromSegments}; diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 93912383de..fd79046408 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -8,14 +8,16 @@ use std::net::IpAddr; use state::{TypeMap, InitCell}; use futures::future::BoxFuture; use ref_swap::OptionRefSwap; +use transient::Transient; +use crate::catcher::TypedError; use crate::{Rocket, Route, Orbit}; use crate::request::{FromParam, FromSegments, FromRequest, Outcome, AtomicMethod}; use crate::form::{self, ValueField, FromForm}; use crate::data::Limits; -use crate::http::ProxyProto; -use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie}; +use crate::http::{Method, Header, HeaderMap, ContentType, Accept, MediaType, CookieJar, Cookie, + ProxyProto, Status}; use crate::http::uri::{fmt::Path, Origin, Segments, Host, Authority}; use crate::listener::{Certificates, Endpoint}; @@ -1175,6 +1177,23 @@ impl<'r> Request<'r> { } } +#[derive(Debug, Clone, Copy, Transient)] +pub struct RequestErrors<'r> { + errors: &'r [RequestError], +} + +impl<'r> RequestErrors<'r> { + pub(crate) fn new(errors: &'r [RequestError]) -> Self { + Self { errors } + } +} + +impl<'r> TypedError<'r> for RequestErrors<'r> { + fn status(&self) -> Status { + Status::BadRequest + } +} + #[derive(Debug, Clone)] pub(crate) enum RequestError { InvalidUri(hyper::Uri), @@ -1194,8 +1213,8 @@ impl fmt::Debug for Request<'_> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Request") .field("method", &self.method()) - .field("uri", &self.uri()) - .field("headers", &self.headers()) + .field("uri", self.uri()) + .field("headers", self.headers()) .field("remote", &self.remote()) .field("cookies", &self.cookies()) .finish() diff --git a/core/lib/src/response/content.rs b/core/lib/src/response/content.rs index 68d7a33d0f..bdeb666496 100644 --- a/core/lib/src/response/content.rs +++ b/core/lib/src/response/content.rs @@ -31,6 +31,7 @@ //! let response = content::RawHtml("

Hello, world!

"); //! ``` +use crate::outcome::try_outcome; use crate::request::Request; use crate::response::{self, Response, Responder}; use crate::http::ContentType; @@ -58,7 +59,8 @@ macro_rules! ctrs { /// Sets the Content-Type of the response then delegates the /// remainder of the response to the wrapped responder. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for $name { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + type Error = R::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { (ContentType::$ct, self.0).respond_to(req) } } @@ -78,9 +80,10 @@ ctrs! { } impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for (ContentType, R) { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + type Error = R::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { Response::build() - .merge(self.1.respond_to(req)?) + .merge(try_outcome!(self.1.respond_to(req))) .header(self.0) .ok() } diff --git a/core/lib/src/response/debug.rs b/core/lib/src/response/debug.rs index a7d3e612a0..8adf2bf7a6 100644 --- a/core/lib/src/response/debug.rs +++ b/core/lib/src/response/debug.rs @@ -1,3 +1,5 @@ +use transient::Static; + use crate::request::Request; use crate::response::{self, Responder}; use crate::http::Status; @@ -29,6 +31,7 @@ use crate::http::Status; /// Because of the generic `From` implementation for `Debug`, conversions /// from `Result` to `Result>` through `?` occur /// automatically: +/// TODO: this has changed /// /// ```rust /// use std::string::FromUtf8Error; @@ -37,7 +40,7 @@ use crate::http::Status; /// use rocket::response::Debug; /// /// #[get("/")] -/// fn rand_str() -> Result> { +/// fn rand_str() -> Result { /// # /* /// let bytes: Vec = random_bytes(); /// # */ @@ -56,17 +59,19 @@ use crate::http::Status; /// use rocket::response::Debug; /// /// #[get("/")] -/// fn rand_str() -> Result> { +/// fn rand_str() -> Result { /// # /* /// let bytes: Vec = random_bytes(); /// # */ /// # let bytes: Vec = vec![]; -/// String::from_utf8(bytes).map_err(Debug) +/// String::from_utf8(bytes) /// } /// ``` #[derive(Debug)] pub struct Debug(pub E); +impl Static for Debug {} + impl From for Debug { #[inline(always)] fn from(e: E) -> Self { @@ -75,17 +80,19 @@ impl From for Debug { } impl<'r, E: std::fmt::Debug> Responder<'r, 'static> for Debug { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { let type_name = std::any::type_name::(); info!(type_name, value = ?self.0, "debug response (500)"); - Err(Status::InternalServerError) + response::Outcome::Forward(Status::InternalServerError) } } /// Prints a warning with the error and forwards to the `500` error catcher. impl<'r> Responder<'r, 'static> for std::io::Error { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { warn!("i/o error response: {self}"); - Err(Status::InternalServerError) + response::Outcome::Forward(Status::InternalServerError) } } diff --git a/core/lib/src/response/flash.rs b/core/lib/src/response/flash.rs index 279ec854b6..c96b2071c2 100644 --- a/core/lib/src/response/flash.rs +++ b/core/lib/src/response/flash.rs @@ -50,13 +50,14 @@ const FLASH_COOKIE_DELIM: char = ':'; /// # #[macro_use] extern crate rocket; /// use rocket::response::{Flash, Redirect}; /// use rocket::request::FlashMessage; +/// use rocket::either::Either; /// /// #[post("/login/")] -/// fn login(name: &str) -> Result<&'static str, Flash> { +/// fn login(name: &str) -> Either<&'static str, Flash> { /// if name == "special_user" { -/// Ok("Hello, special user!") +/// Either::Left("Hello, special user!") /// } else { -/// Err(Flash::error(Redirect::to(uri!(index)), "Invalid username.")) +/// Either::Right(Flash::error(Redirect::to(uri!(index)), "Invalid username.")) /// } /// } /// @@ -189,7 +190,8 @@ impl Flash { /// response handling to the wrapped responder. As a result, the `Outcome` of /// the response is the `Outcome` of the wrapped `Responder`. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Flash { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + type Error = R::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { req.cookies().add(self.cookie()); self.inner.respond_to(req) } diff --git a/core/lib/src/response/mod.rs b/core/lib/src/response/mod.rs index 71f0ff6980..3f7ee7df22 100644 --- a/core/lib/src/response/mod.rs +++ b/core/lib/src/response/mod.rs @@ -35,5 +35,7 @@ pub use self::redirect::Redirect; pub use self::flash::Flash; pub use self::debug::Debug; -/// Type alias for the `Result` of a [`Responder::respond_to()`] call. -pub type Result<'r> = std::result::Result, crate::http::Status>; +use crate::http::Status; + +/// Type alias for the `Outcome` of a [`Responder::respond_to()`] call. +pub type Outcome<'o, Error> = crate::outcome::Outcome, Error, Status>; diff --git a/core/lib/src/response/redirect.rs b/core/lib/src/response/redirect.rs index f685fe1b6c..b8dce12ad4 100644 --- a/core/lib/src/response/redirect.rs +++ b/core/lib/src/response/redirect.rs @@ -1,3 +1,6 @@ +use transient::Transient; + +use crate::catcher::TypedError; use crate::request::Request; use crate::response::{self, Response, Responder}; use crate::http::uri::Reference; @@ -45,7 +48,7 @@ use crate::http::Status; /// /// [`Origin`]: crate::http::uri::Origin /// [`uri!`]: ../macro.uri.html -#[derive(Debug)] +#[derive(Debug, Transient)] pub struct Redirect(Status, Option>); impl Redirect { @@ -87,63 +90,63 @@ impl Redirect { Redirect(Status::TemporaryRedirect, uri.try_into().ok()) } - /// Construct a "permanent" (308) redirect response. This redirect must only - /// be used for permanent redirects as it is cached by clients. This - /// response instructs the client to reissue requests for the current URL to - /// a different URL, now and in the future, maintaining the contents of the - /// request identically. This means that, for example, a `POST` request will - /// be resent, contents included, to the requested URL. - /// - /// # Examples - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// use rocket::response::Redirect; - /// - /// let redirect = Redirect::permanent(uri!("/other_url")); - /// let redirect = Redirect::permanent(format!("some-{}-thing", "crazy")); - /// ``` - pub fn permanent>>(uri: U) -> Redirect { - Redirect(Status::PermanentRedirect, uri.try_into().ok()) - } + /// Construct a "permanent" (308) redirect response. This redirect must only + /// be used for permanent redirects as it is cached by clients. This + /// response instructs the client to reissue requests for the current URL to + /// a different URL, now and in the future, maintaining the contents of the + /// request identically. This means that, for example, a `POST` request will + /// be resent, contents included, to the requested URL. + /// + /// # Examples + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::response::Redirect; + /// + /// let redirect = Redirect::permanent(uri!("/other_url")); + /// let redirect = Redirect::permanent(format!("some-{}-thing", "crazy")); + /// ``` + pub fn permanent>>(uri: U) -> Redirect { + Redirect(Status::PermanentRedirect, uri.try_into().ok()) + } - /// Construct a temporary "found" (302) redirect response. This response - /// instructs the client to reissue the current request to a different URL, - /// ideally maintaining the contents of the request identically. - /// Unfortunately, different clients may respond differently to this type of - /// redirect, so `303` or `307` redirects, which disambiguate, are - /// preferred. - /// - /// # Examples - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// use rocket::response::Redirect; - /// - /// let redirect = Redirect::found(uri!("/other_url")); - /// let redirect = Redirect::found(format!("some-{}-thing", "crazy")); - /// ``` - pub fn found>>(uri: U) -> Redirect { - Redirect(Status::Found, uri.try_into().ok()) - } + /// Construct a temporary "found" (302) redirect response. This response + /// instructs the client to reissue the current request to a different URL, + /// ideally maintaining the contents of the request identically. + /// Unfortunately, different clients may respond differently to this type of + /// redirect, so `303` or `307` redirects, which disambiguate, are + /// preferred. + /// + /// # Examples + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::response::Redirect; + /// + /// let redirect = Redirect::found(uri!("/other_url")); + /// let redirect = Redirect::found(format!("some-{}-thing", "crazy")); + /// ``` + pub fn found>>(uri: U) -> Redirect { + Redirect(Status::Found, uri.try_into().ok()) + } - /// Construct a permanent "moved" (301) redirect response. This response - /// should only be used for permanent redirects as it can be cached by - /// browsers. Because different clients may respond differently to this type - /// of redirect, a `308` redirect, which disambiguates, is preferred. - /// - /// # Examples - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// use rocket::response::Redirect; - /// - /// let redirect = Redirect::moved(uri!("here")); - /// let redirect = Redirect::moved(format!("some-{}-thing", "crazy")); - /// ``` - pub fn moved>>(uri: U) -> Redirect { - Redirect(Status::MovedPermanently, uri.try_into().ok()) - } + /// Construct a permanent "moved" (301) redirect response. This response + /// should only be used for permanent redirects as it can be cached by + /// browsers. Because different clients may respond differently to this type + /// of redirect, a `308` redirect, which disambiguates, is preferred. + /// + /// # Examples + /// + /// ```rust + /// # #[macro_use] extern crate rocket; + /// use rocket::response::Redirect; + /// + /// let redirect = Redirect::moved(uri!("here")); + /// let redirect = Redirect::moved(format!("some-{}-thing", "crazy")); + /// ``` + pub fn moved>>(uri: U) -> Redirect { + Redirect(Status::MovedPermanently, uri.try_into().ok()) + } } /// Constructs a response with the appropriate status code and the given URL in @@ -151,15 +154,36 @@ impl Redirect { /// value used to create the `Responder` is an invalid URI, an error of /// `Status::InternalServerError` is returned. impl<'r> Responder<'r, 'static> for Redirect { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { if let Some(uri) = self.1 { Response::build() .status(self.0) .raw_header("Location", uri.to_string()) .ok() + } else { + error!("Invalid URI used for redirect."); + response::Outcome::Forward(Status::InternalServerError) + } + } +} + +// TODO: This is a hack +impl<'r> TypedError<'r> for Redirect { + fn respond_to(&self, _req: &'r Request<'r>) -> Result, Status> { + if let Some(uri) = &self.1 { + Response::build() + .status(self.0) + .raw_header("Location", uri.to_string()) + .ok::<()>() + .responder_error() } else { error!("Invalid URI used for redirect."); Err(Status::InternalServerError) } } + + fn status(&self) -> Status { + self.0 + } } diff --git a/core/lib/src/response/responder.rs b/core/lib/src/response/responder.rs index f31262c7fc..c49e5a7c79 100644 --- a/core/lib/src/response/responder.rs +++ b/core/lib/src/response/responder.rs @@ -1,11 +1,19 @@ +use std::convert::Infallible; +use std::fmt; use std::fs::File; use std::io::Cursor; use std::sync::Arc; +use either::Either; +use transient::{CanTranscendTo, Inv, Transient}; + +use crate::catcher::TypedError; use crate::http::{Status, ContentType, StatusClass}; use crate::response::{self, Response}; use crate::request::Request; +use super::Outcome; + /// Trait implemented by types that generate responses for clients. /// /// Any type that implements `Responder` can be used as the return type of a @@ -173,7 +181,8 @@ use crate::request::Request; /// # struct A; /// // If the response contains no borrowed data. /// impl<'r> Responder<'r, 'static> for A { -/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { +/// type Error = std::convert::Infallible; +/// fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { /// todo!() /// } /// } @@ -181,7 +190,8 @@ use crate::request::Request; /// # struct B<'r>(&'r str); /// // If the response borrows from the request. /// impl<'r> Responder<'r, 'r> for B<'r> { -/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { +/// type Error = std::convert::Infallible; +/// fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'r, Self::Error> { /// todo!() /// } /// } @@ -189,7 +199,8 @@ use crate::request::Request; /// # struct C; /// // If the response is or wraps a borrow that may outlive the request. /// impl<'r, 'o: 'r> Responder<'r, 'o> for &'o C { -/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { +/// type Error = std::convert::Infallible; +/// fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { /// todo!() /// } /// } @@ -197,7 +208,8 @@ use crate::request::Request; /// # struct D(R); /// // If the response wraps an existing responder. /// impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for D { -/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { +/// type Error = std::convert::Infallible; +/// fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { /// todo!() /// } /// } @@ -244,12 +256,14 @@ use crate::request::Request; /// /// use rocket::request::Request; /// use rocket::response::{self, Response, Responder}; +/// use rocket::outcome::try_outcome; /// use rocket::http::ContentType; /// /// impl<'r> Responder<'r, 'static> for Person { -/// fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { +/// type Error = std::convert::Infallible; +/// fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { /// let string = format!("{}:{}", self.name, self.age); -/// Response::build_from(string.respond_to(req)?) +/// Response::build_from(try_outcome!(string.respond_to(req))) /// .raw_header("X-Person-Name", self.name) /// .raw_header("X-Person-Age", self.age.to_string()) /// .header(ContentType::new("application", "x-person")) @@ -291,6 +305,8 @@ use crate::request::Request; /// # fn person() -> Person { Person::new("Bob", 29) } /// ``` pub trait Responder<'r, 'o: 'r> { + type Error: TypedError<'r> + Transient; + /// Returns `Ok` if a `Response` could be generated successfully. Otherwise, /// returns an `Err` with a failing `Status`. /// @@ -302,13 +318,14 @@ pub trait Responder<'r, 'o: 'r> { /// returned, the error catcher for the given status is retrieved and called /// to generate a final error response, which is then written out to the /// client. - fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o>; + fn respond_to(self, request: &'r Request<'_>) -> response::Outcome<'o, Self::Error>; } /// Returns a response with Content-Type `text/plain` and a fixed-size body /// containing the string `self`. Always returns `Ok`. impl<'r, 'o: 'r> Responder<'r, 'o> for &'o str { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { Response::build() .header(ContentType::Plain) .sized_body(self.len(), Cursor::new(self)) @@ -319,7 +336,8 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for &'o str { /// Returns a response with Content-Type `text/plain` and a fixed-size body /// containing the string `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for String { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build() .header(ContentType::Plain) .sized_body(self.len(), Cursor::new(self)) @@ -339,7 +357,8 @@ impl AsRef<[u8]> for DerefRef where T::Target: AsRef<[u8] /// Returns a response with Content-Type `text/plain` and a fixed-size body /// containing the string `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Arc { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build() .header(ContentType::Plain) .sized_body(self.len(), Cursor::new(DerefRef(self))) @@ -350,7 +369,8 @@ impl<'r> Responder<'r, 'static> for Arc { /// Returns a response with Content-Type `text/plain` and a fixed-size body /// containing the string `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Box { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build() .header(ContentType::Plain) .sized_body(self.len(), Cursor::new(DerefRef(self))) @@ -361,7 +381,8 @@ impl<'r> Responder<'r, 'static> for Box { /// Returns a response with Content-Type `application/octet-stream` and a /// fixed-size body containing the data in `self`. Always returns `Ok`. impl<'r, 'o: 'r> Responder<'r, 'o> for &'o [u8] { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { Response::build() .header(ContentType::Binary) .sized_body(self.len(), Cursor::new(self)) @@ -372,7 +393,8 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for &'o [u8] { /// Returns a response with Content-Type `application/octet-stream` and a /// fixed-size body containing the data in `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Vec { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build() .header(ContentType::Binary) .sized_body(self.len(), Cursor::new(self)) @@ -383,7 +405,8 @@ impl<'r> Responder<'r, 'static> for Vec { /// Returns a response with Content-Type `application/octet-stream` and a /// fixed-size body containing the data in `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Arc<[u8]> { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build() .header(ContentType::Binary) .sized_body(self.len(), Cursor::new(self)) @@ -394,7 +417,8 @@ impl<'r> Responder<'r, 'static> for Arc<[u8]> { /// Returns a response with Content-Type `application/octet-stream` and a /// fixed-size body containing the data in `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Box<[u8]> { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build() .header(ContentType::Binary) .sized_body(self.len(), Cursor::new(self)) @@ -437,8 +461,11 @@ impl<'r> Responder<'r, 'static> for Box<[u8]> { /// Content::Text("hello".to_string()) /// } /// ``` -impl<'r, 'o: 'r, T: Responder<'r, 'o> + Sized> Responder<'r, 'o> for Box { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { +impl<'r, 'o: 'r, T: Responder<'r, 'o> + Sized> Responder<'r, 'o> for Box + where ::Transience: CanTranscendTo>, +{ + type Error = T::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { let inner = *self; inner.respond_to(req) } @@ -446,33 +473,48 @@ impl<'r, 'o: 'r, T: Responder<'r, 'o> + Sized> Responder<'r, 'o> for Box { /// Returns a response with a sized body for the file. Always returns `Ok`. impl<'r> Responder<'r, 'static> for File { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { tokio::fs::File::from(self).respond_to(req) } } /// Returns a response with a sized body for the file. Always returns `Ok`. impl<'r> Responder<'r, 'static> for tokio::fs::File { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build().sized_body(None, self).ok() } } /// Returns an empty, default `Response`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for () { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { - Ok(Response::new()) + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { + Outcome::Success(Response::new()) } } /// Responds with the inner `Responder` in `Cow`. impl<'r, 'o: 'r, R: ?Sized + ToOwned> Responder<'r, 'o> for std::borrow::Cow<'o, R> - where &'o R: Responder<'r, 'o> + 'o, ::Owned: Responder<'r, 'o> + 'r + where &'o R: Responder<'r, 'o> + 'o, + <&'o R as Responder<'r, 'o>>::Error: Transient, + <<&'o R as Responder<'r, 'o>>::Error as Transient>::Transience: CanTranscendTo>, + ::Owned: Responder<'r, 'o> + 'r, + <::Owned as Responder<'r, 'o>>::Error: Transient, + <<::Owned as Responder<'r, 'o>>::Error as Transient>::Transience: + CanTranscendTo>, + // TODO: this atrocious formatting { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + type Error = Either< + <&'o R as Responder<'r, 'o>>::Error, + >::Error, + >; + + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { match self { - std::borrow::Cow::Borrowed(b) => b.respond_to(req), - std::borrow::Cow::Owned(o) => o.respond_to(req), + std::borrow::Cow::Borrowed(b) => b.respond_to(req).map_error(|e| Either::Left(e)), + std::borrow::Cow::Owned(o) => o.respond_to(req).map_error(|e| Either::Right(e)), } } } @@ -480,13 +522,14 @@ impl<'r, 'o: 'r, R: ?Sized + ToOwned> Responder<'r, 'o> for std::borrow::Cow<'o, /// If `self` is `Some`, responds with the wrapped `Responder`. Otherwise prints /// a warning message and returns an `Err` of `Status::NotFound`. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Option { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + type Error = R::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { match self { Some(r) => r.respond_to(req), None => { let type_name = std::any::type_name::(); debug!(type_name, "`Option` responder returned `None`"); - Err(Status::NotFound) + Outcome::Forward(Status::NotFound) }, } } @@ -494,26 +537,36 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Option { /// Responds with the wrapped `Responder` in `self`, whether it is `Ok` or /// `Err`. -impl<'r, 'o: 'r, 't: 'o, 'e: 'o, T, E> Responder<'r, 'o> for Result - where T: Responder<'r, 't>, E: Responder<'r, 'e> +impl<'r, 'o: 'r, T, E> Responder<'r, 'o> for Result + where T: Responder<'r, 'o>, + T::Error: Transient, + ::Transience: CanTranscendTo>, + E: TypedError<'r> + Transient + 'r, + E::Transience: CanTranscendTo>, { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + type Error = Either; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { match self { - Ok(responder) => responder.respond_to(req), - Err(responder) => responder.respond_to(req), + Ok(responder) => responder.respond_to(req).map_error(|e| Either::Left(e)), + Err(error) => Outcome::Error(Either::Right(error)), } } } /// Responds with the wrapped `Responder` in `self`, whether it is `Left` or /// `Right`. -impl<'r, 'o: 'r, 't: 'o, 'e: 'o, T, E> Responder<'r, 'o> for either::Either - where T: Responder<'r, 't>, E: Responder<'r, 'e> +impl<'r, 'o: 'r, T, E> Responder<'r, 'o> for either::Either + where T: Responder<'r, 'o>, + T::Error: Transient, + ::Transience: CanTranscendTo>, + E: Responder<'r, 'o>, + ::Transience: CanTranscendTo>, { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + type Error = Either; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { match self { - either::Either::Left(r) => r.respond_to(req), - either::Either::Right(r) => r.respond_to(req), + either::Either::Left(r) => r.respond_to(req).map_error(|e| Either::Left(e)), + either::Either::Right(r) => r.respond_to(req).map_error(|e| Either::Right(e)), } } } @@ -533,9 +586,10 @@ impl<'r, 'o: 'r, 't: 'o, 'e: 'o, T, E> Responder<'r, 'o> for either::Either Responder<'r, 'static> for Status { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { match self.class() { - StatusClass::ClientError | StatusClass::ServerError => Err(self), + StatusClass::ClientError | StatusClass::ServerError => Outcome::Forward(self), StatusClass::Success if self.code < 206 => { Response::build().status(self).ok() } @@ -547,7 +601,7 @@ impl<'r> Responder<'r, 'static> for Status { "invalid status used as responder\n\ status must be one of 100, 200..=205, 400..=599"); - Err(Status::InternalServerError) + Outcome::Forward(Status::InternalServerError) } } } diff --git a/core/lib/src/response/response.rs b/core/lib/src/response/response.rs index 7abdaecafc..030101bfc6 100644 --- a/core/lib/src/response/response.rs +++ b/core/lib/src/response/response.rs @@ -9,6 +9,8 @@ use crate::http::uncased::{Uncased, AsUncased}; use crate::data::IoHandler; use crate::response::Body; +use super::Outcome; + /// Builder for the [`Response`] type. /// /// Building a [`Response`] can be a low-level ordeal; this structure presents a @@ -432,17 +434,17 @@ impl<'r> Builder<'r> { /// # Example /// /// ```rust - /// use rocket::Response; + /// use rocket::response::{Response, Outcome}; /// - /// let response: Result = Response::build() + /// let response: Outcome<'_, ()> = Response::build() /// // build the response /// .ok(); /// - /// assert!(response.is_ok()); + /// assert!(response.is_success()); /// ``` #[inline(always)] - pub fn ok(&mut self) -> Result, E> { - Ok(self.finalize()) + pub fn ok(&mut self) -> Outcome<'r, E> { + Outcome::Success(self.finalize()) } } diff --git a/core/lib/src/response/status.rs b/core/lib/src/response/status.rs index 935fe88fdf..b2f27d961c 100644 --- a/core/lib/src/response/status.rs +++ b/core/lib/src/response/status.rs @@ -29,6 +29,10 @@ use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; use std::borrow::Cow; +use transient::Static; + +use crate::catcher::TypedError; +use crate::outcome::try_outcome; use crate::request::Request; use crate::response::{self, Responder, Response}; use crate::http::Status; @@ -163,10 +167,11 @@ impl Created { /// a hashable `Responder` is provided via [`Created::tagged_body()`]. The `ETag` /// header is set to a hash value of the responder. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Created { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + type Error = R::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { let mut response = Response::build(); if let Some(responder) = self.1 { - response.merge(responder.respond_to(req)?); + response.merge(try_outcome!(responder.respond_to(req))); } if let Some(hash) = self.2 { @@ -179,6 +184,14 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Created { } } +// TODO: do we want this? +impl TypedError<'_> for Created { + fn status(&self) -> Status { + Status::Created + } +} +impl Static for Created {} + /// Sets the status of the response to 204 No Content. /// /// The response body will be empty. @@ -201,7 +214,8 @@ pub struct NoContent; /// Sets the status code of the response to 204 No Content. impl<'r> Responder<'r, 'static> for NoContent { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build().status(Status::NoContent).ok() } } @@ -234,17 +248,19 @@ pub struct Custom(pub Status, pub R); /// Sets the status code of the response and then delegates the remainder of the /// response to the wrapped responder. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Custom { + type Error = R::Error; #[inline] - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { - Response::build_from(self.1.respond_to(req)?) + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { + Response::build_from(try_outcome!(self.1.respond_to(req))) .status(self.0) .ok() } } impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for (Status, R) { + type Error = R::Error; #[inline(always)] - fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, request: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { Custom(self.0, self.1).respond_to(request) } } @@ -288,11 +304,20 @@ macro_rules! status_response { pub struct $T(pub R); impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for $T { + type Error = R::Error; #[inline(always)] - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { Custom(Status::$T, self.0).respond_to(req) } } + + // TODO: do we want this? + impl TypedError<'_> for $T { + fn status(&self) -> Status { + Status::$T + } + } + impl Static for $T {} } } diff --git a/core/lib/src/response/stream/bytes.rs b/core/lib/src/response/stream/bytes.rs index 52782aa241..e4fc2da9be 100644 --- a/core/lib/src/response/stream/bytes.rs +++ b/core/lib/src/response/stream/bytes.rs @@ -64,7 +64,8 @@ impl From for ByteStream { impl<'r, S: Stream> Responder<'r, 'r> for ByteStream where S: Send + 'r, S::Item: AsRef<[u8]> + Send + Unpin + 'r { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'r, Self::Error> { Response::build() .header(ContentType::Binary) .streamed_body(ReaderStream::from(self.0.map(std::io::Cursor::new))) diff --git a/core/lib/src/response/stream/reader.rs b/core/lib/src/response/stream/reader.rs index d3a3da71bf..a6c6938006 100644 --- a/core/lib/src/response/stream/reader.rs +++ b/core/lib/src/response/stream/reader.rs @@ -39,7 +39,8 @@ pin_project! { /// impl<'r, S: Stream> Responder<'r, 'r> for MyStream /// where S: Send + 'r /// { - /// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + /// type Error = std::convert::Infallible; + /// fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'r, Self::Error> { /// Response::build() /// .header(ContentType::Text) /// .streamed_body(ReaderStream::from(self.0.map(Cursor::new))) @@ -142,7 +143,8 @@ impl From for ReaderStream { impl<'r, S: Stream> Responder<'r, 'r> for ReaderStream where S: Send + 'r, S::Item: AsyncRead + Send, { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'r, Self::Error> { Response::build() .streamed_body(self) .ok() diff --git a/core/lib/src/response/stream/sse.rs b/core/lib/src/response/stream/sse.rs index de24ad2816..12f7edf242 100644 --- a/core/lib/src/response/stream/sse.rs +++ b/core/lib/src/response/stream/sse.rs @@ -569,7 +569,8 @@ impl> From for EventStream { } impl<'r, S: Stream + Send + 'r> Responder<'r, 'r> for EventStream { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'r, Self::Error> { Response::build() .header(ContentType::EventStream) .raw_header("Cache-Control", "no-cache") diff --git a/core/lib/src/response/stream/text.rs b/core/lib/src/response/stream/text.rs index 3064e0f0e2..329535a328 100644 --- a/core/lib/src/response/stream/text.rs +++ b/core/lib/src/response/stream/text.rs @@ -65,7 +65,8 @@ impl From for TextStream { impl<'r, S: Stream> Responder<'r, 'r> for TextStream where S: Send + 'r, S::Item: AsRef + Send + Unpin + 'r { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'r, Self::Error> { struct ByteStr(T); impl> AsRef<[u8]> for ByteStr { diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 0a9dd07afc..082a32bf4e 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -379,7 +379,7 @@ impl Rocket { /// /// ```rust,no_run /// # #[macro_use] extern crate rocket; - /// use rocket::Request; + /// use rocket::http::uri::Origin; /// /// #[catch(500)] /// fn internal_error() -> &'static str { @@ -387,8 +387,8 @@ impl Rocket { /// } /// /// #[catch(404)] - /// fn not_found(req: &Request) -> String { - /// format!("I couldn't find '{}'. Try something else?", req.uri()) + /// fn not_found(uri: &Origin) -> String { + /// format!("I couldn't find '{}'. Try something else?", uri) /// } /// /// #[launch] diff --git a/core/lib/src/route/handler.rs b/core/lib/src/route/handler.rs index b42d81e0fc..ac69672fec 100644 --- a/core/lib/src/route/handler.rs +++ b/core/lib/src/route/handler.rs @@ -1,10 +1,15 @@ +use crate::catcher::TypedError; use crate::{Request, Data}; -use crate::response::{Response, Responder}; +use crate::response::{self, Response, Responder}; use crate::http::Status; /// Type alias for the return type of a [`Route`](crate::Route)'s /// [`Handler::handle()`]. -pub type Outcome<'r> = crate::outcome::Outcome, Status, (Data<'r>, Status)>; +pub type Outcome<'r> = crate::outcome::Outcome< + Response<'r>, + (Status, Option>>), + (Data<'r>, Status, Option>>) +>; /// Type alias for the return type of a _raw_ [`Route`](crate::Route)'s /// [`Handler`]. @@ -170,6 +175,7 @@ impl Handler for F impl<'r, 'o: 'r> Outcome<'o> { /// Return the `Outcome` of response to `req` from `responder`. /// + // TODO: docs /// If the responder returns `Ok`, an outcome of `Success` is returned with /// the response. If the responder returns `Err`, an outcome of `Error` is /// returned with the status code. @@ -186,39 +192,66 @@ impl<'r, 'o: 'r> Outcome<'o> { #[inline] pub fn from>(req: &'r Request<'_>, responder: R) -> Outcome<'r> { match responder.respond_to(req) { - Ok(response) => Outcome::Success(response), - Err(status) => Outcome::Error(status) + response::Outcome::Success(response) => Outcome::Success(response), + response::Outcome::Error(error) => { + crate::trace::info!( + type_name = std::any::type_name_of_val(&error), + "Typed error to catch" + ); + Outcome::Error((error.status(), Some(Box::new(error)))) + }, + response::Outcome::Forward(status) => Outcome::Error((status, None)), } } - /// Return the `Outcome` of response to `req` from `responder`. + // TODO: does this still make sense + // /// Return the `Outcome` of response to `req` from `responder`. + // /// + // /// If the responder returns `Ok`, an outcome of `Success` is returned with + // /// the response. If the responder returns `Err`, an outcome of `Error` is + // /// returned with the status code. + // /// + // /// # Example + // /// + // /// ```rust + // /// use rocket::{Request, Data, route}; + // /// + // /// fn str_responder<'r>(req: &'r Request, _: Data<'r>) -> route::Outcome<'r> { + // /// route::Outcome::from(req, "Hello, world!") + // /// } + // /// ``` + // #[inline] + // pub fn try_from(req: &'r Request<'_>, result: Result) -> Outcome<'r> + // where R: Responder<'r, 'o>, E: std::fmt::Debug + // { + // let responder = result.map_err(crate::response::Debug); + // match responder.respond_to(req) { + // Ok(response) => Outcome::Success(response), + // Err(status) => Outcome::Error((status, Box::new(()))), + // } + // } + + /// Return an `Outcome` of `Error` with the status code `code`. This is + /// equivalent to `Outcome::error_val(code, ())`. /// - /// If the responder returns `Ok`, an outcome of `Success` is returned with - /// the response. If the responder returns `Err`, an outcome of `Error` is - /// returned with the status code. + /// This method exists to be used during manual routing. /// /// # Example /// /// ```rust /// use rocket::{Request, Data, route}; + /// use rocket::http::Status; /// - /// fn str_responder<'r>(req: &'r Request, _: Data<'r>) -> route::Outcome<'r> { - /// route::Outcome::from(req, "Hello, world!") + /// fn bad_req_route<'r>(_: &'r Request, _: Data<'r>) -> route::Outcome<'r> { + /// route::Outcome::error(Status::BadRequest) /// } /// ``` - #[inline] - pub fn try_from(req: &'r Request<'_>, result: Result) -> Outcome<'r> - where R: Responder<'r, 'o>, E: std::fmt::Debug - { - let responder = result.map_err(crate::response::Debug); - match responder.respond_to(req) { - Ok(response) => Outcome::Success(response), - Err(status) => Outcome::Error(status) - } + #[inline(always)] + pub fn error(code: Status) -> Outcome<'r> { + Outcome::Error((code, None)) } - - /// Return an `Outcome` of `Error` with the status code `code`. This is - /// equivalent to `Outcome::Error(code)`. + /// Return an `Outcome` of `Error` with the status code `code`. This adds + /// the value for typed catchers. /// /// This method exists to be used during manual routing. /// @@ -228,17 +261,21 @@ impl<'r, 'o: 'r> Outcome<'o> { /// use rocket::{Request, Data, route}; /// use rocket::http::Status; /// + /// struct CustomError(&'static str); + /// impl rocket::catcher::Static for CustomError {} + /// impl rocket::catcher::TypedError<'_> for CustomError {} + /// /// fn bad_req_route<'r>(_: &'r Request, _: Data<'r>) -> route::Outcome<'r> { - /// route::Outcome::error(Status::BadRequest) + /// route::Outcome::error_val(Status::BadRequest, CustomError("Some data to go with")) /// } /// ``` #[inline(always)] - pub fn error(code: Status) -> Outcome<'r> { - Outcome::Error(code) + pub fn error_val>(code: Status, val: T) -> Outcome<'r> { + Outcome::Error((code, Some(Box::new(val)))) } /// Return an `Outcome` of `Forward` with the data `data` and status - /// `status`. This is equivalent to `Outcome::Forward((data, status))`. + /// `status`. /// /// This method exists to be used during manual routing. /// @@ -254,7 +291,29 @@ impl<'r, 'o: 'r> Outcome<'o> { /// ``` #[inline(always)] pub fn forward(data: Data<'r>, status: Status) -> Outcome<'r> { - Outcome::Forward((data, status)) + Outcome::Forward((data, status, None)) + } + + /// Return an `Outcome` of `Forward` with the data `data`, status + /// `status` and a value of `val`. + /// + /// This method exists to be used during manual routing. + /// + /// # Example + /// + /// ```rust + /// use rocket::{Request, Data, route}; + /// use rocket::http::Status; + /// + /// fn always_forward<'r>(_: &'r Request, data: Data<'r>) -> route::Outcome<'r> { + /// route::Outcome::forward(data, Status::InternalServerError) + /// } + /// ``` + #[inline(always)] + pub fn forward_val>(data: Data<'r>, status: Status, val: T) + -> Outcome<'r> + { + Outcome::Forward((data, status, Some(Box::new(val)))) } } diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index d0e15ae45d..fae9ef390f 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -141,7 +141,9 @@ impl Catcher { /// assert!(!a.collides_with(&b)); /// ``` pub fn collides_with(&self, other: &Self) -> bool { - self.code == other.code && self.base().segments().eq(other.base().segments()) + self.code == other.code && + types_collide(self, other) && + self.base().segments().eq(other.base().segments()) } } @@ -207,6 +209,10 @@ fn formats_collide(route: &Route, other: &Route) -> bool { } } +fn types_collide(catcher: &Catcher, other: &Catcher) -> bool { + catcher.error_type.as_ref().map(|(i, _)| i) == other.error_type.as_ref().map(|(i, _)| i) +} + #[cfg(test)] mod tests { use std::str::FromStr; diff --git a/core/lib/src/router/matcher.rs b/core/lib/src/router/matcher.rs index e0dd66d302..2449139c4c 100644 --- a/core/lib/src/router/matcher.rs +++ b/core/lib/src/router/matcher.rs @@ -1,3 +1,6 @@ +use transient::TypeId; + +use crate::catcher::TypedError; use crate::{Route, Request, Catcher}; use crate::router::Collide; use crate::http::Status; @@ -119,14 +122,14 @@ impl Catcher { /// // Let's say `request` is `GET /` that 404s. The error matches only `a`: /// let request = client.get("/"); /// # let request = request.inner(); - /// assert!(a.matches(Status::NotFound, &request)); - /// assert!(!b.matches(Status::NotFound, &request)); + /// assert!(a.matches(Status::NotFound, &request, None)); + /// assert!(!b.matches(Status::NotFound, &request, None)); /// /// // Now `request` is a 404 `GET /bar`. The error matches `a` and `b`: /// let request = client.get("/bar"); /// # let request = request.inner(); - /// assert!(a.matches(Status::NotFound, &request)); - /// assert!(b.matches(Status::NotFound, &request)); + /// assert!(a.matches(Status::NotFound, &request, None)); + /// assert!(b.matches(Status::NotFound, &request, None)); /// /// // Note that because `b`'s base' has more complete segments that `a's, /// // Rocket would route the error to `b`, not `a`, even though both match. @@ -134,12 +137,15 @@ impl Catcher { /// let b_count = b.base().segments().filter(|s| !s.is_empty()).count(); /// assert!(b_count > a_count); /// ``` - pub fn matches(&self, status: Status, request: &Request<'_>) -> bool { + // TODO: document error matching + pub fn matches(&self, status: Status, request: &Request<'_>, error: Option) -> bool { self.code.map_or(true, |code| code == status.code) + && error == self.error_type.map(|(ty, _)| ty) && self.base().segments().prefix_of(request.uri().path().segments()) } } + fn paths_match(route: &Route, req: &Request<'_>) -> bool { trace!(route.uri = %route.uri, request.uri = %req.uri()); let route_segments = &route.uri.metadata.uri_segments; diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 84da1b9843..8fc0fd7696 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -1,5 +1,8 @@ use std::collections::HashMap; +use transient::TypeId; + +use crate::catcher::TypedError; use crate::request::Request; use crate::http::{Method, Status}; @@ -51,21 +54,34 @@ impl Router { .flat_map(move |routes| routes.iter().filter(move |r| r.matches(req))) } + // TODO: Catch order: + // There are four matches (ignoring uri base): + // - Error type & Status + // - Error type & any status + // - Any type & Status + // - Any type & any status + // + // What order should these be selected in? + // Master prefers longer paths over any other match. However, types could + // be considered more important + // - Error type, longest path, status + // - Any type, longest path, status + // !! There are actually more than 4 - b/c we need to check source() + // What we would want to do, is gather the catchers that match the source() x 5, + // and select the one with the longest path. If none exist, try without error. + // For many catchers, using aho-corasick or similar should be much faster. - pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> { + pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>, error: Option) + -> Option<&Catcher> + { // Note that catchers are presorted by descending base length. - let explicit = self.catchers.get(&Some(status.code)) - .and_then(|c| c.iter().find(|c| c.matches(status, req))); - - let default = self.catchers.get(&None) - .and_then(|c| c.iter().find(|c| c.matches(status, req))); - - match (explicit, default) { - (None, None) => None, - (None, c@Some(_)) | (c@Some(_), None) => c, - (Some(a), Some(b)) if a.rank <= b.rank => Some(a), - (Some(_), Some(b)) => Some(b), - } + self.catchers.get(&Some(status.code)) + .and_then(|c| c.iter().find(|c| c.matches(status, req, error))) + .into_iter() + .chain(self.catchers.get(&None) + .and_then(|c| c.iter().find(|c| c.matches(status, req, error))) + ) + .min_by_key(|c| c.rank) } fn collisions<'a, I, T>(&self, items: I) -> impl Iterator + 'a @@ -555,10 +571,12 @@ mod test { router } - fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { + fn catcher<'a>(router: &'a Router, status: Status, uri: &str, error_ty: Option) + -> Option<&'a Catcher> + { let client = Client::debug_with(vec![]).expect("client"); let request = client.get(Origin::parse(uri).unwrap()); - router.catch(status, &request) + router.catch(status, &request, error_ty) } macro_rules! assert_catcher_routing { @@ -574,7 +592,8 @@ mod test { let router = router_with_catchers(&catchers); for (req, expected) in requests.iter().zip(expected.iter()) { let req_status = Status::from_code(req.0).expect("valid status"); - let catcher = catcher(&router, req_status, req.1).expect("some catcher"); + // TODO: write test cases for typed variant + let catcher = catcher(&router, req_status, req.1, None).expect("some catcher"); assert_eq!(catcher.code, expected.0, "\nmatched {:?}, expected {:?} for req {:?}", catcher, expected, req); diff --git a/core/lib/src/sentinel.rs b/core/lib/src/sentinel.rs index c45517b0da..c4feac7452 100644 --- a/core/lib/src/sentinel.rs +++ b/core/lib/src/sentinel.rs @@ -150,11 +150,12 @@ use crate::{Rocket, Ignite}; /// use rocket::response::Responder; /// # type AnotherSentinel = (); /// -/// #[get("/")] -/// fn f<'r>() -> Either, AnotherSentinel> { -/// /* ... */ -/// # Either::Left(()) -/// } +/// // TODO: this no longer compiles, since the `impl Responder` doesn't meet the full reqs +/// // #[get("/")] +/// // fn f<'r>() -> Either, AnotherSentinel> { +/// // /* ... */ +/// // # Either::Left(()) +/// // } /// ``` /// /// **Note:** _Rocket actively discourages using `impl Trait` in route diff --git a/core/lib/src/serde/json.rs b/core/lib/src/serde/json.rs index d68b24f04b..af2d069c9f 100644 --- a/core/lib/src/serde/json.rs +++ b/core/lib/src/serde/json.rs @@ -35,6 +35,7 @@ use crate::http::uri::fmt::{UriDisplay, FromUriParam, Query, Formatter as UriFor use crate::http::Status; use serde::{Serialize, Deserialize}; +use transient::Transient; #[doc(hidden)] pub use serde_json; @@ -139,6 +140,11 @@ pub enum Error<'a> { Parse(&'a str, serde_json::error::Error), } +unsafe impl<'a> Transient for Error<'a> { + type Static = Error<'static>; + type Transience = transient::Co<'a>; +} + impl<'a> fmt::Display for Error<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -216,14 +222,17 @@ impl<'r, T: Deserialize<'r>> FromData<'r> for Json { /// JSON and a fixed-size body with the serialized value. If serialization /// fails, an `Err` of `Status::InternalServerError` is returned. impl<'r, T: Serialize> Responder<'r, 'static> for Json { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { - let string = serde_json::to_string(&self.0) - .map_err(|e| { + type Error = serde_json::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { + let string = match serde_json::to_string(&self.0) { + Ok(v) => v, + Err(e) => { error!("JSON serialize failure: {}", e); - Status::InternalServerError - })?; + return response::Outcome::Error(e); + } + }; - content::RawJson(string).respond_to(req) + content::RawJson(string).respond_to(req).map_error(|e| match e {}) } } @@ -298,7 +307,8 @@ impl<'v, T: Deserialize<'v> + Send> form::FromFormField<'v> for Json { /// Serializes the value into JSON. Returns a response with Content-Type JSON /// and a fixed-size body with the serialized value. impl<'r> Responder<'r, 'static> for Value { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + type Error = std::convert::Infallible; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { content::RawJson(self.to_string()).respond_to(req) } } diff --git a/core/lib/src/serde/msgpack.rs b/core/lib/src/serde/msgpack.rs index 73217d5a62..7a555d49f6 100644 --- a/core/lib/src/serde/msgpack.rs +++ b/core/lib/src/serde/msgpack.rs @@ -187,14 +187,20 @@ impl<'r, T: Deserialize<'r>> FromData<'r> for MsgPack { /// Content-Type `MsgPack` and a fixed-size body with the serialization. If /// serialization fails, an `Err` of `Status::InternalServerError` is returned. impl<'r, T: Serialize> Responder<'r, 'static> for MsgPack { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { - let buf = rmp_serde::to_vec(&self.0) - .map_err(|e| { + type Error = rmp_serde::encode::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { + let buf = match rmp_serde::to_vec(&self.0) { + Ok(v) => v, + Err(e) => { error!("MsgPack serialize failure: {}", e); - Status::InternalServerError - })?; + return response::Outcome::Error(e); + } + }; + // .map_err(|e| { + // Status::InternalServerError + // })?; - content::RawMsgPack(buf).respond_to(req) + content::RawMsgPack(buf).respond_to(req).map_err_type() } } diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index badfb44c95..5137e0f908 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -10,7 +10,7 @@ use futures::{Future, TryFutureExt}; use tokio::io::{AsyncRead, AsyncWrite}; use crate::{Ignite, Orbit, Request, Rocket}; -use crate::request::ConnectionMeta; +use crate::request::{ConnectionMeta, RequestErrors}; use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler}; use crate::listener::{Listener, Connection, BouncedExt, CancellableExt}; use crate::error::log_server_error; @@ -42,13 +42,20 @@ impl Rocket { span_debug!("request headers" => request.inner().headers().iter().trace_all_debug()); let mut response = request.into_response( stream, - |rocket, request, data| Box::pin(rocket.preprocess(request, data)), - |token, rocket, request, data| Box::pin(async move { + |rocket, request, data, error_ptr| { + Box::pin(rocket.preprocess(request, data, error_ptr)) + }, + |token, rocket, request, data, error_ptr| Box::pin(async move { if !request.errors.is_empty() { - return rocket.dispatch_error(Status::BadRequest, request).await; + error_ptr.write(Some(Box::new(RequestErrors::new(&request.errors)))); + return rocket.dispatch_error( + Status::BadRequest, + request, + error_ptr.get(), + ).await; } - rocket.dispatch(token, request, data).await + rocket.dispatch(token, request, data, error_ptr).await }) ).await; diff --git a/core/lib/src/trace/traceable.rs b/core/lib/src/trace/traceable.rs index 19b8c8b580..bd8800b5e7 100644 --- a/core/lib/src/trace/traceable.rs +++ b/core/lib/src/trace/traceable.rs @@ -244,8 +244,8 @@ impl Trace for route::Outcome<'_> { }, status = match self { Self::Success(r) => r.status().code, - Self::Error(s) => s.code, - Self::Forward((_, s)) => s.code, + Self::Error((s, _)) => s.code, + Self::Forward((_, s, _)) => s.code, }, ) } diff --git a/core/lib/tests/catcher-cookies-1213.rs b/core/lib/tests/catcher-cookies-1213.rs index 332ee23200..2ea9d3e700 100644 --- a/core/lib/tests/catcher-cookies-1213.rs +++ b/core/lib/tests/catcher-cookies-1213.rs @@ -1,11 +1,10 @@ #[macro_use] extern crate rocket; -use rocket::request::Request; use rocket::http::CookieJar; #[catch(404)] -fn not_found(request: &Request<'_>) -> &'static str { - request.cookies().add(("not_found", "404")); +fn not_found(jar: &CookieJar<'_>) -> &'static str { + jar.add(("not_found", "404")); "404 - Not Found" } diff --git a/core/lib/tests/panic-handling.rs b/core/lib/tests/panic-handling.rs index f5e8c1aea5..d9241c1023 100644 --- a/core/lib/tests/panic-handling.rs +++ b/core/lib/tests/panic-handling.rs @@ -1,5 +1,6 @@ #[macro_use] extern crate rocket; +use rocket::catcher::TypedError; use rocket::{Request, Rocket, Route, Catcher, Build, route, catcher}; use rocket::data::Data; use rocket::http::{Method, Status}; @@ -73,7 +74,9 @@ fn catches_early_route_panic() { #[test] fn catches_early_catcher_panic() { - fn pre_future_catcher<'r>(_: Status, _: &'r Request<'_>) -> catcher::BoxFuture<'r> { + fn pre_future_catcher<'r>(_: Status, _: &'r Request<'_>, _: Option<&'r dyn TypedError<'r>>) + -> catcher::BoxFuture<'r> + { panic!("a panicking pre-future catcher") } diff --git a/core/lib/tests/responder_lifetime-issue-345.rs b/core/lib/tests/responder_lifetime-issue-345.rs index 4cd12f000b..017fb12bca 100644 --- a/core/lib/tests/responder_lifetime-issue-345.rs +++ b/core/lib/tests/responder_lifetime-issue-345.rs @@ -3,7 +3,7 @@ #[macro_use] extern crate rocket; use rocket::{Request, State}; -use rocket::response::{Responder, Result}; +use rocket::response::{Responder, Outcome}; struct SomeState; @@ -13,7 +13,8 @@ pub struct CustomResponder<'r, R> { } impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for CustomResponder<'r, R> { - fn respond_to(self, req: &'r Request<'_>) -> Result<'o> { + type Error = >::Error; + fn respond_to(self, req: &'r Request<'_>) -> Outcome<'o, Self::Error> { self.responder.respond_to(req) } } diff --git a/core/lib/tests/sentinel.rs b/core/lib/tests/sentinel.rs index d88e99b98d..174182c653 100644 --- a/core/lib/tests/sentinel.rs +++ b/core/lib/tests/sentinel.rs @@ -1,4 +1,4 @@ -use rocket::{*, either::Either, error::ErrorKind::SentinelAborts}; +use rocket::{catcher::TypedError, either::Either, error::ErrorKind::SentinelAborts, *}; #[get("/two")] fn two_states(_one: &State, _two: &State) {} @@ -147,14 +147,18 @@ async fn data_sentinel_works() { #[test] fn inner_sentinels_detected() { use rocket::local::blocking::Client; + use transient::Transient; #[derive(Responder)] struct MyThing(T); + #[derive(Debug, Transient)] struct ResponderSentinel; + impl TypedError<'_> for ResponderSentinel {} impl<'r, 'o: 'r> response::Responder<'r, 'o> for ResponderSentinel { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> { unimplemented!() } } @@ -222,33 +226,34 @@ fn inner_sentinels_detected() { use rocket::response::Responder; - #[get("/")] - fn half_c<'r>() -> Either< - Inner>, - Result> - > { - Either::Left(Inner(())) - } + // #[get("/")] + // fn half_c<'r>() -> Either< + // Inner>, + // Result> + // > { + // Either::Left(Inner(())) + // } - let err = Client::debug_with(routes![half_c]).unwrap_err(); - assert!(matches!(err.kind(), SentinelAborts(vec) if vec.len() == 2)); + // let err = Client::debug_with(routes![half_c]).unwrap_err(); + // assert!(matches!(err.kind(), SentinelAborts(vec) if vec.len() == 2)); - #[get("/")] - fn half_d<'r>() -> Either< - Inner>, - Result, Inner> - > { - Either::Left(Inner(())) - } + // #[get("/")] + // fn half_d<'r>() -> Either< + // Inner>, + // Result, Inner> + // > { + // Either::Left(Inner(())) + // } - let err = Client::debug_with(routes![half_d]).unwrap_err(); - assert!(matches!(err.kind(), SentinelAborts(vec) if vec.len() == 1)); + // let err = Client::debug_with(routes![half_d]).unwrap_err(); + // assert!(matches!(err.kind(), SentinelAborts(vec) if vec.len() == 1)); // The special `Result` implementation. type MyResult = Result; #[get("/")] - fn half_e<'r>() -> Either>, MyResult> { + // fn half_e<'r>() -> Either>, MyResult> { + fn half_e<'r>() -> Either, MyResult> { Either::Left(Inner(())) } diff --git a/docs/guide/05-requests.md b/docs/guide/05-requests.md index 838cbb12df..22d401775c 100644 --- a/docs/guide/05-requests.md +++ b/docs/guide/05-requests.md @@ -1981,45 +1981,40 @@ Application processing is fallible. Errors arise from the following sources: * A routing failure. If any of these occur, Rocket returns an error to the client. To generate the -error, Rocket invokes the _catcher_ corresponding to the error's status code and -scope. Catchers are similar to routes except in that: +error, Rocket invokes the _catcher_ corresponding to the error's status code, +scope, and type. Catchers are similar to routes except in that: 1. Catchers are only invoked on error conditions. 2. Catchers are declared with the `catch` attribute. 3. Catchers are _registered_ with [`register()`] instead of [`mount()`]. 4. Any modifications to cookies are cleared before a catcher is invoked. - 5. Error catchers cannot invoke guards. 6. Error catchers should not fail to produce a response. 7. Catchers are scoped to a path prefix. To declare a catcher for a given status code, use the [`catch`] attribute, which -takes a single integer corresponding to the HTTP status code to catch. For +takes a single integer corresponding to the HTTP status code to catch as the first +arguement. For instance, to declare a catcher for `404 Not Found` errors, you'd write: ```rust # #[macro_use] extern crate rocket; # fn main() {} -use rocket::Request; - #[catch(404)] -fn not_found(req: &Request) { /* .. */ } +fn not_found() { /* .. */ } ``` -Catchers may take zero, one, or two arguments. If the catcher takes one -argument, it must be of type [`&Request`]. It it takes two, they must be of type -[`Status`] and [`&Request`], in that order. As with routes, the return type must -implement `Responder`. A concrete implementation may look like: +Cathers can include Request Guards, although forwards and errors work differently. +Forwards try the next catcher in the chain, while Errors trigger a `500` response. ```rust # #[macro_use] extern crate rocket; # fn main() {} - -# use rocket::Request; +# use rocket::http::uri::Origin; #[catch(404)] -fn not_found(req: &Request) -> String { - format!("Sorry, '{}' is not a valid path.", req.uri()) +fn not_found(uri: &Origin) -> String { + format!("Sorry, '{}' is not a valid path.", uri) } ``` @@ -2032,14 +2027,49 @@ looks like: ```rust # #[macro_use] extern crate rocket; -# use rocket::Request; -# #[catch(404)] fn not_found(req: &Request) { /* .. */ } +# #[catch(404)] fn not_found() { /* .. */ } fn main() { rocket::build().register("/", catchers![not_found]); } ``` +### Additional parameters. + +Catchers provide two special parameters: `status` and `error`. `status` provides +access to the status code, which is primarily useful for [default catchers](#default-catchers). +`error` provides access the error values returned by a [`FromRequest`], [`FromData`] or [`FromParam`]. +It only provides access to the most recent error, so if a route forwards, and another route is +attempted, only the error produced by the most recent attempt can be extracted. The `error` type +must implement [`Transient`], a non-static re-implementation of [`std::and::Any`]. (Almost) All errror +types returned by built-in guards implement [`Transient`], and can therefore be extracted. See +[the `Transient` derive](@api/master/rocket/catcher/derive.Transient.html) for more information +on implementing [`Transient`] for custom error types. + +[`Transient`]: @api/master/rocket/catcher/trait.Transient.html +[`std::any::Any`]: https://doc.rust-lang.org/1.78.0/core/any/trait.Any.html + +* The form::Errors type does not (yet) implement Transient + +The function arguement must be a reference to the error type expected. See the +[error handling example](@git/master/examples/error-handling) +for a full application, including the route that generates the error. + +```rust +# #[macro_use] extern crate rocket; + +use rocket::Request; +use std::num::ParseIntError; + +#[catch(default, error = "")] +fn default_catcher(error: &ParseIntError) { /* .. */ } + +#[launch] +fn rocket() -> _ { + rocket::build().register("/", catchers![default_catcher]) +} +``` + ### Scoping The first argument to `register()` is a path to scope the catcher under called @@ -2106,8 +2136,8 @@ similarly be registered with [`register()`]: use rocket::Request; use rocket::http::Status; -#[catch(default)] -fn default_catcher(status: Status, request: &Request) { /* .. */ } +#[catch(default, status = "")] +fn default_catcher(status: Status) { /* .. */ } #[launch] fn rocket() -> _ { diff --git a/docs/guide/06-responses.md b/docs/guide/06-responses.md index 9343b5964e..722e9f803a 100644 --- a/docs/guide/06-responses.md +++ b/docs/guide/06-responses.md @@ -263,9 +263,11 @@ use rocket::response::{self, Response, Responder}; use rocket::http::ContentType; # struct String(std::string::String); +// TODO: this needs a full update #[rocket::async_trait] impl<'r> Responder<'r, 'static> for String { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + type Error = std::convert::Infallible; + fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { Response::build() .header(ContentType::Plain) # /* diff --git a/docs/guide/14-faq.md b/docs/guide/14-faq.md index e4dabca4b8..662903b52f 100644 --- a/docs/guide/14-faq.md +++ b/docs/guide/14-faq.md @@ -418,10 +418,13 @@ the example below: use rocket::request::Request; use rocket::response::{self, Response, Responder}; use rocket::serde::json::Json; +use rocket::outcome::try_outcome; +// TODO: this needs a full update impl<'r> Responder<'r, 'static> for Person { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { - Response::build_from(Json(&self).respond_to(req)?) + type Error = as Responder<'r, 'static>>::Error; + fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'static, Self::Error> { + Response::build_from(try_outcome!(Json(&self).respond_to(req))) .raw_header("X-Person-Name", self.name) .raw_header("X-Person-Age", self.age.to_string()) .ok() diff --git a/examples/cookies/src/session.rs b/examples/cookies/src/session.rs index 31d0fc613c..acf2d4b2ce 100644 --- a/examples/cookies/src/session.rs +++ b/examples/cookies/src/session.rs @@ -1,4 +1,5 @@ use rocket::outcome::IntoOutcome; +use rocket::either::Either; use rocket::request::{self, FlashMessage, FromRequest, Request}; use rocket::response::{Redirect, Flash}; use rocket::http::{CookieJar, Status}; @@ -58,12 +59,12 @@ fn login_page(flash: Option>) -> Template { } #[post("/login", data = "")] -fn post_login(jar: &CookieJar<'_>, login: Form>) -> Result> { +fn post_login(jar: &CookieJar<'_>, login: Form>) -> Either> { if login.username == "Sergio" && login.password == "password" { jar.add_private(("user_id", "1")); - Ok(Redirect::to(uri!(index))) + Either::Left(Redirect::to(uri!(index))) } else { - Err(Flash::error(Redirect::to(uri!(login_page)), "Invalid username/password.")) + Either::Right(Flash::error(Redirect::to(uri!(login_page)), "Invalid username/password.")) } } diff --git a/examples/error-handling/Cargo.toml b/examples/error-handling/Cargo.toml index c19138a7b2..1582b06295 100644 --- a/examples/error-handling/Cargo.toml +++ b/examples/error-handling/Cargo.toml @@ -7,3 +7,4 @@ publish = false [dependencies] rocket = { path = "../../core/lib" } +transient = { path = "/code/matthew/transient" } diff --git a/examples/error-handling/src/main.rs b/examples/error-handling/src/main.rs index ffa0a6b13f..85693d49b1 100644 --- a/examples/error-handling/src/main.rs +++ b/examples/error-handling/src/main.rs @@ -2,9 +2,12 @@ #[cfg(test)] mod tests; -use rocket::{Rocket, Request, Build}; +use transient::Transient; + +use rocket::{Rocket, Build}; use rocket::response::{content, status}; -use rocket::http::Status; +use rocket::http::{Status, uri::Origin}; +use std::num::ParseIntError; #[get("/hello//")] fn hello(name: &str, age: i8) -> String { @@ -16,6 +19,22 @@ fn forced_error(code: u16) -> Status { Status::new(code) } +// TODO: Derive TypedError +#[derive(Transient, Debug)] +struct CustomError; + +impl<'r> rocket::catcher::TypedError<'r> for CustomError { } + +#[get("/")] +fn forced_custom_error() -> Result<(), CustomError> { + Err(CustomError) +} + +#[catch(500, error = "")] +fn catch_custom(e: &CustomError) -> &'static str { + "You found the custom error!" +} + #[catch(404)] fn general_not_found() -> content::RawHtml<&'static str> { content::RawHtml(r#" @@ -25,11 +44,23 @@ fn general_not_found() -> content::RawHtml<&'static str> { } #[catch(404)] -fn hello_not_found(req: &Request<'_>) -> content::RawHtml { +fn hello_not_found(uri: &Origin<'_>) -> content::RawHtml { content::RawHtml(format!("\

Sorry, but '{}' is not a valid path!

\

Try visiting /hello/<name>/<age> instead.

", - req.uri())) + uri)) +} + +// `error` is typed error. All other parameters must implement `FromError`. +// Any type that implements `FromRequest` automatically implements `FromError`, +// as well as `Status`, `&Request` and `&dyn TypedError<'_>` +#[catch(422, error = "")] +fn param_error(e: &ParseIntError, uri: &Origin<'_>) -> content::RawHtml { + content::RawHtml(format!("\ +

Sorry, but '{}' is not a valid path!

\ +

Try visiting /hello/<name>/<age> instead.

\ +

Error: {e:?}

", + uri)) } #[catch(default)] @@ -37,9 +68,9 @@ fn sergio_error() -> &'static str { "I...don't know what to say." } -#[catch(default)] -fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom { - let msg = format!("{} ({})", status, req.uri()); +#[catch(default, status = "")] +fn default_catcher(status: Status, uri: &Origin<'_>) -> status::Custom { + let msg = format!("{} ({})", status, uri); status::Custom(status, msg) } @@ -51,9 +82,9 @@ fn rocket() -> Rocket { rocket::build() // .mount("/", routes![hello, hello]) // uncomment this to get an error // .mount("/", routes![unmanaged]) // uncomment this to get a sentinel error - .mount("/", routes![hello, forced_error]) - .register("/", catchers![general_not_found, default_catcher]) - .register("/hello", catchers![hello_not_found]) + .mount("/", routes![hello, forced_error, forced_custom_error]) + .register("/", catchers![general_not_found, default_catcher, catch_custom]) + .register("/hello", catchers![hello_not_found, param_error]) .register("/hello/Sergio", catchers![sergio_error]) } diff --git a/examples/error-handling/src/tests.rs b/examples/error-handling/src/tests.rs index fcd78424c9..a0f90dc1ed 100644 --- a/examples/error-handling/src/tests.rs +++ b/examples/error-handling/src/tests.rs @@ -24,19 +24,19 @@ fn forced_error() { assert_eq!(response.into_string().unwrap(), expected.0); let request = client.get("/405"); - let expected = super::default_catcher(Status::MethodNotAllowed, request.inner()); + let expected = super::default_catcher(Status::MethodNotAllowed, request.uri()); let response = request.dispatch(); assert_eq!(response.status(), Status::MethodNotAllowed); assert_eq!(response.into_string().unwrap(), expected.1); let request = client.get("/533"); - let expected = super::default_catcher(Status::new(533), request.inner()); + let expected = super::default_catcher(Status::new(533), request.uri()); let response = request.dispatch(); assert_eq!(response.status(), Status::new(533)); assert_eq!(response.into_string().unwrap(), expected.1); let request = client.get("/700"); - let expected = super::default_catcher(Status::InternalServerError, request.inner()); + let expected = super::default_catcher(Status::InternalServerError, request.uri()); let response = request.dispatch(); assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.into_string().unwrap(), expected.1); @@ -48,16 +48,19 @@ fn test_hello_invalid_age() { for path in &["Ford/-129", "Trillian/128"] { let request = client.get(format!("/hello/{}", path)); - let expected = super::default_catcher(Status::UnprocessableEntity, request.inner()); + let expected = super::param_error( + &path.split_once("/").unwrap().1.parse::().unwrap_err(), + request.uri() + ); let response = request.dispatch(); assert_eq!(response.status(), Status::UnprocessableEntity); - assert_eq!(response.into_string().unwrap(), expected.1); + assert_eq!(response.into_string().unwrap(), expected.0); } { let path = &"foo/bar/baz"; let request = client.get(format!("/hello/{}", path)); - let expected = super::hello_not_found(request.inner()); + let expected = super::hello_not_found(request.uri()); let response = request.dispatch(); assert_eq!(response.status(), Status::NotFound); assert_eq!(response.into_string().unwrap(), expected.0); diff --git a/examples/fairings/src/main.rs b/examples/fairings/src/main.rs index 0e826a1c69..7ce4d2e91f 100644 --- a/examples/fairings/src/main.rs +++ b/examples/fairings/src/main.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use rocket::{Rocket, Request, State, Data, Build}; use rocket::fairing::{self, AdHoc, Fairing, Info, Kind}; +use rocket::catcher::TypedError; use rocket::trace::Trace; use rocket::http::Method; @@ -39,12 +40,15 @@ impl Fairing for Counter { Ok(rocket.manage(self.clone()).mount("/", routes![counts])) } - async fn on_request(&self, request: &mut Request<'_>, _: &mut Data<'_>) { + async fn on_request<'r>(&self, request: &'r mut Request<'_>, _: &mut Data<'_>) + -> Result<(), Box + 'r>> + { if request.method() == Method::Get { self.get.fetch_add(1, Ordering::Relaxed); } else if request.method() == Method::Post { self.post.fetch_add(1, Ordering::Relaxed); } + Ok(()) } } @@ -83,6 +87,7 @@ fn rocket() -> _ { req.trace_info(); }) } + Ok(()) }) })) .attach(AdHoc::on_response("Response Rewriter", |req, res| { diff --git a/examples/manual-routing/src/main.rs b/examples/manual-routing/src/main.rs index e4a21620f0..5322ebe775 100644 --- a/examples/manual-routing/src/main.rs +++ b/examples/manual-routing/src/main.rs @@ -1,12 +1,12 @@ #[cfg(test)] mod tests; -use rocket::{Request, Route, Catcher, route, catcher}; +use rocket::{Request, Route, Catcher, route, catcher, outcome::Outcome}; use rocket::data::{Data, ToByteUnit}; use rocket::http::{Status, Method::{Get, Post}}; use rocket::response::{Responder, status::Custom}; -use rocket::outcome::{try_outcome, IntoOutcome}; use rocket::tokio::fs::File; +use rocket::catcher::TypedError; fn forward<'r>(_req: &'r Request, data: Data<'r>) -> route::BoxFuture<'r> { Box::pin(async move { route::Outcome::forward(data, Status::NotFound) }) @@ -25,12 +25,17 @@ fn name<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { } fn echo_url<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { - let param_outcome = req.param::<&str>(1) - .and_then(Result::ok) - .or_error(Status::BadRequest); - Box::pin(async move { - route::Outcome::from(req, try_outcome!(param_outcome)) + let param_outcome = match req.param::<&str>(1) { + Some(Ok(v)) => v, + Some(Err(e)) => return Outcome::Error(( + Status::BadRequest, + Some(Box::new(e) as Box>) + )), + None => return Outcome::Error((Status::BadRequest, None)), + }; + + route::Outcome::from(req, param_outcome) }) } @@ -62,9 +67,11 @@ fn get_upload<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { route::Outcome::from(req, std::fs::File::open(path).ok()).pin() } -fn not_found_handler<'r>(_: Status, req: &'r Request) -> catcher::BoxFuture<'r> { +fn not_found_handler<'r>(_: Status, req: &'r Request, _e: Option<&'r dyn TypedError<'r>>) + -> catcher::BoxFuture<'r> +{ let responder = Custom(Status::NotFound, format!("Couldn't find: {}", req.uri())); - Box::pin(async move { responder.respond_to(req) }) + Box::pin(async move { responder.respond_to(req).responder_error() }) } #[derive(Clone)] @@ -82,11 +89,17 @@ impl CustomHandler { impl route::Handler for CustomHandler { async fn handle<'r>(&self, req: &'r Request<'_>, data: Data<'r>) -> route::Outcome<'r> { let self_data = self.data; - let id = req.param::<&str>(0) - .and_then(Result::ok) - .or_forward((data, Status::NotFound)); - - route::Outcome::from(req, format!("{} - {}", self_data, try_outcome!(id))) + let id = match req.param::<&str>(1) { + Some(Ok(v)) => v, + Some(Err(e)) => return Outcome::Forward((data, Status::BadRequest, Some(Box::new(e)))), + None => return Outcome::Forward(( + data, + Status::BadRequest, + None + )), + }; + + route::Outcome::from(req, format!("{} - {}", self_data, id)) } } diff --git a/examples/responders/src/main.rs b/examples/responders/src/main.rs index 90b65b3be2..79d0baba62 100644 --- a/examples/responders/src/main.rs +++ b/examples/responders/src/main.rs @@ -122,7 +122,8 @@ fn maybe_redir(name: &str) -> Result<&'static str, Redirect> { /***************************** `content` Responders ***************************/ -use rocket::Request; +use rocket::request::Request; +use rocket::http::uri::Origin; use rocket::response::content; // NOTE: This example explicitly uses the `RawJson` type from @@ -143,15 +144,17 @@ fn json() -> content::RawJson<&'static str> { content::RawJson(r#"{ "payload": "I'm here" }"#) } +// TODO: Should we allow this? +// Unlike in routes, you actually can use `&Request` in catchers. #[catch(404)] -fn not_found(request: &Request<'_>) -> content::RawHtml { - let html = match request.format() { +fn not_found(req: &Request<'_>, uri: &Origin) -> content::RawHtml { + let html = match req.format() { Some(ref mt) if !(mt.is_xml() || mt.is_html()) => { format!("

'{}' requests are not supported.

", mt) } _ => format!("

Sorry, '{}' is an invalid path! Try \ /hello/<name>/<age> instead.

", - request.uri()) + uri) }; content::RawHtml(html) diff --git a/examples/serialization/src/uuid.rs b/examples/serialization/src/uuid.rs index 15c804b733..f25b2b2e45 100644 --- a/examples/serialization/src/uuid.rs +++ b/examples/serialization/src/uuid.rs @@ -7,11 +7,15 @@ use rocket::serde::uuid::Uuid; // real application this would be a database. struct People(HashMap); +// TODO: this is actually the same as previous, since Result didn't +// set or override the status. #[get("/people/")] -fn people(id: Uuid, people: &State) -> Result { - people.0.get(&id) - .map(|person| format!("We found: {}", person)) - .ok_or_else(|| format!("Missing person for UUID: {}", id)) +fn people(id: Uuid, people: &State) -> String { + if let Some(person) = people.0.get(&id) { + format!("We found: {}", person) + } else { + format!("Missing person for UUID: {}", id) + } } pub fn stage() -> rocket::fairing::AdHoc { diff --git a/examples/templating/src/hbs.rs b/examples/templating/src/hbs.rs index c3edcdb221..46af55e77c 100644 --- a/examples/templating/src/hbs.rs +++ b/examples/templating/src/hbs.rs @@ -1,4 +1,4 @@ -use rocket::Request; +use rocket::http::uri::Origin; use rocket::response::Redirect; use rocket_dyn_templates::{Template, handlebars, context}; @@ -28,9 +28,9 @@ pub fn about() -> Template { } #[catch(404)] -pub fn not_found(req: &Request<'_>) -> Template { +pub fn not_found(uri: &Origin<'_>) -> Template { Template::render("hbs/error/404", context! { - uri: req.uri() + uri, }) } diff --git a/examples/templating/src/tera.rs b/examples/templating/src/tera.rs index 8e5e0b8372..a7c34fc76c 100644 --- a/examples/templating/src/tera.rs +++ b/examples/templating/src/tera.rs @@ -1,4 +1,4 @@ -use rocket::Request; +use rocket::http::uri::Origin; use rocket::response::Redirect; use rocket_dyn_templates::{Template, tera::Tera, context}; @@ -25,9 +25,9 @@ pub fn about() -> Template { } #[catch(404)] -pub fn not_found(req: &Request<'_>) -> Template { +pub fn not_found(uri: &Origin<'_>) -> Template { Template::render("tera/error/404", context! { - uri: req.uri() + uri, }) } diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index a12f7ab439..d3086493d3 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -7,6 +7,7 @@ mod tests; mod task; use rocket::{Rocket, Build}; +use rocket::either::Either; use rocket::fairing::AdHoc; use rocket::request::FlashMessage; use rocket::response::{Flash, Redirect}; @@ -64,23 +65,29 @@ async fn new(todo_form: Form, conn: DbConn) -> Flash { } #[put("/")] -async fn toggle(id: i32, conn: DbConn) -> Result { +async fn toggle(id: i32, conn: DbConn) -> Either { match Task::toggle_with_id(id, &conn).await { - Ok(_) => Ok(Redirect::to("/")), + Ok(_) => Either::Left(Redirect::to("/")), Err(e) => { error!("DB toggle({id}) error: {e}"); - Err(Template::render("index", Context::err(&conn, "Failed to toggle task.").await)) + Either::Right(Template::render( + "index", + Context::err(&conn, "Failed to toggle task.").await + )) } } } #[delete("/")] -async fn delete(id: i32, conn: DbConn) -> Result, Template> { +async fn delete(id: i32, conn: DbConn) -> Either, Template> { match Task::delete_with_id(id, &conn).await { - Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")), + Ok(_) => Either::Left(Flash::success(Redirect::to("/"), "Todo was deleted.")), Err(e) => { error!("DB deletion({id}) error: {e}"); - Err(Template::render("index", Context::err(&conn, "Failed to delete task.").await)) + Either::Right(Template::render( + "index", + Context::err(&conn, "Failed to delete task.").await + )) } } }