Skip to content

Commit

Permalink
iter_mut + enumerate
Browse files Browse the repository at this point in the history
  • Loading branch information
asukaminato0721 committed Mar 10, 2024
1 parent 19a1961 commit db12033
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ impl<B: Backend> FSRS<B> {

if let Some(progress) = &progress {
let mut progress_states = vec![ProgressState::default(); n_splits];
for i in 0..n_splits {
progress_states[i].epoch_total = config.num_epochs;
progress_states[i].items_total = trainsets[i].len();
for (i, progress_state) in progress_states.iter_mut().enumerate() {
progress_state.epoch_total = config.num_epochs;
progress_state.items_total = trainsets[i].len();
}
progress.lock().unwrap().splits = progress_states
}
Expand Down Expand Up @@ -291,10 +291,8 @@ impl<B: Backend> FSRS<B> {
.map(|&sum| sum / n_splits as f32)
.collect();

for weight in &average_parameters {
if !weight.is_finite() {
return Err(FSRSError::InvalidInput);
}
if average_parameters.iter().any(|weight| weight.is_infinite()) {
return Err(FSRSError::InvalidInput);
}

Ok(average_parameters)
Expand Down

0 comments on commit db12033

Please sign in to comment.