From bf4f5173004b807578df8d3f58f485e03b19dc6a Mon Sep 17 00:00:00 2001 From: Olivier 'reivilibre Date: Fri, 28 Jun 2024 16:00:00 +0100 Subject: [PATCH] Pull out common code from both `get` and `post` of the recovery handler --- crates/handlers/src/views/recovery/finish.rs | 175 +++++++++++-------- 1 file changed, 103 insertions(+), 72 deletions(-) diff --git a/crates/handlers/src/views/recovery/finish.rs b/crates/handlers/src/views/recovery/finish.rs index e5dcbe381..1cf17d81f 100644 --- a/crates/handlers/src/views/recovery/finish.rs +++ b/crates/handlers/src/views/recovery/finish.rs @@ -20,10 +20,11 @@ use axum::{ }; use mas_axum_utils::{ cookies::CookieJar, - csrf::{CsrfExt, ProtectedForm}, + csrf::{CsrfExt, CsrfToken, ProtectedForm}, FancyError, }; use mas_data_model::SiteConfig; +use mas_i18n::DataLocale; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRepository, BoxRng}; @@ -47,23 +48,49 @@ pub(crate) struct RouteForm { new_password_confirm: String, } -pub(crate) async fn get( - mut rng: BoxRng, - clock: BoxClock, - mut repo: BoxRepository, - State(site_config): State, - State(templates): State, - PreferredLanguage(locale): PreferredLanguage, +struct Continuation { + pub csrf_token: CsrfToken, + pub cookie_jar: CookieJar, + user: mas_data_model::User, + ticket: mas_data_model::UserRecoveryTicket, + session: mas_data_model::UserRecoverySession, +} + +/// Helper enum for the output of [`common`], which is called by both +/// [`get`] and [`post`]. +#[allow(clippy::large_enum_variant)] +enum CommonOut { + /// Continue handling the request + Continue(Continuation), + /// Stop handling the request and respond immediately + Respond(Response), +} + +impl From for CommonOut { + fn from(value: T) -> Self { + CommonOut::Respond(value.into_response()) + } +} + +/// This is common form validation, login checks and config checks +/// needed by both [`get`] and [`post`]. +async fn common( + rng: &mut BoxRng, + clock: &BoxClock, + repo: &mut BoxRepository, + site_config: &SiteConfig, + templates: &Templates, + locale: DataLocale, cookie_jar: CookieJar, - Query(query): Query, -) -> Result { + query: &RouteQuery, +) -> Result { if !site_config.account_recovery_allowed { let context = EmptyContext.with_language(locale); let rendered = templates.render_recovery_disabled(&context)?; - return Ok((cookie_jar, Html(rendered)).into_response()); + return Ok((cookie_jar, Html(rendered)).into()); } - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng); let ticket = repo .user_recovery() @@ -80,7 +107,7 @@ pub(crate) async fn get( if session.consumed_at.is_some() { let context = EmptyContext.with_language(locale); let rendered = templates.render_recovery_consumed(&context)?; - return Ok((cookie_jar, Html(rendered)).into_response()); + return Ok((cookie_jar, Html(rendered)).into()); } if !ticket.active(clock.now()) { @@ -88,7 +115,7 @@ pub(crate) async fn get( .with_csrf(csrf_token.form_value()) .with_language(locale); let rendered = templates.render_recovery_expired(&context)?; - return Ok((cookie_jar, Html(rendered)).into_response()); + return Ok((cookie_jar, Html(rendered)).into()); } let user_email = repo @@ -112,9 +139,49 @@ pub(crate) async fn get( .with_code("Account locked") .with_language(&locale), )?; - return Ok((cookie_jar, Html(rendered)).into_response()); + return Ok((cookie_jar, Html(rendered)).into()); } + Ok(CommonOut::Continue(Continuation { + csrf_token, + cookie_jar, + user, + ticket, + session, + })) +} + +pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, + State(site_config): State, + State(templates): State, + PreferredLanguage(locale): PreferredLanguage, + cookie_jar: CookieJar, + Query(query): Query, +) -> Result { + let Continuation { + csrf_token, + cookie_jar, + user, + .. + } = match common( + &mut rng, + &clock, + &mut repo, + &site_config, + &templates, + locale.clone(), + cookie_jar, + &query, + ) + .await? + { + CommonOut::Continue(st) => st, + CommonOut::Respond(resp) => return Ok(resp), + }; + let context = RecoveryFinishContext::new(user) .with_csrf(csrf_token.form_value()) .with_language(locale); @@ -139,63 +206,27 @@ pub(crate) async fn post( Query(query): Query, Form(form): Form>, ) -> Result { - if !site_config.account_recovery_allowed { - let context = EmptyContext.with_language(locale); - let rendered = templates.render_recovery_disabled(&context)?; - return Ok((cookie_jar, Html(rendered)).into_response()); - } - - let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); - - let ticket = repo - .user_recovery() - .find_ticket(&query.ticket) - .await? - .context("Unknown ticket")?; - - let session = repo - .user_recovery() - .lookup_session(ticket.user_recovery_session_id) - .await? - .context("Unknown session")?; - - if session.consumed_at.is_some() { - let context = EmptyContext.with_language(locale); - let rendered = templates.render_recovery_consumed(&context)?; - return Ok((cookie_jar, Html(rendered)).into_response()); - } - - if !ticket.active(clock.now()) { - let context = RecoveryExpiredContext::new(session) - .with_csrf(csrf_token.form_value()) - .with_language(locale); - let rendered = templates.render_recovery_expired(&context)?; - return Ok((cookie_jar, Html(rendered)).into_response()); - } - - let user_email = repo - .user_email() - .lookup(ticket.user_email_id) - .await? - // Only allow confirmed email addresses - .filter(|email| email.confirmed_at.is_some()) - .context("Unknown email address")?; - - let user = repo - .user() - .lookup(user_email.user_id) - .await? - .context("Invalid user")?; - - if !user.is_valid() { - // TODO: render a 'account locked' page - let rendered = templates.render_error( - &ErrorContext::new() - .with_code("Account locked") - .with_language(&locale), - )?; - return Ok((cookie_jar, Html(rendered)).into_response()); - } + let Continuation { + csrf_token, + cookie_jar, + user, + ticket, + session, + } = match common( + &mut rng, + &clock, + &mut repo, + &site_config, + &templates, + locale.clone(), + cookie_jar, + &query, + ) + .await? + { + CommonOut::Continue(st) => st, + CommonOut::Respond(resp) => return Ok(resp), + }; let form = cookie_jar.verify_form(&clock, form)?;