Skip to content

Commit

Permalink
Don't panic if Timings are not provided in the Request.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendorff committed Jan 25, 2024
1 parent eee91a5 commit 6fc5d2e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
5 changes: 3 additions & 2 deletions crates/twirp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ where
Req: prost::Message + Default + serde::de::DeserializeOwned,
Resp: prost::Message + serde::Serialize,
{
let mut timings = *req
let mut timings = req
.extensions()
.get::<Timings>()
.expect("invariant violated: timing info not present in request");
.copied()
.unwrap_or_else(|| Timings::new(Instant::now()));

let (req, resp_fmt) = match parse_request(req, &mut timings).await {
Ok(pair) => pair,
Expand Down
66 changes: 66 additions & 0 deletions example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ impl haberdash::HaberdasherAPI for HaberdasherAPIServer {

#[cfg(test)]
mod test {
use service::haberdash::v1::HaberdasherAPIClient;
use twirp::client::Client;
use twirp::url::Url;
use twirp::TwirpErrorCode;

use crate::service::haberdash::v1::HaberdasherAPI;
Expand All @@ -88,4 +91,67 @@ mod test {
let err = res.unwrap_err();
assert_eq!(err.code, TwirpErrorCode::InvalidArgument);
}

/// A running network server task, bound to an arbitrary port on localhost, chosen by the OS
struct NetServer {
port: u16,
server_task: tokio::task::JoinHandle<()>,
shutdown_sender: tokio::sync::oneshot::Sender<()>,
}

impl NetServer {
async fn start(api_impl: Arc<HaberdasherAPIServer>) -> Self {
let twirp_routes =
Router::new().nest(haberdash::SERVICE_FQN, haberdash::router(api_impl));
let app = Router::new()
.nest("/twirp", twirp_routes)
.route("/_ping", get(ping))
.fallback(twirp::server::not_found_handler);

let tcp_listener = tokio::net::TcpListener::bind("localhost:0")
.await
.expect("failed to bind");
let addr = tcp_listener.local_addr().unwrap();
println!("Listening on {addr}");
let port = addr.port();

let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel::<()>();
let server_task = tokio::spawn(async move {
let shutdown_receiver = async move {
shutdown_receiver.await.unwrap();
};
if let Err(e) = axum::serve(tcp_listener, app)
.with_graceful_shutdown(shutdown_receiver)
.await
{
eprintln!("server error: {}", e);
}
});

NetServer {
port,
server_task,
shutdown_sender,
}
}

async fn shutdown(self) {
self.shutdown_sender.send(()).unwrap();
self.server_task.await.unwrap();
}
}

#[tokio::test]
async fn test_net() {
let api_impl = Arc::new(HaberdasherAPIServer {});
let server = NetServer::start(api_impl).await;

let url = Url::parse(&format!("http://localhost:{}/twirp/", server.port)).unwrap();
let client = Client::from_base_url(url).unwrap();
let resp = client.make_hat(MakeHatRequest { inches: 1 }).await;
println!("{:?}", resp);
assert_eq!(resp.unwrap().size, 1);

server.shutdown().await;
}
}

0 comments on commit 6fc5d2e

Please sign in to comment.