From c0937b4956e0e85c91be72ee23406b71f6923aff Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Sat, 16 Dec 2023 18:35:43 +0800 Subject: [PATCH] Feat/flat power forgetting curve --- src/inference.rs | 28 +++++++++++++++------------- src/model.rs | 6 ++++-- src/optimal_retention.rs | 23 ++++++++++++++--------- src/pre_training.rs | 14 ++++++++------ 4 files changed, 41 insertions(+), 30 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index 0ac05e27..01b2125d 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -57,8 +57,10 @@ impl From for MemoryStateTensors { } } -fn next_interval(stability: f32, request_retention: f32) -> u32 { - (9.0 * stability * (1.0 / request_retention - 1.0)) +pub fn next_interval(stability: f32, desired_retention: f32) -> u32 { + let decay: f32 = -0.5; + let factor = 0.9_f32.powf(1.0 / decay) - 1.0; + (stability / factor * (desired_retention.powf(1.0 / decay) - 1.0)) .round() .max(1.0) as u32 } @@ -365,7 +367,7 @@ mod tests { assert_eq!( fsrs.memory_state(item, None).unwrap(), MemoryState { - stability: 51.344814, + stability: 51.31289, difficulty: 7.005062 } ); @@ -383,7 +385,7 @@ mod tests { .good .memory, MemoryState { - stability: 51.344814, + stability: 51.33972, difficulty: 7.005062 } ); @@ -392,12 +394,12 @@ mod tests { #[test] fn test_next_interval() { - let request_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::>(); - let intervals = request_retentions + let desired_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::>(); + let intervals = desired_retentions .iter() .map(|r| next_interval(1.0, *r)) .collect::>(); - assert_eq!(intervals, [81, 36, 21, 14, 9, 6, 4, 2, 1, 1,]); + assert_eq!(intervals, [422, 102, 43, 22, 13, 8, 4, 2, 1, 1]); } #[test] @@ -408,13 +410,13 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.20513662, 0.026_716_57]), 5); + .assert_approx_eq(&Data::from([0.20457467, 0.02268843]), 5); let fsrs = FSRS::new(Some(WEIGHTS))?; let metrics = fsrs.evaluate(items, |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.203_217_7, 0.015_836_29]), 5); + .assert_approx_eq(&Data::from([0.20306083, 0.01326745]), 5); Ok(()) } @@ -447,28 +449,28 @@ mod tests { NextStates { again: ItemState { memory: MemoryState { - stability: 4.5802255, + stability: 4.5778565, difficulty: 8.881129, }, interval: 5 }, hard: ItemState { memory: MemoryState { - stability: 27.7025, + stability: 27.6745, difficulty: 7.9430957 }, interval: 28, }, good: ItemState { memory: MemoryState { - stability: 51.344814, + stability: 51.31289, difficulty: 7.005062 }, interval: 51, }, easy: ItemState { memory: MemoryState { - stability: 101.98282, + stability: 101.94249, difficulty: 6.0670285 }, interval: 102, diff --git a/src/model.rs b/src/model.rs index ebd005f2..07014741 100644 --- a/src/model.rs +++ b/src/model.rs @@ -58,7 +58,9 @@ impl Model { } pub fn power_forgetting_curve(&self, t: Tensor, s: Tensor) -> Tensor { - (t / (s * 9) + 1).powf(-1.0) + let decay: f32 = -0.5; + let factor = 0.9_f32.powf(1.0 / decay) - 1.0; + (t / s * factor + 1).powf(decay) } fn stability_after_success( @@ -267,7 +269,7 @@ mod tests { let retention = model.power_forgetting_curve(delta_t, stability); assert_eq!( retention.to_data(), - Data::from([1.0, 0.9473684, 0.9310345, 0.92307687, 0.9, 0.7826087]) + Data::from([1.0, 0.9460589, 0.9299294, 0.9221679, 0.9, 0.7939459]) ) } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index a04c7aab..414ec98e 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -1,5 +1,5 @@ use crate::error::{FSRSError, Result}; -use crate::inference::{ItemProgress, Weights}; +use crate::inference::{next_interval, ItemProgress, Weights}; use crate::{DEFAULT_WEIGHTS, FSRS}; use burn::tensor::backend::Backend; use itertools::izip; @@ -90,7 +90,7 @@ fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 { .clamp(0.1, s) } -fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: Option) -> f64 { +fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: Option) -> f64 { let SimulatorConfig { deck_size, learn_span, @@ -140,11 +140,17 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O let mut retrievability = Array1::zeros(deck_size); // Create an array for retrievability + fn power_forgetting_curve(t: f64, s: f64) -> f64 { + let decay: f64 = -0.5; + let factor = 0.9_f64.powf(1.0 / decay) - 1.0; + (t / s * factor + 1.0).powf(decay) + } + // Calculate retrievability for entries where has_learned is true izip!(&mut retrievability, &delta_t, &old_stability, &has_learned) .filter(|(.., &has_learned_flag)| has_learned_flag) .for_each(|(retrievability, &delta_t, &stability, ..)| { - *retrievability = (1.0 + delta_t / (9.0 * stability)).powi(-1) + *retrievability = power_forgetting_curve(delta_t, stability) }); // Set 'cost' column to 0 @@ -315,8 +321,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O izip!(&mut new_interval, &new_stability, &true_review, &true_learn) .filter(|(.., &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag) .for_each(|(new_ivl, &new_stab, ..)| { - *new_ivl = (9.0 * new_stab * (1.0 / request_retention - 1.0)) - .round() + *new_ivl = (next_interval(new_stab as f32, desired_retention as f32) as f64) .clamp(1.0, max_ivl); }); @@ -354,7 +359,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O fn sample( config: &SimulatorConfig, weights: &[f64], - request_retention: f64, + desired_retention: f64, n: usize, progress: &mut F, ) -> Result @@ -370,7 +375,7 @@ where simulate( config, weights, - request_retention, + desired_retention, Some((i + 42).try_into().unwrap()), ) }) @@ -626,7 +631,7 @@ mod tests { 0.9, None, ); - assert_eq!(memorization, 3211.3084298933477) + assert_eq!(memorization, 3323.859146517903) } #[test] @@ -634,7 +639,7 @@ mod tests { let config = SimulatorConfig::default(); let fsrs = FSRS::new(None)?; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.8736067949688); + assert_eq!(optimal_retention, 0.8263932); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) } diff --git a/src/pre_training.rs b/src/pre_training.rs index 0113c3c7..378db531 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -87,7 +87,9 @@ fn total_rating_count( } fn power_forgetting_curve(t: &Array1, s: f32) -> Array1 { - 1.0 / (1.0 + t / (9.0 * s)) + let decay: f32 = -0.5; + let factor = 0.9_f32.powf(1.0 / decay) - 1.0; + (t / s * factor + 1.0).mapv(|v| v.powf(decay)) } fn loss( @@ -311,7 +313,7 @@ mod tests { let t = Array1::from(vec![0.0, 1.0, 2.0, 3.0]); let s = 1.0; let y = power_forgetting_curve(&t, s); - let expected = Array1::from(vec![1.0, 0.9, 0.8181818, 0.75]); + let expected = Array1::from(vec![1.0, 0.9, 0.8250286, 0.7661308]); assert_eq!(y, expected); } @@ -322,8 +324,8 @@ mod tests { let count = Array1::from(vec![100.0, 100.0, 100.0]); let init_s0 = 1.0; let actual = loss(&delta_t, &recall, &count, init_s0, init_s0); - assert_eq!(actual, 0.45385247); - assert_eq!(loss(&delta_t, &recall, &count, 2.0, init_s0), 0.48355862); + assert_eq!(actual, 0.45414436); + assert_eq!(loss(&delta_t, &recall, &count, 2.0, init_s0), 0.48402837); } #[test] @@ -354,7 +356,7 @@ mod tests { ], )]); let actual = search_parameters(pretrainset, 0.9); - let expected = [(4, 1.4877763)].into_iter().collect(); + let expected = [(4, 1.2733965)].into_iter().collect(); assert_eq!(actual, expected); } @@ -366,7 +368,7 @@ mod tests { let pretrainset = split_data(items, 1).0; assert_eq!( pretrain(pretrainset, average_recall).unwrap(), - [0.9517492, 1.7152255, 4.149725, 9.399195,], + [0.89360625, 1.6562619, 4.1792974, 9.724018], ) }