Skip to content

Commit

Permalink
fix: router
Browse files Browse the repository at this point in the history
  • Loading branch information
TroyKomodo committed Jan 1, 2024
1 parent 2de7cdc commit b68d97d
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 50 deletions.
27 changes: 20 additions & 7 deletions common/src/http/router/builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt::{Debug, Formatter};

use super::extend::ExtendRouter;
use super::middleware::{Middleware, PostMiddlewareHandler, PreMiddlewareHandler};
use super::route::{Route, RouteHandler, RouterItem};
use super::types::{ErrorHandler, RouteInfo};
Expand Down Expand Up @@ -39,6 +40,10 @@ impl<I: 'static, O: 'static, E: 'static> RouterBuilder<I, O, E> {
}
}

pub fn extend(self, mut extend: impl ExtendRouter<I, O, E>) -> Self {
extend.extend(self)
}

pub fn middleware(mut self, middleware: Middleware<O, E>) -> Self {
match middleware {
Middleware::Pre(handler) => self.pre_middleware.push(handler),
Expand All @@ -48,11 +53,16 @@ impl<I: 'static, O: 'static, E: 'static> RouterBuilder<I, O, E> {
self
}

pub fn data<T: Clone + Send + Sync + 'static>(self, data: T) -> Self {
self.middleware(Middleware::pre(move |mut req| {
req.extensions_mut().insert(data.clone());
async move { Ok(req) }
}))
pub fn data<T: Clone + Send + Sync + 'static>(mut self, data: T) -> Self {
self.pre_middleware.insert(
0,
PreMiddlewareHandler(Box::new(move |mut req: hyper::Request<()>| {
req.extensions_mut().insert(data.clone());
Box::pin(async move { Ok(req) })
})),
);

self
}

pub fn error_handler<F: std::future::Future<Output = hyper::Response<O>> + Send + 'static>(
Expand Down Expand Up @@ -228,7 +238,7 @@ impl<I: 'static, O: 'static, E: 'static> RouterBuilder<I, O, E> {
let full_path = format!(
"/{method}/{}{}{}",
parent_path,
if parent_path.is_empty() { "" } else { "/" },
if parent_path.is_empty() || path.is_empty() { "" } else { "/" },
path
);

