From 78d353bf41705d9a4ebf6047592fdcfd537b5370 Mon Sep 17 00:00:00 2001
From: Christoph Gerum
Date: Thu, 16 May 2024 12:46:26 +0200
Subject: [PATCH] Add a pickle dataset that contains predetermined batches

---
 experiments/test_pickle_set/README.md        | 40 +++++++++
 experiments/test_pickle_set/config.yaml      | 28 ++++++
 experiments/test_pickle_set/create_sample.py | 50 +++++++++++
 .../dataset/test_pickle_set.yaml             | 11 +++
 hannah/datasets/base.py                      | 14 +--
 hannah/modules/base.py                       | 90 +++++++++++++------
 hannah/modules/classifier.py                 |  5 +-
 hannah/modules/vision/base.py                | 62 +------------
 8 files changed, 206 insertions(+), 94 deletions(-)
 create mode 100644 experiments/test_pickle_set/README.md
 create mode 100644 experiments/test_pickle_set/config.yaml
 create mode 100644 experiments/test_pickle_set/create_sample.py
 create mode 100644 experiments/test_pickle_set/dataset/test_pickle_set.yaml

diff --git a/experiments/test_pickle_set/README.md b/experiments/test_pickle_set/README.md
new file mode 100644
index 00000000..62c1f398
--- /dev/null
+++ b/experiments/test_pickle_set/README.md
@@ -0,0 +1,40 @@
+
+# Test Pickle dataset
+
+A simple test implementation for pickled datasets.
+
+Each pickled dataset is expected to contain a tuple of two numpy arrays:
+the first array contains the (preprocessed) input data, and the second array
+contains the target class ids as int32 values.
+
+## Creating Test Data
+
+The following commands create test, val and train datasets with 400, 400 and 4000 samples respectively.
+The input data is initialized randomly, and each sample is assigned one of 2 class labels at random.
+
+    python create_sample.py --size 400 --dim='(20,17)' --classes=2 test.pkl
+    python create_sample.py --size 400 --dim='(20,17)' --classes=2 val.pkl
+    python create_sample.py --size 4000 --dim='(20,17)' --classes=2 train.pkl
+
+## Training
+
+The following command then runs a training of the tc-res8 model on the generated data:
+
+    hannah-train
diff --git a/experiments/test_pickle_set/config.yaml b/experiments/test_pickle_set/config.yaml
new file mode 100644
index 00000000..0c89df91
--- /dev/null
+++ b/experiments/test_pickle_set/config.yaml
@@ -0,0 +1,28 @@
+##
+## Copyright (c) 2022 University of Tübingen.
+##
+## This file is part of hannah.
+## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+##
+## Licensed under the Apache License, Version 2.0 (the "License");
+## you may not use this file except in compliance with the License.
+## You may obtain a copy of the License at
+##
+##     http://www.apache.org/licenses/LICENSE-2.0
+##
+## Unless required by applicable law or agreed to in writing, software
+## distributed under the License is distributed on an "AS IS" BASIS,
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+## See the License for the specific language governing permissions and
+## limitations under the License.
+##
+
+
+defaults:
+  - base_config # The base configuration uses a single neural network training and the kws dataset
+  - override dataset: test_pickle_set # Use the test_pickle_set dataset instead
+  - override features: raw # Use raw features, i.e. no preprocessing
+  - _self_ # Special value: settings in this file take precedence over the composed defaults
+
+trainer: # Trainer arguments set hyperparameters for all trainings
+  max_epochs: 30
diff --git a/experiments/test_pickle_set/create_sample.py b/experiments/test_pickle_set/create_sample.py
new file mode 100644
index 00000000..69fd72de
--- /dev/null
+++ b/experiments/test_pickle_set/create_sample.py
@@ -0,0 +1,50 @@
+#
+# Copyright (c) 2024 Hannah contributors.
+#
+# This file is part of hannah.
+# See https://github.com/ekut-es/hannah for further info.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import argparse
+import pickle
+
+import numpy
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Create a random test dataset and store it as a pickle file."
+    )
+    parser.add_argument("pickle_file", help="The pickle file to create.")
+    parser.add_argument("--size", help="The number of samples in the dataset.")
+    parser.add_argument(
+        "--dim",
+        help='The dimension of the samples, in the form of a tuple e.g. "(3, 32, 32)"',
+    )
+    parser.add_argument("--classes", help="The number of classes in the dataset.")
+
+    args = parser.parse_args()
+
+    size = int(args.size)
+    dim = tuple(map(int, args.dim.strip("()").split(",")))
+    classes = int(args.classes)
+
+    with open(args.pickle_file, "wb") as f:
+        pickle.dump(
+            (
+                numpy.random.rand(size, *dim).astype(numpy.float32),
+                numpy.random.randint(0, classes, size).astype(numpy.int32),  # int32 class ids, as documented in the README
+            ),
+            f,
+        )
diff --git a/experiments/test_pickle_set/dataset/test_pickle_set.yaml b/experiments/test_pickle_set/dataset/test_pickle_set.yaml
new file mode 100644
index 00000000..dbb0c87e
--- /dev/null
+++ b/experiments/test_pickle_set/dataset/test_pickle_set.yaml
@@ -0,0 +1,11 @@
+cls: hannah.datasets.pickle_set.PickleDataset
+train:
+  - ${hydra:runtime.cwd}/train.pkl
+
+val:
+  - ${hydra:runtime.cwd}/val.pkl
+
+test:
+  - ${hydra:runtime.cwd}/test.pkl
+
+samplingrate: 16000
diff --git a/hannah/datasets/base.py b/hannah/datasets/base.py
index 01b78dc6..c37eedf8 100644
--- a/hannah/datasets/base.py
+++ b/hannah/datasets/base.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
 # See https://github.com/ekut-es/hannah for further info.
@@ -36,7 +36,8 @@ class DatasetType(Enum):
 
 
 class AbstractDataset(Dataset, ABC):
-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def prepare(cls, config: Dict[str, Any]) -> None:
         """Prepare the dataset.
 
@@ -52,7 +53,8 @@ def prepare(cls, config: Dict[str, Any]) -> None:
 
         pass
 
-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def splits(
         cls, config: Dict[str, Any]
     ) -> Tuple["AbstractDataset", "AbstractDataset", "AbstractDataset"]:
@@ -64,7 +66,8 @@ def splits(
 
         pass  # pytype: disable=bad-return-type
 
-    @abstractproperty
+    @property
+    @abstractmethod
     def class_names(self) -> List[str]:
         """Returns the names of the classes in the classification dataset"""
         pass  # pytype: disable=bad-return-type
@@ -81,7 +84,8 @@ def class_names_abbreviated(self) -> List[str]:
 
         return self.class_names
 
-    @abstractproperty
+    @property
+    @abstractmethod
     def class_counts(self) -> Optional[Dict[int, int]]:
         """Returns the number of items in each class of the dataset
 
diff --git a/hannah/modules/base.py b/hannah/modules/base.py
index 71434e93..4115f54c 100644
--- a/hannah/modules/base.py
+++ b/hannah/modules/base.py
@@ -150,39 +150,68 @@ def val_dataloader(self):
         return self._get_dataloader(self.dev_set, self.dev_set_unlabeled)
 
     def _get_dataloader(self, dataset, unlabeled_data=None, shuffle=False):
-        dataset_conf = self.hparams.dataset
-        sampler = None
-        if shuffle:
-            sampler_type = dataset_conf.get("sampler", "random")
-            if sampler_type == "weighted":
-                sampler = self.get_balancing_sampler(dataset)
-            else:
-                sampler = data.RandomSampler(dataset)
-
-        loader = data.DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            drop_last=True,
-            num_workers=self.hparams["num_workers"],
-            sampler=sampler if not dataset.sequential else None,
-            multiprocessing_context="fork" if self.hparams["num_workers"] > 0 else None,
-        )
+        batch_size = self.hparams["batch_size"]
+        num_workers = self.hparams["num_workers"]
+
+        def calc_workers(dataset):
+            # Cap the configured worker count at the dataset's limit
+            # (max_workers == -1 means unlimited).
+            result = (
+                num_workers
+                if num_workers <= dataset.max_workers or dataset.max_workers == -1
+                else dataset.max_workers
+            )
+            return result
+
+        if hasattr(dataset, "loader"):
+            # The dataset provides its own (e.g. pre-batched) loader.
+            loader = dataset.loader(batch_size, shuffle=shuffle)
+        else:
+            # FIXME: don't use hparams here
+            dataset_conf = self.hparams.dataset
+            sampler = None
+            if shuffle:
+                sampler_type = dataset_conf.get("sampler", "random")
+                if sampler_type == "weighted":
+                    sampler = self.get_balancing_sampler(dataset)
+                else:
+                    sampler = data.RandomSampler(dataset)
+
+            dataset_workers = calc_workers(dataset)
+
+            loader = data.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                drop_last=True,
+                num_workers=dataset_workers,
+                sampler=sampler if not dataset.sequential else None,
+                collate_fn=self.collate_fn,
+                multiprocessing_context="fork" if dataset_workers > 0 else None,
+                persistent_workers=True if dataset_workers > 0 else False,
+                prefetch_factor=2 if dataset_workers > 0 else None,
+                pin_memory=True,
+            )
+
+        self.batches_per_epoch = len(loader)
 
         if unlabeled_data:
-            loader_unlabeled = data.DataLoader(
-                unlabeled_data,
-                batch_size=self.batch_size,
-                drop_last=True,
-                num_workers=self.hparams["num_workers"],
-                sampler=data.RandomSampler(unlabeled_data)
-                if not dataset.sequential
-                else None,
-                multiprocessing_context="fork"
-                if self.hparams["num_workers"] > 0
-                else None,
+            if hasattr(unlabeled_data, "loader"):
+                loader_unlabeled = unlabeled_data.loader(batch_size, shuffle=shuffle)
+            else:
+                unlabeled_workers = calc_workers(unlabeled_data)
+                loader_unlabeled = data.DataLoader(
+                    unlabeled_data,
+                    batch_size=batch_size,
+                    drop_last=True,
+                    num_workers=unlabeled_workers,
+                    sampler=data.RandomSampler(unlabeled_data)
+                    if not unlabeled_data.sequential
+                    else None,
+                    multiprocessing_context="fork" if unlabeled_workers > 0 else None,
+                )
+
+            return CombinedLoader(
+                {"labeled": loader, "unlabeled": loader_unlabeled},
+                mode="max_size_cycle",
+            )
-            return CombinedLoader({"labeled": loader, "unlabeled": loader_unlabeled})
 
         return loader
@@ -446,3 +475,6 @@ def _setup_loss_weights(self):
             return loss_weights
 
         return None
+
+    def collate_fn(self, batch):
+        return torch.utils.data.default_collate(batch)
diff --git a/hannah/modules/classifier.py b/hannah/modules/classifier.py
index 60dc4add..754313a3 100644
--- a/hannah/modules/classifier.py
+++ b/hannah/modules/classifier.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
 #
 # This file is part of hannah.
 # See https://github.com/ekut-es/hannah for further info.
@@ -396,6 +396,9 @@ def _log_audio(self, x, logits, y):
             )
             self.logged_samples += 1
 
+    def collate_fn(self, batch):
+        return ctc_collate_fn(batch)
+
 
 class StreamClassifierModule(BaseStreamClassifierModule):
     def get_class_names(self):
diff --git a/hannah/modules/vision/base.py b/hannah/modules/vision/base.py
index abadde14..d208d07c 100644
--- a/hannah/modules/vision/base.py
+++ b/hannah/modules/vision/base.py
@@ -293,65 +293,6 @@ def setup_augmentations(self, pipeline_configs):
 
         return augmentations
 
-    def _get_dataloader(self, dataset, unlabeled_data=None, shuffle=False):
-        batch_size = self.hparams["batch_size"]
-
-        # FIXME: don't use hparams here
-        dataset_conf = self.hparams.dataset
-        sampler = None
-        if shuffle:
-            sampler_type = dataset_conf.get("sampler", "random")
-            if sampler_type == "weighted":
-                sampler = self.get_balancing_sampler(dataset)
-            else:
-                sampler = data.RandomSampler(dataset)
-
-        num_workers = self.hparams["num_workers"]
-
-        def calc_workers(dataset):
-            result = (
-                num_workers
-                if num_workers <= dataset.max_workers or dataset.max_workers == -1
-                else dataset.max_workers
-            )
-            return result
-
-        num_workers = calc_workers(dataset)
-
-        loader = data.DataLoader(
-            dataset,
-            batch_size=batch_size,
-            drop_last=True,
-            num_workers=num_workers,
-            sampler=sampler if not dataset.sequential else None,
-            collate_fn=vision_collate_fn,
-            multiprocessing_context="fork" if num_workers > 0 else None,
-            persistent_workers=True if num_workers > 0 else False,
-            prefetch_factor=2 if num_workers > 0 else None,
-            pin_memory=True,
-        )
-        self.batches_per_epoch = len(loader)
-
-        if unlabeled_data:
-            unlabeled_workers = calc_workers(unlabeled_data)
-            loader_unlabeled = data.DataLoader(
-                unlabeled_data,
-                batch_size=batch_size,
-                drop_last=True,
-                num_workers=unlabeled_workers,
-                sampler=data.RandomSampler(unlabeled_data)
-                if not unlabeled_data.sequential
-                else None,
-                multiprocessing_context="fork" if unlabeled_workers > 0 else None,
-            )
-
-            return CombinedLoader(
-                {"labeled": loader, "unlabeled": loader_unlabeled},
-                mode="max_size_cycle",
-            )
-
-        return loader
-
     @property
     def backbone(self):
         if self.model is None:
@@ -363,3 +304,6 @@ def backbone(self):
             return self.model.encoder
         else:
             raise AttributeError("No backbone found in model")
+
+    def collate_fn(self, batch):
+        return vision_collate_fn(batch)
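
As a quick sanity check, a file produced by create_sample.py can be opened and compared against the tuple layout described in the README. This is a usage sketch, not part of the patch itself; the expected shapes correspond to the train split created in the README:

    import pickle

    with open("train.pkl", "rb") as f:
        inputs, targets = pickle.load(f)

    print(inputs.shape, inputs.dtype)    # (4000, 20, 17) float32
    print(targets.shape, targets.dtype)  # (4000,) int32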
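
Note that hannah/datasets/pickle_set.py, which defines the hannah.datasets.pickle_set.PickleDataset referenced by dataset/test_pickle_set.yaml, is not included in this diff. Below is a minimal sketch of what such a dataset could look like, assuming only the pickle layout from the README and the loader(batch_size, shuffle=...) hook that the reworked _get_dataloader probes for. The class body and constructor signature are hypothetical, and the AbstractDataset plumbing (prepare, splits, class_names, class_counts) is omitted:

    # Hypothetical sketch -- not the actual hannah implementation.
    import pickle
    from typing import List

    import numpy as np
    import torch
    from torch.utils import data


    class PickleDataset(data.Dataset):
        """A dataset backed by pickle files of (inputs, int32 class ids)."""

        def __init__(self, files: List[str]):
            inputs, targets = [], []
            for path in files:
                with open(path, "rb") as f:
                    x, y = pickle.load(f)
                inputs.append(x)
                targets.append(y)
            self.inputs = np.concatenate(inputs)
            self.targets = np.concatenate(targets)

        def __len__(self):
            return len(self.inputs)

        def __getitem__(self, idx):
            return torch.from_numpy(self.inputs[idx]), int(self.targets[idx])

        def loader(self, batch_size, shuffle=False):
            # The hook probed by _get_dataloader: the dataset hands back a
            # fully configured DataLoader (which could just as well serve
            # predetermined batches) instead of letting the module build one.
            return data.DataLoader(self, batch_size=batch_size, shuffle=shuffle)

Routing loader construction through the dataset keeps batch-level decisions, such as serving predetermined batches, inside the dataset implementation, while modules that need special collation still override collate_fn, as classifier.py and vision/base.py do above.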