Skip to content

Commit

Permalink
use mode parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Dec 16, 2023
1 parent 9f245da commit 6aa0ef4
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
8 changes: 4 additions & 4 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use burn::tensor::ElementConversion;
pub type Weights = [f32];

pub static DEFAULT_WEIGHTS: [f32; 17] = [
0.5888, 1.4616, 3.8226, 14.1364, 4.9214, 1.0325, 0.8731, 0.0613, 1.57, 0.1395, 0.988, 2.212,
0.0658, 0.3439, 1.3098, 0.2837, 2.7766,
0.27, 0.74, 1.3, 5.52, 5.1, 1.02, 0.78, 0.06, 1.57, 0.14, 0.94, 2.16, 0.06, 0.31, 1.34, 0.21,
2.69,
];

fn infer<B: Backend>(
Expand Down Expand Up @@ -408,7 +408,7 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.20513662, 0.026_716_57]), 5);
.assert_approx_eq(&Data::from([0.21600282, 0.06387164]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();
Expand Down Expand Up @@ -501,7 +501,7 @@ mod tests {
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(),
MemoryState {
stability: 9.999995,
difficulty: 6.8565593
difficulty: 6.6293178
}
);
assert_eq!(
Expand Down
14 changes: 7 additions & 7 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ mod tests {
let stability = model.init_stability(rating);
assert_eq!(
stability.to_data(),
Data::from([0.5888, 1.4616, 3.8226, 14.1364, 0.5888, 1.4616])
Data::from([0.27, 0.74, 1.3, 5.52, 0.27, 0.74])
)
}

Expand All @@ -289,7 +289,7 @@ mod tests {
let difficulty = model.init_difficulty(rating);
assert_eq!(
difficulty.to_data(),
Data::from([6.9864, 5.9539003, 4.9214, 3.8889, 6.9864, 5.9539003])
Data::from([7.14, 6.12, 5.1, 4.08, 7.14, 6.12])
)
}

Expand Down Expand Up @@ -317,13 +317,13 @@ mod tests {
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.7462, 5.8731, 5.0, 4.1269])
Data::from([6.56, 5.7799997, 5.0, 4.2200003])
);
let next_difficulty = model.mean_reversion(next_difficulty);
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.63434, 5.8147607, 4.995182, 4.175603])
Data::from([6.4723997, 5.7391996, 5.006, 4.2728004])
)
}

Expand All @@ -343,19 +343,19 @@ mod tests {
s_recall.clone().backward();
assert_eq!(
s_recall.to_data(),
Data::from([24.938553, 15.710489, 57.993835, 185.87283])
Data::from([23.908455, 12.499619, 54.99991, 169.89117])
);
let s_forget = model.stability_after_failure(stability, difficulty, retention);
s_forget.clone().backward();
assert_eq!(
s_forget.to_data(),
Data::from([2.1479936, 2.339425, 2.596607, 2.904485])
Data::from([1.8343093, 2.0118992, 2.245103, 2.5231054])
);
let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget);
next_stability.clone().backward();
assert_eq!(
next_stability.to_data(),
Data::from([2.1479936, 15.710489, 57.993835, 185.87283])
Data::from([1.8343093, 12.499619, 54.99991, 169.89117])
)
}

Expand Down
4 changes: 2 additions & 2 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,15 +626,15 @@ mod tests {
0.9,
None,
);
assert_eq!(memorization, 3211.3084298933477)
assert_eq!(memorization, 2405.020202735966)
}

#[test]
fn optimal_retention() -> Result<()> {
let config = SimulatorConfig::default();
let fsrs = FSRS::new(None)?;
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.8736067949688);
assert_eq!(optimal_retention, 0.8608067460076987);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
let expected = [(4, 1.4877763)].into_iter().collect();
let expected = [(4, 1.2390649)].into_iter().collect();
assert_eq!(actual, expected);
}

Expand All @@ -366,7 +366,7 @@ mod tests {
let pretrainset = split_data(items, 1).0;
assert_eq!(
pretrain(pretrainset, average_recall).unwrap(),
[0.9517492, 1.7152255, 4.149725, 9.399195,],
[0.94550645, 1.6813093, 3.9867811, 8.992397,],
)
}

Expand Down

0 comments on commit 6aa0ef4

Please sign in to comment.