Skip to content

Commit

Permalink
Feat/flat power forgetting curve
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Dec 16, 2023
1 parent 07621d4 commit c0937b4
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 30 deletions.
28 changes: 15 additions & 13 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ impl<B: Backend> From<MemoryState> for MemoryStateTensors<B> {
}
}

fn next_interval(stability: f32, request_retention: f32) -> u32 {
(9.0 * stability * (1.0 / request_retention - 1.0))
/// Next review interval in whole days for a memory of the given stability,
/// scheduled so that retrievability decays to `desired_retention`.
///
/// Inverts the FSRS-4.5 flat power forgetting curve
/// `R(t) = (t / S * factor + 1) ^ decay` for `t`, where `factor` is chosen
/// so that `R(S) == 0.9`. The result is rounded and clamped to >= 1 day.
pub fn next_interval(stability: f32, desired_retention: f32) -> u32 {
    let decay: f32 = -0.5;
    // factor makes retention exactly 0.9 when the elapsed time equals stability.
    let factor = 0.9_f32.powf(1.0 / decay) - 1.0;
    // t = S / factor * (R^(1/decay) - 1)
    let growth = desired_retention.powf(1.0 / decay) - 1.0;
    let days = stability / factor * growth;
    days.round().max(1.0) as u32
}
Expand Down Expand Up @@ -365,7 +367,7 @@ mod tests {
assert_eq!(
fsrs.memory_state(item, None).unwrap(),
MemoryState {
stability: 51.344814,
stability: 51.31289,
difficulty: 7.005062
}
);
Expand All @@ -383,7 +385,7 @@ mod tests {
.good
.memory,
MemoryState {
stability: 51.344814,
stability: 51.33972,
difficulty: 7.005062
}
);
Expand All @@ -392,12 +394,12 @@ mod tests {

#[test]
fn test_next_interval() {
    // Sweep desired retention from 0.1 to 1.0 at a fixed stability of 1 day
    // and pin the resulting schedule of intervals.
    let desired_retentions: Vec<f32> = (1..=10).map(|i| i as f32 / 10.0).collect();
    let intervals: Vec<u32> = desired_retentions
        .iter()
        .map(|&r| next_interval(1.0, r))
        .collect();
    assert_eq!(intervals, [422, 102, 43, 22, 13, 8, 4, 2, 1, 1]);
}

#[test]
Expand All @@ -408,13 +410,13 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.20513662, 0.026_716_57]), 5);
.assert_approx_eq(&Data::from([0.20457467, 0.02268843]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.203_217_7, 0.015_836_29]), 5);
.assert_approx_eq(&Data::from([0.20306083, 0.01326745]), 5);
Ok(())
}

