diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index 6022d835..d535159b 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -57,57 +57,65 @@ jobs: - name: test_gen_data run: | source ${VENV}/bin/activate - mpirun -np 2 pytest -k test_gen_data[png-tensorflow] -v - mpirun -np 2 pytest -k test_gen_data[npz-tensorflow] -v - mpirun -np 2 pytest -k test_gen_data[jpeg-tensorflow] -v - mpirun -np 2 pytest -k test_gen_data[tfrecord-tensorflow] -v - mpirun -np 2 pytest -k test_gen_data[hdf5-tensorflow] -v + mpirun -np 2 pytest -k test_gen_data[dlio_png-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_gen_data[dlio_npz-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_gen_data[dlio_jpeg-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_gen_data[dlio_tfrecord-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_gen_data[dlio_hdf5-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_gen_data[dali_npz-pytorch-native_dali] -v + - name: test_custom_storage_root_gen_data run: | source ${VENV}/bin/activate - mpirun -np 2 pytest -k test_storage_root_gen_data[png-tensorflow] -v - mpirun -np 2 pytest -k test_storage_root_gen_data[npz-tensorflow] -v - mpirun -np 2 pytest -k test_storage_root_gen_data[jpeg-tensorflow] -v - mpirun -np 2 pytest -k test_storage_root_gen_data[tfrecord-tensorflow] -v - mpirun -np 2 pytest -k test_storage_root_gen_data[hdf5-tensorflow] -v + mpirun -np 2 pytest -k test_storage_root_gen_data[dlio_png-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_storage_root_gen_data[dlio_npz-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_storage_root_gen_data[dlio_jpeg-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_storage_root_gen_data[dlio_tfrecord-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_storage_root_gen_data[dlio_hdf5-tensorflow-dlio_tensorflow] -v + - name: test_train run: | source ${VENV}/bin/activate - mpirun -np 2 pytest -k test_train[png-tensorflow-tensorflow] -v - mpirun -np 2 pytest -k test_train[npz-tensorflow-tensorflow] -v - mpirun -np 2 pytest -k test_train[jpeg-tensorflow-tensorflow] -v - mpirun -np 2 pytest -k test_train[tfrecord-tensorflow-tensorflow] -v - mpirun -np 2 pytest -k test_train[hdf5-tensorflow-tensorflow] -v - mpirun -np 2 pytest -k test_train[csv-tensorflow-tensorflow] -v - mpirun -np 2 pytest -k test_train[png-pytorch-pytorch] -v - mpirun -np 2 pytest -k test_train[npz-pytorch-pytorch] -v - mpirun -np 2 pytest -k test_train[jpeg-pytorch-pytorch] -v - mpirun -np 2 pytest -k test_train[hdf5-pytorch-pytorch] -v - mpirun -np 2 pytest -k test_train[csv-pytorch-pytorch] -v - mpirun -np 2 pytest -k test_train[png-tensorflow-dali] -v - mpirun -np 2 pytest -k test_train[npz-tensorflow-dali] -v - mpirun -np 2 pytest -k test_train[jpeg-tensorflow-dali] -v - mpirun -np 2 pytest -k test_train[hdf5-tensorflow-dali] -v - mpirun -np 2 pytest -k test_train[csv-tensorflow-dali] -v - mpirun -np 2 pytest -k test_train[png-pytorch-dali] -v - mpirun -np 2 pytest -k test_train[npz-pytorch-dali] -v - mpirun -np 2 pytest -k test_train[jpeg-pytorch-dali] -v - mpirun -np 2 pytest -k test_train[hdf5-pytorch-dali] -v - mpirun -np 2 pytest -k test_train[csv-pytorch-dali] -v + mpirun -np 2 pytest -k test_train[dlio_png-tensorflow-dlio_tensorflow0] -v + mpirun -np 2 pytest -k test_train[dlio_npz-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k 
test_train[dlio_jpeg-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_train[dlio_tfrecord-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_train[dlio_hdf5-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_train[dlio_csv-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_train[dlio_png-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_train[dlio_npz-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_train[dlio_jpeg-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_train[dlio_hdf5-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_train[dlio_csv-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_train[dlio_png-tensorflow-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_npz-tensorflow-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_jpeg-tensorflow-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_hdf5-tensorflow-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_csv-tensorflow-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_png-pytorch-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_npz-pytorch-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_jpeg-pytorch-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_hdf5-pytorch-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_csv-pytorch-dlio_dali] -v + mpirun -np 2 pytest -k test_train[dlio_png-tensorflow-dlio_tensorflow1] -v + mpirun -np 2 pytest -k test_train[tf_tfrecord-tensorflow-native_tensorflow] -v + mpirun -np 2 pytest -k test_train[dali_tfrecord-pytorch-native_dali] -v + mpirun -np 2 pytest -k test_train[dali_npz-pytorch-native_dali] -v + - name: test_custom_storage_root_train run: | source ${VENV}/bin/activate - mpirun -np 2 pytest -k test_custom_storage_root_train[png-tensorflow] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[npz-tensorflow] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[jpeg-tensorflow] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[tfrecord-tensorflow] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[hdf5-tensorflow] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[csv-tensorflow] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[png-pytorch] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[npz-pytorch] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[jpeg-pytorch] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[hdf5-pytorch] -v - mpirun -np 2 pytest -k test_custom_storage_root_train[csv-pytorch] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_png-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_npz-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_jpeg-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_tfrecord-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_hdf5-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_csv-tensorflow-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_png-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_npz-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_jpeg-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_hdf5-pytorch-dlio_pytorch] -v + mpirun -np 2 pytest -k test_custom_storage_root_train[dlio_csv-pytorch-dlio_pytorch] 
-v - name: test_checkpoint_epoch run: | source ${VENV}/bin/activate @@ -123,12 +131,12 @@ jobs: - name: test_multi_threads run: | source ${VENV}/bin/activate - mpirun -np 2 pytest -k test_multi_threads[tensorflow-0] -v - mpirun -np 2 pytest -k test_multi_threads[tensorflow-1] -v - mpirun -np 2 pytest -k test_multi_threads[tensorflow-2] -v - mpirun -np 2 pytest -k test_multi_threads[pytorch-0] -v - mpirun -np 2 pytest -k test_multi_threads[pytorch-1] -v - mpirun -np 2 pytest -k test_multi_threads[pytorch-2] -v + mpirun -np 2 pytest -k test_multi_threads[tensorflow-0-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_multi_threads[tensorflow-1-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_multi_threads[tensorflow-2-dlio_tensorflow] -v + mpirun -np 2 pytest -k test_multi_threads[pytorch-0-dlio_pytorch] -v + mpirun -np 2 pytest -k test_multi_threads[pytorch-1-dlio_pytorch] -v + mpirun -np 2 pytest -k test_multi_threads[pytorch-2-dlio_pytorch] -v - name: test-tf-loader-tfrecord run: | source ${VENV}/bin/activate @@ -142,10 +150,10 @@ jobs: - name: test-tf-loader-npz run: | source ${VENV}/bin/activate - mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 - mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 + mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=dlio_tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 + mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=dlio_tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0 - name: test_subset run: | source ${VENV}/bin/activate mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=True ++workload.workflow.train=False - mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=False ++workload.workflow.train=True ++workload.dataset.num_files_train=8 \ No newline at end of file + mpirun -np 2 dlio_benchmark ++workload.workflow.generate_data=False ++workload.workflow.train=True ++workload.dataset.num_files_train=8
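The test ids above now encode a (format, framework, data_loader) triple. As an illustrative sketch of the new native path (option values taken from the dali_npz-pytorch-native_dali test id; the Hydra keys mirror the workload configs, and this command is not part of the patch itself):

mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=pytorch ++workload.dataset.format=dali_npz ++workload.reader.data_loader=native_dali ++workload.workflow.generate_data=True ++workload.workflow.train=True

diff --git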
a/dlio_benchmark/common/enumerations.py b/dlio_benchmark/common/enumerations.py index 86332902..e2878cc4 100644 --- a/dlio_benchmark/common/enumerations.py +++ b/dlio_benchmark/common/enumerations.py @@ -89,41 +89,66 @@ class FormatType(Enum): """ Format Type supported by the benchmark. """ - TFRECORD = 'tfrecord' - HDF5 = 'hdf5' - CSV = 'csv' - NPZ = 'npz' - HDF5_OPT = 'hdf5_opt' - JPEG = 'jpeg' - PNG = 'png' + DLIO_TFRECORD = 'dlio_tfrecord' + DLIO_HDF5 = 'dlio_hdf5' + DLIO_CSV = 'dlio_csv' + DLIO_NPZ = 'dlio_npz' + DLIO_HDF5_OPT = 'dlio_hdf5_opt' + DLIO_JPEG = 'dlio_jpeg' + DLIO_PNG = 'dlio_png' + TF_TFRECORD = 'tf_tfrecord' + DALI_TFRECORD = 'dali_tfrecord' + DALI_NPZ = 'dali_npz' def __str__(self): return self.value + @staticmethod + def getextension(value): + if value in [FormatType.DLIO_TFRECORD.value,FormatType.DALI_TFRECORD.value] : + return "tfrecord" + elif FormatType.DLIO_HDF5.value == value: + return "hdf5" + elif FormatType.DLIO_CSV.value == value: + return "csv" + elif value in [FormatType.DLIO_NPZ.value] : + return "npz" + elif value == FormatType.DALI_NPZ.value: + return "npy" + elif FormatType.DLIO_HDF5_OPT.value == value: + return "hdf5" + elif FormatType.DLIO_JPEG.value == value: + return "jpeg" + elif FormatType.DLIO_PNG.value == value: + return "png" + @ staticmethod def get_enum(value): - if FormatType.TFRECORD.value == value: - return FormatType.TFRECORD - elif FormatType.HDF5.value == value: - return FormatType.HDF5 - elif FormatType.CSV.value == value: - return FormatType.CSV - elif FormatType.NPZ.value == value: - return FormatType.NPZ - elif FormatType.HDF5_OPT.value == value: - return FormatType.HDF5_OPT - elif FormatType.JPEG.value == value: - return FormatType.JPEG - elif FormatType.PNG.value == value: - return FormatType.PNG + if FormatType.DLIO_TFRECORD.value == value: + return FormatType.DLIO_TFRECORD + elif FormatType.DLIO_HDF5.value == value: + return FormatType.DLIO_HDF5 + elif FormatType.DLIO_CSV.value == value: + return FormatType.DLIO_CSV + elif FormatType.DLIO_NPZ.value == value: + return FormatType.DLIO_NPZ + elif FormatType.DLIO_HDF5_OPT.value == value: + return FormatType.DLIO_HDF5_OPT + elif FormatType.DLIO_JPEG.value == value: + return FormatType.DLIO_JPEG + elif FormatType.DLIO_PNG.value == value: + return FormatType.DLIO_PNG class DataLoaderType(Enum): """ Framework DataLoader Type """ - TENSORFLOW='tensorflow' - PYTORCH='pytorch' - DALI='dali' + DLIO_TENSORFLOW='dlio_tensorflow' + DLIO_PYTORCH='dlio_pytorch' + DLIO_DALI='dlio_dali' + NATIVE_TENSORFLOW = 'native_tensorflow' + NATIVE_PYTORCH = 'native_pytorch' + NATIVE_DALI = 'native_dali' CUSTOM='custom' NONE='none' diff --git a/dlio_benchmark/configs/workload/bert.yaml b/dlio_benchmark/configs/workload/bert.yaml index d730132d..b096c4ad 100644 --- a/dlio_benchmark/configs/workload/bert.yaml +++ b/dlio_benchmark/configs/workload/bert.yaml @@ -10,7 +10,7 @@ workflow: dataset: data_folder: data/bert - format: tfrecord + format: dlio_tfrecord num_files_train: 500 num_samples_per_file: 313532 record_length: 2500 @@ -22,7 +22,7 @@ train: total_training_steps: 1000 reader: - data_loader: tensorflow + data_loader: dlio_tensorflow read_threads: 1 computation_threads: 1 transfer_size: 262144 diff --git a/dlio_benchmark/configs/workload/cosmoflow.yaml b/dlio_benchmark/configs/workload/cosmoflow.yaml index 690cfd05..c8cf4161 100644 --- a/dlio_benchmark/configs/workload/cosmoflow.yaml +++ b/dlio_benchmark/configs/workload/cosmoflow.yaml @@ -14,7 +14,7 @@ dataset: reader: - data_loader: tensorflow + 
data_loader: dlio_tensorflow computation_threads: 8 read_threads: 8 batch_size: 1 diff --git a/dlio_benchmark/configs/workload/default.yaml b/dlio_benchmark/configs/workload/default.yaml index 4a944952..b769b137 100644 --- a/dlio_benchmark/configs/workload/default.yaml +++ b/dlio_benchmark/configs/workload/default.yaml @@ -10,7 +10,7 @@ workflow: dataset: data_folder: data/default - format: npz + format: dlio_npz num_files_train: 64 num_files_eval: 8 num_samples_per_file: 1 @@ -19,7 +19,7 @@ dataset: num_subfolders_eval: 2 reader: - data_loader: pytorch + data_loader: dlio_pytorch batch_size: 4 batch_size_eval: 1 diff --git a/dlio_benchmark/configs/workload/resnet50.yaml b/dlio_benchmark/configs/workload/resnet50.yaml index d8376ed9..254ddf7f 100644 --- a/dlio_benchmark/configs/workload/resnet50.yaml +++ b/dlio_benchmark/configs/workload/resnet50.yaml @@ -11,12 +11,12 @@ dataset: num_samples_per_file: 1 record_length: 150528 data_folder: data/resnet50 - format: jpeg + format: dlio_jpeg train: computation_time: 0.1 reader: - data_loader: pytorch + data_loader: dlio_pytorch read_threads: 8 computation_threads: 8 diff --git a/dlio_benchmark/configs/workload/unet3d.yaml b/dlio_benchmark/configs/workload/unet3d.yaml index 2de4b194..8d3b4567 100644 --- a/dlio_benchmark/configs/workload/unet3d.yaml +++ b/dlio_benchmark/configs/workload/unet3d.yaml @@ -9,7 +9,7 @@ workflow: dataset: data_folder: data/unet3d/ - format: npz + format: dlio_npz num_files_train: 168 num_samples_per_file: 1 record_length: 146600628 @@ -17,7 +17,7 @@ dataset: record_length_resize: 2097152 reader: - data_loader: pytorch + data_loader: dlio_pytorch batch_size: 4 read_threads: 4 file_shuffle: seed diff --git a/dlio_benchmark/data_generator/data_generator.py b/dlio_benchmark/data_generator/data_generator.py index 3b9543bf..8dffbf52 100644 --- a/dlio_benchmark/data_generator/data_generator.py +++ b/dlio_benchmark/data_generator/data_generator.py @@ -48,7 +48,7 @@ def __init__(self): self._file_list = None self.num_subfolders_train = self._args.num_subfolders_train self.num_subfolders_eval = self._args.num_subfolders_eval - self.format = self._args.format + self.format = self._args.format_ext self.storage = StorageFactory().get_storage(self._args.storage_type, self._args.storage_root, self._args.framework) def get_dimension(self): diff --git a/dlio_benchmark/data_generator/generator_factory.py b/dlio_benchmark/data_generator/generator_factory.py index 7c05d3a4..c8baeae5 100644 --- a/dlio_benchmark/data_generator/generator_factory.py +++ b/dlio_benchmark/data_generator/generator_factory.py @@ -26,22 +26,25 @@ def __init__(self): @staticmethod def get_generator(type): - if type == FormatType.TFRECORD: + if type in [FormatType.DLIO_TFRECORD, FormatType.DALI_TFRECORD, FormatType.TF_TFRECORD]: from dlio_benchmark.data_generator.tf_generator import TFRecordGenerator return TFRecordGenerator() - elif type == FormatType.HDF5: + elif type == FormatType.DLIO_HDF5: from dlio_benchmark.data_generator.hdf5_generator import HDF5Generator return HDF5Generator() - elif type == FormatType.CSV: + elif type == FormatType.DLIO_CSV: from dlio_benchmark.data_generator.csv_generator import CSVGenerator return CSVGenerator() - elif type == FormatType.NPZ: + elif type == FormatType.DLIO_NPZ: from dlio_benchmark.data_generator.npz_generator import NPZGenerator return NPZGenerator() - elif type == FormatType.JPEG: + elif type == FormatType.DALI_NPZ: + from dlio_benchmark.data_generator.npy_generator import NPYGenerator + return NPYGenerator() + elif 
type == FormatType.DLIO_JPEG: from dlio_benchmark.data_generator.jpeg_generator import JPEGGenerator return JPEGGenerator() - elif type == FormatType.PNG: + elif type == FormatType.DLIO_PNG: from dlio_benchmark.data_generator.png_generator import PNGGenerator return PNGGenerator() else: diff --git a/dlio_benchmark/data_generator/npy_generator.py b/dlio_benchmark/data_generator/npy_generator.py new file mode 100644 index 00000000..cf484b99 --- /dev/null +++ b/dlio_benchmark/data_generator/npy_generator.py @@ -0,0 +1,52 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from dlio_benchmark.common.enumerations import Compression +from dlio_benchmark.data_generator.data_generator import DataGenerator + +import logging +import numpy as np + +from dlio_benchmark.utils.utility import progress, utcnow, Profile +from shutil import copyfile +from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR + +dlp = Profile(MODULE_DATA_GENERATOR) + +""" +Generator for creating data in NPY format. +""" +class NPYGenerator(DataGenerator): + def __init__(self): + super().__init__() + + @dlp.log + def generate(self): + """ + Generator for creating data in NPY format of 3d dataset. + """ + super().generate() + np.random.seed(10) + record_labels = [0] * self.num_samples + for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)): + dim1, dim2 = self.get_dimension() + records = np.random.randint(255, size=(dim1, dim2, self.num_samples), dtype=np.uint8) + out_path_spec = self.storage.get_uri(self._file_list[i]) + progress(i+1, self.total_files_to_generate, "Generating NPY Data") + prev_out_spec = out_path_spec + np.save(out_path_spec, records) + np.random.seed() (Naming note: despite the dali_npz label, this generator writes .npy files via np.save; FormatType.getextension maps DALI_NPZ to "npy" because DALI's fn.readers.numpy, used by the dali_npz reader added later in this patch, reads .npy files rather than .npz archives.) diff --git a/dlio_benchmark/data_generator/tf_generator.py b/dlio_benchmark/data_generator/tf_generator.py index 77fdbd32..d86841d7 100644 --- a/dlio_benchmark/data_generator/tf_generator.py +++ b/dlio_benchmark/data_generator/tf_generator.py @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os +from subprocess import call from dlio_benchmark.data_generator.data_generator import DataGenerator import numpy as np @@ -64,4 +66,14 @@ def generate(self): serialized = example.SerializeToString() # Write the serialized data to the TFRecords file.
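+ # After each TFRecord file is written below, the generator builds a matching
+ # DALI index file: tfrecord2idx is the indexing script that ships with
+ # NVIDIA DALI, and the dali_tfrecord reader added later in this patch expects
+ # one <data_folder>/index/<train|valid>/<basename>.idx per TFRecord file.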
writer.write(serialized) + tfrecord2idx_script = "tfrecord2idx" + folder = "train" + if "valid" in out_path_spec: + folder = "valid" + index_folder = f"{self._args.data_folder}/index/{folder}" + filename = os.path.basename(out_path_spec) + self.storage.create_node(index_folder, exist_ok=True) + tfrecord_idx = f"{index_folder}/{filename}.idx" + if not os.path.isfile(tfrecord_idx): + call([tfrecord2idx_script, out_path_spec, tfrecord_idx]) np.random.seed() diff --git a/dlio_benchmark/data_loader/data_loader_factory.py b/dlio_benchmark/data_loader/data_loader_factory.py index e8457450..53ecfcc8 100644 --- a/dlio_benchmark/data_loader/data_loader_factory.py +++ b/dlio_benchmark/data_loader/data_loader_factory.py @@ -36,15 +36,24 @@ def get_loader(type, format_type, dataset_type, epoch): if _args.data_loader_class is not None: logging.info(f"{utcnow()} Running DLIO with custom data loader class {_args.data_loader_class.__name__}") return _args.data_loader_class(format_type, dataset_type, epoch) - elif type == DataLoaderType.PYTORCH: - from dlio_benchmark.data_loader.torch_data_loader import TorchDataLoader - return TorchDataLoader(format_type, dataset_type, epoch) - elif type == DataLoaderType.TENSORFLOW: - from dlio_benchmark.data_loader.tf_data_loader import TFDataLoader - return TFDataLoader(format_type, dataset_type, epoch) - elif type == DataLoaderType.DALI: - from dlio_benchmark.data_loader.dali_data_loader import DaliDataLoader - return DaliDataLoader(format_type, dataset_type, epoch) + elif type == DataLoaderType.DLIO_PYTORCH: + from dlio_benchmark.data_loader.dlio_torch_data_loader import DLIOTorchDataLoader + return DLIOTorchDataLoader(format_type, dataset_type, epoch) + elif type == DataLoaderType.DLIO_TENSORFLOW: + from dlio_benchmark.data_loader.dlio_tf_data_loader import DLIOTFDataLoader + return DLIOTFDataLoader(format_type, dataset_type, epoch) + elif type == DataLoaderType.DLIO_DALI: + from dlio_benchmark.data_loader.dlio_dali_data_loader import DLIODaliDataLoader + return DLIODaliDataLoader(format_type, dataset_type, epoch) + elif type == DataLoaderType.NATIVE_DALI: + from dlio_benchmark.data_loader.native_dali_data_loader import NativeDaliDataLoader + return NativeDaliDataLoader(format_type, dataset_type, epoch) + elif type == DataLoaderType.NATIVE_PYTORCH: + from dlio_benchmark.data_loader.native_torch_data_loader import NativeTorchDataLoader + return NativeTorchDataLoader(format_type, dataset_type, epoch) + elif type == DataLoaderType.NATIVE_TENSORFLOW: + from dlio_benchmark.data_loader.native_tf_data_loader import NativeTFDataLoader + return NativeTFDataLoader(format_type, dataset_type, epoch) else: print("Data Loader %s not supported or plugins not found" % type) raise Exception(str(ErrorCodes.EC1004)) diff --git a/dlio_benchmark/data_loader/dali_data_loader.py b/dlio_benchmark/data_loader/dlio_dali_data_loader.py similarity index 98% rename from dlio_benchmark/data_loader/dali_data_loader.py rename to dlio_benchmark/data_loader/dlio_dali_data_loader.py index dc36b9f8..21b3f556 100644 --- a/dlio_benchmark/data_loader/dali_data_loader.py +++ b/dlio_benchmark/data_loader/dlio_dali_data_loader.py @@ -45,10 +45,10 @@ def __call__(self, sample_info): return image, np.uint8([sample_idx]) -class DaliDataLoader(BaseDataLoader): +class DLIODaliDataLoader(BaseDataLoader): @dlp.log_init def __init__(self, format_type, dataset_type, epoch): - super().__init__(format_type, dataset_type, epoch, DataLoaderType.DALI) + super().__init__(format_type, dataset_type, epoch, 
DataLoaderType.DLIO_DALI) self.pipelines = [] @dlp.log diff --git a/dlio_benchmark/data_loader/tf_data_loader.py b/dlio_benchmark/data_loader/dlio_tf_data_loader.py similarity index 92% rename from dlio_benchmark/data_loader/tf_data_loader.py rename to dlio_benchmark/data_loader/dlio_tf_data_loader.py index 304d10a3..127a7b2d 100644 --- a/dlio_benchmark/data_loader/tf_data_loader.py +++ b/dlio_benchmark/data_loader/dlio_tf_data_loader.py @@ -15,7 +15,7 @@ dlp = Profile(MODULE_DATA_LOADER) -class TensorflowDataset(tf.data.Dataset): +class DLIOTensorflowDataset(tf.data.Dataset): @staticmethod @dlp.log def _generator(format_type, dataset_type, epoch_number, thread_index): @@ -40,11 +40,11 @@ def __new__(cls, format_type, dataset_type, epoch, shape, thread_index): return dataset -class TFDataLoader(BaseDataLoader): +class DLIOTFDataLoader(BaseDataLoader): @dlp.log_init def __init__(self, format_type, dataset_type, epoch): - super().__init__(format_type, dataset_type, epoch, DataLoaderType.TENSORFLOW) + super().__init__(format_type, dataset_type, epoch, DataLoaderType.DLIO_TENSORFLOW) self._dataset = None @dlp.log @@ -67,8 +67,8 @@ def read(self): batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval self._dataset = tf.data.Dataset.from_tensor_slices(np.arange(read_threads)).with_options(options) - self._dataset = self._dataset.interleave(lambda x: TensorflowDataset(self.format_type, self.dataset_type, - self.epoch_number, ( + self._dataset = self._dataset.interleave(lambda x: DLIOTensorflowDataset(self.format_type, self.dataset_type, + self.epoch_number, ( batch_size, self._args.max_dimension, self._args.max_dimension), x), diff --git a/dlio_benchmark/data_loader/torch_data_loader.py b/dlio_benchmark/data_loader/dlio_torch_data_loader.py similarity index 95% rename from dlio_benchmark/data_loader/torch_data_loader.py rename to dlio_benchmark/data_loader/dlio_torch_data_loader.py index 0f42806e..20ae56a4 100644 --- a/dlio_benchmark/data_loader/torch_data_loader.py +++ b/dlio_benchmark/data_loader/dlio_torch_data_loader.py @@ -13,7 +13,7 @@ dlp = Profile(MODULE_DATA_LOADER) -class TorchDataset(Dataset): +class DLIOTorchDataset(Dataset): """ Currently, we only support loading one sample per file TODO: support multiple samples per file @@ -49,17 +49,17 @@ def __getitem__(self, image_idx): logging.debug(f"{utcnow()} Rank {get_rank()} reading {image_idx} sample") return self.reader.read_index(image_idx, step) -class TorchDataLoader(BaseDataLoader): +class DLIOTorchDataLoader(BaseDataLoader): @dlp.log_init def __init__(self, format_type, dataset_type, epoch_number): - super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.PYTORCH) + super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.DLIO_PYTORCH) @dlp.log def read(self): do_shuffle = True if self._args.sample_shuffle != Shuffle.OFF else False num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval - dataset = TorchDataset(self.format_type, self.dataset_type, self.epoch_number, num_samples, self._args.read_threads, batch_size) + dataset = DLIOTorchDataset(self.format_type, self.dataset_type, self.epoch_number, num_samples, self._args.read_threads, batch_size) if do_shuffle: sampler = RandomSampler(dataset) else: diff --git a/dlio_benchmark/data_loader/native_dali_data_loader.py 
b/dlio_benchmark/data_loader/native_dali_data_loader.py new file mode 100644 index 00000000..2df04a5e --- /dev/null +++ b/dlio_benchmark/data_loader/native_dali_data_loader.py @@ -0,0 +1,60 @@ +from time import time +import logging +import math +import numpy as np +from nvidia.dali.pipeline import Pipeline +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import nvidia.dali as dali +from nvidia.dali.plugin.pytorch import DALIGenericIterator + +from dlio_benchmark.common.constants import MODULE_DATA_LOADER +from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType +from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader +from dlio_benchmark.reader.reader_factory import ReaderFactory +from dlio_benchmark.utils.utility import utcnow, get_rank, timeit, Profile + +dlp = Profile(MODULE_DATA_LOADER) + + +class NativeDaliDataLoader(BaseDataLoader): + @dlp.log_init + def __init__(self, format_type, dataset_type, epoch): + super().__init__(format_type, dataset_type, epoch, DataLoaderType.NATIVE_DALI) + self.pipelines = [] + + @dlp.log + def read(self): + num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval + batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + parallel = True if self._args.read_threads > 0 else False + self.pipelines = [] + num_threads = 1 + if self._args.read_threads > 0: + num_threads = self._args.read_threads + # None executes pipeline on CPU and the reader does the batching + pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads, + exec_async=False, exec_pipelined=False) + with pipeline: + images = ReaderFactory.get_reader(type=self.format_type, + dataset_type=self.dataset_type, + thread_index=-1, + epoch_number=self.epoch_number).read() + pipeline.set_outputs(images) + self.pipelines.append(pipeline) + logging.info(f"{utcnow()} Creating {num_threads} pipelines by {self._args.my_rank} rank ") + + @dlp.log + def next(self): + super().next() + num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval + batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + for step in range(num_samples // batch_size): + _dataset = DALIGenericIterator(self.pipelines, ['data']) + for batch in _dataset: + logging.info(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ") + yield batch + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/data_loader/native_tf_data_loader.py b/dlio_benchmark/data_loader/native_tf_data_loader.py new file mode 100644 index 00000000..a81a47dd --- /dev/null +++ b/dlio_benchmark/data_loader/native_tf_data_loader.py @@ -0,0 +1,58 @@ +from time import time +import logging +import math + +import tensorflow as tf + +from dlio_benchmark.common.constants import MODULE_DATA_LOADER +from dlio_benchmark.common.enumerations import DataLoaderType, Shuffle, FormatType, DatasetType +from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader +from dlio_benchmark.reader.reader_factory import ReaderFactory +from dlio_benchmark.utils.utility import utcnow, Profile + +import numpy as np + +dlp = Profile(MODULE_DATA_LOADER) + + +class NativeTFDataLoader(BaseDataLoader): + + @dlp.log_init + def __init__(self, format_type, dataset_type, epoch): + super().__init__(format_type, dataset_type, 
epoch, DataLoaderType.NATIVE_TENSORFLOW) + self._dataset = None + + @dlp.log + def read(self): + read_threads = self._args.read_threads + if read_threads == 0: + if self._args.my_rank == 0: + logging.warning( + f"{utcnow()} `read_threads` is set to be 0 for tf.data loader. We change it to 1") + read_threads = 1 + + options = tf.data.Options() + if "threading" in dir(options): + options.threading.private_threadpool_size = read_threads + options.threading.max_intra_op_parallelism = read_threads + elif "experimental_threading" in dir(options): + options.experimental_threading.private_threadpool_size = read_threads + options.experimental_threading.max_intra_op_parallelism = read_threads + + batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + self._dataset = ReaderFactory.get_reader(type=self.format_type, + dataset_type=self.dataset_type, + thread_index=-1, + epoch_number=self.epoch_number).read() + if self._args.prefetch_size > 0: + self._dataset = self._dataset.prefetch(buffer_size=self._args.prefetch_size) + self._dataset = self._dataset.batch(batch_size, drop_remainder=True) + + @dlp.log + def next(self): + for batch in self._dataset: + yield batch + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/data_loader/native_torch_data_loader.py b/dlio_benchmark/data_loader/native_torch_data_loader.py new file mode 100644 index 00000000..5ffc699b --- /dev/null +++ b/dlio_benchmark/data_loader/native_torch_data_loader.py @@ -0,0 +1,77 @@ +from time import time +import logging +import math +import torch +from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler + +from dlio_benchmark.common.constants import MODULE_DATA_LOADER +from dlio_benchmark.common.enumerations import Shuffle, DatasetType, DataLoaderType +from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader +from dlio_benchmark.reader.reader_factory import ReaderFactory +from dlio_benchmark.utils.utility import utcnow, get_rank, Profile + +dlp = Profile(MODULE_DATA_LOADER) + +class NativeTorchDataLoader(BaseDataLoader): + @dlp.log_init + def __init__(self, format_type, dataset_type, epoch_number): + super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.NATIVE_PYTORCH) + + @dlp.log + def read(self): + do_shuffle = True if self._args.sample_shuffle != Shuffle.OFF else False + num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval + batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + dataset = ReaderFactory.get_reader(type=self.format_type, + dataset_type=self.dataset_type, + thread_index=-1, + epoch_number=self.epoch_number).read() + if do_shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + if self._args.read_threads > 1: + prefetch_factor = math.ceil(self._args.prefetch_size / self._args.read_threads) + else: + prefetch_factor = self._args.prefetch_size + if prefetch_factor > 0: + if self._args.my_rank == 0: + logging.debug( + f"{utcnow()} Prefetch size is {self._args.prefetch_size}; prefetch factor of {prefetch_factor} will be set to Torch DataLoader.") + else: + if self._args.my_rank == 0: + logging.debug( + f"{utcnow()} Prefetch size is 0; a default prefetch factor of 2 will be set to Torch DataLoader.") + logging.debug(f"{utcnow()} Setup dataloader with {self._args.read_threads} workers {torch.__version__}") + if torch.__version__ == 
'1.3.1': + self._dataset = DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=self._args.read_threads, + pin_memory=True, + drop_last=True, + worker_init_fn=dataset.worker_init) + else: + self._dataset = DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=self._args.read_threads, + pin_memory=True, + drop_last=True, + worker_init_fn=dataset.worker_init, + prefetch_factor=prefetch_factor if prefetch_factor > 0 else 2) # 2 is the default value + logging.debug(f"{utcnow()} Rank {self._args.my_rank} will read {len(self._dataset) * batch_size} files") + + # self._dataset.sampler.set_epoch(epoch_number) + + @dlp.log + def next(self): + super().next() + total = self._args.training_steps if self.dataset_type is DatasetType.TRAIN else self._args.eval_steps + logging.debug(f"{utcnow()} Rank {self._args.my_rank} should read {total} batches") + for batch in self._dataset: + yield batch + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/main.py b/dlio_benchmark/main.py index 7b490bbb..567bad57 100644 --- a/dlio_benchmark/main.py +++ b/dlio_benchmark/main.py @@ -191,7 +191,7 @@ def initialize(self): os.path.join(self.args.data_folder, f"{dataset_type}", filenames[0])) == MetadataType.DIRECTORY: assert(num_subfolders == len(filenames)) - fullpaths = self.storage.walk_node(os.path.join(self.args.data_folder, f"{dataset_type}/*/*.{self.args.format}"), + fullpaths = self.storage.walk_node(os.path.join(self.args.data_folder, f"{dataset_type}/*/*.{self.args.format_ext}"), use_pattern=True) files = [self.storage.get_basename(f) for f in fullpaths] idx = np.argsort(files) @@ -199,14 +199,14 @@ def initialize(self): else: assert(num_subfolders==0) fullpaths = [self.storage.get_uri(os.path.join(self.args.data_folder, f"{dataset_type}", entry)) - for entry in filenames if entry.find(f'{self.args.format}')!=-1] + for entry in filenames if entry.find(f'{self.args.format_ext}')!=-1] fullpaths = sorted(fullpaths) if dataset_type is DatasetType.TRAIN: file_list_train = fullpaths elif dataset_type is DatasetType.VALID: file_list_eval = fullpaths if not self.generate_only: - assert(self.num_files_train <=len(file_list_train)) + assert(self.num_files_train <= len(file_list_train)) if self.do_eval: assert(self.num_files_eval <=len(file_list_eval)) if (self.num_files_train < len(file_list_train)): diff --git a/dlio_benchmark/reader/dali_base_reader.py b/dlio_benchmark/reader/dali_base_reader.py new file mode 100644 index 00000000..28b2ef9f --- /dev/null +++ b/dlio_benchmark/reader/dali_base_reader.py @@ -0,0 +1,65 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +from time import sleep + +import numpy as np +import nvidia +from abc import ABC, abstractmethod +from nvidia import dali + +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.common.enumerations import DatasetType +from dlio_benchmark.utils.config import ConfigArguments +from dlio_benchmark.utils.utility import Profile + +from nvidia.dali import fn +dlp = Profile(MODULE_DATA_READER) + +class DaliBaseReader(ABC): + + @dlp.log_init + def __init__(self, dataset_type): + self.dataset_type = dataset_type + self._args = ConfigArguments.get_instance() + self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + self.file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval + + @dlp.log + def _preprocess(self, dataset): + if self._args.preprocess_time != 0. or self._args.preprocess_time_stdev != 0.: + t = np.random.normal(self._args.preprocess_time, self._args.preprocess_time_stdev) + sleep(max(t, 0.0)) + return dataset + + @dlp.log + def _resize(self, dataset): + return nvidia.dali.fn.reshape(dataset, shape=[self._args.max_dimension, self._args.max_dimension]) + + @abstractmethod + def _load(self): + pass + + @dlp.log + def read(self): + dataset = self._load() + #dataset = self._resize(dataset) + #dataset = nvidia.dali.fn.python_function(dataset, function= self._preprocess, num_outputs=1) + return dataset + + @abstractmethod + def finalize(self): + pass diff --git a/dlio_benchmark/reader/dali_npz_reader.py b/dlio_benchmark/reader/dali_npz_reader.py new file mode 100644 index 00000000..f2887ef5 --- /dev/null +++ b/dlio_benchmark/reader/dali_npz_reader.py @@ -0,0 +1,68 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +import math +import logging +from time import time + +import nvidia.dali.fn as fn +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.reader.dali_base_reader import DaliBaseReader +from dlio_benchmark.reader.tf_base_reader import TFBaseReader +from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile +from dlio_benchmark.common.enumerations import DatasetType, Shuffle +import nvidia.dali.tfrecord as tfrec + +dlp = Profile(MODULE_DATA_READER) + + +class DaliNPZReader(DaliBaseReader): + @dlp.log_init + def __init__(self, dataset_type): + super().__init__(dataset_type) + + @dlp.log + def _load(self): + logging.debug( + f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}") + random_shuffle = False + seed = -1 + seed_change_epoch = False + if self._args.sample_shuffle is not Shuffle.OFF: + if self._args.sample_shuffle is not Shuffle.SEED: + seed = self._args.seed + random_shuffle = True + seed_change_epoch = True + initial_fill = 1024 + if self._args.shuffle_size > 0: + initial_fill = self._args.shuffle_size + prefetch_size = 1 + if self._args.prefetch_size > 0: + prefetch_size = self._args.prefetch_size + + stick_to_shard = True + if seed_change_epoch: + stick_to_shard = False + + dataset = fn.readers.numpy(device='cpu', files=self.file_list, num_shards=self._args.comm_size, + prefetch_queue_depth=prefetch_size, initial_fill=initial_fill, + random_shuffle=random_shuffle, seed=seed, shuffle_after_epoch=seed_change_epoch, + stick_to_shard=stick_to_shard, pad_last_batch=True) + return dataset + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/reader/dali_tfrecord_reader.py b/dlio_benchmark/reader/dali_tfrecord_reader.py new file mode 100644 index 00000000..4b8147af --- /dev/null +++ b/dlio_benchmark/reader/dali_tfrecord_reader.py @@ -0,0 +1,78 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +import os.path + +import math +import logging +from time import time + +import nvidia +import nvidia.dali.fn as fn +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.reader.dali_base_reader import DaliBaseReader +from dlio_benchmark.reader.tf_base_reader import TFBaseReader +from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile +from dlio_benchmark.common.enumerations import DatasetType, Shuffle +import nvidia.dali.tfrecord as tfrec + +dlp = Profile(MODULE_DATA_READER) + + +class DaliTFRecordReader(DaliBaseReader): + @dlp.log_init + def __init__(self, dataset_type): + super().__init__(dataset_type) + + @dlp.log + def _load(self): + folder = "valid" + if self.dataset_type == DatasetType.TRAIN: + folder = "train" + index_folder = f"{self._args.data_folder}/index/{folder}" + index_files = [] + for file in self.file_list: + filename = os.path.basename(file) + index_files.append(f"{index_folder}/{filename}.idx") + logging.info( + f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}") + random_shuffle = False + seed = -1 + if self._args.sample_shuffle is not Shuffle.OFF: + if self._args.sample_shuffle is not Shuffle.SEED: + seed = self._args.seed + random_shuffle = True + initial_fill = 1024 + if self._args.shuffle_size > 0: + initial_fill = self._args.shuffle_size + prefetch_size = 1 + if self._args.prefetch_size > 0: + prefetch_size = self._args.prefetch_size + dataset = fn.readers.tfrecord(path=self.file_list, + index_path=index_files, + features={ + 'image': tfrec.FixedLenFeature((), tfrec.string, ""), + 'size': tfrec.FixedLenFeature([1], tfrec.int64, 0) + }, num_shards=self._args.comm_size, + prefetch_queue_depth=prefetch_size, + initial_fill=initial_fill, + random_shuffle=random_shuffle, seed=seed, + stick_to_shard=True, pad_last_batch=True) + return dataset["image"] + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/reader/reader_handler.py b/dlio_benchmark/reader/dlio_base_reader.py similarity index 96% rename from dlio_benchmark/reader/reader_handler.py rename to dlio_benchmark/reader/dlio_base_reader.py index 658b4394..26fe3c67 100644 --- a/dlio_benchmark/reader/reader_handler.py +++ b/dlio_benchmark/reader/dlio_base_reader.py @@ -33,7 +33,7 @@ dlp = Profile(MODULE_DATA_READER) -class FormatReader(ABC): +class DLIOBaseReader(ABC): read_images = None def __init__(self, dataset_type, thread_index): @@ -44,8 +44,8 @@ def __init__(self, dataset_type, thread_index): self.dataset_type = dataset_type self.open_file_map = {} - if FormatReader.read_images is None: - FormatReader.read_images = 0 + if DLIOBaseReader.read_images is None: + DLIOBaseReader.read_images = 0 self.step = 1 self.image_idx = 0 self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval @@ -106,7 +106,7 @@ def read_index(self, global_sample_idx, step): self.image_idx = global_sample_idx filename, sample_index = self._args.global_index_map[global_sample_idx] logging.debug(f"{utcnow()} read_index {filename}, {sample_index}") - FormatReader.read_images += 1 + DLIOBaseReader.read_images += 1 if self._args.read_type is ReadType.ON_DEMAND or filename not in self.open_file_map: self.open_file_map[filename] = self.open(filename) image = self.get_sample(filename, sample_index) diff --git a/dlio_benchmark/reader/csv_reader.py b/dlio_benchmark/reader/dlio_csv_reader.py similarity index 94% rename from dlio_benchmark/reader/csv_reader.py rename to dlio_benchmark/reader/dlio_csv_reader.py 
index 83a25f39..947e3200 100644 --- a/dlio_benchmark/reader/csv_reader.py +++ b/dlio_benchmark/reader/dlio_csv_reader.py @@ -18,12 +18,12 @@ from dlio_benchmark.common.constants import MODULE_DATA_READER from dlio_benchmark.utils.utility import Profile -from dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.reader.dlio_base_reader import DLIOBaseReader dlp = Profile(MODULE_DATA_READER) -class CSVReader(FormatReader): +class DLIOCSVReader(DLIOBaseReader): """ CSV Reader reader and iterator logic. """ diff --git a/dlio_benchmark/reader/hdf5_reader.py b/dlio_benchmark/reader/dlio_hdf5_reader.py similarity index 94% rename from dlio_benchmark/reader/hdf5_reader.py rename to dlio_benchmark/reader/dlio_hdf5_reader.py index 49cdb196..1fd4b1ed 100644 --- a/dlio_benchmark/reader/hdf5_reader.py +++ b/dlio_benchmark/reader/dlio_hdf5_reader.py @@ -20,7 +20,7 @@ from dlio_benchmark.common.constants import MODULE_DATA_READER from dlio_benchmark.utils.utility import Profile -from dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.reader.dlio_base_reader import DLIOBaseReader dlp = Profile(MODULE_DATA_READER) @@ -29,7 +29,7 @@ """ -class HDF5Reader(FormatReader): +class DLIOHDF5Reader(DLIOBaseReader): @dlp.log_init def __init__(self, dataset_type, thread_index, epoch): diff --git a/dlio_benchmark/reader/jpeg_reader.py b/dlio_benchmark/reader/dlio_jpeg_reader.py similarity index 94% rename from dlio_benchmark/reader/jpeg_reader.py rename to dlio_benchmark/reader/dlio_jpeg_reader.py index c441ffee..174f13cc 100644 --- a/dlio_benchmark/reader/jpeg_reader.py +++ b/dlio_benchmark/reader/dlio_jpeg_reader.py @@ -19,13 +19,13 @@ from PIL import Image from dlio_benchmark.common.constants import MODULE_DATA_READER -from dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.reader.dlio_base_reader import DLIOBaseReader from dlio_benchmark.utils.utility import Profile dlp = Profile(MODULE_DATA_READER) -class JPEGReader(FormatReader): +class DLIOJPEGReader(DLIOBaseReader): """ Reader for JPEG files """ diff --git a/dlio_benchmark/reader/npz_reader.py b/dlio_benchmark/reader/dlio_npz_reader.py similarity index 94% rename from dlio_benchmark/reader/npz_reader.py rename to dlio_benchmark/reader/dlio_npz_reader.py index 7ec99851..e58cd0aa 100644 --- a/dlio_benchmark/reader/npz_reader.py +++ b/dlio_benchmark/reader/dlio_npz_reader.py @@ -17,13 +17,13 @@ import numpy as np from dlio_benchmark.common.constants import MODULE_DATA_READER -from dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.reader.dlio_base_reader import DLIOBaseReader from dlio_benchmark.utils.utility import Profile dlp = Profile(MODULE_DATA_READER) -class NPZReader(FormatReader): +class DLIONPZReader(DLIOBaseReader): """ Reader for NPZ files """ diff --git a/dlio_benchmark/reader/png_reader.py b/dlio_benchmark/reader/dlio_png_reader.py similarity index 94% rename from dlio_benchmark/reader/png_reader.py rename to dlio_benchmark/reader/dlio_png_reader.py index 3d217356..3d095e0b 100644 --- a/dlio_benchmark/reader/png_reader.py +++ b/dlio_benchmark/reader/dlio_png_reader.py @@ -20,12 +20,12 @@ from PIL import Image from dlio_benchmark.common.constants import MODULE_DATA_READER -from dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.reader.dlio_base_reader import DLIOBaseReader from dlio_benchmark.utils.utility import Profile, utcnow dlp = Profile(MODULE_DATA_READER) -class PNGReader(FormatReader): +class 
DLIOPNGReader(DLIOBaseReader): """ Reader for PNG files """ diff --git a/dlio_benchmark/reader/tf_reader.py b/dlio_benchmark/reader/dlio_tfrecord_reader.py similarity index 97% rename from dlio_benchmark/reader/tf_reader.py rename to dlio_benchmark/reader/dlio_tfrecord_reader.py index b4147206..d6384c8a 100644 --- a/dlio_benchmark/reader/tf_reader.py +++ b/dlio_benchmark/reader/dlio_tfrecord_reader.py @@ -21,13 +21,13 @@ from dlio_benchmark.common.constants import MODULE_DATA_READER from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile from dlio_benchmark.common.enumerations import DatasetType -from dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.reader.dlio_base_reader import DLIOBaseReader import tensorflow as tf dlp = Profile(MODULE_DATA_READER) -class TFReader(FormatReader): +class DLIOTFRecordReader(DLIOBaseReader): """ Reader for TFRecord files. """ diff --git a/dlio_benchmark/reader/pytorch_base_reader.py b/dlio_benchmark/reader/pytorch_base_reader.py new file mode 100644 index 00000000..fe4b9d72 --- /dev/null +++ b/dlio_benchmark/reader/pytorch_base_reader.py @@ -0,0 +1,66 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from time import sleep + +import numpy as np +from abc import ABC, abstractmethod +import nvidia.dali as dali +from torchvision.transforms import transforms + +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.common.enumerations import DatasetType +from dlio_benchmark.utils.config import ConfigArguments +from dlio_benchmark.utils.utility import Profile + +from nvidia.dali import fn +import torch +dlp = Profile(MODULE_DATA_READER) + +class PytorchBaseReader(ABC): + + @dlp.log_init + def __init__(self, dataset_type): + self.dataset_type = dataset_type + self._args = ConfigArguments.get_instance() + self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + self.file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval + + @dlp.log + def _preprocess(self, dataset): + if self._args.preprocess_time != 0. 
or self._args.preprocess_time_stdev != 0.: + t = np.random.normal(self._args.preprocess_time, self._args.preprocess_time_stdev) + sleep(max(t, 0.0)) + return dataset + + @dlp.log + def _resize(self, image): + return torch.from_numpy(self._args.resized_image) + + @abstractmethod + def _load(self): + pass + + @dlp.log + def read(self): + dataset = self._load() + dataset = dataset.map(self._resize) + dataset = dataset.map(self._preprocess) + return dataset + + @abstractmethod + def finalize(self): + pass diff --git a/dlio_benchmark/reader/reader_factory.py b/dlio_benchmark/reader/reader_factory.py index 74fc353e..6c2c077d 100644 --- a/dlio_benchmark/reader/reader_factory.py +++ b/dlio_benchmark/reader/reader_factory.py @@ -37,24 +37,33 @@ def get_reader(type, dataset_type, thread_index, epoch_number): if _args.reader_class is not None: logging.info(f"{utcnow()} Running DLIO with custom data loader class {_args.reader_class.__name__}") return _args.reader_class(dataset_type, thread_index, epoch_number) - elif type == FormatType.HDF5: - from dlio_benchmark.reader.hdf5_reader import HDF5Reader - return HDF5Reader(dataset_type, thread_index, epoch_number) - elif type == FormatType.CSV: - from dlio_benchmark.reader.csv_reader import CSVReader - return CSVReader(dataset_type, thread_index, epoch_number) - elif type == FormatType.JPEG: - from dlio_benchmark.reader.jpeg_reader import JPEGReader - return JPEGReader(dataset_type, thread_index, epoch_number) - elif type == FormatType.PNG: - from dlio_benchmark.reader.png_reader import PNGReader - return PNGReader(dataset_type, thread_index, epoch_number) - elif type == FormatType.NPZ: - from dlio_benchmark.reader.npz_reader import NPZReader - return NPZReader(dataset_type, thread_index, epoch_number) - elif type == FormatType.TFRECORD: - from dlio_benchmark.reader.tf_reader import TFReader - return TFReader(dataset_type, thread_index, epoch_number) + elif type == FormatType.DLIO_HDF5: + from dlio_benchmark.reader.dlio_hdf5_reader import DLIOHDF5Reader + return DLIOHDF5Reader(dataset_type, thread_index, epoch_number) + elif type == FormatType.DLIO_CSV: + from dlio_benchmark.reader.dlio_csv_reader import DLIOCSVReader + return DLIOCSVReader(dataset_type, thread_index, epoch_number) + elif type == FormatType.DLIO_JPEG: + from dlio_benchmark.reader.dlio_jpeg_reader import DLIOJPEGReader + return DLIOJPEGReader(dataset_type, thread_index, epoch_number) + elif type == FormatType.DLIO_PNG: + from dlio_benchmark.reader.dlio_png_reader import DLIOPNGReader + return DLIOPNGReader(dataset_type, thread_index, epoch_number) + elif type == FormatType.DLIO_NPZ: + from dlio_benchmark.reader.dlio_npz_reader import DLIONPZReader + return DLIONPZReader(dataset_type, thread_index, epoch_number) + elif type == FormatType.DLIO_TFRECORD: + from dlio_benchmark.reader.dlio_tfrecord_reader import DLIOTFRecordReader + return DLIOTFRecordReader(dataset_type, thread_index, epoch_number) + elif type == FormatType.TF_TFRECORD: + from dlio_benchmark.reader.tf_tfrecord_reader import TFTFRecordReader + return TFTFRecordReader(dataset_type) + elif type == FormatType.DALI_TFRECORD: + from dlio_benchmark.reader.dali_tfrecord_reader import DaliTFRecordReader + return DaliTFRecordReader(dataset_type) + elif type == FormatType.DALI_NPZ: + from dlio_benchmark.reader.dali_npz_reader import DaliNPZReader + return DaliNPZReader(dataset_type) else: print("Loading data of %s format is not supported without framework data loader" %type) raise Exception(type) diff --git 
a/dlio_benchmark/reader/tf_base_reader.py b/dlio_benchmark/reader/tf_base_reader.py new file mode 100644 index 00000000..eadda33d --- /dev/null +++ b/dlio_benchmark/reader/tf_base_reader.py @@ -0,0 +1,61 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from time import sleep + +import numpy as np +from abc import ABC, abstractmethod + +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.common.enumerations import DatasetType +from dlio_benchmark.utils.config import ConfigArguments +from dlio_benchmark.utils.utility import Profile +import tensorflow as tf + +dlp = Profile(MODULE_DATA_READER) + +class TFBaseReader(ABC): + def __init__(self, dataset_type): + self.dataset_type = dataset_type + self._args = ConfigArguments.get_instance() + self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + self.file_list = self._args.file_list_train if self.dataset_type is DatasetType.TRAIN else self._args.file_list_eval + + @dlp.log + def _preprocess(self, image): + if self._args.preprocess_time != 0. or self._args.preprocess_time_stdev != 0.: + t = np.random.normal(self._args.preprocess_time, self._args.preprocess_time_stdev) + sleep(max(t, 0.0)) + return image + + @dlp.log + def _resize(self, image): + return tf.convert_to_tensor(self._args.resized_image, dtype=tf.uint8) + + @abstractmethod + def _load(self): + pass + + @dlp.log + def read(self): + dataset = self._load() + dataset = dataset.map(self._resize) + dataset = dataset.map(self._preprocess) + return dataset + + @abstractmethod + def finalize(self): + pass diff --git a/dlio_benchmark/reader/tf_tfrecord_reader.py b/dlio_benchmark/reader/tf_tfrecord_reader.py new file mode 100644 index 00000000..bca2c3f5 --- /dev/null +++ b/dlio_benchmark/reader/tf_tfrecord_reader.py @@ -0,0 +1,71 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +import math +import logging +from time import time + +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.reader.tf_base_reader import TFBaseReader +from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile +from dlio_benchmark.common.enumerations import DatasetType +import tensorflow as tf + +dlp = Profile(MODULE_DATA_READER) + + +class TFTFRecordReader(TFBaseReader): + @dlp.log_init + def __init__(self, dataset_type): + super().__init__(dataset_type) + + @dlp.log + def parse_image(self, serialized): + """ + performs deserialization of the tfrecord. + :param serialized: is the serialized version using protobuf + :return: deserialized image and label. + """ + features = \ + { + 'image': tf.io.FixedLenFeature([], tf.string), + 'size': tf.io.FixedLenFeature([], tf.int64) + } + parsed_example = tf.io.parse_example(serialized=serialized, features=features) + # Get the image as raw bytes. + image_raw = parsed_example['image'] + dimension = tf.cast(parsed_example['size'], tf.int32).numpy() + # Decode the raw bytes so it becomes a tensor with type. + image_tensor = tf.io.decode_raw(image_raw, tf.uint8) + size = dimension * dimension + dlp.update(image_size=size) + # image_tensor = tf.io.decode_image(image_raw) + return image_tensor + + @dlp.log + def _load(self): + logging.debug( + f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}") + dataset = tf.data.TFRecordDataset(filenames=self.file_list, buffer_size=self._args.transfer_size) + dataset = dataset.shard(num_shards=self._args.comm_size, index=self._args.my_rank) + dataset = dataset.map( + lambda x: tf.py_function(func=self.parse_image, inp=[x], Tout=[tf.uint8]) + , num_parallel_calls=self._args.computation_threads) + return dataset + + @dlp.log + def finalize(self): + pass \ No newline at end of file diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index 449637d3..2a343653 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -43,7 +43,7 @@ class ConfigArguments: model: str = "default" framework: FrameworkType = FrameworkType.TENSORFLOW # Dataset format, such as PNG, JPEG - format: FormatType = FormatType.TFRECORD + format: FormatType = FormatType.DLIO_TFRECORD # Shuffle type file_shuffle: Shuffle = Shuffle.OFF shuffle_size: int = 1024 @@ -98,7 +98,7 @@ class ConfigArguments: eval_after_epoch: int = 1 epochs_between_evals: int = 1 model_size: int = 10240 - data_loader: DataLoaderType = DataLoaderType.TENSORFLOW.value + data_loader: DataLoaderType = DataLoaderType.DLIO_TENSORFLOW.value num_subfolders_train: int = 0 num_subfolders_eval: int = 0 iostat_devices: ClassVar[List[str]] = [] @@ -147,12 +147,12 @@ def validate(self): if (self.do_profiling == True) and (self.profiler == Profiler('darshan')): if ('LD_PRELOAD' not in os.environ or os.environ["LD_PRELOAD"].find("libdarshan") == -1): raise Exception("Please set darshan runtime library in LD_PRELOAD") - if self.format is FormatType.TFRECORD and self.framework is not FrameworkType.TENSORFLOW: + if self.format is FormatType.DLIO_TFRECORD and self.framework is not FrameworkType.TENSORFLOW: raise Exception("Imcompatible between format and framework setup.") - if self.format is FormatType.TFRECORD and self.data_loader is not DataLoaderType.TENSORFLOW: + if self.format is FormatType.DLIO_TFRECORD and self.data_loader is not DataLoaderType.DLIO_TENSORFLOW: raise Exception("Imcompatible between format and data loader setup.") - if (self.framework == 
diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py
index 449637d3..2a343653 100644
--- a/dlio_benchmark/utils/config.py
+++ b/dlio_benchmark/utils/config.py
@@ -43,7 +43,7 @@ class ConfigArguments:
     model: str = "default"
     framework: FrameworkType = FrameworkType.TENSORFLOW
     # Dataset format, such as PNG, JPEG
-    format: FormatType = FormatType.TFRECORD
+    format: FormatType = FormatType.DLIO_TFRECORD
     # Shuffle type
     file_shuffle: Shuffle = Shuffle.OFF
     shuffle_size: int = 1024
@@ -98,7 +98,7 @@ class ConfigArguments:
     eval_after_epoch: int = 1
     epochs_between_evals: int = 1
     model_size: int = 10240
-    data_loader: DataLoaderType = DataLoaderType.TENSORFLOW.value
+    data_loader: DataLoaderType = DataLoaderType.DLIO_TENSORFLOW.value
     num_subfolders_train: int = 0
     num_subfolders_eval: int = 0
     iostat_devices: ClassVar[List[str]] = []
@@ -147,12 +147,12 @@ def validate(self):
         if (self.do_profiling == True) and (self.profiler == Profiler('darshan')):
             if ('LD_PRELOAD' not in os.environ or os.environ["LD_PRELOAD"].find("libdarshan") == -1):
                 raise Exception("Please set darshan runtime library in LD_PRELOAD")
-        if self.format is FormatType.TFRECORD and self.framework is not FrameworkType.TENSORFLOW:
+        if self.format is FormatType.DLIO_TFRECORD and self.framework is not FrameworkType.TENSORFLOW:
             raise Exception("Imcompatible between format and framework setup.")
-        if self.format is FormatType.TFRECORD and self.data_loader is not DataLoaderType.TENSORFLOW:
+        if self.format is FormatType.DLIO_TFRECORD and self.data_loader is not DataLoaderType.DLIO_TENSORFLOW:
            raise Exception("Imcompatible between format and data loader setup.")
-        if (self.framework == FrameworkType.TENSORFLOW and self.data_loader == DataLoaderType.PYTORCH) or (
-                self.framework == FrameworkType.PYTORCH and self.data_loader == DataLoaderType.TENSORFLOW):
+        if (self.framework == FrameworkType.TENSORFLOW and self.data_loader == DataLoaderType.DLIO_PYTORCH) or (
+                self.framework == FrameworkType.PYTORCH and self.data_loader == DataLoaderType.DLIO_TENSORFLOW):
             raise Exception("Imcompatible between framework and data_loader setup.")
         if len(self.file_list_train) != self.num_files_train:
             raise Exception(
@@ -177,11 +177,11 @@ def reset(self):
     @dlp.log
     def derive_configurations(self, file_list_train=None, file_list_eval=None):
         self.dimension = int(math.sqrt(self.record_length))
-        self.dimension_stdev = self.record_length_stdev/2.0/math.sqrt(self.record_length)
+        self.dimension_stdev = self.record_length_stdev / 2.0 / math.sqrt(self.record_length)
         self.max_dimension = self.dimension
-        if (self.record_length_resize>0):
-            self.max_dimension = int(math.sqrt(self.record_length_resize))
-        if (file_list_train !=None and file_list_eval !=None):
+        if self.record_length_resize > 0:
+            self.max_dimension = int(math.sqrt(self.record_length_resize))
+        if file_list_train is not None and file_list_eval is not None:
             self.resized_image = np.random.randint(255, size=(self.max_dimension, self.max_dimension), dtype=np.uint8)
             self.file_list_train = file_list_train
             self.file_list_eval = file_list_eval
@@ -195,9 +195,9 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
         self.training_steps = int(math.ceil(self.total_samples_train / self.batch_size / self.comm_size))
         self.eval_steps = int(math.ceil(self.total_samples_eval / self.batch_size_eval / self.comm_size))
         if self.data_loader_sampler is None and self.data_loader_classname is None:
-            if self.data_loader == DataLoaderType.TENSORFLOW:
+            if self.data_loader == DataLoaderType.DLIO_TENSORFLOW:
                 self.data_loader_sampler = DataLoaderSampler.ITERATIVE
-            elif self.data_loader in [DataLoaderType.PYTORCH, DataLoaderType.DALI]:
+            elif self.data_loader in [DataLoaderType.DLIO_PYTORCH, DataLoaderType.DLIO_DALI]:
                 self.data_loader_sampler = DataLoaderSampler.INDEX
         if self.data_loader_classname is not None:
             from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
@@ -209,11 +209,17 @@ def derive_configurations(self, file_list_train=None, file_list_eval=None):
                     self.data_loader_class = obj
                     break
         if self.reader_classname is not None:
-            from dlio_benchmark.reader.reader_handler import FormatReader
+            from dlio_benchmark.reader.dlio_base_reader import DLIOBaseReader
+            from dlio_benchmark.reader.tf_base_reader import TFBaseReader
+            from dlio_benchmark.reader.pytorch_base_reader import PytorchBaseReader
+            from dlio_benchmark.reader.dali_base_reader import DaliBaseReader
             classname = self.reader_classname.split(".")[-1]
             module = importlib.import_module(".".join(self.reader_classname.split(".")[:-1]))
             for class_name, obj in inspect.getmembers(module):
-                if class_name == classname and issubclass(obj, FormatReader):
+                if class_name == classname and (issubclass(obj, DLIOBaseReader) or
+                                                issubclass(obj, TFBaseReader) or
+                                                issubclass(obj, PytorchBaseReader) or
+                                                issubclass(obj, DaliBaseReader)):
                     logging.info(f"Discovered custom data reader {class_name}")
                     self.reader_class = obj
                     break
@@ -223,7 +229,7 @@ def build_sample_map_iter(self, file_list, total_samples, epoch_number):
         logging.debug(f"ranks {self.comm_size} threads {self.read_threads} tensors")
         num_files = len(file_list)
         num_threads = 1
-        if self.read_threads > 0 and self.data_loader is not DataLoaderType.DALI:
+        if self.read_threads > 0 and self.data_loader is not DataLoaderType.DLIO_DALI:
             num_threads = self.read_threads
         self.samples_per_thread = total_samples / self.comm_size / num_threads
         file_index = 0
@@ -244,7 +250,7 @@ def build_sample_map_iter(self, file_list, total_samples, epoch_number):
                     process_thread_file_map[rank][thread_index] = []
                 selected_samples = 0
                 while selected_samples < self.samples_per_thread:
-                    process_thread_file_map[rank][thread_index].append((sample_global_list[sample_index],
+                    process_thread_file_map[rank][thread_index].append((sample_global_list[sample_index],
                                                                         os.path.abspath(file_list[file_index]),
                                                                         sample_global_list[sample_index] % self.num_samples_per_file))
                     sample_index += 1
@@ -305,7 +311,7 @@ def LoadConfig(args, config):
             args.storage_type = StorageType(config['storage']['storage_type'])
         if 'storage_root' in config['storage']:
             args.storage_root = config['storage']['storage_root']
-    
+
     # dataset related settings
     if 'dataset' in config:
         if 'record_length' in config['dataset']:
@@ -339,6 +345,7 @@ def LoadConfig(args, config):
             args.file_prefix = config['dataset']['file_prefix']
         if 'format' in config['dataset']:
             args.format = FormatType(config['dataset']['format'])
+            args.format_ext = FormatType.getextension(config['dataset']['format'])
         if 'keep_files' in config['dataset']:
             args.keep_files = config['dataset']['keep_files']
 
@@ -379,7 +386,7 @@ def LoadConfig(args, config):
             args.transfer_size = reader['transfer_size']
         if 'preprocess_time' in reader:
             args.preprocess_time = reader['preprocess_time']
-        if 'preprocess_time_stdev' in reader: 
+        if 'preprocess_time_stdev' in reader:
             args.preprocess_time_stdev = reader['preprocess_time_stdev']
 
     # training relevant setting
@@ -424,7 +431,7 @@ def LoadConfig(args, config):
             args.output_folder = config['output']['folder']
         if 'log_file' in config['output']:
             args.log_file = config['output']['log_file']
-    
+
     if 'workflow' in config:
         if 'generate_data' in config['workflow']:
             args.generate_data = config['workflow']['generate_data']
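`LoadConfig` now derives `args.format_ext` via `FormatType.getextension`, and the tests further down use the same helper when globbing for generated files. The diff does not show its implementation; presumably it strips the loader prefix so that a `dlio_png` workload still produces `.png` files on disk. A sketch of that assumed behavior (the function name and prefix list are assumptions, not taken from the source):

```python
def getextension_sketch(fmt: str) -> str:
    """Assumed behavior of FormatType.getextension: map the prefixed format
    name used in configs back to the on-disk file extension."""
    for prefix in ("dlio_", "tf_", "dali_"):
        if fmt.startswith(prefix):
            return fmt[len(prefix):]  # e.g. "dlio_png" -> "png", "dali_tfrecord" -> "tfrecord"
    return fmt
```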
diff --git a/docs/source/config.rst b/docs/source/config.rst
index 8040fb15..98dbc9ef 100644
--- a/docs/source/config.rst
+++ b/docs/source/config.rst
@@ -17,7 +17,7 @@ The characteristics of a workload is specified through a YAML file. This file wi
 
     dataset:
         data_folder: data/unet3d/
-        format: npz
+        format: dlio_npz
         num_files_train: 168
         num_samples_per_file: 1
         record_length: 146600628
@@ -25,7 +25,7 @@ The characteristics of a workload is specified through a YAML file. This file wi
         record_length_resize: 2097152
 
     reader:
-        data_loader: pytorch
+        data_loader: dlio_pytorch
         batch_size: 4
         read_threads: 4
         file_shuffle: seed
@@ -128,7 +128,7 @@ dataset
      - resized sample size
    * - format
      - tfrecord
-     - data format [tfrecord|csv|npz|jpeg|png]
+     - data format [dlio_tfrecord|dlio_csv|dlio_npz|dlio_jpeg|dlio_png]
    * - num_files_train
      - 1
      - number of files for the training set
@@ -181,7 +181,7 @@ reader
      - Description
    * - data_loader
      - tensorflow
-     - select the data loader to use [tensorflow|pytorch].
+     - select the data loader to use [dlio_tensorflow|dlio_pytorch|dlio_dali].
    * - batch_size
      - 1
      - batch size for training
diff --git a/docs/source/examples.rst b/docs/source/examples.rst
index 767bd714..e958fe22 100644
--- a/docs/source/examples.rst
+++ b/docs/source/examples.rst
@@ -25,7 +25,7 @@ UNET3D: 3D Medical Image Segmentation
 
     dataset:
         data_folder: data/unet3d/
-        format: npz
+        format: dlio_npz
         num_files_train: 168
         num_samples_per_file: 1
         record_length: 146600628
@@ -33,7 +33,7 @@ UNET3D: 3D Medical Image Segmentation
         record_length_resize: 2097152
 
     reader:
-        data_loader: pytorch
+        data_loader: dlio_pytorch
         batch_size: 4
         read_threads: 4
         file_shuffle: seed
@@ -496,7 +496,7 @@ BERT: Natural Language Processing Model
 
     dataset:
         data_folder: data/bert
-        format: tfrecord
+        format: dlio_tfrecord
         num_files_train: 500
         num_samples_per_file: 313532
         record_length: 2500
@@ -507,7 +507,7 @@ BERT: Natural Language Processing Model
         total_training_steps: 5000
 
     reader:
-        data_loader: tensorflow
+        data_loader: dlio_tensorflow
         read_threads: 1
         computation_threads: 1
         transfer_size: 262144
@@ -545,7 +545,7 @@ CosmoFlow: 3D CNN to Learn the Universe at Scale
         record_length: 131072
 
     reader:
-        data_loader: tensorflow
+        data_loader: dlio_tensorflow
         computation_threads: 8
         read_threads: 8
         batch_size: 1
@@ -576,9 +576,9 @@ ResNet50: 3D Image classification
         num_samples_per_file: 1024
         record_length: 150528
         data_folder: data/resnet50
-        format: tfrecord
+        format: dlio_tfrecord
 
     data_loader:
-        data_loader: tensorflow
+        data_loader: dlio_tensorflow
         read_threads: 8
         computation_threads: 8
diff --git a/tests/dlio_benchmark_test.py b/tests/dlio_benchmark_test.py
index 7fbc1864..44f4fee8 100644
--- a/tests/dlio_benchmark_test.py
+++ b/tests/dlio_benchmark_test.py
@@ -27,8 +27,9 @@
 import subprocess
 import logging
 import os
-from dlio_benchmark.utils.config import ConfigArguments
 import dlio_benchmark
+from dlio_benchmark.utils.config import ConfigArguments
+from dlio_benchmark.common.enumerations import FormatType
 
 config_dir=os.path.dirname(dlio_benchmark.__file__)+"/configs/"
 logging.basicConfig(
@@ -74,10 +75,10 @@ def run_benchmark(cfg, storage_root="./", verify=True):
 
 
 @pytest.mark.timeout(60, method="thread")
-@pytest.mark.parametrize("fmt, framework", [("png", "tensorflow"), ("npz", "tensorflow"),
-                                            ("jpeg", "tensorflow"), ("tfrecord", "tensorflow"),
-                                            ("hdf5", "tensorflow")])
-def test_gen_data(fmt, framework) -> None:
+@pytest.mark.parametrize("fmt, framework, dataloader", [("dlio_png", "tensorflow", "dlio_tensorflow"), ("dlio_npz", "tensorflow", "dlio_tensorflow"),
+                                            ("dlio_jpeg", "tensorflow", "dlio_tensorflow"), ("dlio_tfrecord", "tensorflow", "dlio_tensorflow"),
+                                            ("dlio_hdf5", "tensorflow", "dlio_tensorflow"), ("dali_npz", "pytorch", "native_dali")])
+def test_gen_data(fmt, framework, dataloader) -> None:
     if (comm.rank == 0):
         logging.info("")
         logging.info("=" * 80)
@@ -85,23 +86,24 @@ def test_gen_data(fmt, framework) -> None:
         logging.info("=" * 80)
     with initialize_config_dir(version_base=None, config_dir=config_dir):
         cfg = compose(config_name='config', overrides=[f'++workload.framework={framework}',
-                                                       f'++workload.reader.data_loader={framework}',
+                                                       f'++workload.reader.data_loader={dataloader}',
                                                        '++workload.workflow.train=False',
                                                        '++workload.workflow.generate_data=True',
                                                        f"++workload.dataset.format={fmt}"])
+    ext = FormatType.getextension(fmt)
     benchmark = run_benchmark(cfg, verify=False)
     if benchmark.args.num_subfolders_train <= 1:
         train = pathlib.Path(f"{cfg.workload.dataset.data_folder}/train")
-        train_files = list(train.glob(f"*.{fmt}"))
+        train_files = list(train.glob(f"*.{ext}"))
         valid = pathlib.Path(f"{cfg.workload.dataset.data_folder}/valid")
-        valid_files = list(valid.glob(f"*.{fmt}"))
+        valid_files = list(valid.glob(f"*.{ext}"))
         assert (len(train_files) == cfg.workload.dataset.num_files_train)
         assert (len(valid_files) == cfg.workload.dataset.num_files_eval)
     else:
         train = pathlib.Path(f"{cfg.workload.dataset.data_folder}/train")
-        train_files = list(train.rglob(f"**/*.{fmt}"))
+        train_files = list(train.rglob(f"**/*.{ext}"))
         valid = pathlib.Path(f"{cfg.workload.dataset.data_folder}/valid")
-        valid_files = list(valid.rglob(f"**/*.{fmt}"))
+        valid_files = list(valid.rglob(f"**/*.{ext}"))
         assert (len(train_files) == cfg.workload.dataset.num_files_train)
         assert (len(valid_files) == cfg.workload.dataset.num_files_eval)
     clean()
@@ -124,10 +126,10 @@ def test_subset() -> None:
     clean()
 
 
 @pytest.mark.timeout(60, method="thread")
-@pytest.mark.parametrize("fmt, framework", [("png", "tensorflow"), ("npz", "tensorflow"),
-                                            ("jpeg", "tensorflow"), ("tfrecord", "tensorflow"),
-                                            ("hdf5", "tensorflow")])
-def test_storage_root_gen_data(fmt, framework) -> None:
+@pytest.mark.parametrize("fmt, framework, dataloader", [("dlio_png", "tensorflow", "dlio_tensorflow"), ("dlio_npz", "tensorflow", "dlio_tensorflow"),
+                                            ("dlio_jpeg", "tensorflow", "dlio_tensorflow"), ("dlio_tfrecord", "tensorflow", "dlio_tensorflow"),
+                                            ("dlio_hdf5", "tensorflow", "dlio_tensorflow")])
+def test_storage_root_gen_data(fmt, framework, dataloader) -> None:
     storage_root = "runs"
     clean(storage_root)
@@ -138,30 +140,31 @@ def test_storage_root_gen_data(fmt, framework, dataloader) -> None:
         logging.info("=" * 80)
     with initialize_config_dir(version_base=None, config_dir=config_dir):
         cfg = compose(config_name='config', overrides=[f'++workload.framework={framework}',
-                                                       f'++workload.reader.data_loader={framework}',
+                                                       f'++workload.reader.data_loader={dataloader}',
                                                        '++workload.workflow.train=False',
                                                        '++workload.workflow.generate_data=True',
                                                        f"++workload.storage.storage_root={storage_root}",
                                                        f"++workload.dataset.format={fmt}"])
+    ext = FormatType.getextension(fmt)
     benchmark = run_benchmark(cfg, verify=False)
     if benchmark.args.num_subfolders_train <= 1:
         assert (
                 len(glob.glob(
-                    os.path.join(storage_root, cfg.workload.dataset.data_folder, f"train/*.{fmt}"))) ==
+                    os.path.join(storage_root, cfg.workload.dataset.data_folder, f"train/*.{ext}"))) ==
                 cfg.workload.dataset.num_files_train)
         assert (
                 len(glob.glob(
-                    os.path.join(storage_root, cfg.workload.dataset.data_folder, f"valid/*.{fmt}"))) ==
+                    os.path.join(storage_root, cfg.workload.dataset.data_folder, f"valid/*.{ext}"))) ==
                 cfg.workload.dataset.num_files_eval)
     else:
-        logging.info(os.path.join(storage_root, cfg.workload.dataset.data_folder, f"train/*/*.{fmt}"))
+        logging.info(os.path.join(storage_root, cfg.workload.dataset.data_folder, f"train/*/*.{ext}"))
         assert (
                 len(glob.glob(
-                    os.path.join(storage_root, cfg.workload.dataset.data_folder, f"train/*/*.{fmt}"))) ==
+                    os.path.join(storage_root, cfg.workload.dataset.data_folder, f"train/*/*.{ext}"))) ==
                 cfg.workload.dataset.num_files_train)
         assert (
                 len(glob.glob(
-                    os.path.join(storage_root, cfg.workload.dataset.data_folder, f"valid/*/*.{fmt}"))) ==
+                    os.path.join(storage_root, cfg.workload.dataset.data_folder, f"valid/*/*.{ext}"))) ==
                 cfg.workload.dataset.num_files_eval)
     clean(storage_root)
@@ -283,9 +286,9 @@ def test_eval() -> None:
 
 
 @pytest.mark.timeout(60, method="thread")
-@pytest.mark.parametrize("framework, nt", [("tensorflow", 0), ("tensorflow", 1),("tensorflow", 2),
-                                           ("pytorch", 0), ("pytorch", 1), ("pytorch", 2)])
-def test_multi_threads(framework, nt) -> None:
+@pytest.mark.parametrize("framework, nt, dataloader", [("tensorflow", 0, "dlio_tensorflow"), ("tensorflow", 1, "dlio_tensorflow"),("tensorflow", 2, "dlio_tensorflow"),
+                                           ("pytorch", 0, "dlio_pytorch"), ("pytorch", 1, "dlio_pytorch"), ("pytorch", 2, "dlio_pytorch")])
+def test_multi_threads(framework, nt, dataloader) -> None:
     clean()
     if (comm.rank == 0):
         logging.info("")
@@ -297,7 +300,7 @@ def test_multi_threads(framework, nt, dataloader) -> None:
         cfg = compose(config_name='config', overrides=['++workload.workflow.train=True',
                                                        '++workload.workflow.generate_data=True',
                                                        f"++workload.framework={framework}",
-                                                       f"++workload.reader.data_loader={framework}",
+                                                       f"++workload.reader.data_loader={dataloader}",
                                                        f"++workload.reader.read_threads={nt}",
                                                        'workload.train.computation_time=0.01',
                                                        'workload.evaluation.eval_time=0.005',
@@ -309,18 +312,19 @@ def test_multi_threads(framework, nt, dataloader) -> None:
 
 
 @pytest.mark.timeout(60, method="thread")
-@pytest.mark.parametrize("fmt, framework, dataloader", [("png", "tensorflow","tensorflow"), ("npz", "tensorflow","tensorflow"),
-                                                        ("jpeg", "tensorflow","tensorflow"), ("tfrecord", "tensorflow","tensorflow"),
-                                                        ("hdf5", "tensorflow","tensorflow"), ("csv", "tensorflow","tensorflow"),
-                                                        ("png", "pytorch", "pytorch"), ("npz", "pytorch", "pytorch"),
-                                                        ("jpeg", "pytorch", "pytorch"), ("hdf5", "pytorch", "pytorch"), ("csv", "pytorch", "pytorch"),
-                                                        ("png", "tensorflow", "dali"), ("npz", "tensorflow", "dali"),
-                                                        ("jpeg", "tensorflow", "dali"),
-                                                        ("hdf5", "tensorflow", "dali"), ("csv", "tensorflow", "dali"),
-                                                        ("png", "pytorch", "dali"), ("npz", "pytorch", "dali"),
-                                                        ("jpeg", "pytorch", "dali"), ("hdf5", "pytorch", "dali"),
-                                                        ("csv", "pytorch", "dali"),
-                                                        ])
+@pytest.mark.parametrize("fmt, framework, dataloader", [("dlio_png", "tensorflow","dlio_tensorflow"), ("dlio_npz", "tensorflow", "dlio_tensorflow"),
+                                                        ("dlio_jpeg", "tensorflow", "dlio_tensorflow"), ("dlio_tfrecord", "tensorflow", "dlio_tensorflow"),
+                                                        ("dlio_hdf5", "tensorflow", "dlio_tensorflow"), ("dlio_csv", "tensorflow", "dlio_tensorflow"),
+                                                        ("dlio_png", "pytorch", "dlio_pytorch"), ("dlio_npz", "pytorch", "dlio_pytorch"),
+                                                        ("dlio_jpeg", "pytorch", "dlio_pytorch"), ("dlio_hdf5", "pytorch", "dlio_pytorch"), ("dlio_csv", "pytorch", "dlio_pytorch"),
+                                                        ("dlio_png", "tensorflow", "dlio_dali"), ("dlio_npz", "tensorflow", "dlio_dali"),
+                                                        ("dlio_jpeg", "tensorflow", "dlio_dali"),
+                                                        ("dlio_hdf5", "tensorflow", "dlio_dali"), ("dlio_csv", "tensorflow", "dlio_dali"),
+                                                        ("dlio_png", "pytorch", "dlio_dali"), ("dlio_npz", "pytorch", "dlio_dali"),
+                                                        ("dlio_jpeg", "pytorch", "dlio_dali"), ("dlio_hdf5", "pytorch", "dlio_dali"),
+                                                        ("dlio_csv", "pytorch", "dlio_dali"),
+                                                        ("dlio_png", "tensorflow", "dlio_tensorflow"), ("tf_tfrecord", "tensorflow", "native_tensorflow"),
+                                                        ("dali_tfrecord", "pytorch", "native_dali"), ("dali_npz", "pytorch", "native_dali")])
 def test_train(fmt, framework, dataloader) -> None:
     clean()
     if comm.rank == 0:
"tensorflow", "dlio_tensorflow"), ("dlio_tfrecord", "tensorflow", "dlio_tensorflow"), + ("dlio_hdf5", "tensorflow", "dlio_tensorflow"), ("dlio_csv", "tensorflow", "dlio_tensorflow"), + ("dlio_png", "pytorch", "dlio_pytorch"), ("dlio_npz", "pytorch", "dlio_pytorch"), + ("dlio_jpeg", "pytorch", "dlio_pytorch"), ("dlio_hdf5", "pytorch", "dlio_pytorch"), + ("dlio_csv", "pytorch", "dlio_pytorch") ]) -def test_custom_storage_root_train(fmt, framework) -> None: +def test_custom_storage_root_train(fmt, framework, dataloader) -> None: storage_root = "root_dir" clean(storage_root) if (comm.rank == 0): @@ -364,7 +368,7 @@ def test_custom_storage_root_train(fmt, framework) -> None: cfg = compose(config_name='config', overrides=['++workload.workflow.train=True', \ '++workload.workflow.generate_data=True', \ f"++workload.framework={framework}", \ - f"++workload.reader.data_loader={framework}", \ + f"++workload.reader.data_loader={dataloader}", \ f"++workload.dataset.format={fmt}", f"++workload.storage.storage_root={storage_root}", \ 'workload.train.computation_time=0.01', \ diff --git a/tests/plugins/reader/custom_npz_reader.py b/tests/plugins/reader/custom_npz_reader.py index a3277e47..2a6b4b15 100644 --- a/tests/plugins/reader/custom_npz_reader.py +++ b/tests/plugins/reader/custom_npz_reader.py @@ -17,13 +17,13 @@ import numpy as np from dlio_benchmark.common.constants import MODULE_DATA_READER -from dlio_benchmark.reader.reader_handler import FormatReader +from dlio_benchmark.reader.dlio_base_reader import DLIOBaseReader from dlio_benchmark.utils.utility import Profile dlp = Profile(MODULE_DATA_READER) -class CustomNPZReader(FormatReader): +class CustomNPZReader(DLIOBaseReader): """ Reader for NPZ files """