Skip to content

Commit

Permalink
use arr
Browse files Browse the repository at this point in the history
  • Loading branch information
asukaminato0721 committed Dec 19, 2023
1 parent dfc89fe commit 9461ab8
Showing 1 changed file with 37 additions and 56 deletions.
93 changes: 37 additions & 56 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::error::{FSRSError, Result};
use crate::inference::{DECAY, FACTOR};
use crate::FSRSItem;
use crate::DEFAULT_WEIGHTS;
use itertools::Itertools;
use ndarray::Array1;
use std::collections::HashMap;

Expand Down Expand Up @@ -200,11 +199,12 @@ fn smooth_and_fill(
.iter()
.cloned()
.collect::<HashMap<_, _>>();
let rating_stability_arr = [
rating_stability.get(&1),
rating_stability.get(&2),
rating_stability.get(&3),
rating_stability.get(&4),
let mut rating_stability_arr = [
None,
rating_stability.get(&1).cloned(),
rating_stability.get(&2).cloned(),
rating_stability.get(&3).cloned(),
rating_stability.get(&4).cloned(),
];
match rating_stability.len() {
0 => return Err(FSRSError::NotEnoughData),
Expand All @@ -216,76 +216,63 @@ fn smooth_and_fill(
}
2 => {
match rating_stability_arr {
[None, None, Some(&r3), Some(&r4)] => {
[_, None, None, Some(r3), Some(r4)] => {
let r2 = r3.powf(1.0 / (1.0 - w2)) * r4.powf(1.0 - 1.0 / (1.0 - w2));
rating_stability.insert(2, r2);
rating_stability.insert(1, (r2.powf(1.0 / w1)) * (r3.powf(1.0 - 1.0 / w1)));
rating_stability_arr[2] = Some(r2);
rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
[None, Some(&r2), None, Some(&r4)] => {
[_, None, Some(r2), None, Some(r4)] => {
let r3 = r2.powf(1.0 - w2) * r4.powf(w2);
rating_stability.insert(3, r3);
rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
rating_stability_arr[3] = Some(r3);
rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
[None, Some(&r2), Some(&r3), None] => {
rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
[_, None, Some(r2), Some(r3), None] => {
rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
[Some(&r1), None, None, Some(&r4)] => {
[_, Some(r1), None, None, Some(r4)] => {
let r2 = r1.powf(w1 / (w1.mul_add(-w2, w1 + w2)))
* r4.powf(1.0 - w1 / (w1.mul_add(-w2, w1 + w2)));
rating_stability.insert(2, r2);
rating_stability.insert(
3,
rating_stability_arr[2] = Some(r2);
rating_stability_arr[3] = Some(
r1.powf(1.0 - w2 / (w1.mul_add(-w2, w1 + w2)))
* r4.powf(w2 / (w1.mul_add(-w2, w1 + w2))),
);
}
[Some(&r1), None, Some(&r3), None] => {
[_, Some(r1), None, Some(r3), None] => {
let r2 = r1.powf(w1) * r3.powf(1.0 - w1);
rating_stability.insert(2, r2);
rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
rating_stability_arr[2] = Some(r2);
rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
}
[Some(&r1), Some(&r2), None, None] => {
[_, Some(r1), Some(r2), None, None] => {
let r3 = r1.powf(1.0 - 1.0 / (1.0 - w1)) * r2.powf(1.0 / (1.0 - w1));
rating_stability.insert(3, r3);
rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
rating_stability_arr[3] = Some(r3);
rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
}
_ => {}
}
init_s0 = rating_stability
.iter()
.sorted_by(|a, b| a.0.cmp(b.0))
.map(|(_, &v)| v)
.collect();
init_s0 = rating_stability_arr.into_iter().flatten().collect();
}
3 => {
match rating_stability_arr {
[None, Some(r2), Some(r3), _] => {
rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
[_, None, Some(r2), Some(r3), _] => {
rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
[Some(r1), None, Some(r3), _] => {
rating_stability.insert(2, r1.powf(w1) * r3.powf(1.0 - w1));
[_, Some(r1), None, Some(r3), _] => {
rating_stability_arr[2] = Some(r1.powf(w1) * r3.powf(1.0 - w1));
}
[_, Some(r2), None, Some(r4)] => {
rating_stability.insert(3, r2.powf(1.0 - w2) * r4.powf(w2));
[_, _, Some(r2), None, Some(r4)] => {
rating_stability_arr[3] = Some(r2.powf(1.0 - w2) * r4.powf(w2));
}
[_, Some(r2), Some(r3), None] => {
rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
[_, _, Some(r2), Some(r3), None] => {
rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
}
_ => {}
}
init_s0 = rating_stability
.iter()
.sorted_by(|a, b| a.0.cmp(b.0))
.map(|(_, &v)| v)
.collect();
init_s0 = rating_stability_arr.into_iter().flatten().collect();
}
4 => {
init_s0 = rating_stability
.iter()
.sorted_by(|a, b| a.0.cmp(b.0))
.map(|(_, &v)| v)
.collect();
init_s0 = rating_stability_arr.into_iter().flatten().collect();
}
_ => {}
}
Expand Down Expand Up @@ -353,8 +340,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
Data::from([*actual.get(&4).unwrap()])
.assert_approx_eq(&Data::from([1.230_132_3]), 4);
Data::from([*actual.get(&4).unwrap()]).assert_approx_eq(&Data::from([1.230_132_3]), 4);
}

#[test]
Expand All @@ -364,12 +350,7 @@ mod tests {
let average_recall = calculate_average_recall(&items);
let pretrainset = split_data(items, 1).0;
Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq(
&Data::from([
0.956_017_43,
1.694_406_5,
3.998_023_5,
8.268_223,
]),
&Data::from([0.956_017_43, 1.694_406_5, 3.998_023_5, 8.268_223]),
4,
)
}
Expand Down

0 comments on commit 9461ab8

Please sign in to comment.