Skip to content

Commit

Permalink
filter and sort train_set when benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Oct 27, 2024
1 parent d715930 commit 55cb2e4
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ impl<B: Backend> FSRS<B> {
Ok(optimized_parameters)
}

pub fn benchmark(&self, train_set: Vec<FSRSItem>) -> Vec<f32> {
pub fn benchmark(&self, mut train_set: Vec<FSRSItem>) -> Vec<f32> {
let average_recall = calculate_average_recall(&train_set);
let (pre_train_set, _next_train_set) = train_set
.clone()
Expand All @@ -302,6 +302,8 @@ impl<B: Backend> FSRS<B> {
},
AdamConfig::new().with_epsilon(1e-8),
);
train_set.retain(|item| item.reviews.len() <= 64);
train_set.sort_by_cached_key(|item| item.long_term_review_cnt());
let model =
train::<Autodiff<B>>(train_set.clone(), train_set, &config, self.device(), None);
let parameters: Vec<f32> = model.unwrap().w.val().to_data().convert().value;
Expand Down

0 comments on commit 55cb2e4

Please sign in to comment.