diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 85b624b567..c61947a397 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -178,11 +178,13 @@ class ServerHttpBoundProtocolTraitImplGenerator( outputSymbol: Symbol, operationShape: OperationShape, ) { + val operationName = symbolProvider.toSymbol(operationShape).name + val staticContentType = "CONTENT_TYPE_${operationName.uppercase()}" val verifyAcceptHeader = writable { httpBindingResolver.responseContentType(operationShape)?.also { contentType -> rustTemplate( """ - if !#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), ${contentType.dq()}) { + if !#{SmithyHttpServer}::protocols::accept_header_classifier(request.headers(), &$staticContentType) { return Err(#{RequestRejection}::NotAcceptable); } """, @@ -190,6 +192,22 @@ class ServerHttpBoundProtocolTraitImplGenerator( ) } } + val verifyAcceptHeaderStaticContentTypeInit = writable { + httpBindingResolver.responseContentType(operationShape)?.also { contentType -> + val init = when (contentType) { + "application/json" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_JSON;" + "application/octet-stream" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_OCTET_STREAM;" + "application/x-www-form-urlencoded" -> "const $staticContentType: #{Mime}::Mime = #{Mime}::APPLICATION_WWW_FORM_URLENCODED;" + else -> + """ + static $staticContentType: #{OnceCell}::sync::Lazy<#{Mime}::Mime> = #{OnceCell}::sync::Lazy::new(|| { + ${contentType.dq()}.parse::<#{Mime}::Mime>().expect("BUG: MIME parsing failed, content_type is not valid") + }); + """ + } + rustTemplate(init, *codegenScope) + } + } val verifyRequestContentTypeHeader = writable { operationShape .inputShape(model) @@ -215,6 +233,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( // TODO(https://github.com/awslabs/smithy-rs/issues/2238): Remove the `Pin>` and replace with thin wrapper around `Collect`. rustTemplate( """ + #{verifyAcceptHeaderStaticContentTypeInit:W} #{PinProjectLite}::pin_project! { /// A [`Future`](std::future::Future) aggregating the body bytes of a [`Request`] and constructing the /// [`${inputSymbol.name}`](#{I}) using modelled bindings. @@ -267,6 +286,7 @@ class ServerHttpBoundProtocolTraitImplGenerator( "Marker" to protocol.markerStruct(), "parse_request" to serverParseRequest(operationShape), "verifyAcceptHeader" to verifyAcceptHeader, + "verifyAcceptHeaderStaticContentTypeInit" to verifyAcceptHeaderStaticContentTypeInit, "verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader, ) diff --git a/rust-runtime/aws-smithy-http-server/src/protocols.rs b/rust-runtime/aws-smithy-http-server/src/protocols.rs index 2db5e7163d..75d723f3fc 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocols.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocols.rs @@ -66,14 +66,10 @@ pub fn content_type_header_classifier( Ok(()) } -pub fn accept_header_classifier(headers: &HeaderMap, content_type: &'static str) -> bool { +pub fn accept_header_classifier(headers: &HeaderMap, content_type: &mime::Mime) -> bool { if !headers.contains_key(http::header::ACCEPT) { return true; } - // Must be of the form: type/subtype - let content_type = content_type - .parse::() - .expect("BUG: MIME parsing failed, content_type is not valid"); headers .get_all(http::header::ACCEPT) .into_iter() @@ -195,41 +191,62 @@ mod tests { #[test] fn valid_accept_header_classifier_multiple_values() { let valid_request = req_accept("text/strings, application/json, invalid"); - assert!(accept_header_classifier(&valid_request, "application/json")); + assert!(accept_header_classifier( + &valid_request, + &"application/json".parse().unwrap() + )); } #[test] fn invalid_accept_header_classifier() { let invalid_request = req_accept("text/invalid, invalid, invalid/invalid"); - assert!(!accept_header_classifier(&invalid_request, "application/json")); + assert!(!accept_header_classifier( + &invalid_request, + &"application/json".parse().unwrap() + )); } #[test] fn valid_accept_header_classifier_star() { let valid_request = req_accept("application/*"); - assert!(accept_header_classifier(&valid_request, "application/json")); + assert!(accept_header_classifier( + &valid_request, + &"application/json".parse().unwrap() + )); } #[test] fn valid_accept_header_classifier_star_star() { let valid_request = req_accept("*/*"); - assert!(accept_header_classifier(&valid_request, "application/json")); + assert!(accept_header_classifier( + &valid_request, + &"application/json".parse().unwrap() + )); } #[test] fn valid_empty_accept_header_classifier() { - assert!(accept_header_classifier(&HeaderMap::new(), "application/json")); + assert!(accept_header_classifier( + &HeaderMap::new(), + &"application/json".parse().unwrap() + )); } #[test] fn valid_accept_header_classifier_with_params() { let valid_request = req_accept("application/json; q=30, */*"); - assert!(accept_header_classifier(&valid_request, "application/json")); + assert!(accept_header_classifier( + &valid_request, + &"application/json".parse().unwrap() + )); } #[test] fn valid_accept_header_classifier() { let valid_request = req_accept("application/json"); - assert!(accept_header_classifier(&valid_request, "application/json")); + assert!(accept_header_classifier( + &valid_request, + &"application/json".parse().unwrap() + )); } }