Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Native Dali Data Loader support for TFRecord, Images, and NPZ files #118

Merged
merged 39 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
00a253d
fixed readthedoc build issue
zhenghh04 Nov 22, 2023
256cecf
partial merged the following PR: https://github.com/argonne-lcf/dlio_…
zhenghh04 Dec 1, 2023
63cb1c1
added back npz_reader
zhenghh04 Dec 1, 2023
cdd7389
fixed bugs
zhenghh04 Dec 1, 2023
4f0008e
fixed bugs
zhenghh04 Dec 1, 2023
fbc9748
fixed image reader issue
zhenghh04 Dec 1, 2023
9bde93c
fixed Profile, PerfTrace
zhenghh04 Dec 6, 2023
c976bc5
removed unnecessary logs
zhenghh04 Dec 6, 2023
58febf3
fixed dali_image_reader
zhenghh04 Dec 6, 2023
3cee0fd
fixed dali_image_reader
zhenghh04 Dec 6, 2023
a666d5c
added support for npy format
zhenghh04 Dec 6, 2023
f0722b6
added support for npy format
zhenghh04 Dec 6, 2023
7155d5d
changed enumerations
zhenghh04 Dec 8, 2023
dcd9855
Merge branch 'dali' of github.com:argonne-lcf/dlio_benchmark into dali
zhenghh04 Dec 8, 2023
9935840
added removed dali base reader
zhenghh04 Dec 11, 2023
5dc3907
fixed a bug
zhenghh04 Dec 11, 2023
3fb3602
added native-dali-loader tests in github action
zhenghh04 Dec 11, 2023
248cfa2
corrected github action formats
zhenghh04 Dec 11, 2023
d2af6a3
Merge branch 'main' into dali
zhenghh04 Dec 11, 2023
344298d
fixed read return
zhenghh04 Dec 11, 2023
d983e6b
Merge branch 'dali' of github.com:argonne-lcf/dlio_benchmark into dali
zhenghh04 Dec 11, 2023
ac025eb
removed abstractmethod
zhenghh04 Dec 11, 2023
d2d544f
fixed bugs
zhenghh04 Dec 11, 2023
7312c01
added dont_use_mmap
zhenghh04 Dec 11, 2023
02c3855
fixed indent
zhenghh04 Dec 11, 2023
60c508c
fixed csvreader
zhenghh04 Dec 11, 2023
5e96841
native_dali test with npy format instead of npz
zhenghh04 Dec 11, 2023
b1412e3
fixed issue of enum
zhenghh04 Dec 11, 2023
d659e5f
modify action so that dlio will always be installed
zhenghh04 Dec 11, 2023
96fa9c3
[skip ci] added documentation for dali
zhenghh04 Dec 11, 2023
f9aaac2
removed read; and define it as pipeline
zhenghh04 Dec 12, 2023
ca760fe
added exceptions for unimplemented methods
zhenghh04 Dec 12, 2023
a8ba464
added preprocessing
zhenghh04 Dec 14, 2023
1f03159
conditional cache for DLIO installation
zhenghh04 Dec 14, 2023
5dd6ebf
fixed bugs
zhenghh04 Dec 14, 2023
6af315b
fixed bugs
zhenghh04 Dec 14, 2023
8dc2b71
fixed bugs
zhenghh04 Dec 14, 2023
7e72ed5
fixing again
zhenghh04 Dec 14, 2023
4e727a8
tests again
zhenghh04 Dec 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,18 @@ jobs:
- name: Install System Tools
run: |
sudo apt update
sudo apt-get install $CC $CXX libc6
sudo apt-get install $CC $CXX libc6 git
sudo apt-get install mpich libhwloc-dev
- name: Install DLIO code only
if: steps.cache-modules.outputs.cache-hit == 'true'
run: |
source ${VENV}/bin/activate
rm -rf *.egg*
rm -rf build
rm -rf dist
pip uninstall -y dlio_benchmark
python setup.py build
python setup.py install
- name: Install DLIO
if: steps.cache-modules.outputs.cache-hit != 'true'
run: |
Expand All @@ -57,8 +67,7 @@ jobs:
pip install virtualenv
python -m venv ${VENV}
source ${VENV}/bin/activate
pip install .[test]
rm -rf dlio_benchmark
pip install .[test]
- name: Install DLIO Profiler
run: |
echo "Profiler ${DLIO_PROFILER} gcc $CC"
Expand Down Expand Up @@ -152,8 +161,18 @@ jobs:
- name: test-tf-loader-npz
run: |
source ${VENV}/bin/activate
mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=8 ++workload.dataset.num_files_eval=8 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=8 ++workload.dataset.num_files_eval=8 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.data_reader.data_loader=tensorflow ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=2 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
- name: test-torch-native-dali-loader-npy
run: |
source ${VENV}/bin/activate
mpirun -np 2 dlio_benchmark workload=unet3d ++workload.reader.data_loader=native_dali ++workload.dataset.format=npy ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
mpirun -np 2 dlio_benchmark workload=unet3d ++workload.reader.data_loader=native_dali ++workload.dataset.format=npy ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
- name: test-tf-native-dali-loader-npy
run: |
source ${VENV}/bin/activate
mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.dataset.format=npy ++workload.reader.data_loader=native_dali ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=False ++workload.workflow.generate_data=True ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
mpirun -np 2 dlio_benchmark workload=unet3d ++workload.framework=tensorflow ++workload.dataset.format=npy ++workload.reader.data_loader=native_dali ++workload.train.computation_time=0.05 ++workload.evaluation.eval_time=0.01 ++workload.train.epochs=1 ++workload.workflow.train=True ++workload.workflow.generate_data=False ++workload.dataset.num_files_train=16 ++workload.dataset.num_files_eval=16 ++workload.reader.read_threads=2 ++workload.dataset.record_length=4096 ++workload.dataset.record_length_stdev=0
- name: test_subset
run: |
source ${VENV}/bin/activate
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dlio_benchmark ++workload.workflow.generate_data=True
git clone https://github.com/argonne-lcf/dlio_benchmark
cd dlio_benchmark/
pip install .[dlio_profiler]

