Skip to content

Commit

Permalink
Feat/filter outlier in trainset (#119)
Browse files Browse the repository at this point in the history
Co-authored-by: AsukaMinato <[email protected]>
  • Loading branch information
L-M-Sherlock and asukaminato0721 authored Nov 21, 2023
1 parent df5bad7 commit a9cc36a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[toolchain]
# older versions may fail to compile; newer versions may fail the clippy tests
channel = "1.73"
channel = "1.74"
components = ["rustfmt", "clippy"]
53 changes: 38 additions & 15 deletions src/dataset.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use burn::data::dataloader::batcher::Batcher;
use burn::{
Expand Down Expand Up @@ -135,44 +135,52 @@ impl From<Vec<FSRSItem>> for FSRSDataset {
}
}

pub fn filter_outlier(items: Vec<FSRSItem>) -> Vec<FSRSItem> {
pub fn filter_outlier(
pretrainset: Vec<FSRSItem>,
mut trainset: Vec<FSRSItem>,
) -> (Vec<FSRSItem>, Vec<FSRSItem>) {
let mut groups = HashMap::<u32, HashMap<u32, Vec<FSRSItem>>>::new();

// 首先按照第一个 review 的 rating 和第二个 review 的 delta 进行分组
for item in items.iter() {
// group by rating of first review and delta_t of second review
for item in pretrainset.into_iter() {
let (first_review, second_review) = (item.reviews.first().unwrap(), item.current());
let rating_group = groups.entry(first_review.rating).or_default();
let delta_t_group = rating_group.entry(second_review.delta_t).or_default();
delta_t_group.push(item.clone());
delta_t_group.push(item);
}

let mut filtered_items = vec![];
let mut removed_pairs: [HashSet<_>; 5] = Default::default();

// 对每个按 rating 分组的子组进一步处理
for (_rating, delta_t_groups) in groups.iter() {
let mut sub_groups = delta_t_groups.iter().collect::<Vec<_>>();
for (rating, delta_t_groups) in groups.into_iter() {
let mut sub_groups = delta_t_groups.into_iter().collect::<Vec<_>>();

// 按子组大小升序排序,大小相同的按 delta_t 降序排序
// order by size of sub group ascending and delta_t descending
sub_groups.sort_by(|(delta_t_a, subv_a), (delta_t_b, subv_b)| {
subv_b
.len()
.cmp(&subv_a.len())
.then(delta_t_a.cmp(delta_t_b))
.then(delta_t_b.cmp(delta_t_a))
});

// 计算总大小
// remove 5% of items from each sub group
let total = sub_groups.iter().map(|(_, vec)| vec.len()).sum::<usize>();
let mut has_been_removed = 0;

for (_delta_t, sub_group) in sub_groups.iter().rev() {
for (delta_t, sub_group) in sub_groups.iter().rev() {
if has_been_removed + sub_group.len() > total / 20 {
filtered_items.extend_from_slice(sub_group);
} else {
has_been_removed += sub_group.len();
removed_pairs[rating as usize].insert(*delta_t);
}
}
}
filtered_items
// keep the items in trainset if they are not removed from filtered_items
trainset.retain(|item| {
!removed_pairs[item.reviews[0].rating as usize].contains(&item.reviews[1].delta_t)
});
(filtered_items, trainset)
}

fn stratified_kfold(mut trainset: Vec<FSRSItem>, n_splits: usize) -> Vec<Vec<FSRSItem>> {
Expand All @@ -190,9 +198,11 @@ pub fn split_data(
items: Vec<FSRSItem>,
n_splits: usize,
) -> (Vec<FSRSItem>, Vec<Vec<FSRSItem>>, Vec<FSRSItem>) {
let (pretrainset, trainset) = items.into_iter().partition(|item| item.reviews.len() == 2);
let (mut pretrainset, mut trainset) =
items.into_iter().partition(|item| item.reviews.len() == 2);
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
(
filter_outlier(pretrainset),
pretrainset,
stratified_kfold(trainset.clone(), n_splits),
trainset,
)
Expand Down Expand Up @@ -406,4 +416,17 @@ mod tests {
);
assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
}

#[test]
fn test_filter_outlier() {
let dataset = anki21_sample_file_converted_to_fsrs();
let (mut pretrainset, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) = dataset
.into_iter()
.partition(|item| item.reviews.len() == 2);
assert_eq!(pretrainset.len(), 3315);
assert_eq!(trainset.len(), 10806);
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
assert_eq!(pretrainset.len(), 3265);
assert_eq!(trainset.len(), 10731);
}
}

0 comments on commit a9cc36a

Please sign in to comment.