Add a pickle dataset that contains predetermined batches
cgerum committed May 16, 2024
1 parent 072d850 commit 78d353b
Showing 8 changed files with 206 additions and 94 deletions.
40 changes: 40 additions & 0 deletions experiments/test_pickle_set/README.md
@@ -0,0 +1,40 @@
<!--
Copyright (c) 2024 Hannah contributors.
This file is part of hannah.
See https://github.com/ekut-es/hannah for further info.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Test Pickle dataset

A simple test implementation for pickled datasets.

Each pickled dataset is expected to contain a tuple of two numpy arrays:

The first array contains the (preprocessed) input data, and the second array contains the target class ids as int32 values.
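
A minimal sketch of producing a file in this format (the file name and array shapes here are illustrative assumptions, not values used by hannah):

```python
import pickle

import numpy

# Ten samples of shape (20, 17): float32 inputs plus int32 class ids,
# matching the (inputs, targets) tuple format described above.
inputs = numpy.random.rand(10, 20, 17).astype(numpy.float32)
targets = numpy.random.randint(0, 2, 10).astype(numpy.int32)

with open("toy.pkl", "wb") as f:
    pickle.dump((inputs, targets), f)
```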

## Creating Test Data

The following creates test, val and train datasets with 400, 400 and 4000 samples, respectively.
The input data is randomly initialized, and each sample is randomly assigned one of 2 classes.

```bash
python create_sample.py --size 400 --dim='(20,17)' --classes=2 test.pkl
python create_sample.py --size 400 --dim='(20,17)' --classes=2 val.pkl
python create_sample.py --size 4000 --dim='(20,17)' --classes=2 train.pkl
```
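
A quick sanity check of the generated files (assuming the commands above were run in the current directory) could look like this:

```python
import pickle

with open("train.pkl", "rb") as f:
    inputs, targets = pickle.load(f)

# Expect 4000 samples of shape (20, 17) with class ids in {0, 1}.
assert inputs.shape == (4000, 20, 17)
assert set(targets.tolist()) <= {0, 1}
print(inputs.dtype, targets.shape)
```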

## Training

The following then runs a training of the tc-res8 model on the generated data:

```bash
hannah-train
```
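
Because the experiment configuration is composed with Hydra (see `config.yaml` below), individual values can presumably also be overridden on the command line, e.g. for a shorter smoke-test run (assuming standard Hydra override syntax):

```bash
hannah-train trainer.max_epochs=5
```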
28 changes: 28 additions & 0 deletions experiments/test_pickle_set/config.yaml
@@ -0,0 +1,28 @@
##
## Copyright (c) 2022 University of Tübingen.
##
## This file is part of hannah.
## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
##
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
##
## http://www.apache.org/licenses/LICENSE-2.0
##
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.
##


defaults:
- base_config # Base configuration uses a single neural network training and kws dataset
- override dataset: test_pickle_set # Override the dataset to use the test_pickle_set dataset
- override features: raw # Override the features to not use any preprocessing
- _self_ # This is a special value that specifies that values defined in this file take precedence over values from the other files

trainer: # Trainer arguments set hyperparameters for all trainings
max_epochs: 30
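
As a usage note, other groups from hannah's base configuration can presumably be overridden in the same way. A hypothetical variant that also pins the model (the `model` group name and the `tc-res8` value are assumptions based on the README above, not taken from this commit):

```yaml
defaults:
  - base_config
  - override dataset: test_pickle_set
  - override features: raw
  - override model: tc-res8  # hypothetical: the group and value must exist in hannah's config tree
  - _self_

trainer:
  max_epochs: 30
```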
50 changes: 50 additions & 0 deletions experiments/test_pickle_set/create_sample.py
@@ -0,0 +1,50 @@
#
# Copyright (c) 2024 Hannah contributors.
#
# This file is part of hannah.
# See https://github.com/ekut-es/hannah for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import pickle

import numpy

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="This script creates a pickle file containing a randomly generated test dataset."
    )
parser.add_argument("pickle_file", help="The pickle file to create.")
parser.add_argument("--size", help="The number of samples in the dataset.")
parser.add_argument(
"--dim",
help='The dimension of the samples, in the form of a tuple e.g. "(3, 32, 32)"',
)
parser.add_argument("--classes", help="The number of classes in the dataset.")

args = parser.parse_args()

    size = int(args.size)
    # Parse a dimension string like "(20,17)" into a tuple of ints.
    dim = tuple(map(int, args.dim.strip("()").split(",")))
    classes = int(args.classes)

with open(args.pickle_file, "wb") as f:
pickle.dump(
(
numpy.random.rand(size, *dim).astype(numpy.float32),
                numpy.random.randint(0, classes, size).astype(numpy.int32),  # int32 class ids, as documented
),
f,
)
11 changes: 11 additions & 0 deletions experiments/test_pickle_set/dataset/test_pickle_set.yaml
@@ -0,0 +1,11 @@
cls: hannah.datasets.pickle_set.PickleDataset
train:
- ${hydra:runtime.cwd}/train.pkl

val:
- ${hydra:runtime.cwd}/val.pkl

test:
- ${hydra:runtime.cwd}/test.pkl

samplingrate: 16000
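
The `hannah.datasets.pickle_set.PickleDataset` class referenced by `cls:` is added by this commit, but its source is not reproduced in this excerpt. A hypothetical sketch of the shape such a class could take, based only on the `AbstractDataset` interface and the `sequential`/`max_workers` attributes consumed by `_get_dataloader` in the diffs below (every name and method body here is an illustrative assumption, not the committed code; the optional `loader` hook for predetermined batches is omitted):

```python
import pickle
from typing import Dict, List, Optional, Tuple

import torch
from torch.utils import data


class PickleDatasetSketch(data.Dataset):
    """Illustrative only: serves samples from a pickled (inputs, targets) tuple."""

    sequential = False  # _get_dataloader only applies a sampler to non-sequential datasets
    max_workers = -1  # -1 lets _get_dataloader keep the configured num_workers

    def __init__(self, path: str):
        with open(path, "rb") as f:
            self.inputs, self.targets = pickle.load(f)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        return torch.from_numpy(self.inputs[index]), int(self.targets[index])

    @property
    def class_names(self) -> List[str]:
        return [str(c) for c in sorted(set(self.targets.tolist()))]

    @property
    def class_counts(self) -> Optional[Dict[int, int]]:
        counts: Dict[int, int] = {}
        for c in self.targets.tolist():
            counts[c] = counts.get(c, 0) + 1
        return counts
```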
14 changes: 9 additions & 5 deletions hannah/datasets/base.py
@@ -1,5 +1,5 @@
#
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
#
# This file is part of hannah.
# See https://github.com/ekut-es/hannah for further info.
@@ -36,7 +36,8 @@ class DatasetType(Enum):


class AbstractDataset(Dataset, ABC):
-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
def prepare(cls, config: Dict[str, Any]) -> None:
"""Prepare the dataset.
@@ -52,7 +53,8 @@ def prepare(cls, config: Dict[str, Any]) -> None:

pass

-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
def splits(
cls, config: Dict[str, Any]
) -> Tuple["AbstractDataset", "AbstractDataset", "AbstractDataset"]:
@@ -64,7 +66,8 @@ def splits(

pass # pytype: disable=bad-return-type

-    @abstractproperty
+    @property
+    @abstractmethod
def class_names(self) -> List[str]:
"""Returns the names of the classes in the classification dataset"""
pass # pytype: disable=bad-return-type
@@ -81,7 +84,8 @@ def class_names_abbreviated(self) -> List[str]:

return self.class_names

-    @abstractproperty
+    @property
+    @abstractmethod
def class_counts(self) -> Optional[Dict[int, int]]:
"""Returns the number of items in each class of the dataset
90 changes: 61 additions & 29 deletions hannah/modules/base.py
@@ -150,39 +150,68 @@ def val_dataloader(self):
        return self._get_dataloader(self.dev_set, self.dev_set_unlabeled)

    def _get_dataloader(self, dataset, unlabeled_data=None, shuffle=False):
-        dataset_conf = self.hparams.dataset
-        sampler = None
-        if shuffle:
-            sampler_type = dataset_conf.get("sampler", "random")
-            if sampler_type == "weighted":
-                sampler = self.get_balancing_sampler(dataset)
-            else:
-                sampler = data.RandomSampler(dataset)
-
-        loader = data.DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            drop_last=True,
-            num_workers=self.hparams["num_workers"],
-            sampler=sampler if not dataset.sequential else None,
-            multiprocessing_context="fork" if self.hparams["num_workers"] > 0 else None,
-        )
+        batch_size = self.hparams["batch_size"]
+
+        def calc_workers(dataset):
+            result = (
+                num_workers
+                if num_workers <= dataset.max_workers or dataset.max_workers == -1
+                else dataset.max_workers
+            )
+            return result
+
+        if hasattr(dataset, "loader"):
+            loader = dataset.loader(batch_size, shuffle=shuffle)
+        else:
+            # FIXME: don't use hparams here
+            dataset_conf = self.hparams.dataset
+            sampler = None
+            if shuffle:
+                sampler_type = dataset_conf.get("sampler", "random")
+                if sampler_type == "weighted":
+                    sampler = self.get_balancing_sampler(dataset)
+                else:
+                    sampler = data.RandomSampler(dataset)
+
+            num_workers = self.hparams["num_workers"]
+
+            num_workers = calc_workers(dataset)
+
+            loader = data.DataLoader(
+                dataset,
+                batch_size=batch_size,
+                drop_last=True,
+                num_workers=num_workers,
+                sampler=sampler if not dataset.sequential else None,
+                collate_fn=self.collate_fn,
+                multiprocessing_context="fork" if num_workers > 0 else None,
+                persistent_workers=True if num_workers > 0 else False,
+                prefetch_factor=2 if num_workers > 0 else None,
+                pin_memory=True,
+            )
+
+        self.batches_per_epoch = len(loader)

        if unlabeled_data:
-            loader_unlabeled = data.DataLoader(
-                unlabeled_data,
-                batch_size=self.batch_size,
-                drop_last=True,
-                num_workers=self.hparams["num_workers"],
-                sampler=data.RandomSampler(unlabeled_data)
-                if not dataset.sequential
-                else None,
-                multiprocessing_context="fork"
-                if self.hparams["num_workers"] > 0
-                else None,
-            )
-
-            return CombinedLoader(
-                {"labeled": loader, "unlabeled": loader_unlabeled},
-                mode="max_size_cycle",
-            )
+            if hasattr(unlabeled_data, "loader"):
+                unlabeled_data = unlabeled_data.loader(batch_size, shuffle=shuffle)
+            else:
+                unlabeled_workers = calc_workers(unlabeled_data)
+                loader_unlabeled = data.DataLoader(
+                    unlabeled_data,
+                    batch_size=batch_size,
+                    drop_last=True,
+                    num_workers=unlabeled_workers,
+                    sampler=data.RandomSampler(unlabeled_data)
+                    if not unlabeled_data.sequential
+                    else None,
+                    multiprocessing_context="fork" if unlabeled_workers > 0 else None,
+                )
+
+            return CombinedLoader({"labeled": loader, "unlabeled": loader_unlabeled})

        return loader

@@ -446,3 +475,6 @@ def _setup_loss_weights(self):
return loss_weights

return None

+    def collate_fn(self, batch):
+        return torch.utils.data.default_collate(batch)
5 changes: 4 additions & 1 deletion hannah/modules/classifier.py
@@ -1,5 +1,5 @@
#
-# Copyright (c) 2023 Hannah contributors.
+# Copyright (c) 2024 Hannah contributors.
#
# This file is part of hannah.
# See https://github.com/ekut-es/hannah for further info.
@@ -396,6 +396,9 @@ def _log_audio(self, x, logits, y):
)
self.logged_samples += 1

+    def collate_fn(self, batch):
+        return ctc_collate_fn(batch)


class StreamClassifierModule(BaseStreamClassifierModule):
def get_class_names(self):
62 changes: 3 additions & 59 deletions hannah/modules/vision/base.py
@@ -293,65 +293,6 @@ def setup_augmentations(self, pipeline_configs):

        return augmentations

-    def _get_dataloader(self, dataset, unlabeled_data=None, shuffle=False):
-        batch_size = self.hparams["batch_size"]
-
-        # FIXME: don't use hparams here
-        dataset_conf = self.hparams.dataset
-        sampler = None
-        if shuffle:
-            sampler_type = dataset_conf.get("sampler", "random")
-            if sampler_type == "weighted":
-                sampler = self.get_balancing_sampler(dataset)
-            else:
-                sampler = data.RandomSampler(dataset)
-
-        num_workers = self.hparams["num_workers"]
-
-        def calc_workers(dataset):
-            result = (
-                num_workers
-                if num_workers <= dataset.max_workers or dataset.max_workers == -1
-                else dataset.max_workers
-            )
-            return result
-
-        num_workers = calc_workers(dataset)
-
-        loader = data.DataLoader(
-            dataset,
-            batch_size=batch_size,
-            drop_last=True,
-            num_workers=num_workers,
-            sampler=sampler if not dataset.sequential else None,
-            collate_fn=vision_collate_fn,
-            multiprocessing_context="fork" if num_workers > 0 else None,
-            persistent_workers=True if num_workers > 0 else False,
-            prefetch_factor=2 if num_workers > 0 else None,
-            pin_memory=True,
-        )
-        self.batches_per_epoch = len(loader)
-
-        if unlabeled_data:
-            unlabeled_workers = calc_workers(unlabeled_data)
-            loader_unlabeled = data.DataLoader(
-                unlabeled_data,
-                batch_size=batch_size,
-                drop_last=True,
-                num_workers=unlabeled_workers,
-                sampler=data.RandomSampler(unlabeled_data)
-                if not unlabeled_data.sequential
-                else None,
-                multiprocessing_context="fork" if unlabeled_workers > 0 else None,
-            )
-
-            return CombinedLoader(
-                {"labeled": loader, "unlabeled": loader_unlabeled},
-                mode="max_size_cycle",
-            )
-
-        return loader

@property
def backbone(self):
if self.model is None:
@@ -363,3 +304,6 @@ def backbone(self):
return self.model.encoder
else:
raise AttributeError("No backbone found in model")

+    def collate_fn(self, batch):
+        return vision_collate_fn(batch)
