Skip to content

Commit

Permalink
move retain out the loop
Browse files Browse the repository at this point in the history
  • Loading branch information
asukaminato0721 committed Nov 21, 2023
1 parent c0af31a commit c9756fd
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions src/dataset.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use burn::data::dataloader::batcher::Batcher;
use burn::{
Expand Down Expand Up @@ -150,6 +150,7 @@ pub fn filter_outlier(
}

let mut filtered_items = vec![];
let mut to_rm_index = HashSet::new();

for (rating, delta_t_groups) in groups.into_iter() {
let mut sub_groups = delta_t_groups.into_iter().collect::<Vec<_>>();
Expand All @@ -165,30 +166,19 @@ pub fn filter_outlier(
// remove 5% of items from each sub group
let total = sub_groups.iter().map(|(_, vec)| vec.len()).sum::<usize>();
let mut has_been_removed = 0;
let mut to_rm_index = vec![false; trainset.len()];

for (delta_t, sub_group) in sub_groups.iter().rev() {
if has_been_removed + sub_group.len() > total / 20 {
filtered_items.extend_from_slice(sub_group);
} else {
has_been_removed += sub_group.len();
trainset
.iter()
.enumerate() //
.filter(|(.., item)| {
item.reviews[0].rating == rating && item.reviews[1].delta_t == *delta_t
})
.for_each(|(index, ..)| to_rm_index[index] = true);
to_rm_index.insert((rating, *delta_t));
}
}
// keep the items in trainset if they are not removed from filtered_items
trainset = trainset
.into_iter()
.enumerate()
.filter(|(index, ..)| !to_rm_index[*index])
.map(|(.., x)| x)
.collect();
}
trainset
.retain(|item| !to_rm_index.contains(&(item.reviews[0].rating, item.reviews[1].delta_t)));
(filtered_items, trainset)
}

Expand Down

0 comments on commit c9756fd

Please sign in to comment.