Skip to content

Commit

Permalink
Merge branch 'main' into Feat/flat-power-forgetting-curve
Browse files Browse the repository at this point in the history
  • Loading branch information
asukaminato0721 authored Dec 18, 2023
2 parents c64cf1d + 2c7cdf9 commit e72b00e
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
7 changes: 6 additions & 1 deletion src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,12 @@ pub fn filter_outlier(

for (delta_t, sub_group) in sub_groups.iter().rev() {
if has_been_removed + sub_group.len() > total / 20 {
filtered_items.extend_from_slice(sub_group);
// keep the group if it includes at least one item rated again (retention < 100%)
if sub_group.iter().any(|item| item.current().rating == 1) {
filtered_items.extend_from_slice(sub_group);
} else {
removed_pairs[rating as usize].insert(*delta_t);
}
} else {
has_been_removed += sub_group.len();
removed_pairs[rating as usize].insert(*delta_t);
Expand Down
6 changes: 3 additions & 3 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ pub const FACTOR: f64 = 19f64 / 81f64;
pub type Weights = [f32];

pub static DEFAULT_WEIGHTS: [f32; 17] = [
0.5888, 1.4616, 3.8226, 14.1364, 4.9214, 1.0325, 0.8731, 0.0613, 1.57, 0.1395, 0.988, 2.212,
0.0658, 0.3439, 1.3098, 0.2837, 2.7766,
0.27, 0.74, 1.3, 5.52, 5.1, 1.02, 0.78, 0.06, 1.57, 0.14, 0.94, 2.16, 0.06, 0.31, 1.34, 0.21,
2.69,
];

fn infer<B: Backend>(
Expand Down Expand Up @@ -503,7 +503,7 @@ mod tests {
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(),
MemoryState {
stability: 9.999995,
difficulty: 6.8565593
difficulty: 6.6293178
}
);
assert_eq!(
Expand Down
14 changes: 7 additions & 7 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ mod tests {
let stability = model.init_stability(rating);
assert_eq!(
stability.to_data(),
Data::from([0.5888, 1.4616, 3.8226, 14.1364, 0.5888, 1.4616])
Data::from([0.27, 0.74, 1.3, 5.52, 0.27, 0.74])
)
}

Expand All @@ -289,7 +289,7 @@ mod tests {
let difficulty = model.init_difficulty(rating);
assert_eq!(
difficulty.to_data(),
Data::from([6.9864, 5.9539003, 4.9214, 3.8889, 6.9864, 5.9539003])
Data::from([7.14, 6.12, 5.1, 4.08, 7.14, 6.12])
)
}

Expand Down Expand Up @@ -317,13 +317,13 @@ mod tests {
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.7462, 5.8731, 5.0, 4.1269])
Data::from([6.56, 5.7799997, 5.0, 4.2200003])
);
let next_difficulty = model.mean_reversion(next_difficulty);
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.63434, 5.8147607, 4.995182, 4.175603])
Data::from([6.4723997, 5.7391996, 5.006, 4.2728004])
)
}

Expand All @@ -343,19 +343,19 @@ mod tests {
s_recall.clone().backward();
assert_eq!(
s_recall.to_data(),
Data::from([24.938553, 15.710489, 57.993835, 185.87283])
Data::from([23.908455, 12.499619, 54.99991, 169.89117])
);
let s_forget = model.stability_after_failure(stability, difficulty, retention);
s_forget.clone().backward();
assert_eq!(
s_forget.to_data(),
Data::from([2.1479936, 2.339425, 2.596607, 2.904485])
Data::from([1.8343093, 2.0118992, 2.245103, 2.5231054])
);
let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget);
next_stability.clone().backward();
assert_eq!(
next_stability.to_data(),
Data::from([2.1479936, 15.710489, 57.993835, 185.87283])
Data::from([1.8343093, 12.499619, 54.99991, 169.89117])
)
}

Expand Down
3 changes: 1 addition & 2 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
let expected = [(4, 1.4525769)].into_iter().collect();
assert_eq!(actual, expected);
Data::from([actual.get(&4).unwrap().clone()]).assert_approx_eq(&Data::from([1.2390649]), 4);
}

#[test]
Expand Down

0 comments on commit e72b00e

Please sign in to comment.