Skip to content

Commit

Permalink
Minor tweaks to #166
Browse files Browse the repository at this point in the history
- Avoid par_iter() inside existing into_par_iter()
- Avoid extra trainset clone
  • Loading branch information
dae committed Mar 11, 2024
1 parent 08771a8 commit 2426871
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ impl<B: Backend> FSRS<B> {
.into_par_iter()
.map(|i| {
trainsets
.par_iter()
.iter()
.enumerate()
.filter(|&(j, _)| j != i)
.flat_map(|(_, trainset)| trainset.clone())
Expand All @@ -257,15 +257,16 @@ impl<B: Backend> FSRS<B> {
progress.lock().unwrap().splits = progress_states
}

let weight_sets: Result<Vec<Vec<f32>>> = (0..n_splits)
let weight_sets: Result<Vec<Vec<f32>>> = trainsets
.into_par_iter()
.map(|i| {
.enumerate()
.map(|(idx, trainset)| {
let model = train::<Autodiff<B>>(
trainsets[i].clone(),
trainset,
testset.clone(),
&config,
self.device(),
progress.clone().map(|p| ProgressCollector::new(p, i)),
progress.clone().map(|p| ProgressCollector::new(p, idx)),
);
Ok(model
.map_err(|e| {
Expand Down

0 comments on commit 2426871

Please sign in to comment.