Skip to content

Commit

Permalink
Rework progress code for optimal retention
Browse files Browse the repository at this point in the history
- The Brent method makes it harder to predict how many iterations will
be required, so don't set a total; the frontend will just display the
number of elapsed iterations.
- Make sure the progress gets updated during the initial bracket, for
better feedback and so the user can cancel faster.
  • Loading branch information
dae committed Nov 26, 2023
1 parent 4dd9c2e commit 51fb9fe
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use rand::{
};
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use std::sync::{Arc, Mutex};
use strum::EnumCount;

#[derive(Debug, EnumCount)]
Expand Down Expand Up @@ -352,48 +351,59 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
memorized_cnt_per_day[memorized_cnt_per_day.len() - 1]
}

fn sample(
fn sample<F>(
config: &SimulatorConfig,
weights: &[f64],
request_retention: f64,
n: usize,
) -> Result<f64, FSRSError> {
let out = (0..n)
progress: &mut F,
) -> Result<f64>
where
F: FnMut() -> bool,
{
if !progress() {
return Err(FSRSError::Interrupted);
}
Ok((0..n)
.into_par_iter()
.map(|i| {
let result = simulate(
simulate(
config,
weights,
request_retention,
Some((i + 42).try_into().unwrap()),
);
Ok(result)
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(out.iter().sum::<f64>() / n as f64)
.sum::<f64>()
/ n as f64)
}

/// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2894
fn bracket(
fn bracket<F>(
mut xa: f64,
mut xb: f64,
config: &SimulatorConfig,
weights: &[f64],
) -> (f64, f64, f64, f64, f64, f64) {
progress: &mut F,
) -> Result<(f64, f64, f64, f64, f64, f64)>
where
F: FnMut() -> bool,
{
const U_LIM: f64 = 0.95;
const L_LIM: f64 = 0.75;
const GROW_LIMIT: f64 = 100f64;
const GOLD: f64 = 1.618_033_988_749_895; // wait for https://doc.rust-lang.org/std/f64/consts/constant.PHI.html
const MAXITER: i32 = 20;

let mut fa = -sample(config, weights, xa, 5).unwrap();
let mut fb = -sample(config, weights, xb, 5).unwrap();
let mut fa = -sample(config, weights, xa, 5, progress)?;
let mut fb = -sample(config, weights, xb, 5, progress)?;

if fa < fb {
(fa, fb) = (fb, fa);
(xa, xb) = (xb, xa);
}
let mut xc = GOLD.mul_add(xb - xa, xb).clamp(L_LIM, U_LIM);
let mut fc = -sample(config, weights, xc, 5).unwrap();
let mut fc = -sample(config, weights, xc, 5, progress)?;

let mut iter = 0;
while fc < fb {
Expand All @@ -413,34 +423,34 @@ fn bracket(
let mut fw: f64;

if (w - xc) * (xb - w) > 0.0 {
fw = -sample(config, weights, w, 5).unwrap();
fw = -sample(config, weights, w, 5, progress)?;
if fw < fc {
(xa, xb) = (xb.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM));
(fa, fb) = (fb, fw);
break;
} else if fw > fb {
xc = w.clamp(L_LIM, U_LIM);
fc = -sample(config, weights, xc, 5).unwrap();
fc = -sample(config, weights, xc, 5, progress)?;
break;
}
w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
fw = -sample(config, weights, w, 5).unwrap();
fw = -sample(config, weights, w, 5, progress)?;
} else if (w - wlim) * (wlim - xc) >= 0.0 {
w = wlim;
fw = -sample(config, weights, w, 5).unwrap();
fw = -sample(config, weights, w, 5, progress)?;
} else if (w - wlim) * (xc - w) > 0.0 {
fw = -sample(config, weights, w, 5).unwrap();
fw = -sample(config, weights, w, 5, progress)?;
if fw < fc {
(xb, xc, w) = (
xc.clamp(L_LIM, U_LIM),
w.clamp(L_LIM, U_LIM),
GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM),
);
(fb, fc, fw) = (fc, fw, -sample(config, weights, w, 5).unwrap());
(fb, fc, fw) = (fc, fw, -sample(config, weights, w, 5, progress)?);
}
} else {
w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
fw = -sample(config, weights, w, 5).unwrap();
fw = -sample(config, weights, w, 5, progress)?;
}
(xa, xb, xc) = (
xb.clamp(L_LIM, U_LIM),
Expand All @@ -449,7 +459,7 @@ fn bracket(
);
(fa, fb, fc) = (fb, fc, fw);
}
(xa, xb, xc, fa, fb, fc)
Ok((xa, xb, xc, fa, fb, fc))
}

impl<B: Backend> FSRS<B> {
Expand All @@ -459,7 +469,7 @@ impl<B: Backend> FSRS<B> {
&self,
config: &SimulatorConfig,
weights: &Weights,
progress: F,
mut progress: F,
) -> Result<f64>
where
F: FnMut(ItemProgress) -> bool + Send,
Expand All @@ -474,7 +484,17 @@ impl<B: Backend> FSRS<B> {
.iter()
.map(|v| *v as f64)
.collect::<Vec<_>>();
Self::brent(config, &weights, progress)
let mut progress_info = ItemProgress {
current: 0,
// not provided for this method
total: 0,
};
let inc_progress = move || {
progress_info.current += 1;
progress(progress_info)
};

Self::brent(config, &weights, inc_progress)
}
/// https://argmin-rs.github.io/argmin/argmin/solver/brent/index.html
/// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2446
Expand All @@ -484,22 +504,14 @@ impl<B: Backend> FSRS<B> {
mut progress: F,
) -> Result<f64, FSRSError>
where
F: FnMut(ItemProgress) -> bool + Send,
F: FnMut() -> bool,
{
let mut progress_info = ItemProgress {
current: 0,
total: 100,
};
let inc_progress = Arc::new(Mutex::new(move || {
progress_info.current += 1;
progress(progress_info)
}));
let mintol = 1e-10;
let cg = 0.3819660;
let maxiter = 20;
let tol = 0.01f64;

let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, weights);
let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, weights, &mut progress)?;

let (mut v, mut w, mut x) = (xb, xb, xb);
let (mut fx, mut fv, mut fw) = (fb, fb, fb);
Expand All @@ -510,9 +522,6 @@ impl<B: Backend> FSRS<B> {
let mut u;

while iter < maxiter {
if !(inc_progress.lock().unwrap()()) {
return Err(FSRSError::Interrupted);
}
let tol1 = tol.mul_add(x.abs(), mintol);
let tol2 = 2.0 * tol1;
let xmid = 0.5 * (a + b);
Expand Down Expand Up @@ -557,7 +566,7 @@ impl<B: Backend> FSRS<B> {
rat
};
// calculate new output value
let fu = -sample(config, weights, u, 5).unwrap();
let fu = -sample(config, weights, u, 5, &mut progress)?;

// if it's bigger than current
if fu > fx {
Expand Down

0 comments on commit 51fb9fe

Please sign in to comment.