From 458ed0a55950a0719d4b9b676630638bc30b2d91 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 19 Dec 2023 17:46:33 +0900 Subject: [PATCH] Use array instead of HashMap (#140) * match arr * clippy --fix * use arr --- src/inference.rs | 2 +- src/pre_training.rs | 104 +++++++++++++++++--------------------------- 2 files changed, 41 insertions(+), 65 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index 0dc22b59..8c652ac1 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -410,7 +410,7 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.21364396810531616, 0.05370686203241348]), 5); + .assert_approx_eq(&Data::from([0.213_643_97, 0.053_706_862]), 5); let fsrs = FSRS::new(Some(WEIGHTS))?; let metrics = fsrs.evaluate(items, |_| true).unwrap(); diff --git a/src/pre_training.rs b/src/pre_training.rs index a67fe5da..f95c2832 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -2,7 +2,6 @@ use crate::error::{FSRSError, Result}; use crate::inference::{DECAY, FACTOR}; use crate::FSRSItem; use crate::DEFAULT_WEIGHTS; -use itertools::Itertools; use ndarray::Array1; use std::collections::HashMap; @@ -200,7 +199,13 @@ fn smooth_and_fill( .iter() .cloned() .collect::>(); - + let mut rating_stability_arr = [ + None, + rating_stability.get(&1).cloned(), + rating_stability.get(&2).cloned(), + rating_stability.get(&3).cloned(), + rating_stability.get(&4).cloned(), + ]; match rating_stability.len() { 0 => return Err(FSRSError::NotEnoughData), 1 => { @@ -210,87 +215,64 @@ fn smooth_and_fill( init_s0.sort_by(|a, b| a.partial_cmp(b).unwrap()); } 2 => { - match ( - rating_stability.get(&1), - rating_stability.get(&2), - rating_stability.get(&3), - rating_stability.get(&4), - ) { - (None, None, Some(&r3), Some(&r4)) => { + match rating_stability_arr { + [_, None, None, Some(r3), Some(r4)] => { let r2 = r3.powf(1.0 / (1.0 - w2)) * r4.powf(1.0 - 1.0 / (1.0 - w2)); - rating_stability.insert(2, r2); - rating_stability.insert(1, (r2.powf(1.0 / w1)) * (r3.powf(1.0 - 1.0 / w1))); + rating_stability_arr[2] = Some(r2); + rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1)); } - (None, Some(&r2), None, Some(&r4)) => { + [_, None, Some(r2), None, Some(r4)] => { let r3 = r2.powf(1.0 - w2) * r4.powf(w2); - rating_stability.insert(3, r3); - rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1)); + rating_stability_arr[3] = Some(r3); + rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1)); } - (None, Some(&r2), Some(&r3), None) => { - rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2)); - rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1)); + [_, None, Some(r2), Some(r3), None] => { + rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2)); + rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1)); } - (Some(&r1), None, None, Some(&r4)) => { + [_, Some(r1), None, None, Some(r4)] => { let r2 = r1.powf(w1 / (w1.mul_add(-w2, w1 + w2))) * r4.powf(1.0 - w1 / (w1.mul_add(-w2, w1 + w2))); - rating_stability.insert(2, r2); - rating_stability.insert( - 3, + rating_stability_arr[2] = Some(r2); + rating_stability_arr[3] = Some( r1.powf(1.0 - w2 / (w1.mul_add(-w2, w1 + w2))) * r4.powf(w2 / (w1.mul_add(-w2, w1 + w2))), ); } - (Some(&r1), None, Some(&r3), None) => { + [_, Some(r1), None, Some(r3), None] => { let r2 = r1.powf(w1) * r3.powf(1.0 - w1); - rating_stability.insert(2, r2); - rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2)); + rating_stability_arr[2] = Some(r2); + rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2)); } - (Some(&r1), Some(&r2), None, None) => { + [_, Some(r1), Some(r2), None, None] => { let r3 = r1.powf(1.0 - 1.0 / (1.0 - w1)) * r2.powf(1.0 / (1.0 - w1)); - rating_stability.insert(3, r3); - rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2)); + rating_stability_arr[3] = Some(r3); + rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2)); } _ => {} } - init_s0 = rating_stability - .iter() - .sorted_by(|a, b| a.0.cmp(b.0)) - .map(|(_, &v)| v) - .collect(); + init_s0 = rating_stability_arr.into_iter().flatten().collect(); } 3 => { - match ( - rating_stability.get(&1), - rating_stability.get(&2), - rating_stability.get(&3), - rating_stability.get(&4), - ) { - (None, Some(r2), Some(r3), _) => { - rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1)); + match rating_stability_arr { + [_, None, Some(r2), Some(r3), _] => { + rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1)); } - (Some(r1), None, Some(r3), _) => { - rating_stability.insert(2, r1.powf(w1) * r3.powf(1.0 - w1)); + [_, Some(r1), None, Some(r3), _] => { + rating_stability_arr[2] = Some(r1.powf(w1) * r3.powf(1.0 - w1)); } - (_, Some(r2), None, Some(r4)) => { - rating_stability.insert(3, r2.powf(1.0 - w2) * r4.powf(w2)); + [_, _, Some(r2), None, Some(r4)] => { + rating_stability_arr[3] = Some(r2.powf(1.0 - w2) * r4.powf(w2)); } - (_, Some(r2), Some(r3), None) => { - rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2)); + [_, _, Some(r2), Some(r3), None] => { + rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2)); } _ => {} } - init_s0 = rating_stability - .iter() - .sorted_by(|a, b| a.0.cmp(b.0)) - .map(|(_, &v)| v) - .collect(); + init_s0 = rating_stability_arr.into_iter().flatten().collect(); } 4 => { - init_s0 = rating_stability - .iter() - .sorted_by(|a, b| a.0.cmp(b.0)) - .map(|(_, &v)| v) - .collect(); + init_s0 = rating_stability_arr.into_iter().flatten().collect(); } _ => {} } @@ -358,8 +340,7 @@ mod tests { ], )]); let actual = search_parameters(pretrainset, 0.9); - Data::from([actual.get(&4).unwrap().clone()]) - .assert_approx_eq(&Data::from([1.2301323413848877]), 4); + Data::from([*actual.get(&4).unwrap()]).assert_approx_eq(&Data::from([1.230_132_3]), 4); } #[test] @@ -369,12 +350,7 @@ mod tests { let average_recall = calculate_average_recall(&items); let pretrainset = split_data(items, 1).0; Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq( - &Data::from([ - 0.9560174345970154, - 1.694406509399414, - 3.998023509979248, - 8.26822280883789, - ]), + &Data::from([0.956_017_43, 1.694_406_5, 3.998_023_5, 8.268_223]), 4, ) }