Skip to content

Commit

Permalink
Feat/update default weights (#128)
Browse files Browse the repository at this point in the history
* Feat/update default weights

open-spaced-repetition/srs-benchmark#14

* update tests
  • Loading branch information
L-M-Sherlock authored Dec 8, 2023
1 parent 34570f3 commit 6e5ff88
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
10 changes: 5 additions & 5 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use burn::tensor::ElementConversion;
pub type Weights = [f32];

pub static DEFAULT_WEIGHTS: [f32; 17] = [
0.4, 0.9, 2.3, 10.9, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29,
2.61,
0.5888, 1.4616, 3.8226, 14.1364, 4.9214, 1.0325, 0.8731, 0.0613, 1.57, 0.1395, 0.988, 2.212,
0.0658, 0.3439, 1.3098, 0.2837, 2.7766,
];

fn infer<B: Backend>(
Expand Down Expand Up @@ -408,7 +408,7 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.20745006, 0.040_497_02]), 5);
.assert_approx_eq(&Data::from([0.20513662, 0.026_716_57]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();
Expand Down Expand Up @@ -501,14 +501,14 @@ mod tests {
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(),
MemoryState {
stability: 9.999995,
difficulty: 6.2652965
difficulty: 6.8565593
}
);
assert_eq!(
fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap(),
MemoryState {
stability: 19.99999,
difficulty: 9.956561
difficulty: 10.0
}
);
let interval = 15;
Expand Down
21 changes: 8 additions & 13 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@ impl<B: Backend> Model<B> {
.initial_stability
.unwrap_or_else(|| DEFAULT_WEIGHTS[0..4].try_into().unwrap())
.into_iter()
.chain([
4.93, 0.94, 0.86, 0.01, // difficulty
1.49, 0.14, 0.94, // success
2.18, 0.05, 0.34, 1.26, // failure
0.29, 2.61, // hard penalty, easy bonus
])
.chain(DEFAULT_WEIGHTS[4..].iter().copied())
.collect();

Self {
Expand Down Expand Up @@ -283,7 +278,7 @@ mod tests {
let stability = model.init_stability(rating);
assert_eq!(
stability.to_data(),
Data::from([0.4, 0.9, 2.3, 10.9, 0.4, 0.9])
Data::from([0.5888, 1.4616, 3.8226, 14.1364, 0.5888, 1.4616])
)
}

Expand All @@ -294,7 +289,7 @@ mod tests {
let difficulty = model.init_difficulty(rating);
assert_eq!(
difficulty.to_data(),
Data::from([6.81, 5.87, 4.93, 3.9899998, 6.81, 5.87])
Data::from([6.9864, 5.9539003, 4.9214, 3.8889, 6.9864, 5.9539003])
)
}

Expand Down Expand Up @@ -322,13 +317,13 @@ mod tests {
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.7200003, 5.86, 5.0, 4.14])
Data::from([6.7462, 5.8731, 5.0, 4.1269])
);
let next_difficulty = model.mean_reversion(next_difficulty);
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.7021003, 5.8507, 4.9993, 4.1478996])
Data::from([6.63434, 5.8147607, 4.995182, 4.175603])
)
}

Expand All @@ -348,19 +343,19 @@ mod tests {
s_recall.clone().backward();
assert_eq!(
s_recall.to_data(),
Data::from([22.454704, 14.560361, 51.15574, 152.6869])
Data::from([24.938553, 15.710489, 57.993835, 185.87283])
);
let s_forget = model.stability_after_failure(stability, difficulty, retention);
s_forget.clone().backward();
assert_eq!(
s_forget.to_data(),
Data::from([2.074517, 2.2729328, 2.526406, 2.8247323])
Data::from([2.1479936, 2.339425, 2.596607, 2.904485])
);
let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget);
next_stability.clone().backward();
assert_eq!(
next_stability.to_data(),
Data::from([2.074517, 14.560361, 51.15574, 152.6869])
Data::from([2.1479936, 15.710489, 57.993835, 185.87283])
)
}

Expand Down
2 changes: 1 addition & 1 deletion src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ mod tests {
0.9,
None,
);
assert_eq!(memorization, 2633.365434092778)
assert_eq!(memorization, 3211.3084298933477)
}

#[test]
Expand Down

0 comments on commit 6e5ff88

Please sign in to comment.