diff --git a/rust/Cargo.toml b/rust/Cargo.toml index de0a8e059ad9f..eae4b047f3dbf 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -73,7 +73,7 @@ time = { version = "0.3.36", features = [ ] } thiserror = { version = "1.0" } tokio = { version = "1.34.0", features = ["full"] } -tower = "0.4.13" +tower = { version = "0.4.13", features = ["default", "limit"] } tower-http = { version = "0.5.2", features = ["cors", "limit", "trace"] } tracing = "0.1.40" tracing-opentelemetry = "0.23.0" diff --git a/rust/hook-api/src/config.rs b/rust/hook-api/src/config.rs index fa1bbb3c7e484..1e02a575ffafa 100644 --- a/rust/hook-api/src/config.rs +++ b/rust/hook-api/src/config.rs @@ -20,6 +20,9 @@ pub struct Config { #[envconfig(default = "5000000")] pub max_body_size: usize, + #[envconfig(default = "100")] + pub concurrency_limit: usize, + #[envconfig(default = "false")] pub hog_mode: bool, } diff --git a/rust/hook-api/src/handlers/app.rs b/rust/hook-api/src/handlers/app.rs index 1dea37c1bdc5b..ffe20720b444f 100644 --- a/rust/hook-api/src/handlers/app.rs +++ b/rust/hook-api/src/handlers/app.rs @@ -1,4 +1,7 @@ +use std::convert::Infallible; + use axum::{routing, Router}; +use tower::limit::ConcurrencyLimitLayer; use tower_http::limit::RequestBodyLimitLayer; use hook_common::pgqueue::PgQueue; @@ -10,6 +13,7 @@ pub fn add_routes( pg_pool: PgQueue, hog_mode: bool, max_body_size: usize, + concurrency_limit: usize, ) -> Router { let router = router .route("/", routing::get(index)) @@ -21,6 +25,7 @@ pub fn add_routes( "/hoghook", routing::post(webhook::post_hoghook) .with_state(pg_pool) + .layer::<_, Infallible>(ConcurrencyLimitLayer::new(concurrency_limit)) .layer(RequestBodyLimitLayer::new(max_body_size)), ) } else { @@ -28,6 +33,7 @@ pub fn add_routes( "/webhook", routing::post(webhook::post_webhook) .with_state(pg_pool) + .layer::<_, Infallible>(ConcurrencyLimitLayer::new(concurrency_limit)) .layer(RequestBodyLimitLayer::new(max_body_size)), ) } @@ -54,7 +60,7 @@ mod tests { let pg_queue = PgQueue::new_from_pool("test_index", db).await; let hog_mode = false; - let app = add_routes(Router::new(), pg_queue, hog_mode, 1_000_000); + let app = add_routes(Router::new(), pg_queue, hog_mode, 1_000_000, 10); let response = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) diff --git a/rust/hook-api/src/handlers/webhook.rs b/rust/hook-api/src/handlers/webhook.rs index 3f4cd8c8a5b8a..d797f41eb22ba 100644 --- a/rust/hook-api/src/handlers/webhook.rs +++ b/rust/hook-api/src/handlers/webhook.rs @@ -243,13 +243,20 @@ mod tests { use crate::handlers::app::add_routes; const MAX_BODY_SIZE: usize = 1_000_000; + const CONCURRENCY_LIMIT: usize = 10; #[sqlx::test(migrations = "../migrations")] async fn webhook_success(db: PgPool) { let pg_queue = PgQueue::new_from_pool("test_index", db).await; let hog_mode = false; - let app = add_routes(Router::new(), pg_queue, hog_mode, MAX_BODY_SIZE); + let app = add_routes( + Router::new(), + pg_queue, + hog_mode, + MAX_BODY_SIZE, + CONCURRENCY_LIMIT, + ); let mut headers = collections::HashMap::new(); headers.insert("Content-Type".to_owned(), "application/json".to_owned()); @@ -292,7 +299,13 @@ mod tests { let pg_queue = PgQueue::new_from_pool("test_index", db).await; let hog_mode = false; - let app = add_routes(Router::new(), pg_queue, hog_mode, MAX_BODY_SIZE); + let app = add_routes( + Router::new(), + pg_queue, + hog_mode, + MAX_BODY_SIZE, + CONCURRENCY_LIMIT, + ); let response = app .oneshot( @@ -330,7 +343,13 @@ mod tests { let pg_queue = PgQueue::new_from_pool("test_index", db).await; let hog_mode = false; - let app = add_routes(Router::new(), pg_queue, hog_mode, MAX_BODY_SIZE); + let app = add_routes( + Router::new(), + pg_queue, + hog_mode, + MAX_BODY_SIZE, + CONCURRENCY_LIMIT, + ); let response = app .oneshot( @@ -352,7 +371,13 @@ mod tests { let pg_queue = PgQueue::new_from_pool("test_index", db).await; let hog_mode = false; - let app = add_routes(Router::new(), pg_queue, hog_mode, MAX_BODY_SIZE); + let app = add_routes( + Router::new(), + pg_queue, + hog_mode, + MAX_BODY_SIZE, + CONCURRENCY_LIMIT, + ); let response = app .oneshot( @@ -374,7 +399,13 @@ mod tests { let pg_queue = PgQueue::new_from_pool("test_index", db).await; let hog_mode = false; - let app = add_routes(Router::new(), pg_queue, hog_mode, MAX_BODY_SIZE); + let app = add_routes( + Router::new(), + pg_queue, + hog_mode, + MAX_BODY_SIZE, + CONCURRENCY_LIMIT, + ); let bytes: Vec = vec![b'a'; MAX_BODY_SIZE + 1]; let long_string = String::from_utf8_lossy(&bytes); @@ -422,7 +453,13 @@ mod tests { let pg_queue = PgQueue::new_from_pool("test_index", db.clone()).await; let hog_mode = true; - let app = add_routes(Router::new(), pg_queue, hog_mode, MAX_BODY_SIZE); + let app = add_routes( + Router::new(), + pg_queue, + hog_mode, + MAX_BODY_SIZE, + CONCURRENCY_LIMIT, + ); let valid_payloads = vec![ ( @@ -507,7 +544,13 @@ mod tests { let pg_queue = PgQueue::new_from_pool("test_index", db.clone()).await; let hog_mode = true; - let app = add_routes(Router::new(), pg_queue, hog_mode, MAX_BODY_SIZE); + let app = add_routes( + Router::new(), + pg_queue, + hog_mode, + MAX_BODY_SIZE, + CONCURRENCY_LIMIT, + ); let invalid_payloads = vec![ r#"{}"#, diff --git a/rust/hook-api/src/main.rs b/rust/hook-api/src/main.rs index d20c5a11a37e1..7ca8de09513ff 100644 --- a/rust/hook-api/src/main.rs +++ b/rust/hook-api/src/main.rs @@ -39,6 +39,7 @@ async fn main() { pg_queue, config.hog_mode, config.max_body_size, + config.concurrency_limit, ); let app = setup_metrics_routes(app);