From 51fb9fe00d7e87d8a8fca603bb080d1a8b2c705f Mon Sep 17 00:00:00 2001 From: Damien Elmes Date: Sun, 26 Nov 2023 13:52:24 +1000 Subject: [PATCH] Rework progress code for optimal retention - The Brent method makes it harder to predict how many iterations will be required, so don't set a total; the frontend will just display the number of elapsed iterations. - Make sure the progress gets updated during the initial bracket, for better feedback and so the user can cancel faster. --- src/optimal_retention.rs | 85 ++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 2b8d91c2..1500d0d9 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -13,7 +13,6 @@ use rand::{ }; use rayon::iter::IntoParallelIterator; use rayon::iter::ParallelIterator; -use std::sync::{Arc, Mutex}; use strum::EnumCount; #[derive(Debug, EnumCount)] @@ -352,48 +351,59 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O memorized_cnt_per_day[memorized_cnt_per_day.len() - 1] } -fn sample( +fn sample( config: &SimulatorConfig, weights: &[f64], request_retention: f64, n: usize, -) -> Result { - let out = (0..n) + progress: &mut F, +) -> Result +where + F: FnMut() -> bool, +{ + if !progress() { + return Err(FSRSError::Interrupted); + } + Ok((0..n) .into_par_iter() .map(|i| { - let result = simulate( + simulate( config, weights, request_retention, Some((i + 42).try_into().unwrap()), - ); - Ok(result) + ) }) - .collect::, _>>()?; - Ok(out.iter().sum::() / n as f64) + .sum::() + / n as f64) } + /// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2894 -fn bracket( +fn bracket( mut xa: f64, mut xb: f64, config: &SimulatorConfig, weights: &[f64], -) -> (f64, f64, f64, f64, f64, f64) { + progress: &mut F, +) -> Result<(f64, f64, f64, f64, f64, f64)> +where + F: FnMut() -> bool, +{ const U_LIM: f64 = 0.95; const L_LIM: f64 = 0.75; const GROW_LIMIT: f64 = 100f64; const GOLD: f64 = 1.618_033_988_749_895; // wait for https://doc.rust-lang.org/std/f64/consts/constant.PHI.html const MAXITER: i32 = 20; - let mut fa = -sample(config, weights, xa, 5).unwrap(); - let mut fb = -sample(config, weights, xb, 5).unwrap(); + let mut fa = -sample(config, weights, xa, 5, progress)?; + let mut fb = -sample(config, weights, xb, 5, progress)?; if fa < fb { (fa, fb) = (fb, fa); (xa, xb) = (xb, xa); } let mut xc = GOLD.mul_add(xb - xa, xb).clamp(L_LIM, U_LIM); - let mut fc = -sample(config, weights, xc, 5).unwrap(); + let mut fc = -sample(config, weights, xc, 5, progress)?; let mut iter = 0; while fc < fb { @@ -413,34 +423,34 @@ fn bracket( let mut fw: f64; if (w - xc) * (xb - w) > 0.0 { - fw = -sample(config, weights, w, 5).unwrap(); + fw = -sample(config, weights, w, 5, progress)?; if fw < fc { (xa, xb) = (xb.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM)); (fa, fb) = (fb, fw); break; } else if fw > fb { xc = w.clamp(L_LIM, U_LIM); - fc = -sample(config, weights, xc, 5).unwrap(); + fc = -sample(config, weights, xc, 5, progress)?; break; } w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM); - fw = -sample(config, weights, w, 5).unwrap(); + fw = -sample(config, weights, w, 5, progress)?; } else if (w - wlim) * (wlim - xc) >= 0.0 { w = wlim; - fw = -sample(config, weights, w, 5).unwrap(); + fw = -sample(config, weights, w, 5, progress)?; } else if (w - wlim) * (xc - w) > 0.0 { - fw = -sample(config, weights, w, 5).unwrap(); + fw = -sample(config, weights, w, 5, progress)?; if fw < fc { (xb, xc, w) = ( xc.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM), GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM), ); - (fb, fc, fw) = (fc, fw, -sample(config, weights, w, 5).unwrap()); + (fb, fc, fw) = (fc, fw, -sample(config, weights, w, 5, progress)?); } } else { w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM); - fw = -sample(config, weights, w, 5).unwrap(); + fw = -sample(config, weights, w, 5, progress)?; } (xa, xb, xc) = ( xb.clamp(L_LIM, U_LIM), @@ -449,7 +459,7 @@ fn bracket( ); (fa, fb, fc) = (fb, fc, fw); } - (xa, xb, xc, fa, fb, fc) + Ok((xa, xb, xc, fa, fb, fc)) } impl FSRS { @@ -459,7 +469,7 @@ impl FSRS { &self, config: &SimulatorConfig, weights: &Weights, - progress: F, + mut progress: F, ) -> Result where F: FnMut(ItemProgress) -> bool + Send, @@ -474,7 +484,17 @@ impl FSRS { .iter() .map(|v| *v as f64) .collect::>(); - Self::brent(config, &weights, progress) + let mut progress_info = ItemProgress { + current: 0, + // not provided for this method + total: 0, + }; + let inc_progress = move || { + progress_info.current += 1; + progress(progress_info) + }; + + Self::brent(config, &weights, inc_progress) } /// https://argmin-rs.github.io/argmin/argmin/solver/brent/index.html /// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2446 @@ -484,22 +504,14 @@ impl FSRS { mut progress: F, ) -> Result where - F: FnMut(ItemProgress) -> bool + Send, + F: FnMut() -> bool, { - let mut progress_info = ItemProgress { - current: 0, - total: 100, - }; - let inc_progress = Arc::new(Mutex::new(move || { - progress_info.current += 1; - progress(progress_info) - })); let mintol = 1e-10; let cg = 0.3819660; let maxiter = 20; let tol = 0.01f64; - let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, weights); + let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, weights, &mut progress)?; let (mut v, mut w, mut x) = (xb, xb, xb); let (mut fx, mut fv, mut fw) = (fb, fb, fb); @@ -510,9 +522,6 @@ impl FSRS { let mut u; while iter < maxiter { - if !(inc_progress.lock().unwrap()()) { - return Err(FSRSError::Interrupted); - } let tol1 = tol.mul_add(x.abs(), mintol); let tol2 = 2.0 * tol1; let xmid = 0.5 * (a + b); @@ -557,7 +566,7 @@ impl FSRS { rat }; // calculate new output value - let fu = -sample(config, weights, u, 5).unwrap(); + let fu = -sample(config, weights, u, 5, &mut progress)?; // if it's bigger than current if fu > fx {