Skip to content

Commit

Permalink
pretrain default S0 should be from DEFAULT_WEIGHTS (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Dec 14, 2023
1 parent d8b2aad commit 07621d4
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use crate::error::{FSRSError, Result};
use crate::FSRSItem;
use crate::DEFAULT_WEIGHTS;
use itertools::Itertools;
use ndarray::Array1;
use std::collections::HashMap;

static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[(1, 0.4), (2, 0.9), (3, 2.3), (4, 10.9)];
static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[
(1, DEFAULT_WEIGHTS[0]),
(2, DEFAULT_WEIGHTS[1]),
(3, DEFAULT_WEIGHTS[2]),
(4, DEFAULT_WEIGHTS[3]),
];

pub fn pretrain(fsrs_items: Vec<FSRSItem>, average_recall: f32) -> Result<[f32; 4]> {
let pretrainset = create_pretrain_data(fsrs_items);
Expand Down Expand Up @@ -109,6 +115,9 @@ fn append_default_point(data: &mut Vec<AverageRecall>, default_s0: f32) {
});
}

const MIN_INIT_S0: f32 = 0.1;
const MAX_INIT_S0: f32 = 100.0;

fn search_parameters(
mut pretrainset: HashMap<FirstRating, Vec<AverageRecall>>,
average_recall: f32,
Expand All @@ -135,8 +144,8 @@ fn search_parameters(

let count = Array1::from_iter(data.iter().map(|d| d.count));

let mut low = 0.1;
let mut high = 100.0;
let mut low = MIN_INIT_S0;
let mut high = MAX_INIT_S0;
let mut optimal_s = 1.0;

let mut iter = 0;
Expand Down Expand Up @@ -284,6 +293,10 @@ fn smooth_and_fill(
}
_ => {}
}
init_s0 = init_s0
.iter()
.map(|&v| v.clamp(MIN_INIT_S0, MAX_INIT_S0))
.collect();
Ok(init_s0[0..=3].try_into().unwrap())
}

Expand Down Expand Up @@ -341,7 +354,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
let expected = [(4, 1.4098487)].into_iter().collect();
let expected = [(4, 1.4877763)].into_iter().collect();
assert_eq!(actual, expected);
}

Expand All @@ -353,7 +366,7 @@ mod tests {
let pretrainset = split_data(items, 1).0;
assert_eq!(
pretrain(pretrainset, average_recall).unwrap(),
[0.948_268_3, 1.695_154, 4.051_595_7, 9.332_188,],
[0.9517492, 1.7152255, 4.149725, 9.399195,],
)
}

Expand Down

0 comments on commit 07621d4

Please sign in to comment.