```
## Container

```bash
Expand Down
6 changes: 5 additions & 1 deletion dlio_benchmark/common/enumerations.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ class FormatType(Enum):
HDF5 = 'hdf5'
CSV = 'csv'
NPZ = 'npz'
NPY = 'npy'
HDF5_OPT = 'hdf5_opt'
JPEG = 'jpeg'
PNG = 'png'

def __str__(self):
return self.value

@ staticmethod
@staticmethod
def get_enum(value):
if FormatType.TFRECORD.value == value:
return FormatType.TFRECORD
Expand All @@ -110,6 +111,8 @@ def get_enum(value):
return FormatType.CSV
elif FormatType.NPZ.value == value:
return FormatType.NPZ
elif FormatType.NPY.value == value:
return FormatType.NPY
elif FormatType.HDF5_OPT.value == value:
return FormatType.HDF5_OPT
elif FormatType.JPEG.value == value:
Expand All @@ -124,6 +127,7 @@ class DataLoaderType(Enum):
TENSORFLOW='tensorflow'
PYTORCH='pytorch'
DALI='dali'
NATIVE_DALI='native_dali'
CUSTOM='custom'
NONE='none'

Expand Down
3 changes: 3 additions & 0 deletions dlio_benchmark/data_generator/generator_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def get_generator(type):
elif type == FormatType.NPZ:
from dlio_benchmark.data_generator.npz_generator import NPZGenerator
return NPZGenerator()
elif type == FormatType.NPY:
from dlio_benchmark.data_generator.npy_generator import NPYGenerator
return NPYGenerator()
elif type == FormatType.JPEG:
from dlio_benchmark.data_generator.jpeg_generator import JPEGGenerator
return JPEGGenerator()
Expand Down
53 changes: 53 additions & 0 deletions dlio_benchmark/data_generator/npy_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dlio_benchmark.common.enumerations import Compression
from dlio_benchmark.data_generator.data_generator import DataGenerator

