Skip to content

Commit

Permalink
Fix/remove append_default_point to improve accuracy (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Dec 21, 2023
1 parent 8ff6ccf commit f45f46b
Showing 1 changed file with 3 additions and 18 deletions.
21 changes: 3 additions & 18 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,6 @@ fn loss(
logloss + l1
}

fn append_default_point(data: &mut Vec<AverageRecall>, default_s0: f32) {
data.push(AverageRecall {
delta_t: default_s0,
recall: 0.9,
count: 16.0,
});
}

const MIN_INIT_S0: f32 = 0.1;
const MAX_INIT_S0: f32 = 100.0;

Expand All @@ -127,11 +119,7 @@ fn search_parameters(
for (first_rating, data) in &mut pretrainset {
let r_s0_default: HashMap<u32, f32> = R_S0_DEFAULT_ARRAY.iter().cloned().collect();
let default_s0 = r_s0_default[first_rating];

append_default_point(data, default_s0);

let delta_t = Array1::from_iter(data.iter().map(|d| d.delta_t));

let recall = {
// Laplace smoothing
// (real_recall * n + average_recall * 1) / (n + 1)
Expand All @@ -140,7 +128,6 @@ fn search_parameters(
let n = data.iter().map(|d| d.count).sum::<f32>();
(real_recall * n + average_recall) / (n + 1.0)
};

let count = Array1::from_iter(data.iter().map(|d| d.count));

let mut low = MIN_INIT_S0;
Expand Down Expand Up @@ -340,7 +327,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
Data::from([*actual.get(&4).unwrap()]).assert_approx_eq(&Data::from([1.230_132_3]), 4);
Data::from([*actual.get(&4).unwrap()]).assert_approx_eq(&Data::from([0.944_284]), 4);
}

#[test]
Expand All @@ -349,10 +336,8 @@ mod tests {
let items = anki21_sample_file_converted_to_fsrs();
let average_recall = calculate_average_recall(&items);
let pretrainset = split_data(items, 1).0;
Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq(
&Data::from([0.956_017_43, 1.694_406_5, 3.998_023_5, 8.268_223]),
4,
)
Data::from(pretrain(pretrainset, average_recall).unwrap())
.assert_approx_eq(&Data::from([1.001_276, 1.811_072, 4.405_640, 8.532_001]), 4)
}

#[test]
Expand Down

0 comments on commit f45f46b

Please sign in to comment.