Expand All @@ -240,7 +250,10 @@ impl<I: 'static, O: 'static, E: 'static> RouterBuilder<I, O, E> {
let parent_path = parent_path.trim_matches('/');
let path = path.trim_matches('/');
router.build_scoped(
&format!("{parent_path}{}{path}", if parent_path.is_empty() { "" } else { "/" }),
&format!(
"{parent_path}{}{path}",
if parent_path.is_empty() || path.is_empty() { "" } else { "/" }
),
target,
&pre_middleware_idxs,
&post_middleware_idxs,
Expand Down
15 changes: 15 additions & 0 deletions common/src/http/router/extend.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use super::builder::RouterBuilder;

pub trait ExtendRouter<I, O, E> {
fn extend(&mut self, router: RouterBuilder<I, O, E>) -> RouterBuilder<I, O, E>;
}

impl<I, O, E, F: Fn(RouterBuilder<I, O, E>) -> RouterBuilder<I, O, E>> ExtendRouter<I, O, E> for F {
fn extend(&mut self, router: RouterBuilder<I, O, E>) -> RouterBuilder<I, O, E> {
self(router)
}
}

pub fn extend_fn<I, O, E, F: Fn(RouterBuilder<I, O, E>) -> RouterBuilder<I, O, E>>(f: F) -> impl ExtendRouter<I, O, E> {
f
}
1 change: 1 addition & 0 deletions common/src/http/router/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod builder;
pub mod compat;
pub mod error;
pub mod ext;
pub mod extend;
pub mod middleware;
pub mod route;
pub mod types;
Expand Down
37 changes: 26 additions & 11 deletions platform/api/src/api/middleware/cors.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,39 @@
use std::sync::Arc;

use common::http::router::extend::{extend_fn, ExtendRouter};
use common::http::router::middleware::Middleware;
use common::http::RouteError;
use common::make_response;
use hyper::body::Incoming;
use hyper::http::header;
use serde_json::json;

use crate::api::error::ApiError;
use crate::api::Body;
use crate::global::ApiGlobal;

pub fn cors_middleware<G: ApiGlobal>(_: &Arc<G>) -> Middleware<Body, RouteError<ApiError>> {
Middleware::post(|mut resp| async move {
resp.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
resp.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_METHODS, "GET, POST, OPTIONS".parse().unwrap());
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_HEADERS,
"Content-Type, Authorization".parse().unwrap(),
);
pub fn cors_middleware<G: ApiGlobal>(_: &Arc<G>) -> impl ExtendRouter<Incoming, Body, RouteError<ApiError>> {
extend_fn(|router| {
router
.middleware(Middleware::post(|mut resp| async move {
resp.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
resp.headers_mut()
.insert(header::ACCESS_CONTROL_ALLOW_METHODS, "GET, POST, OPTIONS".parse().unwrap());
resp.headers_mut().insert(
header::ACCESS_CONTROL_ALLOW_HEADERS,
"Content-Type, Authorization".parse().unwrap(),
);

Ok(resp)
Ok(resp)
}))
.options("/*", |_| async move {
Ok(make_response!(
hyper::StatusCode::OK,
json!({
"success": true,
})
))
})
})
}
54 changes: 25 additions & 29 deletions platform/api/src/api/middleware/response_headers.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,40 @@
use std::sync::{Arc, Mutex};

use common::http::router::ext::RequestExt as _;
use common::http::router::extend::{extend_fn, ExtendRouter};
use common::http::router::middleware::Middleware;
use common::http::RouteError;
use hyper::body::Incoming;
use hyper::header::IntoHeaderName;
use hyper::Request;

use crate::api::error::ApiError;
use crate::api::Body;
use crate::global::ApiGlobal;

#[derive(Clone)]
pub struct ResponseHeadersMiddleware(pub Arc<Mutex<hyper::HeaderMap>>);

impl Default for ResponseHeadersMiddleware {
fn default() -> Self {
Self(Arc::new(Mutex::new(hyper::HeaderMap::new())))
}
}

pub fn pre_flight_middleware<G: ApiGlobal>(_: &Arc<G>) -> Middleware<Body, RouteError<ApiError>> {
Middleware::pre(|mut req| async move {
req.extensions_mut().insert(ResponseHeadersMiddleware::default());

Ok(req)
})
}

pub fn post_flight_middleware<G: ApiGlobal>(_: &Arc<G>) -> Middleware<Body, RouteError<ApiError>> {
Middleware::post_with_req(|mut resp, req| async move {
let headers = req.data::<ResponseHeadersMiddleware>();

if let Some(headers) = headers {
let headers = headers.0.lock().expect("failed to lock headers");
headers.iter().for_each(|(k, v)| {
resp.headers_mut().insert(k, v.clone());
});
}

Ok(resp)
#[derive(Clone, Default)]
struct ResponseHeadersMiddleware(Arc<Mutex<hyper::HeaderMap>>);

pub fn response_headers<G: ApiGlobal>(_: &Arc<G>) -> impl ExtendRouter<Incoming, Body, RouteError<ApiError>> {
extend_fn(|router| {
router
.middleware(Middleware::pre(|mut req| async move {
req.extensions_mut().insert(ResponseHeadersMiddleware::default());

Ok(req)
}))
.middleware(Middleware::post_with_req(|mut resp, req| async move {
let headers = req.data::<ResponseHeadersMiddleware>();

if let Some(headers) = headers {
let headers = headers.0.lock().expect("failed to lock headers");
headers.iter().for_each(|(k, v)| {
resp.headers_mut().insert(k, v.clone());
});
}

Ok(resp)
}))
})
}

Expand Down
5 changes: 2 additions & 3 deletions platform/api/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ pub fn routes<G: ApiGlobal>(global: &Arc<G>) -> Router<Incoming, Body, RouteErro
.data(weak)
// These response header middlewares lets us add headers to the response from the request
// handlers
.middleware(middleware::response_headers::pre_flight_middleware(global))
.middleware(middleware::response_headers::post_flight_middleware(global))
.extend(middleware::response_headers::response_headers(global))
// Our error handler
// The CORS middleware adds the CORS headers to the response
.middleware(middleware::cors::cors_middleware(global))
.extend(middleware::cors::cors_middleware(global))
// The auth middleware checks the Authorization header, and if it's valid, it adds the user
// to the request extensions This way, we can access the user in the handlers, this does not
// fail the request if the token is invalid or not present.
Expand Down

0 comments on commit b68d97d

Please sign in to comment.