Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/filter outlier in trainset #119

Merged
merged 8 commits into from
Nov 21, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,14 @@ impl From<Vec<FSRSItem>> for FSRSDataset {
}
}

pub fn filter_outlier(items: Vec<FSRSItem>) -> Vec<FSRSItem> {
pub fn filter_outlier(
pretrainset: Vec<FSRSItem>,
mut trainset: Vec<FSRSItem>,
) -> (Vec<FSRSItem>, Vec<FSRSItem>) {
let mut groups = HashMap::<u32, HashMap<u32, Vec<FSRSItem>>>::new();

// 首先按照第一个 review 的 rating 和第二个 review 的 delta 进行分组
for item in items.iter() {
for item in pretrainset.iter() {
asukaminato0721 marked this conversation as resolved.
Show resolved Hide resolved
let (first_review, second_review) = (item.reviews.first().unwrap(), item.current());
let rating_group = groups.entry(first_review.rating).or_default();
let delta_t_group = rating_group.entry(second_review.delta_t).or_default();
Expand All @@ -149,7 +152,7 @@ pub fn filter_outlier(items: Vec<FSRSItem>) -> Vec<FSRSItem> {
let mut filtered_items = vec![];

// 对每个按 rating 分组的子组进一步处理
for (_rating, delta_t_groups) in groups.iter() {
for (rating, delta_t_groups) in groups.iter() {
let mut sub_groups = delta_t_groups.iter().collect::<Vec<_>>();

// 按子组大小升序排序,大小相同的按 delta_t 降序排序
Expand All @@ -164,15 +167,21 @@ pub fn filter_outlier(items: Vec<FSRSItem>) -> Vec<FSRSItem> {
let total = sub_groups.iter().map(|(_, vec)| vec.len()).sum::<usize>();
let mut has_been_removed = 0;

for (_delta_t, sub_group) in sub_groups.iter().rev() {
for (delta_t, sub_group) in sub_groups.iter().rev() {
if has_been_removed + sub_group.len() > total / 20 {
filtered_items.extend_from_slice(sub_group);
} else {
has_been_removed += sub_group.len();
// 删除 trainset 中第一个 review 的 rating 等于 rating,第二个 review 的 delta_t 等于 delta_t 的 item
trainset.retain(|item| {
let (first_review, second_review) =
(item.reviews.first().unwrap(), item.reviews.get(1).unwrap());
first_review.rating != *rating || second_review.delta_t != **delta_t
L-M-Sherlock marked this conversation as resolved.
Show resolved Hide resolved
});
asukaminato0721 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
filtered_items
(filtered_items, trainset)
}

fn stratified_kfold(mut trainset: Vec<FSRSItem>, n_splits: usize) -> Vec<Vec<FSRSItem>> {
Expand All @@ -190,9 +199,11 @@ pub fn split_data(
items: Vec<FSRSItem>,
n_splits: usize,
) -> (Vec<FSRSItem>, Vec<Vec<FSRSItem>>, Vec<FSRSItem>) {
let (pretrainset, trainset) = items.into_iter().partition(|item| item.reviews.len() == 2);
let (mut pretrainset, mut trainset) =
items.into_iter().partition(|item| item.reviews.len() == 2);
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
(
filter_outlier(pretrainset),
pretrainset,
stratified_kfold(trainset.clone(), n_splits),
trainset,
)
Expand Down Expand Up @@ -406,4 +417,17 @@ mod tests {
);
assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
}

#[test]
fn test_filter_outlier() {
let dataset = anki21_sample_file_converted_to_fsrs();
let (mut pretrainset, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) = dataset
.into_iter()
.partition(|item| item.reviews.len() == 2);
assert_eq!(pretrainset.len(), 3315);
assert_eq!(trainset.len(), 10806);
(pretrainset, trainset) = filter_outlier(pretrainset, trainset);
assert_eq!(pretrainset.len(), 3265);
assert_eq!(trainset.len(), 10731);
}
}
Loading