diff --git a/Cargo.lock b/Cargo.lock index c639c01..f39e2b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,6 +62,61 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1236b4b292f6c4d6dc34604bb5120d85c3fe1d1aa596bd5cc52ca054d13e7b9e" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "http-body-util", + "hyper 1.1.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -202,9 +257,10 @@ name = "example" version = "0.1.0" dependencies = [ "async-trait", + "axum", "fs-err", "glob", - "hyper", + "hyper 0.14.28", "prost", "prost-build", "prost-wkt", @@ -390,7 +446,26 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.11", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 1.0.0", "indexmap", "slab", "tokio", @@ -436,6 +511,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -443,7 +529,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.11", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.0.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", "pin-project-lite", ] @@ -469,9 +578,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.3.23", + "http 0.2.11", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -483,6 +592,26 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.2", + "http 1.0.0", + "http-body 1.0.0", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "tokio", + "want", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -490,12 +619,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.28", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdea9aac0dbe5a9240d68cfd9501e2db94222c6dc06843e06640b9e07f0fdc67" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "hyper 1.1.0", + "pin-project-lite", + "socket2", + "tokio", + "tracing", +] + [[package]] name = "idna" version = "0.5.0" @@ -576,6 +723,12 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "memchr" version = "2.7.1" @@ -726,6 +879,26 @@ dependencies = [ "indexmap", ] +[[package]] +name = "pin-project" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -922,10 +1095,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", + "h2 0.3.23", + "http 0.2.11", + "http-body 0.4.6", + "hyper 0.14.28", "hyper-tls", "ipnet", "js-sys", @@ -969,6 +1142,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + [[package]] name = "ryu" version = "1.0.16" @@ -1038,6 +1217,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebd154a240de39fdebcf5775d2675c204d7c13cf39a4c697be6493c8e734337c" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1080,6 +1269,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "system-configuration" version = "0.5.1" @@ -1201,6 +1396,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + [[package]] name = "tower-service" version = "0.3.2" @@ -1213,6 +1430,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-core", ] @@ -1234,23 +1452,28 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "twirp" -version = "0.1.0" +version = "0.2.0" dependencies = [ "async-trait", + "axum", + "bytes", "futures", - "hyper", + "http 1.0.0", + "http-body-util", + "hyper 1.1.0", "prost", "reqwest", "serde", "serde_json", "thiserror", "tokio", + "tower", "url", ] [[package]] name = "twirp-build" -version = "0.1.0" +version = "0.2.0" dependencies = [ "prost-build", ] diff --git a/README.md b/README.md index 43cdd2d..16b00fa 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Add the `twirp-build` crate as a build dependency in your `Cargo.toml` (you'll n ```toml # Cargo.toml [build-dependencies] -twirp-build = "0.1" +twirp-build = "0.2" prost-build = "0.12" ``` @@ -58,22 +58,22 @@ Include the generated code, create a router, register your service, and then ser mod haberdash { include!(concat!(env!("OUT_DIR"), "/service.haberdash.v1.rs")); } -use haberdash + +use haberdash::{MakeHatRequest, MakeHatResponse}; #[tokio::main] pub async fn main() { - let mut router = Router::default(); - let server = Arc::new(HaberdasherAPIServer {}); - haberdash::add_service(&mut router, server.clone()); - let router = Arc::new(router); - let service = make_service_fn(move |_| { - let router = router.clone(); - async { Ok::<_, GenericError>(service_fn(move |req| twirp::serve(router.clone(), req))) } - }); - - let addr = ([127, 0, 0, 1], 3000).into(); - let server = Server::bind(&addr).serve(service); - server.await.expect("server error") + let api_impl = Arc::new(HaberdasherAPIServer {}); + let twirp_routes = Router::new() + .nest(haberdash::SERVICE_FQN, haberdash::router(api_impl)); + let app = Router::new() + .nest("/twirp", twirp_routes) + .fallback(twirp::server::not_found_handler); + + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap(); + if let Err(e) = axum::serve(tcp_listener, app).await { + eprintln!("server error: {}", e); + } } // Define the server and implement the trait. diff --git a/crates/twirp-build/Cargo.toml b/crates/twirp-build/Cargo.toml index fb0be95..1206b77 100644 --- a/crates/twirp-build/Cargo.toml +++ b/crates/twirp-build/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "twirp-build" -version = "0.1.0" +version = "0.2.0" authors = ["The blackbird team "] edition = "2021" diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 74a4fd3..c0aa439 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -19,6 +19,8 @@ impl prost_build::ServiceGenerator for ServiceGenerator { let service_fqn = format!("{}.{}", service.package, service_name); writeln!(buf).unwrap(); + writeln!(buf, "pub const SERVICE_FQN: &str = \"{service_fqn}\";").unwrap(); + // // generate the twirp server // @@ -37,32 +39,44 @@ impl prost_build::ServiceGenerator for ServiceGenerator { // add_service writeln!( buf, - r#"pub fn add_service(router: &mut twirp::Router, api: std::sync::Arc) + r#"pub fn router(api: std::sync::Arc) -> twirp::Router where - T: {} + Send + Sync + 'static, -{{"#, - service_name + T: {service_name} + Send + Sync + 'static, +{{ + twirp::Router::new()"#, ) .unwrap(); for m in &service.methods { + let uri = &m.proto_name; + let rust_method_name = &m.name; writeln!( buf, - r#" {{ - #[allow(clippy::redundant_clone)] - let api = api.clone(); - router.add_method( - "{}/{}", - move |req| {{ - let api = api.clone(); - async move {{ api.{}(req).await }} - }}, - ); - }}"#, - service_fqn, m.proto_name, m.name + r#" .route( + "/{uri}", + twirp::details::post( + |twirp::details::State(api): twirp::details::State>, + req: twirp::details::Request| async move {{ + twirp::server::handle_request( + req, + move |req| async move {{ + api.{rust_method_name}(req).await + }}, + ) + .await + }}, + ), + )"#, ) .unwrap(); } - writeln!(buf, "}}").unwrap(); + writeln!( + buf, + r#" + .with_state(api) + .fallback(twirp::server::not_found_handler) +}}"# + ) + .unwrap(); // // generate the twirp client diff --git a/crates/twirp/Cargo.toml b/crates/twirp/Cargo.toml index b6a4676..871735a 100644 --- a/crates/twirp/Cargo.toml +++ b/crates/twirp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "twirp" -version = "0.1.0" +version = "0.2.0" authors = ["The blackbird team "] edition = "2021" @@ -22,7 +22,12 @@ reqwest = { version = "0.11", features = ["default", "gzip", "json"], optional = url = { version = "2.5", optional = true } # For the server feature -hyper = { version = "0.14", features = ["full"], optional = true } +axum = "0.7" +bytes = "1.5" +http = "1.0" +http-body-util = "0.1" +hyper = { version = "1.1", features = ["full"], optional = true } +tower = "0.4" # For the test-support feature async-trait = { version = "0.1", optional = true } diff --git a/crates/twirp/src/client.rs b/crates/twirp/src/client.rs index 5ff23b7..e4de595 100644 --- a/crates/twirp/src/client.rs +++ b/crates/twirp/src/client.rs @@ -1,13 +1,13 @@ use std::sync::Arc; use async_trait::async_trait; -use hyper::header::{InvalidHeaderValue, CONTENT_TYPE}; -use hyper::StatusCode; +use reqwest::header::{InvalidHeaderValue, CONTENT_TYPE}; +use reqwest::StatusCode; use thiserror::Error; use url::Url; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; -use crate::{to_proto_body, TwirpErrorResponse}; +use crate::{serialize_proto_message, TwirpErrorResponse}; #[derive(Debug, Error)] pub enum ClientError { @@ -146,7 +146,7 @@ impl Client { .http_client .post(url) .header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) - .body(to_proto_body(body)) + .body(serialize_proto_message(body)) .build()?; // Create and execute the middleware handlers diff --git a/crates/twirp/src/details.rs b/crates/twirp/src/details.rs new file mode 100644 index 0000000..1ab5989 --- /dev/null +++ b/crates/twirp/src/details.rs @@ -0,0 +1,10 @@ +//! Undocumented features that are public for use in generated code (see `twirp-build`). + +#[doc(hidden)] +pub use axum::extract::{Request, State}; + +#[doc(hidden)] +pub use axum::routing::post; + +#[doc(hidden)] +pub use axum::response::Response; diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index e8ab926..4c1a98d 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -2,7 +2,10 @@ use std::collections::HashMap; -use hyper::{header, Body, Response, StatusCode}; +use axum::body::Body; +use axum::response::IntoResponse; +use http::header::{self, HeaderMap, HeaderValue}; +use hyper::{Response, StatusCode}; use serde::{Deserialize, Serialize, Serializer}; // Alias for a generic error @@ -149,14 +152,20 @@ impl TwirpErrorResponse { pub fn insert_meta(&mut self, key: String, value: String) -> Option { self.meta.insert(key, value) } +} + +impl IntoResponse for TwirpErrorResponse { + fn into_response(self) -> Response { + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + + let json = + serde_json::to_string(&self).expect("JSON serialization of an error should not fail"); - pub fn to_response(&self) -> Result, GenericError> { - let json = serde_json::to_string(self)?; - let response = Response::builder() - .status(self.code.http_status_code()) - .header(header::CONTENT_TYPE, "application/json") - .body(Body::from(json))?; - Ok(response) + (self.code.http_status_code(), headers, json).into_response() } } diff --git a/crates/twirp/src/lib.rs b/crates/twirp/src/lib.rs index 49c08d8..ee05cab 100644 --- a/crates/twirp/src/lib.rs +++ b/crates/twirp/src/lib.rs @@ -8,9 +8,11 @@ pub mod server; #[cfg(any(test, feature = "test-support"))] pub mod test; +pub mod details; + pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result}; pub use error::*; // many constructors like `invalid_argument()` -pub use server::{serve, Router, Timings}; +pub use server::{Router, Timings}; // Re-export `reqwest` so that it's easy to implement middleware. pub use reqwest; @@ -18,7 +20,7 @@ pub use reqwest; // Re-export `url so that the generated code works without additional dependencies beyond just the `twirp` crate. pub use url; -pub(crate) fn to_proto_body(m: T) -> hyper::Body +pub(crate) fn serialize_proto_message(m: T) -> Vec where T: prost::Message, { @@ -27,5 +29,5 @@ where m.encode(&mut data) .expect("can only fail if buffer does not have capacity"); assert_eq!(data.len(), len); - hyper::Body::from(data) + data } diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index d8fb863..de886e0 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -1,155 +1,21 @@ -use std::collections::HashMap; use std::fmt::Debug; -use std::sync::Arc; +use axum::body::Body; +use axum::response::IntoResponse; +pub use axum::Router; use futures::Future; -use hyper::{header, Body, Method, Request, Response}; +use http_body_util::BodyExt; +use hyper::{header, Request, Response}; use serde::de::DeserializeOwned; use serde::Serialize; use tokio::time::{Duration, Instant}; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; -use crate::{error, to_proto_body, GenericError, TwirpErrorResponse}; - -/// A function that handles a request and returns a response. -type HandlerFn = Box) -> HandlerResponse + Send + Sync>; - -/// Type alias for a handler response. -type HandlerResponse = - Box, GenericError>> + Unpin + Send>; - -/// A Router maps a request (method, path) tuple to a handler. -pub struct Router { - routes: HashMap<(Method, String), HandlerFn>, - prefix: &'static str, -} +use crate::{error, serialize_proto_message, GenericError, TwirpErrorResponse}; /// The canonical twirp path prefix. You don't have to use this, but it's the default. pub const DEFAULT_TWIRP_PATH_PREFIX: &str = "/twirp"; -impl Default for Router { - fn default() -> Self { - Self::new(DEFAULT_TWIRP_PATH_PREFIX) - } -} - -impl Debug for Router { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Router") - .field("routes", &self.routes.keys()) - .finish() - } -} - -impl Router { - /// Create a new router at the given prefix. Since this prefix is - /// canonically `/twirp`, it is recommended to use `Router::default()` - /// instead. - pub fn new(prefix: &'static str) -> Self { - Self { - routes: Default::default(), - prefix, - } - } - - /// Adds a sync handler to the router for the given method and path. - pub fn add_sync_handler(&mut self, method: Method, path: &str, f: F) - where - F: Fn(Request) -> Result, GenericError> - + Clone - + Sync - + Send - + 'static, - { - let g = move |req| -> Box< - dyn Future, GenericError>> + Unpin + Send, - > { - let f = f.clone(); - Box::new(Box::pin(async move { f(req) })) - }; - let key = (method, path.to_string()); - self.routes.insert(key, Box::new(g)); - } - - /// Adds an async handler to the router for the given method and path. - pub fn add_handler(&mut self, method: Method, path: &str, f: F) - where - F: Fn(Request) -> Fut + Clone + Sync + Send + 'static, - Fut: Future, GenericError>> + Send, - { - let g = move |req| -> Box< - dyn Future, GenericError>> + Unpin + Send, - > { - let f = f.clone(); - Box::new(Box::pin(async move { f(req).await })) - }; - let key = (method, path.to_string()); - self.routes.insert(key, Box::new(g)); - } - - /// Adds a twirp method handler to the router for the given path. - pub fn add_method(&mut self, path: &str, f: F) - where - F: Fn(Req) -> Fut + Clone + Sync + Send + 'static, - Fut: Future> + Send, - Req: prost::Message + Default + serde::de::DeserializeOwned, - Resp: prost::Message + serde::Serialize, - { - let g = move |req: Request| -> Box< - dyn Future, GenericError>> + Unpin + Send, - > { - let f = f.clone(); - Box::new(Box::pin(async move { - let mut timings = *req - .extensions() - .get::() - .expect("invariant violated: timing info not present in request"); - match parse_request(req, &mut timings).await { - Ok((req, resp_fmt)) => { - let res = f(req).await; - timings.set_response_handled(); - write_response(res, resp_fmt) - } - Err(err) => { - // This is the only place we use tracing (would be nice to remove) - // tracing::error!(?err, "failed to parse request"); - // TODO: We don't want to loose the underlying error - // here, but it might not be safe to include in the - // response like this always. - let mut twirp_err = error::malformed("bad request"); - twirp_err.insert_meta("error".to_string(), err.to_string()); - twirp_err.to_response() - } - } - .map(|mut resp| { - timings.set_response_written(); - resp.extensions_mut().insert(timings); - resp - }) - })) - }; - let key = (Method::POST, [self.prefix, path].join("/")); - self.routes.insert(key, Box::new(g)); - } -} - -/// Serve a request using the given router. -pub async fn serve( - router: Arc, - mut req: Request, -) -> Result, GenericError> { - if req.extensions().get::().is_none() { - let start = tokio::time::Instant::now(); - req.extensions_mut().insert(Timings::new(start)); - } - let key = (req.method().clone(), req.uri().path().to_string()); - if let Some(handler) = router.routes.get(&key) { - handler(req).await - } else { - error::bad_route("not found").to_response() - } -} - // TODO: Properly implement JsonPb (de)serialization as it is slightly different // than standard JSON. #[derive(Debug, Clone, Copy, Default)] @@ -172,6 +38,49 @@ impl BodyFormat { } } +/// Entry point used in code generated by `twirp-build`. +pub async fn handle_request(req: Request, f: F) -> Response +where + F: FnOnce(Req) -> Fut + Clone + Sync + Send + 'static, + Fut: Future> + Send, + Req: prost::Message + Default + serde::de::DeserializeOwned, + Resp: prost::Message + serde::Serialize, +{ + let mut timings = *req + .extensions() + .get::() + .expect("invariant violated: timing info not present in request"); + + let (req, resp_fmt) = match parse_request(req, &mut timings).await { + Ok(pair) => pair, + Err(err) => { + // This is the only place we use tracing (would be nice to remove) + // tracing::error!(?err, "failed to parse request"); + // TODO: We don't want to lose the underlying error here, but it might not be safe to + // include in the response like this always. + let mut twirp_err = error::malformed("bad request"); + twirp_err.insert_meta("error".to_string(), err.to_string()); + return twirp_err.into_response(); + } + }; + + let res = f(req).await; + timings.set_response_handled(); + + let mut resp = match write_response(res, resp_fmt) { + Ok(resp) => resp, + Err(err) => { + let mut twirp_err = error::unknown("error serializing response"); + twirp_err.insert_meta("error".to_string(), err.to_string()); + return twirp_err.into_response(); + } + }; + timings.set_response_written(); + + resp.extensions_mut().insert(timings); + resp +} + async fn parse_request( req: Request, timings: &mut Timings, @@ -180,10 +89,10 @@ where T: prost::Message + Default + DeserializeOwned, { let format = BodyFormat::from_content_type(&req); - let bytes = hyper::body::to_bytes(req.into_body()).await?; + let bytes = req.into_body().collect().await?.to_bytes(); timings.set_received(); let request = match format { - BodyFormat::Pb => T::decode(bytes)?, + BodyFormat::Pb => T::decode(&bytes[..])?, BodyFormat::JsonPb => serde_json::from_slice(&bytes)?, }; timings.set_parsed(); @@ -199,25 +108,28 @@ where { let res = match response { Ok(response) => match response_format { - BodyFormat::Pb => { - let response = Response::builder() - .header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) - .body(to_proto_body(response))?; - Ok(response) - } + BodyFormat::Pb => Response::builder() + .header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) + .body(Body::from(serialize_proto_message(response)))?, _ => { let data = serde_json::to_string(&response)?; - let response = Response::builder() + Response::builder() .header(header::CONTENT_TYPE, CONTENT_TYPE_JSON) - .body(Body::from(data))?; - Ok(response) + .body(Body::from(data))? } }, - Err(err) => err.to_response(), - }?; + Err(err) => err.into_response(), + }; Ok(res) } +/// Axum handler function that returns 404 Not Found with a Twirp JSON payload. +/// +/// `axum::Router`'s default fallback handler returns a 404 Not Found with no body content. +pub async fn not_found_handler() -> Response { + error::bad_route("not found").into_response() +} + /// Contains timing information associated with a request. /// To access the timings in a given request, use the [extensions](Request::extensions) /// method and specialize to `Timings` appropriately. @@ -302,30 +214,29 @@ mod tests { use super::*; use crate::test::*; + use tower::Service; + + fn timings() -> Timings { + Timings::new(Instant::now()) + } + #[tokio::test] async fn test_bad_route() { - let router = Arc::new(Router::default()); - let req = Request::get("/nothing").body(Body::empty()).unwrap(); - let resp = serve(router, req).await.unwrap(); + let mut router = test_api_router(); + let req = Request::get("/nothing") + .extension(timings()) + .body(Body::empty()) + .unwrap(); + + let resp = router.call(req).await.unwrap(); let data = read_err_body(resp.into_body()).await; assert_eq!(data, error::bad_route("not found")); } - #[tokio::test] - async fn test_routes() { - let router = test_api_router().await; - assert!(router - .routes - .contains_key(&(Method::POST, "/twirp/test.TestAPI/Ping".to_string()))); - assert!(router - .routes - .contains_key(&(Method::POST, "/twirp/test.TestAPI/Boom".to_string()))); - } - #[tokio::test] async fn test_ping_success() { - let router = test_api_router().await; - let resp = serve(router, gen_ping_request("hi")).await.unwrap(); + let mut router = test_api_router(); + let resp = router.call(gen_ping_request("hi")).await.unwrap(); assert!(resp.status().is_success(), "{:?}", resp); let data: PingResponse = read_json_body(resp.into_body()).await; assert_eq!(&data.name, "hi"); @@ -333,11 +244,12 @@ mod tests { #[tokio::test] async fn test_ping_invalid_request() { - let router = test_api_router().await; + let mut router = test_api_router(); let req = Request::post("/twirp/test.TestAPI/Ping") + .extension(timings()) .body(Body::empty()) // not a valid request .unwrap(); - let resp = serve(router, req).await.unwrap(); + let resp = router.call(req).await.unwrap(); assert!(resp.status().is_client_error(), "{:?}", resp); let data = read_err_body(resp.into_body()).await; @@ -354,15 +266,16 @@ mod tests { #[tokio::test] async fn test_boom() { - let router = test_api_router().await; + let mut router = test_api_router(); let req = serde_json::to_string(&PingRequest { name: "hi".to_string(), }) .unwrap(); let req = Request::post("/twirp/test.TestAPI/Boom") + .extension(timings()) .body(Body::from(req)) .unwrap(); - let resp = serve(router, req).await.unwrap(); + let resp = router.call(req).await.unwrap(); assert!(resp.status().is_server_error(), "{:?}", resp); let data = read_err_body(resp.into_body()).await; assert_eq!(data, error::internal("boom!")); diff --git a/crates/twirp/src/test.rs b/crates/twirp/src/test.rs index df3ac9e..d3386ec 100644 --- a/crates/twirp/src/test.rs +++ b/crates/twirp/src/test.rs @@ -3,63 +3,79 @@ use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Server}; +use axum::body::Body; +use axum::Router; +use http_body_util::BodyExt; +use hyper::Request; use serde::de::DeserializeOwned; use tokio::task::JoinHandle; +use tokio::time::Instant; -use crate::{error, Client, GenericError, Result, Router, TwirpErrorResponse}; +use crate::server::Timings; +use crate::{error, Client, Result, TwirpErrorResponse}; -pub async fn run_test_server(port: u16) -> JoinHandle> { - let router = test_api_router().await; - let service = make_service_fn(move |_| { - let router = router.clone(); - async { Ok::<_, GenericError>(service_fn(move |req| crate::serve(router.clone(), req))) } - }); - - let addr = ([127, 0, 0, 1], port).into(); - let server = Server::bind(&addr).serve(service); +pub async fn run_test_server(port: u16) -> JoinHandle> { + let router = test_api_router(); + let addr: std::net::SocketAddr = ([127, 0, 0, 1], port).into(); + let tcp_listener = tokio::net::TcpListener::bind(addr).await.unwrap(); println!("Listening on {addr}"); - let h = tokio::spawn(server); + let h = tokio::spawn(async move { axum::serve(tcp_listener, router).await }); tokio::time::sleep(Duration::from_millis(100)).await; h } -pub async fn test_api_router() -> Arc { +pub fn test_api_router() -> Router { let api = Arc::new(TestAPIServer {}); - let mut router = Router::default(); - // NB: This would be generated - { - let api = api.clone(); - router.add_method("test.TestAPI/Ping", move |req| { - let api = api.clone(); - async move { api.ping(req).await } - }); - } - { - router.add_method("test.TestAPI/Boom", move |req| { - let api = api.clone(); - async move { api.boom(req).await } - }); - } - Arc::new(router) + + // NB: This part would be generated + let test_router = crate::Router::new() + .route( + "/Ping", + crate::details::post( + |crate::details::State(api): crate::details::State>, + req: crate::details::Request| async move { + crate::server::handle_request( + req, + move |req| async move { api.ping(req).await }, + ) + .await + }, + ), + ) + .route( + "/Boom", + crate::details::post( + |crate::details::State(api): crate::details::State>, + req: crate::details::Request| async move { + crate::server::handle_request( + req, + move |req| async move { api.boom(req).await }, + ) + .await + }, + ), + ) + .fallback(crate::server::not_found_handler) + .with_state(api); + + axum::Router::new() + .nest("/twirp/test.TestAPI", test_router) + .fallback(crate::server::not_found_handler) } -pub fn gen_ping_request(name: &str) -> Request { +pub fn gen_ping_request(name: &str) -> Request { let req = serde_json::to_string(&PingRequest { name: name.to_string(), }) .expect("will always be valid json"); Request::post("/twirp/test.TestAPI/Ping") + .extension(Timings::new(Instant::now())) .body(Body::from(req)) .expect("always a valid twirp request") } pub async fn read_string_body(body: Body) -> String { - let data = hyper::body::to_bytes(body) - .await - .expect("invalid body") - .to_vec(); + let data = Vec::::from(body.collect().await.expect("invalid body").to_bytes()); String::from_utf8(data).expect("non-utf8 body") } @@ -67,10 +83,7 @@ pub async fn read_json_body(body: Body) -> T where T: DeserializeOwned, { - let data = hyper::body::to_bytes(body) - .await - .expect("invalid body") - .to_vec(); + let data = Vec::::from(body.collect().await.expect("invalid body").to_bytes()); serde_json::from_slice(&data).expect("twirp response isn't valid JSON") } diff --git a/example/Cargo.toml b/example/Cargo.toml index b161acf..19b8150 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] twirp = { path = "../crates/twirp" } async-trait = "0.1" +axum = "0.7" hyper = { version = "0.14", features = ["full"] } prost = "0.12" prost-wkt = "0.5" diff --git a/example/src/main.rs b/example/src/main.rs index 2aaf783..a69fdf5 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -1,10 +1,10 @@ +use std::net::SocketAddr; use std::sync::Arc; use std::time::UNIX_EPOCH; use async_trait::async_trait; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Response, Server}; -use twirp::{invalid_argument, GenericError, Router, TwirpErrorResponse}; +use axum::routing::get; +use twirp::{invalid_argument, Router, TwirpErrorResponse}; pub mod service { pub mod haberdash { @@ -15,25 +15,25 @@ pub mod service { } use service::haberdash::v1::{self as haberdash, MakeHatRequest, MakeHatResponse}; +async fn ping() -> &'static str { + "Pong\n" +} + #[tokio::main] pub async fn main() { - let mut router = Router::default(); - let example = Arc::new(HaberdasherAPIServer {}); - haberdash::add_service(&mut router, example.clone()); - router.add_sync_handler(Method::GET, "/_ping", |_req| { - Ok(Response::new(Body::from("Pong\n"))) - }); - println!("{router:?}"); - let router = Arc::new(router); - let service = make_service_fn(move |_| { - let router = router.clone(); - async { Ok::<_, GenericError>(service_fn(move |req| twirp::serve(router.clone(), req))) } - }); + let api_impl = Arc::new(HaberdasherAPIServer {}); + let twirp_routes = Router::new().nest(haberdash::SERVICE_FQN, haberdash::router(api_impl)); + let app = Router::new() + .nest("/twirp", twirp_routes) + .route("/_ping", get(ping)) + .fallback(twirp::server::not_found_handler); - let addr = ([127, 0, 0, 1], 3000).into(); - let server = Server::bind(&addr).serve(service); + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let tcp_listener = tokio::net::TcpListener::bind(addr) + .await + .expect("failed to bind"); println!("Listening on {addr}"); - if let Err(e) = server.await { + if let Err(e) = axum::serve(tcp_listener, app).await { eprintln!("server error: {}", e); } }