import logging
import numpy as np

from dlio_benchmark.utils.utility import progress, utcnow
from dlio_profiler.logger import fn_interceptor as Profile
from shutil import copyfile
from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR

dlp = Profile(MODULE_DATA_GENERATOR)

"""
Generator for creating data in NPZ format.
"""
class NPYGenerator(DataGenerator):
def __init__(self):
super().__init__()

@dlp.log
def generate(self):
"""
Generator for creating data in NPY format of 3d dataset.
"""
super().generate()
np.random.seed(10)
record_labels = [0] * self.num_samples
for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)):
dim1, dim2 = self.get_dimension()
records = np.random.randint(255, size=(dim1, dim2, self.num_samples), dtype=np.uint8)
out_path_spec = self.storage.get_uri(self._file_list[i])
progress(i+1, self.total_files_to_generate, "Generating NPY Data")
prev_out_spec = out_path_spec
np.save(out_path_spec, records)
np.random.seed()
15 changes: 14 additions & 1 deletion dlio_benchmark/data_generator/tf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from subprocess import call

from dlio_benchmark.data_generator.data_generator import DataGenerator
import numpy as np
import tensorflow as tf
from dlio_benchmark.utils.utility import progress, utcnow
from dlio_profiler.logger import fn_interceptor as Profile

from dlio_benchmark.utils.utility import progress, utcnow
from shutil import copyfile
from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR

Expand Down Expand Up @@ -64,4 +67,14 @@ def generate(self):
serialized = example.SerializeToString()
# Write the serialized data to the TFRecords file.
writer.write(serialized)
tfrecord2idx_script = "tfrecord2idx"
folder = "train"
if "valid" in out_path_spec:
folder = "valid"
index_folder = f"{self._args.data_folder}/index/{folder}"
filename = os.path.basename(out_path_spec)
self.storage.create_node(index_folder, exist_ok=True)
tfrecord_idx = f"{index_folder}/{filename}.idx"
if not os.path.isfile(tfrecord_idx):
call([tfrecord2idx_script, out_path_spec, tfrecord_idx])
np.random.seed()
3 changes: 3 additions & 0 deletions dlio_benchmark/data_loader/data_loader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def get_loader(type, format_type, dataset_type, epoch):
elif type == DataLoaderType.DALI:
from dlio_benchmark.data_loader.dali_data_loader import DaliDataLoader
return DaliDataLoader(format_type, dataset_type, epoch)
elif type == DataLoaderType.NATIVE_DALI:
from dlio_benchmark.data_loader.native_dali_data_loader import NativeDaliDataLoader
return NativeDaliDataLoader(format_type, dataset_type, epoch)
else:
print("Data Loader %s not supported or plugins not found" % type)
raise Exception(str(ErrorCodes.EC1004))
60 changes: 60 additions & 0 deletions dlio_benchmark/data_loader/native_dali_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from time import time
import logging
import math
import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali as dali
from nvidia.dali.plugin.pytorch import DALIGenericIterator

from dlio_benchmark.common.constants import MODULE_DATA_LOADER
from dlio_benchmark.common.enumerations import Shuffle, DataLoaderType, DatasetType
from dlio_benchmark.data_loader.base_data_loader import BaseDataLoader
from dlio_benchmark.reader.reader_factory import ReaderFactory
from dlio_benchmark.utils.utility import utcnow, get_rank, timeit
from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile

dlp = Profile(MODULE_DATA_LOADER)


class NativeDaliDataLoader(BaseDataLoader):
@dlp.log_init
def __init__(self, format_type, dataset_type, epoch):
super().__init__(format_type, dataset_type, epoch, DataLoaderType.NATIVE_DALI)
self.pipelines = []

