Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/filter outlier in trainset #119

Merged
merged 8 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[toolchain]
# older versions may fail to compile; newer versions may fail the clippy tests
channel = "1.73"
channel = "1.74"
components = ["rustfmt", "clippy"]
53 changes: 38 additions & 15 deletions src/dataset.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use burn::data::dataloader::batcher::Batcher;
use burn::{
Expand Down Expand Up @@ -135,44 +135,52 @@ impl From<Vec<FSRSItem>> for FSRSDataset {
}
}

pub fn filter_outlier(items: Vec<FSRSItem>) -> Vec<FSRSItem> {
pub fn filter_outlier(
pretrainset: Vec<FSRSItem>,
mut trainset: Vec<FSRSItem>,
) -> (Vec<FSRSItem>, Vec<FSRSItem>) {
let mut groups = HashMap::<u32, HashMap<u32, Vec<FSRSItem>>>::new();

// 首先按照第一个 review 的 rating 和第二个 review 的 delta 进行分组
for item in items.iter() {
// group by rating of first review and delta_t of second review
for item in pretrainset.into_iter() {
let (first_review, second_review) = (item.reviews.first().unwrap(), item.current());
let rating_group = groups.entry(first_review.rating).or_default();
let delta_t_group = rating_group.entry(second_review.delta_t).or_default();
delta_t_group.push(item.clone());
delta_t_group.push(item);
}

let mut filtered_items = vec![];
let mut removed_pairs: [HashSet<_>; 5] = Default::default();

// 对每个按 rating 分组的子组进一步处理
for (_rating, delta_t_groups) in groups.iter() {
let mut sub_groups = delta_t_groups.iter().collect::<Vec<_>>();
for (rating, delta_t_groups) in groups.into_iter() {
let mut sub_groups = delta_t_groups.into_iter().collect::<Vec<_>>();

// 按子组大小升序排序,大小相同的按 delta_t 降序排序
// order by size of sub group ascending and delta_t descending
sub_groups.sort_by(|(delta_t_a, subv_a), (delta_t_b, subv_b)| {
subv_b
.len()
.cmp(&subv_a.len())
.then(delta_t_a.cmp(delta_t_b))
.then(delta_t_b.cmp(delta_t_a))
});

// 计算总大小
// remove 5% of items from each sub group
let total = sub_groups.iter().map(|(_, vec)| vec.len()).sum::<usize>();
let mut has_been_removed = 0;

for (_delta_t, sub_group) in sub_groups.iter().rev() {
for (delta_t, sub_group) in sub_groups.iter().rev() {
if has_been_removed + sub_group.len() > total / 20 {
filtered_items.extend_from_slice(sub_group);
} else {
has_been_removed += sub_group.len();
removed_pairs[rating as usize].insert(*delta_t);
}
}
}
filtered_items
// keep the items in trainset if they are not removed from filtered_items
trainset.retain(|item| {
!removed_pairs[item.reviews[0].rating as usize].contains(&item.reviews[1].delta_t)
});
(filtered_items, trainset)
}

fn stratified_kfold(mut trainset: Vec<FSRSItem>, n_splits: usize) -> Vec<Vec<FSRSItem>> {
Expand All @@ -190,9 +198,11 @@ pub fn split_data(
items: Vec<FSRSItem>,
n_splits: usize,
) -> (Vec<FSRSItem>, Vec<Vec<FSRSItem>>, Vec<FSRSItem>) {
let (pretrainset, trainset) = items.into_iter().partition(|item| item.reviews.len() == 2);
let (mut pretrainset, mut trainset) =
items.into_iter().partition(|item| item.reviews.len() == 2);
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
(
filter_outlier(pretrainset),
pretrainset,
stratified_kfold(trainset.clone(), n_splits),
trainset,
)
Expand Down Expand Up @@ -406,4 +416,17 @@ mod tests {
);
assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
}

#[test]
fn test_filter_outlier() {
let dataset = anki21_sample_file_converted_to_fsrs();
let (mut pretrainset, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) = dataset
.into_iter()
.partition(|item| item.reviews.len() == 2);
assert_eq!(pretrainset.len(), 3315);
assert_eq!(trainset.len(), 10806);
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
assert_eq!(pretrainset.len(), 3265);
assert_eq!(trainset.len(), 10731);
}
}
Loading