Expand Down Expand Up @@ -447,28 +449,28 @@ mod tests {
NextStates {
again: ItemState {
memory: MemoryState {
stability: 4.5802255,
stability: 4.5778565,
difficulty: 8.881129,
},
interval: 5
},
hard: ItemState {
memory: MemoryState {
stability: 27.7025,
stability: 27.6745,
difficulty: 7.9430957
},
interval: 28,
},
good: ItemState {
memory: MemoryState {
stability: 51.344814,
stability: 51.31289,
difficulty: 7.005062
},
interval: 51,
},
easy: ItemState {
memory: MemoryState {
stability: 101.98282,
stability: 101.94249,
difficulty: 6.0670285
},
interval: 102,
Expand Down
6 changes: 4 additions & 2 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ impl<B: Backend> Model<B> {
}

/// Element-wise retrievability of memories after `t` days, for memories of
/// stability `s` (both rank-1 tensors of equal length — assumed; TODO confirm
/// broadcasting expectations at call sites).
///
/// Implements the FSRS-4.5 flat power forgetting curve
/// `R = (t / s * factor + 1) ^ decay` with `decay = -0.5`; `factor` is
/// derived so that `R == 0.9` exactly when `t == s`.
pub fn power_forgetting_curve(&self, t: Tensor<B, 1>, s: Tensor<B, 1>) -> Tensor<B, 1> {
    let decay: f32 = -0.5;
    // 0.9^(1/decay) - 1: anchors the curve at 90% retention at t == s.
    let factor = 0.9_f32.powf(1.0 / decay) - 1.0;
    (t / s * factor + 1).powf(decay)
}

fn stability_after_success(
Expand Down Expand Up @@ -267,7 +269,7 @@ mod tests {
let retention = model.power_forgetting_curve(delta_t, stability);
assert_eq!(
retention.to_data(),
Data::from([1.0, 0.9473684, 0.9310345, 0.92307687, 0.9, 0.7826087])
Data::from([1.0, 0.9460589, 0.9299294, 0.9221679, 0.9, 0.7939459])
)
}

Expand Down
23 changes: 14 additions & 9 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::{FSRSError, Result};
use crate::inference::{ItemProgress, Weights};
use crate::inference::{next_interval, ItemProgress, Weights};
use crate::{DEFAULT_WEIGHTS, FSRS};
use burn::tensor::backend::Backend;
use itertools::izip;
Expand Down Expand Up @@ -90,7 +90,7 @@ fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 {
.clamp(0.1, s)
}

fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: Option<u64>) -> f64 {
fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: Option<u64>) -> f64 {
let SimulatorConfig {
deck_size,
learn_span,
Expand Down Expand Up @@ -140,11 +140,17 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O

let mut retrievability = Array1::zeros(deck_size); // Create an array for retrievability

// Scalar FSRS-4.5 forgetting curve: retrievability after `t` days for a
// memory of stability `s`. Anchored so the result is 0.9 when t == s.
fn power_forgetting_curve(t: f64, s: f64) -> f64 {
    const DECAY: f64 = -0.5;
    let factor = 0.9_f64.powf(1.0 / DECAY) - 1.0;
    (t / s * factor + 1.0).powf(DECAY)
}

// Calculate retrievability for entries where has_learned is true
izip!(&mut retrievability, &delta_t, &old_stability, &has_learned)
.filter(|(.., &has_learned_flag)| has_learned_flag)
.for_each(|(retrievability, &delta_t, &stability, ..)| {
*retrievability = (1.0 + delta_t / (9.0 * stability)).powi(-1)
*retrievability = power_forgetting_curve(delta_t, stability)
});

// Set 'cost' column to 0
Expand Down Expand Up @@ -315,8 +321,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
izip!(&mut new_interval, &new_stability, &true_review, &true_learn)
.filter(|(.., &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag)
.for_each(|(new_ivl, &new_stab, ..)| {
*new_ivl = (9.0 * new_stab * (1.0 / request_retention - 1.0))
.round()
*new_ivl = (next_interval(new_stab as f32, desired_retention as f32) as f64)
.clamp(1.0, max_ivl);
});

Expand Down Expand Up @@ -354,7 +359,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
fn sample<F>(
config: &SimulatorConfig,
weights: &[f64],
request_retention: f64,
desired_retention: f64,
n: usize,
progress: &mut F,
) -> Result<f64>
Expand All @@ -370,7 +375,7 @@ where
simulate(
config,
weights,
request_retention,
desired_retention,
Some((i + 42).try_into().unwrap()),
)
})
Expand Down Expand Up @@ -626,15 +631,15 @@ mod tests {
0.9,
None,
);
assert_eq!(memorization, 3211.3084298933477)
assert_eq!(memorization, 3323.859146517903)
}

#[test]
fn optimal_retention() -> Result<()> {
    // Default simulator + default weights should converge to a stable optimum.
    let config = SimulatorConfig::default();
    let fsrs = FSRS::new(None)?;
    let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
    assert_eq!(optimal_retention, 0.8263932);
    // A weight vector of the wrong length must be rejected.
    assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
    Ok(())
}
Expand Down
14 changes: 8 additions & 6 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ fn total_rating_count(
}

// Vectorized FSRS-4.5 forgetting curve: per-element retrievability after the
// elapsed days in `t` for a memory of stability `s` (0.9 exactly at t == s).
fn power_forgetting_curve(t: &Array1<f32>, s: f32) -> Array1<f32> {
    let decay: f32 = -0.5;
    let factor = 0.9_f32.powf(1.0 / decay) - 1.0;
    t.mapv(|elapsed| (elapsed / s * factor + 1.0).powf(decay))
}

fn loss(
Expand Down Expand Up @@ -311,7 +313,7 @@ mod tests {
let t = Array1::from(vec![0.0, 1.0, 2.0, 3.0]);
let s = 1.0;
let y = power_forgetting_curve(&t, s);
let expected = Array1::from(vec![1.0, 0.9, 0.8181818, 0.75]);
let expected = Array1::from(vec![1.0, 0.9, 0.8250286, 0.7661308]);
assert_eq!(y, expected);
}

Expand All @@ -322,8 +324,8 @@ mod tests {
let count = Array1::from(vec![100.0, 100.0, 100.0]);
let init_s0 = 1.0;
let actual = loss(&delta_t, &recall, &count, init_s0, init_s0);
assert_eq!(actual, 0.45385247);
assert_eq!(loss(&delta_t, &recall, &count, 2.0, init_s0), 0.48355862);
assert_eq!(actual, 0.45414436);
assert_eq!(loss(&delta_t, &recall, &count, 2.0, init_s0), 0.48402837);
}

#[test]
Expand Down Expand Up @@ -354,7 +356,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
let expected = [(4, 1.4877763)].into_iter().collect();
let expected = [(4, 1.2733965)].into_iter().collect();
assert_eq!(actual, expected);
}

Expand All @@ -366,7 +368,7 @@ mod tests {
let pretrainset = split_data(items, 1).0;
assert_eq!(
pretrain(pretrainset, average_recall).unwrap(),
[0.9517492, 1.7152255, 4.149725, 9.399195,],
[0.89360625, 1.6562619, 4.1792974, 9.724018],
)
}

Expand Down

0 comments on commit c0937b4

Please sign in to comment.