@dlp.log
def read(self):
num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
parallel = True if self._args.read_threads > 0 else False
self.pipelines = []
num_threads = 1
if self._args.read_threads > 0:
num_threads = self._args.read_threads
# None executes pipeline on CPU and the reader does the batching
pipeline = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=None, py_num_workers=num_threads,
exec_async=False, exec_pipelined=False)
with pipeline:
images = ReaderFactory.get_reader(type=self.format_type,
dataset_type=self.dataset_type,
thread_index=-1,
epoch_number=self.epoch_number).pipeline()
pipeline.set_outputs(images)
self.pipelines.append(pipeline)

@dlp.log
def next(self):
super().next()
num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
for step in range(num_samples // batch_size):
_dataset = DALIGenericIterator(self.pipelines, ['data'])
for batch in _dataset:
logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ")
yield batch

@dlp.log
def finalize(self):
pass
2 changes: 1 addition & 1 deletion dlio_benchmark/reader/csv_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ def read_index(self, image_idx, step):

@dlp.log
def finalize(self):
return super().finalize()
return super().finalize()
96 changes: 96 additions & 0 deletions dlio_benchmark/reader/dali_image_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Copyright (c) 2022, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
import logging
from time import time, sleep
import numpy as np

import nvidia.dali.fn as fn
from dlio_benchmark.common.constants import MODULE_DATA_READER
from dlio_benchmark.dlio_benchmark.reader.reader_handler import FormatReader
from dlio_benchmark.utils.utility import utcnow
from dlio_benchmark.common.enumerations import DatasetType, Shuffle
import nvidia.dali.tfrecord as tfrec
from dlio_profiler.logger import dlio_logger as PerfTrace, fn_interceptor as Profile

dlp = Profile(MODULE_DATA_READER)


class DaliImageReader(FormatReader):
@dlp.log_init
def __init__(self, dataset_type, thread_index, epoch):
super().__init__(dataset_type, thread_index)

@dlp.log
def open(self, filename):
super().open(filename)

def close(self):
super().close()

def get_sample(self, filename, sample_index):
super().get_sample(filename, sample_index)
raise Exception("get sample method is not implemented in dali readers")

def next(self):
super().next()
raise Exception("next method is not implemented in dali readers")

def read_index(self):
super().read_index()
raise Exception("read_index method is not implemented in dali readers")

@dlp.log
def pipeline(self):
logging.debug(
zhenghh04 marked this conversation as resolved.
Show resolved Hide resolved
f"{utcnow()} Reading {len(self._file_list)} files rank {self._args.my_rank}")
random_shuffle = False
seed = -1
seed_change_epoch = False
if self._args.sample_shuffle is not Shuffle.OFF:
if self._args.sample_shuffle is not Shuffle.SEED:
seed = self._args.seed
random_shuffle = True
seed_change_epoch = True
initial_fill = 1024
if self._args.shuffle_size > 0:
initial_fill = self._args.shuffle_size
prefetch_size = 1
if self._args.prefetch_size > 0:
prefetch_size = self._args.prefetch_size

stick_to_shard = True
if seed_change_epoch:
stick_to_shard = False
images, labels = fn.readers.file(files=self._file_list, num_shards=self._args.comm_size,
zhenghh04 marked this conversation as resolved.
Show resolved Hide resolved
prefetch_queue_depth=prefetch_size,
initial_fill=initial_fill, random_shuffle=random_shuffle,
shuffle_after_epoch=seed_change_epoch,
stick_to_shard=stick_to_shard, pad_last_batch=True,
dont_use_mmap=self._args.dont_use_mmap)
images = fn.decoders.image(images, device='cpu')
fn.python_function(dataset, function=self.preprocess, num_outputs=0)
dataset = self._resize(images)
return dataset

@dlp.log
def _resize(self, dataset):
return fn.resize(dataset, size=[self._args.max_dimension, self._args.max_dimension])

@dlp.log
def finalize(self):
pass
hariharan-devarajan marked this conversation as resolved.
Show resolved Hide resolved
Loading