Skip to content

Commit

Permalink
Increase the sample size.rs (#129)
Browse files Browse the repository at this point in the history
Since the new method is roughly 2 times faster now, how about sacrificing that speed in favor of accuracy? Of course, ideally we want both, but if I had to choose, I would say accuracy is more important in this case.

---------

Co-authored-by: Jarrett Ye <[email protected]>
  • Loading branch information
Expertium and L-M-Sherlock authored Dec 11, 2023
1 parent 6e5ff88 commit d9b69c6
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ where
/ n as f64)
}

const SAMPLE_SIZE: usize = 10;

/// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2894
fn bracket<F>(
mut xa: f64,
Expand All @@ -395,15 +397,15 @@ where
const GOLD: f64 = 1.618_033_988_749_895; // wait for https://doc.rust-lang.org/std/f64/consts/constant.PHI.html
const MAXITER: i32 = 20;

let mut fa = -sample(config, weights, xa, 5, progress)?;
let mut fb = -sample(config, weights, xb, 5, progress)?;
let mut fa = -sample(config, weights, xa, SAMPLE_SIZE, progress)?;
let mut fb = -sample(config, weights, xb, SAMPLE_SIZE, progress)?;

if fa < fb {
(fa, fb) = (fb, fa);
(xa, xb) = (xb, xa);
}
let mut xc = GOLD.mul_add(xb - xa, xb).clamp(L_LIM, U_LIM);
let mut fc = -sample(config, weights, xc, 5, progress)?;
let mut fc = -sample(config, weights, xc, SAMPLE_SIZE, progress)?;

let mut iter = 0;
while fc < fb {
Expand All @@ -423,34 +425,34 @@ where
let mut fw: f64;

if (w - xc) * (xb - w) > 0.0 {
fw = -sample(config, weights, w, 5, progress)?;
fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
if fw < fc {
(xa, xb) = (xb.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM));
(fa, fb) = (fb, fw);
break;
} else if fw > fb {
xc = w.clamp(L_LIM, U_LIM);
fc = -sample(config, weights, xc, 5, progress)?;
fc = -sample(config, weights, xc, SAMPLE_SIZE, progress)?;
break;
}
w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
fw = -sample(config, weights, w, 5, progress)?;
fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
} else if (w - wlim) * (wlim - xc) >= 0.0 {
w = wlim;
fw = -sample(config, weights, w, 5, progress)?;
fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
} else if (w - wlim) * (xc - w) > 0.0 {
fw = -sample(config, weights, w, 5, progress)?;
fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
if fw < fc {
(xb, xc, w) = (
xc.clamp(L_LIM, U_LIM),
w.clamp(L_LIM, U_LIM),
GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM),
);
(fb, fc, fw) = (fc, fw, -sample(config, weights, w, 5, progress)?);
(fb, fc, fw) = (fc, fw, -sample(config, weights, w, SAMPLE_SIZE, progress)?);
}
} else {
w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
fw = -sample(config, weights, w, 5, progress)?;
fw = -sample(config, weights, w, SAMPLE_SIZE, progress)?;
}
(xa, xb, xc) = (
xb.clamp(L_LIM, U_LIM),
Expand Down Expand Up @@ -508,7 +510,7 @@ impl<B: Backend> FSRS<B> {
{
let mintol = 1e-10;
let cg = 0.3819660;
let maxiter = 20;
let maxiter = 64;
let tol = 0.01f64;

let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, weights, &mut progress)?;
Expand All @@ -526,7 +528,7 @@ impl<B: Backend> FSRS<B> {
let tol2 = 2.0 * tol1;
let xmid = 0.5 * (a + b);
// check for convergence
if (x - xmid).abs() <= (tol2 - 0.5 * (b - a)).abs() {
if (x - xmid).abs() < (tol2 - 0.5 * (b - a)) {
break;
}
if deltax.abs() <= tol1 {
Expand All @@ -546,7 +548,10 @@ impl<B: Backend> FSRS<B> {
let deltax_tmp = deltax;
deltax = rat;
// check parabolic fit
if (p > tmp2 * (a - x)) && (p < tmp2 * (b - x)) && (p.abs() < deltax_tmp.abs()) {
if (p > tmp2 * (a - x))
&& (p < tmp2 * (b - x))
&& (p.abs() < (0.5 * tmp2 * deltax_tmp).abs())
{
// if parabolic step is useful
rat = p / tmp2;
u = x + rat;
Expand All @@ -566,7 +571,7 @@ impl<B: Backend> FSRS<B> {
rat
};
// calculate new output value
let fu = -sample(config, weights, u, 5, &mut progress)?;
let fu = -sample(config, weights, u, SAMPLE_SIZE, &mut progress)?;

// if it's bigger than current
if fu > fx {
Expand Down Expand Up @@ -629,7 +634,7 @@ mod tests {
let config = SimulatorConfig::default();
let fsrs = FSRS::new(None)?;
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.8263932);
assert_eq!(optimal_retention, 0.8736067949688);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down

0 comments on commit d9b69c6

Please sign in to comment.