Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Pull out common code from both get and post of the recovery handler #2886

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 103 additions & 72 deletions crates/handlers/src/views/recovery/finish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use axum::{
};
use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
csrf::{CsrfExt, CsrfToken, ProtectedForm},
FancyError,
};
use mas_data_model::SiteConfig;
use mas_i18n::DataLocale;
use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng};
Expand All @@ -47,23 +48,49 @@ pub(crate) struct RouteForm {
new_password_confirm: String,
}

pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
PreferredLanguage(locale): PreferredLanguage,
struct Continuation {
pub csrf_token: CsrfToken,
pub cookie_jar: CookieJar,
user: mas_data_model::User,
ticket: mas_data_model::UserRecoveryTicket,
session: mas_data_model::UserRecoverySession,
}

/// Helper enum for the output of [`common`], which is called by both
/// [`get`] and [`post`].
#[allow(clippy::large_enum_variant)]
enum CommonOut {
/// Continue handling the request
Continue(Continuation),
/// Stop handling the request and respond immediately
Respond(Response),
}

impl<T: IntoResponse> From<T> for CommonOut {
fn from(value: T) -> Self {
CommonOut::Respond(value.into_response())
}
}

/// This is common form validation, login checks and config checks
/// needed by both [`get`] and [`post`].
async fn common(
rng: &mut BoxRng,
clock: &BoxClock,
repo: &mut BoxRepository,
site_config: &SiteConfig,
templates: &Templates,
locale: DataLocale,
cookie_jar: CookieJar,
Query(query): Query<RouteQuery>,
) -> Result<Response, FancyError> {
query: &RouteQuery,
) -> Result<CommonOut, FancyError> {
if !site_config.account_recovery_allowed {
let context = EmptyContext.with_language(locale);
let rendered = templates.render_recovery_disabled(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
return Ok((cookie_jar, Html(rendered)).into());
}

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);

let ticket = repo
.user_recovery()
Expand All @@ -80,15 +107,15 @@ pub(crate) async fn get(
if session.consumed_at.is_some() {
let context = EmptyContext.with_language(locale);
let rendered = templates.render_recovery_consumed(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
return Ok((cookie_jar, Html(rendered)).into());
}

if !ticket.active(clock.now()) {
let context = RecoveryExpiredContext::new(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let rendered = templates.render_recovery_expired(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
return Ok((cookie_jar, Html(rendered)).into());
}

let user_email = repo
Expand All @@ -112,9 +139,49 @@ pub(crate) async fn get(
.with_code("Account locked")
.with_language(&locale),
)?;
return Ok((cookie_jar, Html(rendered)).into_response());
return Ok((cookie_jar, Html(rendered)).into());
}

Ok(CommonOut::Continue(Continuation {
csrf_token,
cookie_jar,
user,
ticket,
session,
}))
}

pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Query(query): Query<RouteQuery>,
) -> Result<Response, FancyError> {
let Continuation {
csrf_token,
cookie_jar,
user,
..
} = match common(
&mut rng,
&clock,
&mut repo,
&site_config,
&templates,
locale.clone(),
cookie_jar,
&query,
)
.await?
{
CommonOut::Continue(st) => st,
CommonOut::Respond(resp) => return Ok(resp),
};

let context = RecoveryFinishContext::new(user)
.with_csrf(csrf_token.form_value())
.with_language(locale);
Expand All @@ -139,63 +206,27 @@ pub(crate) async fn post(
Query(query): Query<RouteQuery>,
Form(form): Form<ProtectedForm<RouteForm>>,
) -> Result<Response, FancyError> {
if !site_config.account_recovery_allowed {
let context = EmptyContext.with_language(locale);
let rendered = templates.render_recovery_disabled(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
}

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);

let ticket = repo
.user_recovery()
.find_ticket(&query.ticket)
.await?
.context("Unknown ticket")?;

let session = repo
.user_recovery()
.lookup_session(ticket.user_recovery_session_id)
.await?
.context("Unknown session")?;

if session.consumed_at.is_some() {
let context = EmptyContext.with_language(locale);
let rendered = templates.render_recovery_consumed(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
}

if !ticket.active(clock.now()) {
let context = RecoveryExpiredContext::new(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let rendered = templates.render_recovery_expired(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
}

let user_email = repo
.user_email()
.lookup(ticket.user_email_id)
.await?
// Only allow confirmed email addresses
.filter(|email| email.confirmed_at.is_some())
.context("Unknown email address")?;

let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("Invalid user")?;

if !user.is_valid() {
// TODO: render a 'account locked' page
let rendered = templates.render_error(
&ErrorContext::new()
.with_code("Account locked")
.with_language(&locale),
)?;
return Ok((cookie_jar, Html(rendered)).into_response());
}
let Continuation {
csrf_token,
cookie_jar,
user,
ticket,
session,
} = match common(
&mut rng,
&clock,
&mut repo,
&site_config,
&templates,
locale.clone(),
cookie_jar,
&query,
)
.await?
{
CommonOut::Continue(st) => st,
CommonOut::Respond(resp) => return Ok(resp),
};

let form = cookie_jar.verify_form(&clock, form)?;

Expand Down
Loading