From c43121b1546baec3bfa731ee3d3c205e9d395bc4 Mon Sep 17 00:00:00 2001 From: Shao <41271167+realradical@users.noreply.github.com> Date: Sat, 30 Nov 2024 17:49:14 -0800 Subject: [PATCH] add basic auth --- utoipa-swagger-ui/Cargo.toml | 1 + utoipa-swagger-ui/src/axum.rs | 49 ++++++++++++++++++++++++++++++++--- utoipa-swagger-ui/src/lib.rs | 30 +++++++++++++++++++++ 3 files changed, 76 insertions(+), 4 deletions(-) diff --git a/utoipa-swagger-ui/Cargo.toml b/utoipa-swagger-ui/Cargo.toml index 13ed4908..3997186b 100644 --- a/utoipa-swagger-ui/Cargo.toml +++ b/utoipa-swagger-ui/Cargo.toml @@ -32,6 +32,7 @@ utoipa = { version = "5.0.0", path = "../utoipa", default-features = false, feat ] } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } +base64 = { version = "0.22.1" } [dev-dependencies] axum-test = "16.2.0" diff --git a/utoipa-swagger-ui/src/axum.rs b/utoipa-swagger-ui/src/axum.rs index 0c2c2aa8..822d4b29 100644 --- a/utoipa-swagger-ui/src/axum.rs +++ b/utoipa-swagger-ui/src/axum.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use axum::{ - extract::Path, http::StatusCode, response::IntoResponse, routing, Extension, Json, Router, + body::Body, extract::Path, http::{HeaderMap, Request, Response, StatusCode}, middleware::{self, Next}, response::IntoResponse, routing, Extension, Json, Router }; +use base64::{prelude::BASE64_STANDARD, Engine}; -use crate::{ApiDoc, Config, SwaggerUi, Url}; +use crate::{ApiDoc, BasicAuth, Config, SwaggerUi, Url}; impl From for Router where @@ -42,10 +43,10 @@ where Config::new(urls) }; - let handler = routing::get(serve_swagger_ui).layer(Extension(Arc::new(config))); + let handler = routing::get(serve_swagger_ui).layer(Extension(Arc::new(config.clone()))); let path: &str = swagger_ui.path.as_ref(); - if path == "/" { + let mut router = if path == "/" { router .route(path, handler.clone()) .route(&format!("{}*rest", path), handler) @@ -65,7 +66,30 @@ where ) .route(&format!("{}/", path), handler.clone()) .route(&format!("{}/*rest", path), handler) + }; + + if let Some(BasicAuth {username, password}) = config.basic_auth { + let username = Arc::new(username); + let password = Arc::new(password); + let basic_auth_middleware = move |headers: HeaderMap, req: Request, next: Next| { + let username = username.clone(); + let password = password.clone(); + async move { + if let Some(header) = headers.get("Authorization") { + if let Ok(header_str) = header.to_str() { + let base64_encoded_credentials = BASE64_STANDARD.encode(format!("{}:{}", &username, &password)); + if header_str == format!("Basic {}", base64_encoded_credentials) { + return Ok::, StatusCode>(next.run(req).await); + } + } + } + Ok::, StatusCode>((StatusCode::UNAUTHORIZED, [("WWW-Authenticate", "Basic realm=\":\"")]).into_response()) + } + }; + router = router.layer(middleware::from_fn(basic_auth_middleware)); } + + router } } @@ -152,4 +176,21 @@ mod tests { let response = server.get("/swagger-ui/swagger-ui.css").await; response.assert_status_ok(); } + + #[tokio::test] + async fn basic_auth() { + let swagger_ui = SwaggerUi::new("/swagger-ui") + .config(Config::default().basic_auth(BasicAuth { username: "admin".to_string(), password: "password".to_string() })); + let app = Router::<()>::from(swagger_ui); + let server = TestServer::new(app).unwrap(); + let response = server.get("/swagger-ui").await; + response.assert_status_unauthorized(); + let encoded_credentials = BASE64_STANDARD.encode("admin:password"); + let response = server.get("/swagger-ui").authorization(format!("Basic {}", encoded_credentials)).await; + response.assert_status_see_other(); + let response = server.get("/swagger-ui/").authorization(format!("Basic {}", encoded_credentials)).await; + response.assert_status_ok(); + let response = server.get("/swagger-ui/swagger-ui.css").authorization(format!("Basic {}", encoded_credentials)).await; + response.assert_status_ok(); + } } diff --git a/utoipa-swagger-ui/src/lib.rs b/utoipa-swagger-ui/src/lib.rs index ab8823aa..5f6dcb31 100644 --- a/utoipa-swagger-ui/src/lib.rs +++ b/utoipa-swagger-ui/src/lib.rs @@ -669,6 +669,10 @@ pub struct Config<'a> { /// The layout of Swagger UI uses, default is `"StandaloneLayout"`. layout: &'a str, + + /// Basic authentication configuration. If configured, the Swagger UI will prompt for basic auth credentials. + #[serde(skip_serializing_if = "Option::is_none")] + basic_auth: Option, } impl<'a> Config<'a> { @@ -1268,6 +1272,25 @@ impl<'a> Config<'a> { self } + + /// Set basic authentication configuration. + /// If configured, the Swagger UI will prompt for basic auth credentials. + /// username and password are required. "{username}:{password}" will be base64 encoded and added to the "Authorization" header. + /// If not provided or wrong credentials are provided, the user will be prompted again. + /// # Examples + /// + /// Configure basic authentication. + /// ```rust + /// # use utoipa_swagger_ui::Config; + /// # use utoipa_swagger_ui::BasicAuth; + /// let config = Config::new(["/api-docs/openapi.json"]) + /// .basic_auth(BasicAuth { username: "admin".to_string(), password: "password".to_string() }); + /// ``` + pub fn basic_auth(mut self, basic_auth: BasicAuth) -> Self { + self.basic_auth = Some(basic_auth); + + self + } } impl Default for Config<'_> { @@ -1301,6 +1324,7 @@ impl Default for Config<'_> { oauth: Default::default(), syntax_highlight: Default::default(), layout: SWAGGER_STANDALONE_LAYOUT, + basic_auth: Default::default(), } } } @@ -1317,6 +1341,12 @@ impl From for Config<'_> { } } +#[derive(Serialize, Clone)] +pub struct BasicAuth { + pub username: String, + pub password: String, +} + /// Represents settings related to syntax highlighting of payloads and /// cURL commands. #[derive(Serialize, Clone)]