From df877c945f157be27cbb1e4bdbaf49b3f864c763 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Fri, 8 Dec 2023 11:30:29 +0800 Subject: [PATCH] update tests --- src/inference.rs | 6 +++--- src/model.rs | 21 ++++++++------------- src/optimal_retention.rs | 2 +- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index b7c4c3bf..62692d78 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -408,7 +408,7 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.20745006, 0.040_497_02]), 5); + .assert_approx_eq(&Data::from([0.20513662, 0.026_716_57]), 5); let fsrs = FSRS::new(Some(WEIGHTS))?; let metrics = fsrs.evaluate(items, |_| true).unwrap(); @@ -501,14 +501,14 @@ mod tests { fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(), MemoryState { stability: 9.999995, - difficulty: 6.2652965 + difficulty: 6.8565593 } ); assert_eq!( fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap(), MemoryState { stability: 19.99999, - difficulty: 9.956561 + difficulty: 10.0 } ); let interval = 15; diff --git a/src/model.rs b/src/model.rs index 0108abc6..ebd005f2 100644 --- a/src/model.rs +++ b/src/model.rs @@ -45,12 +45,7 @@ impl Model { .initial_stability .unwrap_or_else(|| DEFAULT_WEIGHTS[0..4].try_into().unwrap()) .into_iter() - .chain([ - 4.93, 0.94, 0.86, 0.01, // difficulty - 1.49, 0.14, 0.94, // success - 2.18, 0.05, 0.34, 1.26, // failure - 0.29, 2.61, // hard penalty, easy bonus - ]) + .chain(DEFAULT_WEIGHTS[4..].iter().copied()) .collect(); Self { @@ -283,7 +278,7 @@ mod tests { let stability = model.init_stability(rating); assert_eq!( stability.to_data(), - Data::from([0.4, 0.9, 2.3, 10.9, 0.4, 0.9]) + Data::from([0.5888, 1.4616, 3.8226, 14.1364, 0.5888, 1.4616]) ) } @@ -294,7 +289,7 @@ mod tests { let difficulty = model.init_difficulty(rating); assert_eq!( difficulty.to_data(), - Data::from([6.81, 5.87, 4.93, 3.9899998, 6.81, 5.87]) + Data::from([6.9864, 5.9539003, 4.9214, 3.8889, 6.9864, 5.9539003]) ) } @@ -322,13 +317,13 @@ mod tests { next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.7200003, 5.86, 5.0, 4.14]) + Data::from([6.7462, 5.8731, 5.0, 4.1269]) ); let next_difficulty = model.mean_reversion(next_difficulty); next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.7021003, 5.8507, 4.9993, 4.1478996]) + Data::from([6.63434, 5.8147607, 4.995182, 4.175603]) ) } @@ -348,19 +343,19 @@ mod tests { s_recall.clone().backward(); assert_eq!( s_recall.to_data(), - Data::from([22.454704, 14.560361, 51.15574, 152.6869]) + Data::from([24.938553, 15.710489, 57.993835, 185.87283]) ); let s_forget = model.stability_after_failure(stability, difficulty, retention); s_forget.clone().backward(); assert_eq!( s_forget.to_data(), - Data::from([2.074517, 2.2729328, 2.526406, 2.8247323]) + Data::from([2.1479936, 2.339425, 2.596607, 2.904485]) ); let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget); next_stability.clone().backward(); assert_eq!( next_stability.to_data(), - Data::from([2.074517, 14.560361, 51.15574, 152.6869]) + Data::from([2.1479936, 15.710489, 57.993835, 185.87283]) ) } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 1500d0d9..9142b4fc 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -621,7 +621,7 @@ mod tests { 0.9, None, ); - assert_eq!(memorization, 2633.365434092778) + assert_eq!(memorization, 3211.3084298933477) } #[test]