From 8b35144001ac043a7335a73305af76b4db93e47c Mon Sep 17 00:00:00 2001 From: PhilipMay Date: Sat, 5 Nov 2022 09:14:56 +0100 Subject: [PATCH 1/3] add NoDuplicateClassesDataLoader --- src/setfit/data.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/setfit/data.py b/src/setfit/data.py index abb6d8a2..e50a1559 100644 --- a/src/setfit/data.py +++ b/src/setfit/data.py @@ -280,3 +280,43 @@ def collate_fn(batch): labels = torch.Tensor(labels).long() return features, labels + + +class NoDuplicateClassesDataLoader: + + def __init__(self, train_examples, batch_size): + self.batch_size = batch_size + self.collate_fn = None + self.train_examples = train_examples + + # TODO: add assert batch_size <= num_classes + + def __iter__(self): + label_class_dict = {} + random.shuffle(self.train_examples) + for example in self.train_examples: + example_label_list = label_class_dict.get(example.label, []) + example_label_list.append(example) + label_class_dict[example.label] = example_label_list + + for _ in range(self.__len__()): + batch = [] + classes_in_batch = set() + + while len(batch) < self.batch_size: + class_to_add = random.choice(label_class_dict.keys()) + if class_to_add not in classes_in_batch: + example = label_class_dict[class_to_add].pop(0) + batch.append(example) + + # list of examples for this class is empty and needs to be refilled + if len(label_class_dict[class_to_add]) == 0: + random.shuffle(self.train_examples) + for example in self.train_examples: + if example.label == class_to_add: + label_class_dict[class_to_add].append(example) + + yield self.collate_fn(batch) if self.collate_fn is not None else batch + + def __len__(self): + return math.floor(len(self.train_examples) / self.batch_size) From e9e4195221603df641e450ca1e50dfffc6ee6f57 Mon Sep 17 00:00:00 2001 From: PhilipMay Date: Sat, 5 Nov 2022 09:25:36 +0100 Subject: [PATCH 2/3] fix formatting --- src/setfit/data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/setfit/data.py b/src/setfit/data.py index e50a1559..aa1c5007 100644 --- a/src/setfit/data.py +++ b/src/setfit/data.py @@ -283,7 +283,6 @@ def collate_fn(batch): class NoDuplicateClassesDataLoader: - def __init__(self, train_examples, batch_size): self.batch_size = batch_size self.collate_fn = None From 2f3fdb6e8d4c4030cdc3e93b139975cd7d2c0825 Mon Sep 17 00:00:00 2001 From: PhilipMay Date: Sat, 5 Nov 2022 09:30:36 +0100 Subject: [PATCH 3/3] add missing imports --- src/setfit/data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/setfit/data.py b/src/setfit/data.py index aa1c5007..275591c0 100644 --- a/src/setfit/data.py +++ b/src/setfit/data.py @@ -1,3 +1,5 @@ +import math +import random from typing import TYPE_CHECKING, Dict, List, Tuple import pandas as pd