Skip to content

Commit

Permalink
add basic auth
Browse files Browse the repository at this point in the history
  • Loading branch information
realradical committed Dec 1, 2024
1 parent a0c3415 commit c43121b
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 4 deletions.
1 change: 1 addition & 0 deletions utoipa-swagger-ui/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ utoipa = { version = "5.0.0", path = "../utoipa", default-features = false, feat
] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }
base64 = { version = "0.22.1" }

[dev-dependencies]
axum-test = "16.2.0"
Expand Down
49 changes: 45 additions & 4 deletions utoipa-swagger-ui/src/axum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
use std::sync::Arc;

use axum::{
extract::Path, http::StatusCode, response::IntoResponse, routing, Extension, Json, Router,
body::Body, extract::Path, http::{HeaderMap, Request, Response, StatusCode}, middleware::{self, Next}, response::IntoResponse, routing, Extension, Json, Router
};
use base64::{prelude::BASE64_STANDARD, Engine};

use crate::{ApiDoc, Config, SwaggerUi, Url};
use crate::{ApiDoc, BasicAuth, Config, SwaggerUi, Url};

impl<S> From<SwaggerUi> for Router<S>
where
Expand Down Expand Up @@ -42,10 +43,10 @@ where
Config::new(urls)
};

let handler = routing::get(serve_swagger_ui).layer(Extension(Arc::new(config)));
let handler = routing::get(serve_swagger_ui).layer(Extension(Arc::new(config.clone())));
let path: &str = swagger_ui.path.as_ref();

if path == "/" {
let mut router = if path == "/" {
router
.route(path, handler.clone())
.route(&format!("{}*rest", path), handler)
Expand All @@ -65,7 +66,30 @@ where
)
.route(&format!("{}/", path), handler.clone())
.route(&format!("{}/*rest", path), handler)
};

if let Some(BasicAuth {username, password}) = config.basic_auth {
let username = Arc::new(username);
let password = Arc::new(password);
let basic_auth_middleware = move |headers: HeaderMap, req: Request<Body>, next: Next| {
let username = username.clone();
let password = password.clone();
async move {
if let Some(header) = headers.get("Authorization") {
if let Ok(header_str) = header.to_str() {
let base64_encoded_credentials = BASE64_STANDARD.encode(format!("{}:{}", &username, &password));
if header_str == format!("Basic {}", base64_encoded_credentials) {
return Ok::<Response<Body>, StatusCode>(next.run(req).await);
}
}
}
Ok::<Response<Body>, StatusCode>((StatusCode::UNAUTHORIZED, [("WWW-Authenticate", "Basic realm=\":\"")]).into_response())
}
};
router = router.layer(middleware::from_fn(basic_auth_middleware));
}

router
}
}

Expand Down Expand Up @@ -152,4 +176,21 @@ mod tests {
let response = server.get("/swagger-ui/swagger-ui.css").await;
response.assert_status_ok();
}

#[tokio::test]
async fn basic_auth() {
let swagger_ui = SwaggerUi::new("/swagger-ui")
.config(Config::default().basic_auth(BasicAuth { username: "admin".to_string(), password: "password".to_string() }));
let app = Router::<()>::from(swagger_ui);
let server = TestServer::new(app).unwrap();
let response = server.get("/swagger-ui").await;
response.assert_status_unauthorized();
let encoded_credentials = BASE64_STANDARD.encode("admin:password");
let response = server.get("/swagger-ui").authorization(format!("Basic {}", encoded_credentials)).await;
response.assert_status_see_other();
let response = server.get("/swagger-ui/").authorization(format!("Basic {}", encoded_credentials)).await;
response.assert_status_ok();
let response = server.get("/swagger-ui/swagger-ui.css").authorization(format!("Basic {}", encoded_credentials)).await;
response.assert_status_ok();
}
}
30 changes: 30 additions & 0 deletions utoipa-swagger-ui/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,10 @@ pub struct Config<'a> {

/// The layout of Swagger UI uses, default is `"StandaloneLayout"`.
layout: &'a str,

/// Basic authentication configuration. If configured, the Swagger UI will prompt for basic auth credentials.
#[serde(skip_serializing_if = "Option::is_none")]
basic_auth: Option<BasicAuth>,
}

impl<'a> Config<'a> {
Expand Down Expand Up @@ -1268,6 +1272,25 @@ impl<'a> Config<'a> {

self
}

/// Set basic authentication configuration.
/// If configured, the Swagger UI will prompt for basic auth credentials.
/// username and password are required. "{username}:{password}" will be base64 encoded and added to the "Authorization" header.
/// If not provided or wrong credentials are provided, the user will be prompted again.
/// # Examples
///
/// Configure basic authentication.
/// ```rust
/// # use utoipa_swagger_ui::Config;
/// # use utoipa_swagger_ui::BasicAuth;
/// let config = Config::new(["/api-docs/openapi.json"])
/// .basic_auth(BasicAuth { username: "admin".to_string(), password: "password".to_string() });
/// ```
pub fn basic_auth(mut self, basic_auth: BasicAuth) -> Self {
self.basic_auth = Some(basic_auth);

self
}
}

impl Default for Config<'_> {
Expand Down Expand Up @@ -1301,6 +1324,7 @@ impl Default for Config<'_> {
oauth: Default::default(),
syntax_highlight: Default::default(),
layout: SWAGGER_STANDALONE_LAYOUT,
basic_auth: Default::default(),
}
}
}
Expand All @@ -1317,6 +1341,12 @@ impl From<String> for Config<'_> {
}
}

#[derive(Serialize, Clone)]
pub struct BasicAuth {
pub username: String,
pub password: String,
}

/// Represents settings related to syntax highlighting of payloads and
/// cURL commands.
#[derive(Serialize, Clone)]
Expand Down

0 comments on commit c43121b

Please sign in to comment.