Skip to content

Commit

Permalink
Feat/filter outlier in trainset
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Nov 20, 2023
1 parent 6716304 commit 88d0794
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,14 @@ impl From<Vec<FSRSItem>> for FSRSDataset {
}
}

pub fn filter_outlier(items: Vec<FSRSItem>) -> Vec<FSRSItem> {
pub fn filter_outlier(
pretrainset: Vec<FSRSItem>,
mut trainset: Vec<FSRSItem>,
) -> (Vec<FSRSItem>, Vec<FSRSItem>) {
let mut groups = HashMap::<u32, HashMap<u32, Vec<FSRSItem>>>::new();

// 首先按照第一个 review 的 rating 和第二个 review 的 delta 进行分组
for item in items.iter() {
for item in pretrainset.iter() {
let (first_review, second_review) = (item.reviews.first().unwrap(), item.current());
let rating_group = groups.entry(first_review.rating).or_default();
let delta_t_group = rating_group.entry(second_review.delta_t).or_default();
Expand All @@ -149,7 +152,7 @@ pub fn filter_outlier(items: Vec<FSRSItem>) -> Vec<FSRSItem> {
let mut filtered_items = vec![];

// 对每个按 rating 分组的子组进一步处理
for (_rating, delta_t_groups) in groups.iter() {
for (rating, delta_t_groups) in groups.iter() {
let mut sub_groups = delta_t_groups.iter().collect::<Vec<_>>();

// 按子组大小升序排序,大小相同的按 delta_t 降序排序
Expand All @@ -164,15 +167,21 @@ pub fn filter_outlier(items: Vec<FSRSItem>) -> Vec<FSRSItem> {
let total = sub_groups.iter().map(|(_, vec)| vec.len()).sum::<usize>();
let mut has_been_removed = 0;

for (_delta_t, sub_group) in sub_groups.iter().rev() {
for (delta_t, sub_group) in sub_groups.iter().rev() {
if has_been_removed + sub_group.len() > total / 20 {
filtered_items.extend_from_slice(sub_group);
} else {
has_been_removed += sub_group.len();
// 删除 trainset 中第一个 review 的 rating 等于 rating,第二个 review 的 delta_t 等于 delta_t 的 item
trainset.retain(|item| {
let (first_review, second_review) =
(item.reviews.first().unwrap(), item.reviews.get(1).unwrap());
first_review.rating != *rating || second_review.delta_t != **delta_t
});
}
}
}
filtered_items
(filtered_items, trainset)
}

fn stratified_kfold(mut trainset: Vec<FSRSItem>, n_splits: usize) -> Vec<Vec<FSRSItem>> {
Expand All @@ -190,9 +199,11 @@ pub fn split_data(
items: Vec<FSRSItem>,
n_splits: usize,
) -> (Vec<FSRSItem>, Vec<Vec<FSRSItem>>, Vec<FSRSItem>) {
let (pretrainset, trainset) = items.into_iter().partition(|item| item.reviews.len() == 2);
let (mut pretrainset, mut trainset) =
items.into_iter().partition(|item| item.reviews.len() == 2);
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
(
filter_outlier(pretrainset),
pretrainset,
stratified_kfold(trainset.clone(), n_splits),
trainset,
)
Expand Down Expand Up @@ -406,4 +417,17 @@ mod tests {
);
assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
}

#[test]
fn test_filter_outlier() {
let dataset = anki21_sample_file_converted_to_fsrs();
let (mut pretrainset, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) = dataset
.into_iter()
.partition(|item| item.reviews.len() == 2);
dbg!(pretrainset.len());
dbg!(trainset.len());
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
dbg!(pretrainset.len());
dbg!(trainset.len());
}
}

0 comments on commit 88d0794

Please sign in to comment.