diff --git a/dali/python/nvidia/dali/plugin/pytorch/__init__.py b/dali/python/nvidia/dali/plugin/pytorch/__init__.py index 546c2ecf79a..c2085eec4c3 100644 --- a/dali/python/nvidia/dali/plugin/pytorch/__init__.py +++ b/dali/python/nvidia/dali/plugin/pytorch/__init__.py @@ -27,24 +27,12 @@ from nvidia.dali.plugin.base_iterator import LastBatchPolicy import torch -import torch.multiprocessing as mp import torch.utils.dlpack as torch_dlpack # noqa: F401 -from torch.utils import data -from torch.utils.data._utils.collate import collate -from torch.utils.data.dataloader import ( - DataLoader, - _MultiProcessingDataLoaderIter, - _SingleProcessDataLoaderIter, - _BaseDataLoaderIter, -) import ctypes import numpy as np -import threading -from queue import Empty - -from nvtx import nvtx from . import fn # noqa: F401 +from . import proxy # noqa: F401 from nvidia.dali.plugin.pytorch._torch_function import TorchPythonFunction as TorchPythonFunction @@ -66,51 +54,6 @@ } -def to_torch_tensor(tensor_or_tl, device_id=0): - """ - Copy contents of DALI tensor to PyTorch's Tensor. - - Parameters - ---------- - `tensor_or_tl` : TensorGPU or TensorListGPU - `arr` : torch.Tensor - Destination of the copy - `cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t. - CUDA stream to be used for the copy - (if not provided, an internal user stream will be selected) - In most cases, using pytorch's current stream is expected (for example, - if we are copying to a tensor allocated with torch.zeros(...)) - """ - if isinstance(tensor_or_tl, (TensorListGPU, TensorListCPU)): - dali_tensor = tensor_or_tl.as_tensor() - else: - dali_tensor = tensor_or_tl - - if isinstance(dali_tensor, (TensorGPU)): - torch_device = torch.device("cuda", device_id) - else: - torch_device = torch.device("cpu") - - out_torch = torch.empty( - dali_tensor.shape(), - dtype=to_torch_type[dali_tensor.dtype], - device=torch_device, - ) - - # turn raw int to a c void pointer - c_type_pointer = ctypes.c_void_p(out_torch.data_ptr()) - if isinstance(dali_tensor, (TensorGPU)): - non_blocking = True - cuda_stream = torch.cuda.current_stream(device=torch_device) - cuda_stream = types._raw_cuda_stream(cuda_stream) - stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream) - tensor_or_tl.copy_to_external(c_type_pointer, stream, non_blocking) - else: - tensor_or_tl.copy_to_external(c_type_pointer) - - return out_torch - - def feed_ndarray( dali_tensor: Union[TensorCPU, TensorGPU, TensorListCPU, TensorListGPU], arr: torch.Tensor, @@ -853,266 +796,3 @@ def __next__(self) -> List[Dict[str, torch.Tensor]]: DENSE_TAG: str = "dense" SPARSE_LIST_TAG: str = "sparse_list" SPARSE_COO_TAG: str = "sparse_coo" - - -class DALIProxy: - """ - Proxy to communicate to send processing requests to a DALI pipeline running on the main loop. - This is used by PyTorch data workers to assign some processing to the loaded samples, which can - execute in the main process via the GPU. - - Background: As the PyTorch workers run on separate processes, using the GPU directly from each - of those is not ideal for performance, due to the usage of several CUDA contexts. - """ - - def __init__(self, input_names): - """ - Initializes a new DALI proxy instance. - - Args: - input_names (list): list of strings representing the inputs to the pipeline. Those should match - the names of the ``external_source`` nodes in the DALI pipeline. - """ - self.input_names = input_names - self.num_inputs = len(input_names) - # Multi-process queue used to transfer data from the pytorch workers to the main process - self.send_q = mp.Queue() - # Multi-process queue used by the main process to remember the actual order of execution of the requests - self.order_q = mp.Queue() - # Torch worker id, to be filled on first call to worker_id() - self._worker_id = None - # Iteration index for the current worker - self.data_idx = 0 - - def worker_id(self): - if self._worker_id is None: - self._worker_id = torch.utils.data.get_worker_info().id - return self._worker_id - - class PipelineOutputRef: - """ - Placeholder for a pipeline output reference, after the iteration has been scheduled to DALI. - """ - def __init__(self, info): - self.info = info - - def schedule_batch(self, inputs): - """ - Schedule a pipeline run to DALI, by queuing the worker id, the iteration index and the inputs - """ - # Identifier of this request - info = (self.worker_id(), self.data_idx) - with nvtx.annotate(f"dali_proxy.send_q.put {info}", color="blue"): - self.send_q.put((info, inputs)) - self.data_idx = self.data_idx + 1 - # Returns a placeholder, which is replaced with the actual data once the iteration completes - return DALIProxy.PipelineOutputRef(info) - - class PipelineRunRef: - """ - Placeholder for a pipeline run reference, which is returned by the data worker instead of the actual data - - The PyTorch worker returns this trivial object, only containing information about this proxy instance and - the input data to the pipeline. Later in the collate function, we send the data for execution to DALI. - """ - def __init__(self, dali_proxy, inputs): - self.dali_proxy = dali_proxy - self.inputs = inputs - assert len(self.inputs) == dali_proxy.num_inputs - - def transform(self, *inputs): - """ - The 'transform' function consists of returning a reference to the pipeline run - """ - assert len(inputs) == self.num_inputs, f"Unexpected number of inputs: {len(inputs)}" - return DALIProxy.PipelineRunRef(self, inputs) - - -def collate_pipeline_run_ref_fn(pipe_out, *, collate_fn_map=None): - """ - Special collate function that schedules a batch for execution - """ - assert len(pipe_out) > 0 - first_elem = pipe_out[0] - inputs = [[] for idx in range(len(first_elem.inputs))] - for elem in pipe_out: - assert first_elem.dali_proxy == elem.dali_proxy - for idx, input_ref in enumerate(elem.inputs): - inputs[idx].append(input_ref) - return first_elem.dali_proxy.schedule_batch(inputs) - - -def custom_collate(batch): - """ - Subscribe a special collate function for PipelineRunRef, that handles the scheduling of the iteration - on the fly - """ - collate_fn_map = data._utils.collate.default_collate_fn_map - collate_fn_map.update({DALIProxy.PipelineRunRef: collate_pipeline_run_ref_fn}) - return collate(batch, collate_fn_map=collate_fn_map) - - -def flatten_tuple(nested_tuple): - """ - Flattens a nested tuple - """ - flat_list = [] - - def _flatten(t): - for item in t: - if isinstance(item, tuple): - _flatten(item) - else: - flat_list.append(item) - - _flatten(nested_tuple) - return tuple(flat_list) - - -class DALIMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter): - """ - Data loader iterator used by the DALI proxy data loader - """ - def __init__(self, loader): - super().__init__(loader) - self.loader = loader - - def _next_data(self): - data = super()._next_data() - if not hasattr(data, "__iter__"): - print( - "Warning: Non iterable returned from dataloader. Please " - " review the code, since it usually indicates a bug in a pipeline." - ) - data = [data] - for data_idx, data_elem in enumerate(data): - # If loader returns a dictionary the iterator iterates over its keys. - # We need to access a value. Probably need to address more casess. - if isinstance(data, dict): - if isinstance(data[data_elem], DALIProxy.PipelineOutputRef): - data[data_elem] = self.loader.get_outputs(data[data_elem].info) - if isinstance(data_elem, DALIProxy.PipelineOutputRef): - data[data_idx] = self.loader.get_outputs(data_elem.info) - return data - - -class DALIDataLoader(DataLoader): - """ - DALI data loader to be used in the main loop, which runs the DALI pipeline doing the processing - asynchronously with regards to the training. - """ - def __init__(self, pipe, dali_proxy, *args, **kwargs): - if "collate_fn" in kwargs and kwargs["collate_fn"] is not None: - print( - "Warning: Make sure to handle DALIProxy.PipelineRunRef when providing" - " a custom collate_fn" - ) - else: - kwargs["collate_fn"] = custom_collate - super().__init__(*args, **kwargs) - self.pipe = pipe - self.dali_proxy = dali_proxy - self.t = None - self.thread_stop_event = None - self.cache_outputs = dict() - self.cache_inputs = dict() - - def outputs(self): - # Get the information about the order of execution, so that we know which one is the next iteration - torch.cuda.nvtx.range_push("order_q.get") - info = self.dali_proxy.order_q.get() - torch.cuda.nvtx.range_pop() - - # Get the outputs from the current iteration - torch.cuda.nvtx.range_push(f"pipe.outputs {info}") - outputs = self.pipe.outputs() - torch.cuda.nvtx.range_pop() - - # Return information about the iteration, together with the data - processed_outputs = tuple( - [to_torch_tensor(output, device_id=self.pipe.device_id) for output in outputs] - ) - return (info, processed_outputs) - - def get_outputs(self, req_info): - req_outputs = None - # If the data was already read, just return it (and clear the cache entry) - if req_info in self.cache_outputs: - req_outputs = self.cache_outputs[req_info] - del self.cache_outputs[req_info] - del self.cache_inputs[req_info] - else: - info = None - # If not the data we are looking for, store it and keep processing until we find it - while req_info != info: - info, processed_outputs = self.outputs() - if info == req_info: - req_outputs = processed_outputs - del self.cache_inputs[req_info] - else: - self.cache_outputs[info] = processed_outputs - # Unpack single element tuples - if isinstance(req_outputs, tuple) and len(req_outputs) == 1: - req_outputs = req_outputs[0] - return req_outputs - - def thread_fn(self): - """ - Asynchronous DALI thread that gets iteration data from the queue and schedules it for execution - """ - while not self.thread_stop_event.is_set(): - try: - torch.cuda.nvtx.range_push("dali_proxy.send_q.get") - info, inputs = self.dali_proxy.send_q.get(timeout=5) - torch.cuda.nvtx.range_pop() - self.cache_inputs[info] = inputs - except mp.TimeoutError: - continue - except Empty: - continue - torch.cuda.nvtx.range_push("dali_proxy.order_q.put {info}") - self.dali_proxy.order_q.put(info) - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_push(f"feed_input {info}") - for idx, input_name in enumerate(self.dali_proxy.input_names): - self.pipe.feed_input(input_name, inputs[idx]) - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_push("schedule_run {info}") - self.pipe.schedule_run() - torch.cuda.nvtx.range_pop() - - def start_thread(self): - """ - Starts the DALI pipeline thread - """ - if self.t is not None: - return - self.t = threading.Thread(target=DALIDataLoader.thread_fn, args=(self,)) - self.thread_stop_event = threading.Event() - self.t.start() - - def stop_thread(self): - """ - Stops the DALI pipeline thread - """ - if self.thread_stop_event is None: - return - self.thread_stop_event.set() - self.t.join() - self.t = None - self.thread_stop_event = None - - def __enter__(self): - self.start_thread() - - def __exit__(self, exc_type, exc_value, tb): - self.stop_thread() - - def _get_iterator(self) -> "_BaseDataLoaderIter": - if self.num_workers == 0: - return _SingleProcessDataLoaderIter(self) - else: - self.check_worker_number_rationality() - return DALIMultiProcessingDataLoaderIter(self) diff --git a/dali/python/nvidia/dali/plugin/pytorch/proxy/__init__.py b/dali/python/nvidia/dali/plugin/pytorch/proxy/__init__.py new file mode 100644 index 00000000000..68c34adfa05 --- /dev/null +++ b/dali/python/nvidia/dali/plugin/pytorch/proxy/__init__.py @@ -0,0 +1,345 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.multiprocessing as mp +from torch.utils import data as torchdata +from torch.utils.data._utils.collate import collate +from nvidia.dali.backend import TensorGPU, TensorListCPU, TensorListGPU +from nvidia.dali import types, Pipeline +from nvidia.dali.external_source import ExternalSource +import ctypes +import threading +from queue import Empty +from nvtx import nvtx +from .. import to_torch_type + + +def _external_source_node_names(pipeline): + if not pipeline._py_graph_built: + pipeline._build_graph() + input_node_names = [] + for op in pipeline._ops: + if isinstance(op._op, ExternalSource): + input_node_names.append(op.name) + return input_node_names + + +def _to_torch_tensor(tensor_or_tl, device_id=0): + """ + Copy contents of DALI tensor to PyTorch's Tensor. + + Parameters + ---------- + `tensor_or_tl` : TensorGPU or TensorListGPU + `arr` : torch.Tensor + Destination of the copy + `cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t. + CUDA stream to be used for the copy + (if not provided, an internal user stream will be selected) + In most cases, using pytorch's current stream is expected (for example, + if we are copying to a tensor allocated with torch.zeros(...)) + """ + if isinstance(tensor_or_tl, (TensorListGPU, TensorListCPU)): + dali_tensor = tensor_or_tl.as_tensor() + else: + dali_tensor = tensor_or_tl + + if isinstance(dali_tensor, (TensorGPU)): + torch_device = torch.device("cuda", device_id) + else: + torch_device = torch.device("cpu") + + out_torch = torch.empty( + dali_tensor.shape(), + dtype=to_torch_type[dali_tensor.dtype], + device=torch_device, + ) + + # turn raw int to a c void pointer + c_type_pointer = ctypes.c_void_p(out_torch.data_ptr()) + if isinstance(dali_tensor, (TensorGPU)): + non_blocking = True + cuda_stream = torch.cuda.current_stream(device=torch_device) + cuda_stream = types._raw_cuda_stream(cuda_stream) + stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream) + tensor_or_tl.copy_to_external(c_type_pointer, stream, non_blocking) + else: + tensor_or_tl.copy_to_external(c_type_pointer) + + return out_torch + + +class DALIPipelineOutputRef: + """ + Placeholder for a pipeline output reference, after the iteration has been scheduled to DALI. + """ + + def __init__(self, info): + self.info = info + + +class DALIProxy: + def __init__(self, input_names, send_q): + self.input_names = input_names + # Shared queue with the server + self.send_q = send_q + # Torch worker id, to be filled on first call to worker_id() + self._worker_id = None + # Iteration index for the current worker + self.data_idx = 0 + + @property + def worker_id(self): + """Getter for 'worker_id'""" + if self._worker_id is None: + self._worker_id = torchdata.get_worker_info().id + return self._worker_id + + class _PipelineRunRef: + """ + Placeholder for a pipeline run reference, which is returned by the data worker instead of + the actual data. The PyTorch worker returns this trivial object, only containing information + about this proxy instance and the input data to the pipeline. Later in the collate function, + we send the data for execution to DALI. + """ + + def __init__(self, proxy, inputs): + self.proxy = proxy + self.inputs = inputs + if len(inputs) != len(self.proxy.input_names): + raise RuntimeError( + f"Unexpected number of inputs. Expected: {self.input_names}, got: {inputs}" + ) + + def schedule_batch(self, inputs): + # Identifier of this request + info = (self.worker_id(), self.data_idx) + with nvtx.annotate(f"dali_proxy.send_q.put {info}", color="blue"): + self.send_q.put((info, inputs)) + self.data_idx = self.data_idx + 1 + # Returns a placeholder, which is replaced with the actual data once the iteration completes + return DALIPipelineOutputRef(self.info) + + def __call__(self, *inputs): + """ + Returns a reference to the pipeline run + """ + if len(inputs) != len(self.input_names): + raise RuntimeError( + f"Unexpected number of inputs. Expected: {self.input_names}, got: {inputs}" + ) + return self._PipelineRunRef(self, inputs) + + +class DALIServer: + def __init__(self, pipeline, input_names=None): + """ + Initializes a new DALI proxy instance. + + Args: + input_names (list): list of strings representing the inputs to the pipeline. Those + should match the names of the ``external_source`` nodes in the DALI pipeline. + """ + assert isinstance(pipeline, Pipeline), f"Expected an NVIDIA DALI pipeline, got: {pipeline}" + self.pipe = pipeline + self.pipe_input_names = _external_source_node_names(self.pipe) + if len(self.pipe_input_names) == 0: + raise RuntimeError("The provided pipeline doesn't have any inputs") + if len(self.pipe_input_names) == 1: + assert input_names is None or input_names[0] == self.pipe_input_names[0] + self.input_names = self.pipe_input_names + elif input_names is None or len(input_names) != len(self.pipe_input_names): + raise RuntimeError( + "The provided pipeline has more than one output. In such case, the argument " + "`input_names` should containi the same exact number of strings, one for " + "each pipeline input to be mapped by the proxy callable object" + ) + self.input_names = input_names + self.num_inputs = len(input_names) + + # Multi-process queue used to transfer data from the pytorch workers to the main process + self.send_q = mp.Queue() + # Multi-process queue used by the main process to remember the actual order + # of execution of the requests + self.order_q = mp.Queue() + + @property + def proxy(self): + return DALIProxy(self.input_names, self.send_q) + + def next_outputs(self): + # Get the information about the order of execution, so that we know which one is + # the next iteration + torch.cuda.nvtx.range_push("order_q.get") + info = self.dali_proxy.order_q.get() + torch.cuda.nvtx.range_pop() + + # Get the outputs from the current iteration + torch.cuda.nvtx.range_push(f"pipe.outputs {info}") + outputs = self.pipe.outputs() + torch.cuda.nvtx.range_pop() + + # Return information about the iteration, together with the data + processed_outputs = tuple( + [_to_torch_tensor(output, device_id=self.pipe.device_id) for output in outputs] + ) + return (info, processed_outputs) + + def get_outputs(self, req_info): + req_outputs = None + # If the data was already read, just return it (and clear the cache entry) + if req_info in self.cache_outputs: + req_outputs = self.cache_outputs[req_info] + del self.cache_outputs[req_info] + del self.cache_inputs[req_info] + else: + info = None + # If not the data we are looking for, store it and keep processing until we find it + while req_info != info: + info, processed_outputs = self.next_outputs() + if info == req_info: + req_outputs = processed_outputs + del self.cache_inputs[req_info] + else: + self.cache_outputs[info] = processed_outputs + # Unpack single element tuples + if isinstance(req_outputs, tuple) and len(req_outputs) == 1: + req_outputs = req_outputs[0] + return req_outputs + + def thread_fn(self): + """ + Asynchronous DALI thread that gets iteration data from the queue and schedules it + for execution + """ + while not self.thread_stop_event.is_set(): + try: + torch.cuda.nvtx.range_push("send_q.get") + info, inputs = self.send_q.get(timeout=5) + torch.cuda.nvtx.range_pop() + self.cache_inputs[info] = inputs + except mp.TimeoutError: + continue + except Empty: + continue + torch.cuda.nvtx.range_push(f"order_q.put {info}") + self.order_q.put(info) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push(f"feed_input {info}") + for idx, input_name in enumerate(self.input_names): + self.pipe.feed_input(input_name, inputs[idx]) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push(f"schedule_run {info}") + self.pipe.schedule_run() + torch.cuda.nvtx.range_pop() + + def start_thread(self): + """ + Starts the DALI pipeline thread + """ + if self.thread is not None: + return + self.thread = threading.Thread(target=DALIServer.thread_fn, args=(self,)) + self.thread_stop_event = threading.Event() + self.thread.start() + + def stop_thread(self): + """ + Stops the DALI pipeline thread + """ + if self.thread_stop_event is None: + return + self.thread_stop_event.set() + self.thread.join() + self.thread = None + self.thread_stop_event = None + + def __enter__(self): + self.start_thread() + + def __exit__(self, exc_type, exc_value, tb): + self.stop_thread() + + +def _collate_pipeline_run_ref_fn(pipe_out, *, collate_fn_map=None): + """ + Special collate function that schedules a batch for execution + """ + assert len(pipe_out) > 0 + first_elem = pipe_out[0] + inputs = [[] for idx in range(len(first_elem.inputs))] + proxy = first_elem.proxy + for elem in pipe_out: + assert proxy == elem.proxy + for idx, input_ref in enumerate(elem.inputs): + inputs[idx].append(input_ref) + return proxy.schedule_batch(inputs) + + +def _custom_collate(batch): + """ + Subscribe a special collate function for PipelineRunRef, that handles the scheduling + of the iteration on the fly + """ + collate_fn_map = torchdata._utils.collate.default_collate_fn_map + collate_fn_map.update({DALIServer.PipelineRunRef: _collate_pipeline_run_ref_fn}) + return collate(batch, collate_fn_map=collate_fn_map) + + +class DataLoader(torchdata.dataloader.DataLoader): + """ + DALI data loader to be used in the main loop, which runs the DALI pipeline doing the + processing asynchronously with regards to the training. + """ + + class Iterator(torchdata.dataloader._MultiProcessingDataLoaderIter): + """ + Data loader iterator used by the DALI proxy data loader + """ + + def __init__(self, loader): + super().__init__(loader) + self.loader = loader + + def _next_data(self): + data = super()._next_data() + if not hasattr(data, "__iter__"): + print( + "Warning: Non iterable returned from dataloader. Please " + " review the code, since it usually indicates a bug in the pipeline." + ) + data = [data] + for data_idx, data_elem in enumerate(data): + # If loader returns a dictionary the iterator iterates over its keys. + # We need to access a value. Probably need to address more casess. + if isinstance(data, dict): + if isinstance(data[data_elem], DALIPipelineOutputRef): + data[data_elem] = self.loader.dali_server.get_outputs(data[data_elem].info) + if isinstance(data_elem, DALIPipelineOutputRef): + data[data_idx] = self.loader.dali_server.get_outputs(data_elem.info) + return data + + def __init__(self, dali_server, *args, **kwargs): + self.dali_server = dali_server + super().__init__() + if "collate_fn" in kwargs and kwargs["collate_fn"] is not None: + print( + "Warning: Make sure to handle DALIServer.PipelineRunRef when providing" + " a custom collate_fn (see collate_pipeline_run_ref_fn)" + ) + else: + kwargs["collate_fn"] = _custom_collate diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py index 44e3f12f945..8d7360ac07e 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py +++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py @@ -19,7 +19,6 @@ from nvidia.dali.auto_aug import auto_augment, trivial_augment - @pipeline_def(enable_conditionals=True) def training_pipe(data_dir, interpolation, image_size, output_layout, automatic_augmentation, dali_device="gpu", rank=0, world_size=1, send_filepaths=False): diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py index 637f291056a..1a200b3bc58 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py +++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py @@ -38,12 +38,11 @@ DATA_BACKEND_CHOICES = ["pytorch", "synthetic"] try: from nvidia.dali.plugin.pytorch import DALIClassificationIterator + from nvidia.dali.plugin.pytorch import proxy as dali_proxy import nvidia.dali.types as types from image_classification.dali import training_pipe, validation_pipe - from nvidia.dali.plugin.pytorch import DALIProxy, DALIDataLoader - DATA_BACKEND_CHOICES.append("dali") DATA_BACKEND_CHOICES.append("dali_proxy") @@ -314,14 +313,6 @@ def __iter__(self): self.dataloader, self.num_classes, self.one_hot, self.normalize ) - def __enter__(self): - if hasattr(self.dataloader, "__enter__"): - self.dataloader.__enter__() - - def __exit__(self, exc_type, exc_value, tb): - if hasattr(self.dataloader, "__exit__"): - self.dataloader.__exit__(exc_type, exc_value, tb) - def __len__(self): return len(self.dataloader) @@ -463,7 +454,7 @@ def get_impl(data_path, "triangular": types.INTERP_TRIANGULAR, }[interpolation] - output_layout = 'CHW' #"HWC" if memory_format == torch.channels_last else "CHW" + output_layout = "HWC" if memory_format == torch.channels_last else "CHW" rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 pipeline_kwargs = { @@ -479,11 +470,11 @@ def get_impl(data_path, **pipeline_kwargs) pipe.build() - dali_proxy = DALIProxy(input_names=["images"]) + dali_server = dali_proxy.DALIServer(pipe) train_dataset = datasets.ImageFolder( traindir, - transform=dali_proxy.transform, + transform=dali_server.proxy, loader=read_filepath if send_filepaths else read_file ) @@ -494,9 +485,8 @@ def get_impl(data_path, else: train_sampler = None - train_loader = DALIDataLoader( - pipe, - dali_proxy, + train_loader = dali_proxy.DataLoader( + dali_server, train_dataset, sampler=train_sampler, batch_size=batch_size, @@ -534,7 +524,7 @@ def get_impl(data_path, "triangular": types.INTERP_TRIANGULAR, }[interpolation] - output_layout = 'CHW' #"HWC" if memory_format == torch.channels_last else "CHW" + output_layout = "HWC" if memory_format == torch.channels_last else "CHW" valdir = os.path.join(data_path, "val") rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 @@ -554,10 +544,10 @@ def get_impl(data_path, pipe.build() - dali_proxy = DALIProxy(input_names=["images"]) + dali_server = dali_proxy.DALIServer(pipe) val_dataset = datasets.ImageFolder( valdir, - transform=dali_proxy.transform, + transform=dali_server.proxy, loader=read_filepath if send_filepaths else read_file ) @@ -569,9 +559,8 @@ def get_impl(data_path, else: val_sampler = None - val_loader = DALIDataLoader( - pipe, - dali_proxy, + val_loader = dali_proxy.DataLoader( + dali_server, val_dataset, sampler=val_sampler, batch_size=batch_size, diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/mixup.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/mixup.py index 09fb061ddab..b135f336c40 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/mixup.py +++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/mixup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the BSD 3-Clause License (the "License"); # you may not use this file except in compliance with the License. @@ -38,14 +38,6 @@ def mixup_loader(self, loader): i, t = mixup(self.alpha, input, target) yield i, t - def __enter__(self): - if hasattr(self.dataloader, '__enter__'): - self.dataloader.__enter__() - - def __exit__(self, exc_type, exc_value, tb): - if hasattr(self.dataloader, '__exit__'): - self.dataloader.__exit__(exc_type, exc_value, tb) - def __iter__(self): return self.mixup_loader(self.dataloader) diff --git a/docs/examples/use_cases/pytorch/efficientnet/main.py b/docs/examples/use_cases/pytorch/efficientnet/main.py index a098fce45ea..726c38b6cf2 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/main.py +++ b/docs/examples/use_cases/pytorch/efficientnet/main.py @@ -631,7 +631,10 @@ def main(args, model_args, model_arch): best_prec1, ) = prepare_for_training(args, model_args, model_arch) - with conditional_with(train_loader), conditional_with(val_loader): + dali_server_train = train_loader.dali_server if hasattr(train_loader, "dali_server") else None + dali_server_val = val_loader.dali_server if hasattr(val_loader, "dali_server") else None + + with conditional_with(dali_server_train), conditional_with(dali_server_val): train_loop( trainer, lr_policy, diff --git a/docs/examples/use_cases/pytorch/resnet50/main.py b/docs/examples/use_cases/pytorch/resnet50/main.py index a088afece99..7f8a9d0df4b 100644 --- a/docs/examples/use_cases/pytorch/resnet50/main.py +++ b/docs/examples/use_cases/pytorch/resnet50/main.py @@ -19,7 +19,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP try: - from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy, DALIProxy, DALIDataLoader + from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy + from nvidia.dali.plugin.pytorch import proxy as dali_proxy from nvidia.dali.pipeline import pipeline_def import nvidia.dali.types as types import nvidia.dali.fn as fn @@ -334,7 +335,7 @@ def resume(): elif args.dali_proxy: assert train_pipe is not None assert val_pipe is not None - dali_proxy_train = DALIProxy(input_names=["images"]) + dali_server_train = dali_proxy.DALIServer(train_pipe) def read_file(path): return np.fromfile(path, dtype=np.uint8) @@ -344,14 +345,14 @@ def read_filepath(path): train_dataset = datasets.ImageFolder( traindir, - transform=dali_proxy_train.transform, + transform=dali_server_train.proxy, loader=read_filepath if args.send_filepaths else read_file ) - dali_proxy_val = DALIProxy(input_names=["images"]) + dali_server_val = dali_proxy.DALIServer(val_pipe) val_dataset = datasets.ImageFolder( valdir, - transform=dali_proxy_val.transform, + transform=dali_server_val.proxy, loader=read_filepath if args.send_filepaths else read_file ) @@ -361,9 +362,8 @@ def read_filepath(path): train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) - train_loader = DALIDataLoader( - train_pipe, - dali_proxy_train, + train_loader = dali_proxy.DataLoader( + dali_server_train, train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), @@ -373,9 +373,8 @@ def read_filepath(path): collate_fn=None ) - val_loader = DALIDataLoader( - val_pipe, - dali_proxy_val, + val_loader = dali_proxy.DataLoader( + dali_server_val, val_dataset, batch_size=args.batch_size, shuffle=False, @@ -428,7 +427,7 @@ def read_filepath(path): enabled=args.fp16_mode) total_time = AverageMeter() - with conditional_with(train_loader), conditional_with(val_loader): + with conditional_with(dali_server_train), conditional_with(dali_server_val): for epoch in range(args.start_epoch, args.epochs): # train for one epoch avg_train_time = train(train_loader, model, criterion, scaler, optimizer, epoch)