diff --git a/dlio_benchmark/common/enumerations.py b/dlio_benchmark/common/enumerations.py index 64227772..0081195e 100644 --- a/dlio_benchmark/common/enumerations.py +++ b/dlio_benchmark/common/enumerations.py @@ -124,6 +124,7 @@ class DataLoaderType(Enum): TENSORFLOW='tensorflow' PYTORCH='pytorch' DALI='dali' + NATIVE_DALI = 'native_dali' CUSTOM='custom' NONE='none' diff --git a/dlio_benchmark/data_loader/data_loader_factory.py b/dlio_benchmark/data_loader/data_loader_factory.py index e8457450..13bf16b0 100644 --- a/dlio_benchmark/data_loader/data_loader_factory.py +++ b/dlio_benchmark/data_loader/data_loader_factory.py @@ -45,6 +45,9 @@ def get_loader(type, format_type, dataset_type, epoch): elif type == DataLoaderType.DALI: from dlio_benchmark.data_loader.dali_data_loader import DaliDataLoader return DaliDataLoader(format_type, dataset_type, epoch) + elif type == DataLoaderType.NATIVE_DALI: + from dlio_benchmark.data_loader.native_dali_data_loader import NativeDaliDataLoader + return NativeDaliDataLoader(format_type, dataset_type, epoch) else: print("Data Loader %s not supported or plugins not found" % type) raise Exception(str(ErrorCodes.EC1004)) diff --git a/dlio_benchmark/data_loader/native_dali_data_loader.py b/dlio_benchmark/data_loader/native_dali_data_loader.py new file mode 100644 index 00000000..2df04a5e --- /dev/null +++ b/dlio_benchmark/data_loader/native_dali_data_loader.py @@ -0,0 +1,60 @@ +from time import time +import logging +import math +import numpy as np +from nvidia.dali.pipeline import Pipeline +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import nvidia.dali as dali +from nvidia.dali.plugin.pytorch import DALIGenericIterator + +from dlio_benchmark.common.constants import MODULE_DATA_LOADER +from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType +from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader +from dlio_benchmark.reader.reader_factory import ReaderFactory +from dlio_benchmark.utils.utility import utcnow, get_rank, timeit, Profile + +dlp = Profile(MODULE_DATA_LOADER) + + +class NativeDaliDataLoader(BaseDataLoader): + @dlp.log_init + def __init__(self, format_type, dataset_type, epoch): + super().__init__(format_type, dataset_type, epoch, DataLoaderType.NATIVE_DALI) + self.pipelines = [] + + @dlp.log + def read(self): + num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval + batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + parallel = True if self._args.read_threads > 0 else False + self.pipelines = [] + num_threads = 1 + if self._args.read_threads > 0: + num_threads = self._args.read_threads + # None executes pipeline on CPU and the reader does the batching + pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads, + exec_async=False, exec_pipelined=False) + with pipeline: + images = ReaderFactory.get_reader(type=self.format_type, + dataset_type=self.dataset_type, + thread_index=-1, + epoch_number=self.epoch_number).read() + pipeline.set_outputs(images) + self.pipelines.append(pipeline) + logging.info(f"{utcnow()} Creating {num_threads} pipelines by {self._args.my_rank} rank ") + + @dlp.log + def next(self): + super().next() + num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval + batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval + for step in range(num_samples // batch_size): + _dataset = DALIGenericIterator(self.pipelines, ['data']) + for batch in _dataset: + logging.info(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ") + yield batch + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/reader/dali_image_reader.py b/dlio_benchmark/reader/dali_image_reader.py new file mode 100644 index 00000000..adbc2c55 --- /dev/null +++ b/dlio_benchmark/reader/dali_image_reader.py @@ -0,0 +1,69 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import math +import logging +from time import time + +import nvidia.dali.fn as fn +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.reader.dali_base_reader import DaliBaseReader +from dlio_benchmark.reader.tf_base_reader import TFBaseReader +from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile +from dlio_benchmark.common.enumerations import DatasetType, Shuffle +import nvidia.dali.tfrecord as tfrec + +dlp = Profile(MODULE_DATA_READER) + + +class DaliImageReader(DaliBaseReader): + @dlp.log_init + def __init__(self, dataset_type): + super().__init__(dataset_type) + + @dlp.log + def _load(self): + logging.debug( + f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}") + random_shuffle = False + seed = -1 + seed_change_epoch = False + if self._args.sample_shuffle is not Shuffle.OFF: + if self._args.sample_shuffle is not Shuffle.SEED: + seed = self._args.seed + random_shuffle = True + seed_change_epoch = True + initial_fill = 1024 + if self._args.shuffle_size > 0: + initial_fill = self._args.shuffle_size + prefetch_size = 1 + if self._args.prefetch_size > 0: + prefetch_size = self._args.prefetch_size + + stick_to_shard = True + if seed_change_epoch: + stick_to_shard = False + images, labels = fn.readers.file(files=files, num_shards=self._args.comm_size, + prefetch_queue_depth=prefetch_size, + initial_fill=initial_fill, random_shuffle=random_shuffle, + shuffle_after_epoch=seed_change_epoch, + stick_to_shard=stick_to_shard, pad_last_batch=True) + dataset = fn.decoders.image(jpegs, device='cpu') + return dataset + + @dlp.log + def finalize(self): + pass \ No newline at end of file diff --git a/dlio_benchmark/reader/dali_npz_reader.py b/dlio_benchmark/reader/dali_npz_reader.py new file mode 100644 index 00000000..f2887ef5 --- /dev/null +++ b/dlio_benchmark/reader/dali_npz_reader.py @@ -0,0 +1,68 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import math +import logging +from time import time + +import nvidia.dali.fn as fn +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.reader.dali_base_reader import DaliBaseReader +from dlio_benchmark.reader.tf_base_reader import TFBaseReader +from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile +from dlio_benchmark.common.enumerations import DatasetType, Shuffle +import nvidia.dali.tfrecord as tfrec + +dlp = Profile(MODULE_DATA_READER) + + +class DaliNPZReader(DaliBaseReader): + @dlp.log_init + def __init__(self, dataset_type): + super().__init__(dataset_type) + + @dlp.log + def _load(self): + logging.debug( + f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}") + random_shuffle = False + seed = -1 + seed_change_epoch = False + if self._args.sample_shuffle is not Shuffle.OFF: + if self._args.sample_shuffle is not Shuffle.SEED: + seed = self._args.seed + random_shuffle = True + seed_change_epoch = True + initial_fill = 1024 + if self._args.shuffle_size > 0: + initial_fill = self._args.shuffle_size + prefetch_size = 1 + if self._args.prefetch_size > 0: + prefetch_size = self._args.prefetch_size + + stick_to_shard = True + if seed_change_epoch: + stick_to_shard = False + + dataset = fn.readers.numpy(device='cpu', files=self.file_list, num_shards=self._args.comm_size, + prefetch_queue_depth=prefetch_size, initial_fill=initial_fill, + random_shuffle=random_shuffle, seed=seed, shuffle_after_epoch=seed_change_epoch, + stick_to_shard=stick_to_shard, pad_last_batch=True) + return dataset + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/reader/dali_tfrecord_reader.py b/dlio_benchmark/reader/dali_tfrecord_reader.py new file mode 100644 index 00000000..4b8147af --- /dev/null +++ b/dlio_benchmark/reader/dali_tfrecord_reader.py @@ -0,0 +1,78 @@ +""" + Copyright (c) 2022, UChicago Argonne, LLC + All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import os.path + +import math +import logging +from time import time + +import nvidia +import nvidia.dali.fn as fn +from dlio_benchmark.common.constants import MODULE_DATA_READER +from dlio_benchmark.reader.dali_base_reader import DaliBaseReader +from dlio_benchmark.reader.tf_base_reader import TFBaseReader +from dlio_benchmark.utils.utility import utcnow, PerfTrace, Profile +from dlio_benchmark.common.enumerations import DatasetType, Shuffle +import nvidia.dali.tfrecord as tfrec + +dlp = Profile(MODULE_DATA_READER) + + +class DaliTFRecordReader(DaliBaseReader): + @dlp.log_init + def __init__(self, dataset_type): + super().__init__(dataset_type) + + @dlp.log + def _load(self): + folder = "valid" + if self.dataset_type == DatasetType.TRAIN: + folder = "train" + index_folder = f"{self._args.data_folder}/index/{folder}" + index_files = [] + for file in self.file_list: + filename = os.path.basename(file) + index_files.append(f"{index_folder}/{filename}.idx") + logging.info( + f"{utcnow()} Reading {len(self.file_list)} files rank {self._args.my_rank}") + random_shuffle = False + seed = -1 + if self._args.sample_shuffle is not Shuffle.OFF: + if self._args.sample_shuffle is not Shuffle.SEED: + seed = self._args.seed + random_shuffle = True + initial_fill = 1024 + if self._args.shuffle_size > 0: + initial_fill = self._args.shuffle_size + prefetch_size = 1 + if self._args.prefetch_size > 0: + prefetch_size = self._args.prefetch_size + dataset = fn.readers.tfrecord(path=self.file_list, + index_path=index_files, + features={ + 'image': tfrec.FixedLenFeature((), tfrec.string, ""), + 'size': tfrec.FixedLenFeature([1], tfrec.int64, 0) + }, num_shards=self._args.comm_size, + prefetch_queue_depth=prefetch_size, + initial_fill=initial_fill, + random_shuffle=random_shuffle, seed=seed, + stick_to_shard=True, pad_last_batch=True) + return dataset["image"] + + @dlp.log + def finalize(self): + pass diff --git a/dlio_benchmark/reader/png_reader.py b/dlio_benchmark/reader/image_reader.py similarity index 96% rename from dlio_benchmark/reader/png_reader.py rename to dlio_benchmark/reader/image_reader.py index 64183dd3..1fe63a05 100644 --- a/dlio_benchmark/reader/png_reader.py +++ b/dlio_benchmark/reader/image_reader.py @@ -26,9 +26,9 @@ dlp = Profile(MODULE_DATA_READER) -class PNGReader(FormatReader): +class ImageReader(FormatReader): """ - Reader for PNG files + Reader for PNG / JPEG files """ @dlp.log_init diff --git a/dlio_benchmark/reader/jpeg_reader.py b/dlio_benchmark/reader/jpeg_reader.py deleted file mode 100644 index 664cde04..00000000 --- a/dlio_benchmark/reader/jpeg_reader.py +++ /dev/null @@ -1,62 +0,0 @@ -""" - Copyright (c) 2022, UChicago Argonne, LLC - All Rights Reserved - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import numpy as np -from PIL import Image - -from dlio_benchmark.common.constants import MODULE_DATA_READER -from dlio_benchmark.reader.reader_handler import FormatReader -from dlio_profiler.logger import fn_interceptor as Profile - -dlp = Profile(MODULE_DATA_READER) - - -class JPEGReader(FormatReader): - """ - Reader for JPEG files - """ - - @dlp.log_init - def __init__(self, dataset_type, thread_index, epoch): - super().__init__(dataset_type, thread_index) - - @dlp.log - def open(self, filename): - super().open(filename) - return np.asarray(Image.open(filename)) - - @dlp.log - def close(self, filename): - super().close(filename) - - @dlp.log - def get_sample(self, filename, sample_index): - super().get_sample(filename, sample_index) - image = self.open_file_map[filename] - dlp.update(image_size=image.nbytes) - - def next(self): - for batch in super().next(): - yield batch - - @dlp.log - def read_index(self, image_idx, step): - return super().read_index(image_idx, step) - - @dlp.log - def finalize(self): - return super().finalize() diff --git a/dlio_benchmark/reader/npz_reader.py b/dlio_benchmark/reader/npz_reader.py deleted file mode 100644 index f0144f74..00000000 --- a/dlio_benchmark/reader/npz_reader.py +++ /dev/null @@ -1,60 +0,0 @@ -""" - Copyright (c) 2022, UChicago Argonne, LLC - All Rights Reserved - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" -import numpy as np - -from dlio_benchmark.common.constants import MODULE_DATA_READER -from dlio_benchmark.reader.reader_handler import FormatReader -from dlio_profiler.logger import fn_interceptor as Profile - -dlp = Profile(MODULE_DATA_READER) - - -class NPZReader(FormatReader): - """ - Reader for NPZ files - """ - - @dlp.log_init - def __init__(self, dataset_type, thread_index, epoch): - super().__init__(dataset_type, thread_index) - - @dlp.log - def open(self, filename): - super().open(filename) - return np.load(filename, allow_pickle=True)['x'] - - @dlp.log - def close(self, filename): - super().close(filename) - - @dlp.log - def get_sample(self, filename, sample_index): - super().get_sample(filename, sample_index) - image = self.open_file_map[filename][..., sample_index] - dlp.update(image_size=image.nbytes) - - def next(self): - for batch in super().next(): - yield batch - - @dlp.log - def read_index(self, image_idx, step): - return super().read_index(image_idx, step) - - @dlp.log - def finalize(self): - return super().finalize() \ No newline at end of file diff --git a/dlio_benchmark/reader/reader_factory.py b/dlio_benchmark/reader/reader_factory.py index 74fc353e..e6055dc4 100644 --- a/dlio_benchmark/reader/reader_factory.py +++ b/dlio_benchmark/reader/reader_factory.py @@ -43,18 +43,24 @@ def get_reader(type, dataset_type, thread_index, epoch_number): elif type == FormatType.CSV: from dlio_benchmark.reader.csv_reader import CSVReader return CSVReader(dataset_type, thread_index, epoch_number) - elif type == FormatType.JPEG: - from dlio_benchmark.reader.jpeg_reader import JPEGReader - return JPEGReader(dataset_type, thread_index, epoch_number) - elif type == FormatType.PNG: - from dlio_benchmark.reader.png_reader import PNGReader - return PNGReader(dataset_type, thread_index, epoch_number) + elif type == FormatType.JPEG or FormatType.PNG: + if _args.data_loader == DataLoaderType.NATIVE_DALI + from dlio_benchmark.reader.image_reader import ImageReader + return DaliImageReader(dataset_type, thread_index, epoch_number) elif type == FormatType.NPZ: - from dlio_benchmark.reader.npz_reader import NPZReader - return NPZReader(dataset_type, thread_index, epoch_number) + if _args.data_loader == DataLoaderType.NATIVE_DALI + from dlio_benchmark.reader.dali_npz_reader import DaliNPZReader + return DaliNPZReader(dataset_type, thread_index, epoch_number) + else: + from dlio_benchmark.reader.npz_reader import NPZReader + return NPZReader(dataset_type, thread_index, epoch_number) elif type == FormatType.TFRECORD: - from dlio_benchmark.reader.tf_reader import TFReader - return TFReader(dataset_type, thread_index, epoch_number) + if _args.data_loader == DataLoaderType.NATIVE_DALI: + from dlio_benchmark.reader.dali_tf_reader import DaliTFReader + return TFReader(dataset_type, thread_index, epoch_number) + else: + from dlio_benchmark.reader.tf_reader import TFReader + return TFReader(dataset_type, thread_index, epoch_number) else: print("Loading data of %s format is not supported without framework data loader" %type) raise Exception(type)