Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a reason for not acceptable server request rejections #2607

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
rustTemplate(
"""
if !#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()}) {
return Err(#{RequestRejection}::NotAcceptable);
}
#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()})?;
""",
*codegenScope,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

use crate::rejection::MissingContentTypeReason;
use crate::rejection::{MissingContentTypeReason, NotAcceptableReason};
use thiserror::Error;

#[derive(Debug, Error)]
Expand All @@ -18,8 +18,8 @@ pub enum ResponseRejection {
pub enum RequestRejection {
#[error("error converting non-streaming body to bytes: {0}")]
BufferHttpBodyBytes(crate::Error),
#[error("request contains invalid value for `Accept` header")]
NotAcceptable,
#[error("request is not acceptable: {0}")]
NotAcceptable(#[from] NotAcceptableReason),
#[error("expected `Content-Type` header not found: {0}")]
MissingContentType(#[from] MissingContentTypeReason),
#[error("error deserializing request HTTP body as JSON: {0}")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
//!
//! Consult `crate::proto::$protocolName::rejection` for rejection types for other protocols.

use crate::rejection::MissingContentTypeReason;
use crate::rejection::{MissingContentTypeReason, NotAcceptableReason};
use std::num::TryFromIntError;
use thiserror::Error;

Expand Down Expand Up @@ -109,10 +109,10 @@ pub enum RequestRejection {
#[error("error converting non-streaming body to bytes: {0}")]
BufferHttpBodyBytes(crate::Error),

/// Used when the request contained an `Accept` header with a MIME type, and the server cannot
/// return a response body adhering to that MIME type.
#[error("request contains invalid value for `Accept` header")]
NotAcceptable,
/// Used when the request contained `Accept` headers with certain MIME types, and the server cannot
/// return a response body adhering to _any_ of those MIME types.
#[error("request is not acceptable: {0}")]
NotAcceptable(#[from] NotAcceptableReason),

/// Used when checking the `Content-Type` header.
#[error("expected `Content-Type` header not found: {0}")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl From<RequestRejection> for RuntimeError {
match err {
RequestRejection::MissingContentType(_reason) => Self::UnsupportedMediaType,
RequestRejection::ConstraintViolation(reason) => Self::Validation(reason),
RequestRejection::NotAcceptable => Self::NotAcceptable,
RequestRejection::NotAcceptable(_reason) => Self::NotAcceptable,
_ => Self::Serialization(crate::Error::new(err)),
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! [`crate::proto::rest_json_1::rejection::RequestRejection::JsonDeserialize`] is swapped for
//! [`RequestRejection::XmlDeserialize`].

use crate::rejection::MissingContentTypeReason;
use crate::rejection::{MissingContentTypeReason, NotAcceptableReason};
use std::num::TryFromIntError;
use thiserror::Error;

Expand All @@ -28,8 +28,8 @@ pub enum RequestRejection {
#[error("error converting non-streaming body to bytes: {0}")]
BufferHttpBodyBytes(crate::Error),

#[error("request contains invalid value for `Accept` header")]
NotAcceptable,
#[error("request is not acceptable: {0}")]
NotAcceptable(#[from] NotAcceptableReason),

#[error("expected `Content-Type` header not found: {0}")]
MissingContentType(#[from] MissingContentTypeReason),
Expand Down
149 changes: 98 additions & 51 deletions rust-runtime/aws-smithy-http-server/src/protocols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
*/

//! Protocol helpers.
use crate::rejection::MissingContentTypeReason;
use http::HeaderMap;
use crate::rejection::{MissingContentTypeReason, NotAcceptableReason};
use http::{HeaderMap, HeaderValue};

/// When there are no modeled inputs,
/// a request body is empty and the content-type request header must not be set
Expand Down Expand Up @@ -66,17 +66,18 @@ pub fn content_type_header_classifier(
Ok(())
}

pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str) -> bool {
pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str) -> Result<(), NotAcceptableReason> {
if !headers.contains_key(http::header::ACCEPT) {
return true;
return Ok(());
}
// Must be of the form: type/subtype
// Must be of the form: `type/subtype`.
let content_type = content_type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note: I feel like this conversion can be done in a static once_cell::Lazy outside of this function, then we pass the &Mime into this function instead. That way we're not doing it over and over.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at Mime::parse and it appears to be a non-negligible process, so putting Mime behind a OnceCell and incurring an atomic read should indeed speed things up.

I also learnt that mime is not actively maintained, and stumbled upon mediatype, which is const-constructible, so for our static mime type coming from the sSDK I think it would make more sense to switch to that. It's also zero-copy, so parsing the request's Accept and Content-Type headers' mime types and performing the type + subtype equality check should be faster.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like a plan

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the first part: #2629

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#2629 is throwaway work if we switch away from mime. I've opened #2666 to track that.

.parse::<mime::Mime>()
.expect("BUG: MIME parsing failed, content_type is not valid");
headers
.get_all(http::header::ACCEPT)
.into_iter()
// `expect` safety: content_type` is sent in from the generated server SDK and we know it's valid.
.expect("MIME parsing failed, `content_type` is not valid; please file a bug report under https://github.com/awslabs/smithy-rs/issues");
let accept_headers = headers.get_all(http::header::ACCEPT);
let can_satisfy_some_accept_header = accept_headers
.iter()
.flat_map(|header| {
header
.to_str()
Expand All @@ -88,20 +89,30 @@ pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str)
* and remove the optional "; q=x" parameters
* NOTE: the `unwrap`() is safe, because it takes the first element (if there's nothing to split, returns the string)
*/
.flat_map(|s| s.split(',').map(|typ| typ.split(';').next().unwrap().trim()))
.flat_map(|s| s.split(',').map(|type_| type_.split(';').next().unwrap().trim()))
})
.filter_map(|h| h.parse::<mime::Mime>().ok())
.any(|mim| {
let typ = content_type.type_();
.any(|mime| {
let type_ = content_type.type_();
let subtype = content_type.subtype();
// Accept: */*, type/*, type/subtype
match (mim.type_(), mim.subtype()) {
(t, s) if t == typ && s == subtype => true,
(t, mime::STAR) if t == typ => true,
match (mime.type_(), mime.subtype()) {
(t, s) if t == type_ && s == subtype => true,
(t, mime::STAR) if t == type_ => true,
(mime::STAR, mime::STAR) => true,
_ => false,
}
})
});
if can_satisfy_some_accept_header {
Ok(())
} else {
// We can't make `NotAcceptableReason`/`RequestRejection` borrow the header values because
// non-static lifetimes are not allowed in the source of an error, because
// `std::error::Error` requires the source is `dyn Error + 'static`. So we clone them into
// a vector in the case of a request rejection.
let cloned_accept_headers: Vec<HeaderValue> = accept_headers.into_iter().cloned().collect();
Err(NotAcceptableReason::CannotSatisfyAcceptHeaders(cloned_accept_headers))
}
}

#[cfg(test)]
Expand All @@ -115,9 +126,9 @@ mod tests {
headers
}

fn req_accept(accept: &'static str) -> HeaderMap {
fn req_accept(accept: &str) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_static(accept));
headers.insert(ACCEPT, HeaderValue::from_str(accept).unwrap());
headers
}

Expand Down Expand Up @@ -192,44 +203,80 @@ mod tests {
assert!(matches!(result.unwrap_err(), MissingContentTypeReason::ToStrError(_)));
}

#[test]
fn valid_accept_header_classifier_multiple_values() {
let valid_request = req_accept("text/strings, application/json, invalid");
assert!(accept_header_classifier(&valid_request, "application/json"));
}
mod accept_header_classifier {
use super::*;

#[test]
fn invalid_accept_header_classifier() {
let invalid_request = req_accept("text/invalid, invalid, invalid/invalid");
assert!(!accept_header_classifier(&invalid_request, "application/json"));
}
#[test]
fn valid_single_value() {
let valid_request = req_accept("application/json");
assert!(accept_header_classifier(&valid_request, "application/json").is_ok());
}

#[test]
fn valid_accept_header_classifier_star() {
let valid_request = req_accept("application/*");
assert!(accept_header_classifier(&valid_request, "application/json"));
}
#[test]
fn valid_multiple_values() {
let valid_request = req_accept("text/strings, application/json, invalid");
assert!(accept_header_classifier(&valid_request, "application/json").is_ok());
}

#[test]
fn valid_accept_header_classifier_star_star() {
let valid_request = req_accept("*/*");
assert!(accept_header_classifier(&valid_request, "application/json"));
}
#[test]
fn subtype_star() {
let valid_request = req_accept("application/*");
assert!(accept_header_classifier(&valid_request, "application/json").is_ok());
}

#[test]
fn valid_empty_accept_header_classifier() {
assert!(accept_header_classifier(&HeaderMap::new(), "application/json"));
}
#[test]
fn type_star_subtype_star() {
let valid_request = req_accept("*/*");
assert!(accept_header_classifier(&valid_request, "application/json").is_ok());
}

#[test]
fn valid_accept_header_classifier_with_params() {
let valid_request = req_accept("application/json; q=30, */*");
assert!(accept_header_classifier(&valid_request, "application/json"));
}
#[test]
fn empty() {
assert!(accept_header_classifier(&HeaderMap::new(), "application/json").is_ok());
}

#[test]
fn valid_accept_header_classifier() {
let valid_request = req_accept("application/json");
assert!(accept_header_classifier(&valid_request, "application/json"));
#[test]
fn valid_with_params() {
let valid_request = req_accept("application/json; q=30, */*");
assert!(accept_header_classifier(&valid_request, "application/json").is_ok());
}

#[test]
fn unstatisfiable_multiple_values() {
let accept_header_values = ["text/invalid, invalid, invalid/invalid"];
let joined = accept_header_values.join(", ");
let invalid_request = req_accept(&joined);
match accept_header_classifier(&invalid_request, "application/json").unwrap_err() {
NotAcceptableReason::CannotSatisfyAcceptHeaders(returned_accept_header_values) => {
for header_value in accept_header_values {
let header_value = HeaderValue::from_str(header_value).unwrap();
assert!(returned_accept_header_values.contains(&header_value));
}
}
}
}

#[test]
fn unstatisfiable_unparseable() {
let header_value = "foo_"; // Not a valid MIME type.
assert!(header_value.parse::<mime::Mime>().is_err());
let invalid_request = req_accept(header_value);
match accept_header_classifier(&invalid_request, "application/json").unwrap_err() {
NotAcceptableReason::CannotSatisfyAcceptHeaders(returned_accept_header_values) => {
let header_value = HeaderValue::from_str(header_value).unwrap();
assert!(returned_accept_header_values.contains(&header_value));
}
}
}

#[test]
#[should_panic]
fn panic_if_content_type_not_parseable() {
let header_value = "foo_"; // Not a valid MIME type.
let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_str(header_value).unwrap());
assert!(header_value.parse::<mime::Mime>().is_err());
let _res = accept_header_classifier(&headers, header_value);
}
}
}
8 changes: 8 additions & 0 deletions rust-runtime/aws-smithy-http-server/src/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/

use crate::response::IntoResponse;
use http::HeaderValue;
use thiserror::Error;

// This is used across different protocol-specific `rejection` modules.
Expand All @@ -24,6 +25,13 @@ pub enum MissingContentTypeReason {
},
}

// This is used across different protocol-specific `rejection` modules.
#[derive(Debug, Error)]
pub enum NotAcceptableReason {
#[error("cannot satisfy any of `Accept` header values: {0:?}")]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't log values

CannotSatisfyAcceptHeaders(Vec<HeaderValue>),
}

pub mod any_rejections {
//! This module hosts enums, up to size 8, which implement [`IntoResponse`] when their variants implement
//! [`IntoResponse`].
Expand Down