Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Pull out common code from both get and post of the recovery handler
Browse files Browse the repository at this point in the history
  • Loading branch information
reivilibre committed Jun 28, 2024
1 parent 93d3cfa commit bf4f517
Showing 1 changed file with 103 additions and 72 deletions.
175 changes: 103 additions & 72 deletions crates/handlers/src/views/recovery/finish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use axum::{
};
use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
csrf::{CsrfExt, CsrfToken, ProtectedForm},
FancyError,
};
use mas_data_model::SiteConfig;
use mas_i18n::DataLocale;
use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng};
Expand All @@ -47,23 +48,49 @@ pub(crate) struct RouteForm {
new_password_confirm: String,
}

pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
PreferredLanguage(locale): PreferredLanguage,
struct Continuation {
pub csrf_token: CsrfToken,
pub cookie_jar: CookieJar,
user: mas_data_model::User,
ticket: mas_data_model::UserRecoveryTicket,
session: mas_data_model::UserRecoverySession,
}

/// Helper enum for the output of [`common`], which is called by both
/// [`get`] and [`post`].
#[allow(clippy::large_enum_variant)]
enum CommonOut {
/// Continue handling the request
Continue(Continuation),
/// Stop handling the request and respond immediately
Respond(Response),
}

impl<T: IntoResponse> From<T> for CommonOut {
fn from(value: T) -> Self {
CommonOut::Respond(value.into_response())
}
}

/// This is common form validation, login checks and config checks
/// needed by both [`get`] and [`post`].
async fn common(
rng: &mut BoxRng,
clock: &BoxClock,
repo: &mut BoxRepository,
site_config: &SiteConfig,
templates: &Templates,
locale: DataLocale,
cookie_jar: CookieJar,
Query(query): Query<RouteQuery>,
) -> Result<Response, FancyError> {
query: &RouteQuery,
) -> Result<CommonOut, FancyError> {
if !site_config.account_recovery_allowed {
let context = EmptyContext.with_language(locale);
let rendered = templates.render_recovery_disabled(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
return Ok((cookie_jar, Html(rendered)).into());
}

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);

let ticket = repo
.user_recovery()
Expand All @@ -80,15 +107,15 @@ pub(crate) async fn get(
if session.consumed_at.is_some() {
let context = EmptyContext.with_language(locale);
let rendered = templates.render_recovery_consumed(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
return Ok((cookie_jar, Html(rendered)).into());
}

if !ticket.active(clock.now()) {
let context = RecoveryExpiredContext::new(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let rendered = templates.render_recovery_expired(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
return Ok((cookie_jar, Html(rendered)).into());
}

let user_email = repo
Expand All @@ -112,9 +139,49 @@ pub(crate) async fn get(
.with_code("Account locked")
.with_language(&locale),
)?;
return Ok((cookie_jar, Html(rendered)).into_response());
return Ok((cookie_jar, Html(rendered)).into());
}

Ok(CommonOut::Continue(Continuation {
csrf_token,
cookie_jar,
user,
ticket,
session,
}))
}

pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Query(query): Query<RouteQuery>,
) -> Result<Response, FancyError> {
let Continuation {
csrf_token,
cookie_jar,
user,
..
} = match common(
&mut rng,
&clock,
&mut repo,
&site_config,
&templates,
locale.clone(),
cookie_jar,
&query,
)
.await?
{
CommonOut::Continue(st) => st,
CommonOut::Respond(resp) => return Ok(resp),
};

let context = RecoveryFinishContext::new(user)
.with_csrf(csrf_token.form_value())
.with_language(locale);
Expand All @@ -139,63 +206,27 @@ pub(crate) async fn post(
Query(query): Query<RouteQuery>,
Form(form): Form<ProtectedForm<RouteForm>>,
) -> Result<Response, FancyError> {
if !site_config.account_recovery_allowed {
let context = EmptyContext.with_language(locale);
let rendered = templates.render_recovery_disabled(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
}

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);

let ticket = repo
.user_recovery()
.find_ticket(&query.ticket)
.await?
.context("Unknown ticket")?;

let session = repo
.user_recovery()
.lookup_session(ticket.user_recovery_session_id)
.await?
.context("Unknown session")?;

if session.consumed_at.is_some() {
let context = EmptyContext.with_language(locale);
let rendered = templates.render_recovery_consumed(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
}

if !ticket.active(clock.now()) {
let context = RecoveryExpiredContext::new(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let rendered = templates.render_recovery_expired(&context)?;
return Ok((cookie_jar, Html(rendered)).into_response());
}

let user_email = repo
.user_email()
.lookup(ticket.user_email_id)
.await?
// Only allow confirmed email addresses
.filter(|email| email.confirmed_at.is_some())
.context("Unknown email address")?;

let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("Invalid user")?;

if !user.is_valid() {
// TODO: render a 'account locked' page
let rendered = templates.render_error(
&ErrorContext::new()
.with_code("Account locked")
.with_language(&locale),
)?;
return Ok((cookie_jar, Html(rendered)).into_response());
}
let Continuation {
csrf_token,
cookie_jar,
user,
ticket,
session,
} = match common(
&mut rng,
&clock,
&mut repo,
&site_config,
&templates,
locale.clone(),
cookie_jar,
&query,
)
.await?
{
CommonOut::Continue(st) => st,
CommonOut::Respond(resp) => return Ok(resp),
};

let form = cookie_jar.verify_form(&clock, form)?;

Expand Down

0 comments on commit bf4f517

Please sign in to comment.