From 826f463084d91f05a66c2821abf00814046f0d11 Mon Sep 17 00:00:00 2001 From: Joaquin Anton Guirao Date: Thu, 19 Dec 2024 21:21:53 +0100 Subject: [PATCH] Apply code review suggestions Signed-off-by: Joaquin Anton Guirao --- dali/python/nvidia/dali/external_source.py | 1 + dali/python/nvidia/dali/pipeline.py | 6 + .../pytorch/experimental/proxy/__init__.py | 438 +++++++++++------- dali/test/python/test_dali_proxy.py | 180 ++++++- .../efficientnet/image_classification/dali.py | 117 +++-- .../image_classification/dataloaders.py | 25 +- .../use_cases/pytorch/resnet50/main.py | 104 ++++- docs/plugins/pytorch_dali_proxy.rst | 2 +- 8 files changed, 627 insertions(+), 246 deletions(-) diff --git a/dali/python/nvidia/dali/external_source.py b/dali/python/nvidia/dali/external_source.py index 9e7ab9b65c..bb86627e93 100644 --- a/dali/python/nvidia/dali/external_source.py +++ b/dali/python/nvidia/dali/external_source.py @@ -26,6 +26,7 @@ SourceKind as _SourceKind, ) + def _get_shape(data): if isinstance(data, (_tensors.TensorCPU, _tensors.TensorGPU)): if callable(data.shape): diff --git a/dali/python/nvidia/dali/pipeline.py b/dali/python/nvidia/dali/pipeline.py index 0b909fc4fc..b934975151 100644 --- a/dali/python/nvidia/dali/pipeline.py +++ b/dali/python/nvidia/dali/pipeline.py @@ -478,6 +478,11 @@ def is_restored_from_checkpoint(self): """If True, this pipeline was restored from checkpoint.""" return self._is_restored_from_checkpoint + @property + def num_outputs(self): + """Number of pipeline outputs.""" + return self._num_outputs + def output_dtype(self) -> list: """Data types expected at the outputs.""" self.build() @@ -854,6 +859,7 @@ def contains_nested_datanode(nested): self._require_no_foreign_ops("The pipeline does not support checkpointing") self._graph_outputs = outputs + self._num_outputs = len(self._graph_outputs) self._setup_input_callbacks() self._disable_pruned_external_source_instances() self._py_graph_built = True diff --git a/dali/python/nvidia/dali/plugin/pytorch/experimental/proxy/__init__.py b/dali/python/nvidia/dali/plugin/pytorch/experimental/proxy/__init__.py index f3fb023d41..b79899057a 100644 --- a/dali/python/nvidia/dali/plugin/pytorch/experimental/proxy/__init__.py +++ b/dali/python/nvidia/dali/plugin/pytorch/experimental/proxy/__init__.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -import torch.multiprocessing as mp -from torch.utils import data as torchdata -from torch.utils.data._utils.collate import default_collate_fn_map -from nvidia.dali import Pipeline -from nvidia.dali.external_source import ExternalSource +__all__ = ["DALIServer", "DataLoader"] + +import torch as _torch +import torch.multiprocessing as _mp +from torch.utils import data as _torchdata +from torch.utils.data._utils.collate import default_collate_fn_map as _default_collate_fn_map +from nvidia.dali import Pipeline as _Pipeline +from nvidia.dali.external_source import ExternalSource as _ExternalSource import threading import queue from queue import Empty from nvidia.dali.plugin.pytorch.torch_utils import to_torch_tensor import warnings +import inspect # DALI proxy is a PyTorch specific layer that connects a multi-process PyTorch DataLoader with # a single DALI pipeline. This allows to run CUDA processing in a single process, avoiding the @@ -37,9 +40,9 @@ # +-------+ +---------------+ +-------------+ +---------------+ +-----------+ +-----------+ # | main | | dali_output_q | | data_thread | | dali_input_q | | worker_0 | | worker_1 | # +-------+ +---------------+ +-------------+ +---------------+ +-----------+ +-----------+ -# |~~~get()~~~~~~>| | | | | -# | | |~~~get()~~~~~~~>| | | -# | | | | | | +# |~~~get()~~~~~~>| | | | sample 0 | | +# | | |~~~get()~~~~~~~>| | sample 1 | | +# | | | | | sample N | | # | | | | ------------- | # | | | | | collate | | # | | | | ------------- | @@ -49,10 +52,10 @@ # | | | |---------------->| | # | | | | | | # | | |<--req_0_0------| | | -# | | | | | | -# | | --------------- | | | -# | | | run | | | | -# | | --------------- | | | +# | | | | | ------------- +# | | --------------- | | | sample 0 | +# | | | run #0 | | | | sample 1 | +# | | --------------- | | | sample N | # | | | | | ------------- # | | | | | | collate | # | | | | | ------------- @@ -61,11 +64,11 @@ # | | | | | | # | | | |------------------------------>| # | | | | | | -# | |<~~put(data_0_0)~~| | | | -# | | | | | | -# | |----------------->| | | | -# | | | | | | -# | | | | | ------------- +# | |<~~put(data_0_0)~~| | | ------------- +# | | | | | | sample 0 | +# | |----------------->| | | | sample 1 | +# |<--data_0_0----| | | | | sample N | +# |~~~get()~~~~~~>| | | | ------------- # | | | | | | collate | # | | | | | ------------- # | | | | | | @@ -73,17 +76,17 @@ # | | | | | | # | | | |------------------------------>| # | | | | | | -# |<--data_0_0----| | | | | -# | | | | | | -# |~~~get()~~~~~~>| | | | | -# | | |~~~get()~~~~~~~>| | | # | | | | | | -# | | |<--req_1_0------| | | # | | | | | | -# | | --------------- | | | -# | | | run | | | | -# | | --------------- | | | # | | | | | | +# | | |~~~get()~~~~~~~>| | | +# | | | | ------------- | +# | | |<--req_1_0------| | sample 0 | | +# | | | | | sample 1 | | +# | | --------------- | | sample N | | +# | | | run #1 | | ------------- | +# | | --------------- | | collate | | +# | | | | ------------- | # | | | |<~~put(req_0_1)~~| | # | | | | | | # | | | |-----------------| | @@ -91,9 +94,9 @@ # | |<~~put(data_1_0)~~| | | | # | | | | | | # | |----------------->| | | | -# | | | | | | # |<--data_1_0----| | | | | # | | | | | | +# | | | | | | # +-------+ +---------------+ +-------------+ +---------------+ +-----------+ +-----------+ # | main | | dali_output_q | | data_thread | | dali_input_q | | worker_0 | | worker_1 | # +-------+ +---------------+ +-------------+ +---------------+ +-----------+ +-----------+ @@ -114,74 +117,87 @@ def _external_source_node_names(pipeline): pipeline._build_graph() input_node_names = [] for op in pipeline._ops: - if isinstance(op._op, ExternalSource): + if isinstance(op._op, _ExternalSource): input_node_names.append(op.name) return input_node_names +class DALIOutputSampleRef: + """ + Reference for a single sample output bound to a pipeline run. + """ + + def __init__(self, proxy, pipe_run_ref, output_idx, sample_idx): + self.proxy = proxy + self.pipe_run_ref = pipe_run_ref + self.output_idx = output_idx + self.sample_idx = sample_idx + + def __repr__(self): + return ( + f"DALIOutputSampleRef({self.pipe_run_ref}, " + + f"output_idx={self.output_idx}, sample_idx={self.sample_idx})" + ) + + class DALIPipelineRunRef: """ Reference for a DALI pipeline run iteration. """ - def __init__(self, batch_id, inputs): - """ - batch_id: Identifier of the batch - is_scheduled: Whether the iteration has been scheduled for execution already - inputs: Inputs to be used when scheduling the iteration (makes sense only if - is_scheduled is False) - """ + def __init__(self, proxy, batch_id): self.batch_id = batch_id - self.inputs = inputs - assert self.inputs is not None + self.inputs = {name: [] for name in proxy._dali_input_names} self.is_scheduled = False + self.is_complete = False + + def __repr__(self): + return ( + f"DALIPipelineRunRef(batch_id={self.batch_id}, is_scheduled=" + f"{self.is_scheduled}, is_complete={self.is_complete})" + ) -class DALIProcessedSampleRef: +class DALIOutputBatchRef: """ - Placeholder for a pipeline run reference, which is returned by the data worker instead of - the actual data. The PyTorch worker returns this trivial object, only containing information - about this proxy instance and the input data to the pipeline. Later in the collate function, - we send the data for execution to DALI. + Reference for a batched output bound to a pipeline run. """ - def __init__(self, proxy, inputs): - self.proxy = proxy - self.inputs = inputs - if len(inputs) != len(self.proxy._dali_input_names): - raise RuntimeError( - f"Unexpected number of inputs. Expected: {self._dali_input_names}, got: {inputs}" - ) + def __init__(self, pipe_run_ref, output_idx): + self.pipe_run_ref = pipe_run_ref + self.output_idx = output_idx + + def __repr__(self): + return f"DALIOutputBatchRef(pipe_run_ref={self.pipe_run_ref}, output_idx={self.output_idx})" -def _collate_dali_processed_sample_ref_fn(samples, *, collate_fn_map=None): +def _collate_dali_output_sample_ref_fn(samples, *, collate_fn_map=None): """ Special collate function that schedules a DALI iteration for execution """ assert len(samples) > 0 - sample = samples[0] - inputs = [[] for _ in range(len(sample.inputs))] - proxy = sample.proxy - for sample in samples: - assert proxy == sample.proxy - for idx, input_ref in enumerate(sample.inputs): - inputs[idx].append(input_ref) - pipe_run_ref = proxy._create_pipe_run_ref(inputs) - if not proxy._deterministic: + pipe_run_ref = samples[0].pipe_run_ref + output_idx = samples[0].output_idx + proxy = samples[0].proxy + for i, sample in enumerate(samples): + # Sanity check + assert sample.proxy == proxy + assert sample.pipe_run_ref == pipe_run_ref + assert sample.output_idx == output_idx + assert sample.sample_idx == i, f"{samples}" + if not proxy._deterministic and not pipe_run_ref.is_scheduled: proxy._schedule_batch(pipe_run_ref) - # No need for the inputs now pipe_run_ref.inputs = None - # Mark as already scheduled - pipe_run_ref.is_scheduled = True - return pipe_run_ref + pipe_run_ref.is_complete = True + return DALIOutputBatchRef(pipe_run_ref, output_idx) -# In-place modify `default_collate_fn_map` to handle DALIProcessedSampleRef -default_collate_fn_map.update({DALIProcessedSampleRef: _collate_dali_processed_sample_ref_fn}) +# In-place modify `default_collate_fn_map` to handle DALIOutputSampleRef +_default_collate_fn_map.update({DALIOutputSampleRef: _collate_dali_output_sample_ref_fn}) class _DALIProxy: - def __init__(self, dali_input_names, dali_input_q, deterministic): + def __init__(self, dali_input_names, dali_input_q, dali_num_outputs, deterministic): # External source instance names self._dali_input_names = dali_input_names # If True, the request is not sent to DALI upon creation, so that it can be scheduled @@ -189,43 +205,64 @@ def __init__(self, dali_input_names, dali_input_q, deterministic): self._deterministic = deterministic # Shared queue with the server self._dali_input_q = dali_input_q - # Torch worker id, to be filled on first call to worker_id() - self._worker_id = None - # Iteration index for the current worker - self.data_idx = 0 - - @property - def worker_id(self): + # Number of outputs in the pipeline + self._dali_num_outputs = dali_num_outputs + # Per worker + self._worker_data = None + + def _init_worker_data(self): + self._worker_data = { + "worker_id": self._get_worker_id(), + "data_idx": 0, + "pipe_run_ref": None, + "batch_sample_idx": 0, + } + + def _get_worker_data(self): + if self._worker_data is None: + self._init_worker_data() + return self._worker_data + + def _get_worker_id(self): """ Getter for 'worker_id'. In case of torch data worker it is the worker info, and in case of a thread the thread identifier """ - if self._worker_id is None: - worker_info = torchdata.get_worker_info() - self._worker_id = worker_info.id if worker_info else threading.get_ident() - return self._worker_id - - def _create_pipe_run_ref(self, inputs): - # Identifier of this request - batch_id = (self.worker_id, self.data_idx) - self.data_idx = self.data_idx + 1 - return DALIPipelineRunRef(batch_id, inputs=inputs) + worker_info = _torchdata.get_worker_info() + return worker_info.id if worker_info else threading.get_ident() + + def _add_sample(self, bound_args): + state = self._get_worker_data() + if state["pipe_run_ref"] is None or state["pipe_run_ref"].is_complete: + state["pipe_run_ref"] = DALIPipelineRunRef( + self, (state["worker_id"], state["data_idx"]) + ) + state["data_idx"] += 1 + state["batch_sample_idx"] = 0 + + for name, value in bound_args.arguments.items(): + if name != "self": + state["pipe_run_ref"].inputs[name].append(value) + + ret = tuple( + DALIOutputSampleRef( + self, + pipe_run_ref=state["pipe_run_ref"], + output_idx=i, + sample_idx=state["batch_sample_idx"], + ) + for i in range(self._dali_num_outputs) + ) + state["batch_sample_idx"] += 1 + return ret[0] if len(ret) == 1 else ret def _schedule_batch(self, pipe_run_ref): - torch.cuda.nvtx.range_push(f"dali_proxy.dali_input_q.put {pipe_run_ref.batch_id}") + if pipe_run_ref.inputs is None: + raise RuntimeError("No inputs for the pipeline to run (was it already scheduled?)") + _torch.cuda.nvtx.range_push(f"dali_proxy.dali_input_q.put {pipe_run_ref.batch_id}") self._dali_input_q.put((pipe_run_ref.batch_id, pipe_run_ref.inputs)) - torch.cuda.nvtx.range_pop() - - def __call__(self, *inputs): - """ - Returns a reference to the pipeline run - """ - if len(inputs) != len(self._dali_input_names): - raise RuntimeError( - f"Unexpected number of inputs. Expected: {self._dali_input_names}, got: {inputs}" - ) - - return DALIProcessedSampleRef(self, inputs) + pipe_run_ref.is_scheduled = True + _torch.cuda.nvtx.range_pop() class DALIServer: @@ -309,24 +346,16 @@ def read_filepath(path): pass """ - assert isinstance(pipeline, Pipeline), f"Expected an NVIDIA DALI pipeline, got: {pipeline}" - self._pipe = pipeline - self._pipe_input_names = _external_source_node_names(self._pipe) - if len(self._pipe_input_names) == 0: - raise RuntimeError("The provided pipeline doesn't have any inputs") - elif len(self._pipe_input_names) == 1: - assert input_names is None or input_names[0] == self._pipe_input_names[0] - self._dali_input_names = self._pipe_input_names - elif input_names is None or len(input_names) != len(self._pipe_input_names): - raise RuntimeError( - "The provided pipeline has more than one output. In such case, the argument " - "`input_names` should containi the same exact number of strings, one for " - "each pipeline input to be mapped by the proxy callable object" - ) - self._num_inputs = len(self._dali_input_names) - + if not isinstance(pipeline, _Pipeline): + raise RuntimeError(f"Expected an NVIDIA DALI pipeline, got: {pipeline}") + else: + self._pipe = pipeline + # get and validate dali pipeline input names + self._dali_input_names, self._allow_positional_args = self._check_dali_input_names( + input_names + ) # Multi-process queue used to transfer data from the pytorch workers to the main process - self._dali_input_q = mp.Queue() + self._dali_input_q = _mp.Queue() # Multi-process queue used by the main process to consume outputs from the DALI pipeline self._dali_output_q = queue.Queue() # Thread @@ -335,15 +364,68 @@ def read_filepath(path): self._cache_outputs = dict() # Whether we want the order of DALI execution to be reproducible self._deterministic = deterministic + # Proxy + self._proxy = None + + def _check_dali_input_names(self, input_names): + pipe_input_names = _external_source_node_names(self._pipe) + if len(pipe_input_names) == 0: + raise RuntimeError("The provided pipeline doesn't have any inputs") + pipe_input_names_set = set(pipe_input_names) + input_names_set = set(input_names or []) + if len(input_names_set) != len(input_names_set): + raise RuntimeError("``input_names`` argument should not contain any duplicated values") + + if len(input_names_set) == 0: + allow_positional_args = True if len(pipe_input_names) == 1 else False + return pipe_input_names, allow_positional_args + + if input_names_set != pipe_input_names_set: + raise RuntimeError( + "The set of DALI input names provided should match exactly the " + "ones provided by the pipeline. " + f"\nProvided input names are: {input_names}" + f"\nPipeline input names are: {pipe_input_names}" + ) + return input_names, True @property def proxy(self): - return _DALIProxy(self._dali_input_names, self._dali_input_q, self._deterministic) + if self._proxy is None: + parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + for input_name in self._dali_input_names: + if self._allow_positional_args: + parameters.append( + inspect.Parameter(input_name, inspect.Parameter.POSITIONAL_OR_KEYWORD) + ) + else: + parameters.append(inspect.Parameter(input_name, inspect.Parameter.KEYWORD_ONLY)) + + signature = inspect.Signature(parameters) + + def call_impl(self, *args, **kwargs): + try: + bound_args = signature.bind(self, *args, **kwargs) + except Exception as exc: + raise RuntimeError(f"{exc}. Signature is {signature}") + return self._add_sample(bound_args) + + call_impl.__signature__ = inspect.Signature(parameters) + _DALIProxy.__call__ = call_impl + self._proxy = _DALIProxy( + self._dali_input_names, + self._dali_input_q, + self._pipe.num_outputs, + self._deterministic, + ) + + return self._proxy def _schedule_batch(self, pipe_run_ref): - torch.cuda.nvtx.range_push(f"dali_proxy.dali_input_q.put {pipe_run_ref.batch_id}") + _torch.cuda.nvtx.range_push(f"dali_proxy.dali_input_q.put {pipe_run_ref.batch_id}") self._dali_input_q.put((pipe_run_ref.batch_id, pipe_run_ref.inputs)) - torch.cuda.nvtx.range_pop() + pipe_run_ref.is_scheduled = True + _torch.cuda.nvtx.range_pop() def _get_outputs(self, pipe_run_ref: DALIPipelineRunRef): """ @@ -355,7 +437,6 @@ def _get_outputs(self, pipe_run_ref: DALIPipelineRunRef): # In case we haven't scheduled the iteration yet (i.e. deterministic config), do it now if not pipe_run_ref.is_scheduled: self._schedule_batch(pipe_run_ref) - pipe_run_ref.is_scheduled = True # Wait for the requested output to be ready req_outputs = None @@ -368,9 +449,9 @@ def _get_outputs(self, pipe_run_ref: DALIPipelineRunRef): curr_batch_id = None # If not the data we are looking for, store it and keep processing until we find it while req_batch_id != curr_batch_id: - torch.cuda.nvtx.range_push("dali_output_q.get") + _torch.cuda.nvtx.range_push("dali_output_q.get") curr_batch_id, curr_processed_outputs, err = self._dali_output_q.get() - torch.cuda.nvtx.range_pop() + _torch.cuda.nvtx.range_pop() if err is not None: raise err @@ -379,47 +460,78 @@ def _get_outputs(self, pipe_run_ref: DALIPipelineRunRef): req_outputs = curr_processed_outputs else: self._cache_outputs[curr_batch_id] = curr_processed_outputs - # Unpack single element tuples - if isinstance(req_outputs, tuple) and len(req_outputs) == 1: - req_outputs = req_outputs[0] return req_outputs - def produce_data(self, obj): + def _produce_data_impl(self, obj, cache): """ - A generic function to recursively visits all elements in a nested structure and replace - instances of DALIPipelineRunRef with the actual data provided by the DALI server - - Args: - obj: The object to map (can be an instance of any class). - - Returns: - A new object where any instance of DALIPipelineRunRef has been replaced with actual - data. + Recursive implementation of produce_data """ - # If it is an instance of DALIPipelineRunRef, replace it with data - if isinstance(obj, DALIPipelineRunRef): - return self._get_outputs(obj) + # If it is an instance of DALIOutputBatchRef, replace it with data + if isinstance(obj, DALIOutputBatchRef): + if obj.pipe_run_ref.batch_id not in cache: + cache[obj.pipe_run_ref.batch_id] = self._get_outputs(obj.pipe_run_ref) + return cache[obj.pipe_run_ref.batch_id][obj.output_idx] + # If it is a custom class, recursively call produce data on its members elif hasattr(obj, "__dict__"): for attr_name, attr_value in obj.__dict__.items(): - setattr(obj, attr_name, self.produce_data(attr_value)) + setattr(obj, attr_name, self._produce_data_impl(attr_value, cache)) return obj # If it's a list, recursively apply the function to each element elif isinstance(obj, list): - return [self.produce_data(item) for item in obj] + return [self._produce_data_impl(item, cache) for item in obj] # If it's a tuple, recursively apply the function to each element (and preserve tuple type) elif isinstance(obj, tuple): - return tuple(self.produce_data(item) for item in obj) + return tuple(self._produce_data_impl(item, cache) for item in obj) - # If it's a dictionary, apply the function to both keys and values + # If it's a dictionary, apply the function to the values elif isinstance(obj, dict): - return {key: self.produce_data(value) for key, value in obj.items()} + return {key: self._produce_data_impl(value, cache) for key, value in obj.items()} else: # return directly anything else return obj + def produce_data(self, obj): + """ + A generic function to recursively visits all elements in a nested structure and replace + instances of DALIOutputBatchRef with the actual data provided by the DALI server + + Args: + obj: The object to map (can be an instance of any class). + + Returns: + A new object where any instance of DALIOutputBatchRef has been replaced with actual + data. + """ + cache = dict() + ret = self._produce_data_impl(obj, cache) + del cache + return ret + + def _get_input_batches(self, max_num_batches, timeout=None): + _torch.cuda.nvtx.range_push("dali_input_q.get") + count = 0 + batches = [] + if timeout is not None: + try: + batches.append(self._dali_input_q.get(timeout=timeout)) + count = count + 1 + except Empty: + return None + except _mp.TimeoutError: + return None + + while count < max_num_batches: + try: + batches.append(self._dali_input_q.get_nowait()) + count = count + 1 + except Empty: + break + _torch.cuda.nvtx.range_pop() + return batches + def _thread_fn(self): """ Asynchronous DALI thread that gets iteration data from the queue and schedules it @@ -427,23 +539,33 @@ def _thread_fn(self): """ self._pipe.build() # just in case + fed_batches = queue.Queue() while not self._thread_stop_event.is_set(): - try: - torch.cuda.nvtx.range_push("dali_input_q.get") - batch_id, inputs = self._dali_input_q.get(timeout=5) - torch.cuda.nvtx.range_pop() - except mp.TimeoutError: - continue - except Empty: + _torch.cuda.nvtx.range_push("get_input_batches") + timeout = 5 if fed_batches.empty() else None + # We try to feed as many batches as the prefetch queue (if available) + batches = self._get_input_batches( + self._pipe.prefetch_queue_depth - fed_batches.qsize(), timeout=timeout + ) + _torch.cuda.nvtx.range_pop() + if batches is not None and len(batches) > 0: + _torch.cuda.nvtx.range_push("feed_pipeline") + for batch_id, inputs in batches: + for input_name, input_data in inputs.items(): + self._pipe.feed_input(input_name, input_data) + self._pipe._run_once() + fed_batches.put(batch_id) + _torch.cuda.nvtx.range_pop() + + # If no batches to consume, continue + if fed_batches.qsize() == 0: continue + _torch.cuda.nvtx.range_push("outputs") + batch_id = fed_batches.get_nowait() # we are sure there's at least one err = None torch_outputs = None - torch.cuda.nvtx.range_push(f"schedule iteration {batch_id}") try: - for idx, input_name in enumerate(self._dali_input_names): - self._pipe.feed_input(input_name, inputs[idx]) - self._pipe._run_once() pipe_outputs = self._pipe.outputs() torch_outputs = tuple( [ @@ -453,8 +575,9 @@ def _thread_fn(self): ) except Exception as exception: err = exception + self._dali_output_q.put((batch_id, torch_outputs, err)) - torch.cuda.nvtx.range_pop() + _torch.cuda.nvtx.range_pop() def start_thread(self): """ @@ -488,13 +611,13 @@ def __exit__(self, exc_type, exc_value, tb): return False # Return False to propagate exceptions -class DataLoader(torchdata.dataloader.DataLoader): +class DataLoader(_torchdata.dataloader.DataLoader): """ DALI data loader to be used in the main loop, which runs the DALI pipeline doing the processing asynchronously with regards to the training. """ - class _Iter(torchdata.dataloader._MultiProcessingDataLoaderIter): + class _Iter(_torchdata.dataloader._MultiProcessingDataLoaderIter): """ Data loader iterator used by the DALI proxy data loader """ @@ -502,13 +625,12 @@ class _Iter(torchdata.dataloader._MultiProcessingDataLoaderIter): def __init__(self, loader): super().__init__(loader) self.loader = loader + if self.loader.dali_server._thread is None: + raise RuntimeError("DALI server is not running") def _next_data(self): data = super()._next_data() - if self.loader.dali_server._thread is None: - raise RuntimeError("DALI server is not running") - data = self.loader.dali_server.produce_data(data) - return data + return self.loader.dali_server.produce_data(data) def __init__(self, dali_server, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/dali/test/python/test_dali_proxy.py b/dali/test/python/test_dali_proxy.py index ef6ed96a35..ee4e5264a6 100644 --- a/dali/test/python/test_dali_proxy.py +++ b/dali/test/python/test_dali_proxy.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from nvidia.dali import pipeline_def, fn, types import numpy as np import os @@ -5,6 +19,7 @@ from nose_utils import attr, assert_raises import PIL.Image + def read_file(path): return np.fromfile(path, dtype=np.uint8) @@ -49,9 +64,7 @@ def image_pipe(dali_device="gpu", include_decoder=True, random_pipe=True): if random_pipe: shapes = images.shape() crop_anchor, crop_shape = fn.random_crop_generator( - shapes, - random_aspect_ratio=[0.75, 4.0 / 3.0], - random_area=[0.08, 1.0] + shapes, random_aspect_ratio=[0.75, 4.0 / 3.0], random_area=[0.08, 1.0] ) images = fn.slice(images, start=crop_anchor, shape=crop_shape, axes=[0, 1]) @@ -75,7 +88,7 @@ def image_pipe(dali_device="gpu", include_decoder=True, random_pipe=True): @attr("pytorch") -@params(("cpu", False), ("cpu", True), ("cpu", False), ("gpu", True)) +@params(("cpu", False), ("cpu", True), ("gpu", False), ("gpu", True)) def test_dali_proxy_torch_data_loader(device, include_decoder, debug=False): # Shows how DALI proxy is used in practice with a PyTorch data loader @@ -114,7 +127,9 @@ def test_dali_proxy_torch_data_loader(device, include_decoder, debug=False): if include_decoder: dataset = datasets.ImageFolder(jpeg, transform=dali_server.proxy, loader=read_filepath) - dataset_ref = datasets.ImageFolder(jpeg, transform=lambda x: x.copy(), loader=read_filepath) + dataset_ref = datasets.ImageFolder( + jpeg, transform=lambda x: x.copy(), loader=read_filepath + ) else: dataset = datasets.ImageFolder(jpeg, transform=dali_server.proxy) dataset_ref = datasets.ImageFolder(jpeg, transform=lambda x: x.copy()) @@ -149,7 +164,9 @@ def ref_collate_fn(batch): target.shape, ) np.testing.assert_array_equal(target, ref_target) - ref_data_nparrays = [np.array(obj) if isinstance(obj, PIL.Image.Image) else obj for obj in ref_data] + ref_data_nparrays = [ + np.array(obj) if isinstance(obj, PIL.Image.Image) else obj for obj in ref_data + ] ref_data_tensors = [TensorCPU(arr) for arr in ref_data_nparrays] pipe_ref.feed_input("images", ref_data_tensors) (ref_data,) = pipe_ref.run() @@ -161,7 +178,8 @@ def ref_collate_fn(batch): @attr("pytorch") -def test_dali_proxy_torch_data_loader_manual_integration(device="gpu", debug=False): +@params(("gpu",)) +def test_dali_proxy_manual_integration(device, debug=False): # Shows how to integrate with DALI proxy manually with an existing data loader from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy @@ -251,18 +269,16 @@ def __getitem__(self, idx): return img2, other # This is just for educational purposes. It is recommended to rely - # default_collate_fn_map, which is updated to handle DALIProcessedSampleRef + # default_collate_fn_map, which is updated to handle DALIOuputSampleRef def custom_collate_fn(batch): images, labels = zip(*batch) - return dali_proxy._collate_dali_processed_sample_ref_fn(images), torch.tensor( + return dali_proxy._collate_dali_output_sample_ref_fn(images), torch.tensor( labels, dtype=torch.long ) # Run the server (it also cleans up on scope exit) with dali_proxy.DALIServer(pipe) as dali_server: - dataset = CustomDatasetDALI(plain_dataset, dali_server.proxy) - loader = torchdata.dataloader.DataLoader( dataset, batch_size=batch_size, @@ -272,9 +288,8 @@ def custom_collate_fn(batch): ) assert len(loader) > 0 - for next_input, next_target in loader: - assert isinstance(next_input, dali_proxy.DALIPipelineRunRef) + assert isinstance(next_input, dali_proxy.DALIOutputBatchRef) next_input = dali_server.produce_data(next_input) assert isinstance(next_input, torch.Tensor) np.testing.assert_equal([batch_size, 3, 224, 224], next_input.shape) @@ -392,3 +407,142 @@ def pipe_with_error(): # messages in the next test pipe._shutdown() del pipe + + +@attr("pytorch") +@params(("cpu",), ("gpu",)) +def test_dali_proxy_duplicated_outputs(device, debug=False): + from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy + from torch.utils import data as torchdata + from PIL import Image + + batch_size = 4 + num_threads = 3 + device_id = 0 + nworkers = 4 + pipe = image_pipe( + dali_device=device, + include_decoder=False, + random_pipe=False, + batch_size=batch_size, + num_threads=num_threads, + device_id=device_id, + prefetch_queue_depth=2 + nworkers, + ) + + class MyDataset(torchdata.Dataset): + def __init__(self, folder_path, transform): + self.folder_path = folder_path + self.image_files = self._find_images_in_folder(folder_path) + self.transform = transform + + def _find_images_in_folder(self, folder_path): + """ + Recursively find all image files in the folder and its subdirectories. + """ + image_files = [] + + # Walk through all directories and subdirectories + for root, _, files in os.walk(folder_path): + for file in files: + if file.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif")): + image_files.append(os.path.join(root, file)) + + return image_files + + def __len__(self): + """Returns the number of images in the folder.""" + return len(self.image_files) + + def __getitem__(self, idx): + img_name = self.image_files[idx] + img_path = os.path.join(self.folder_path, img_name) + img = Image.open(img_path).convert("RGB") # Convert image to RGB (3 channels) + img = self.transform(img) + return img, 1, img + + with dali_proxy.DALIServer(pipe) as dali_server: + dataset = MyDataset(jpeg, transform=dali_server.proxy) + loader = dali_proxy.DataLoader( + dali_server, + dataset, + batch_size=batch_size, + num_workers=nworkers, + drop_last=True, + ) + + for data1, _, data2 in loader: + np.testing.assert_array_equal(data1, data2) + + +@attr("pytorch") +@params(("cpu",), ("gpu",)) +def test_dali_proxy_rearrange_output_order_and_positional_args(device, debug=False): + from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy + from torch.utils import data as torchdata + + batch_size = 4 + num_threads = 3 + device_id = 0 + nworkers = 4 + arrs = np.random.rand(20, 3) + + @pipeline_def + def pipe_2_outputs(): + a = fn.external_source(name="a", no_copy=True) + b = fn.external_source(name="b", no_copy=True) + if device == "gpu": + a = a.gpu() + b = b.gpu() + return a + b, b - a + + pipe1 = pipe_2_outputs( + batch_size=batch_size, + num_threads=num_threads, + device_id=device_id, + prefetch_queue_depth=2 + nworkers, + ) + pipe2 = pipe_2_outputs( + batch_size=batch_size, + num_threads=num_threads, + device_id=device_id, + prefetch_queue_depth=2 + nworkers, + ) + + class MyDataset(torchdata.Dataset): + def __init__(self, arrs, transform, reverse_order): + self.arrs = arrs + self.n = len(arrs) + self.transform = transform + self.reverse_order = reverse_order + + def __len__(self): + """Returns the number of images in the folder.""" + return self.n + + def __getitem__(self, idx): + a = self.arrs[idx] + b = self.arrs[idx + 1 if idx < self.n - 1 else 0] + a_plus_b, b_minus_a = self.transform(b=b, a=a) # reverse order in purpose + return (b_minus_a, 1, a_plus_b) if self.reverse_order else (a_plus_b, 1, b_minus_a) + + with dali_proxy.DALIServer(pipe1) as dali_server1, dali_proxy.DALIServer(pipe2) as dali_server2: + loader1 = dali_proxy.DataLoader( + dali_server1, + MyDataset(arrs, dali_server1.proxy, reverse_order=False), + batch_size=batch_size, + num_workers=nworkers, + drop_last=True, + ) + loader2 = dali_proxy.DataLoader( + dali_server2, + MyDataset(arrs, dali_server2.proxy, reverse_order=True), + batch_size=batch_size, + num_workers=nworkers, + drop_last=True, + ) + + for data1, data2 in zip(loader1, loader2): + np.testing.assert_array_equal(data1[0].cpu(), data2[2].cpu()) + np.testing.assert_array_equal(data1[1].cpu(), data2[1].cpu()) + np.testing.assert_array_equal(data1[2].cpu(), data2[0].cpu()) diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py index c43c5f689e..88ff08c44a 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py +++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py @@ -21,31 +21,82 @@ @pipeline_def(enable_conditionals=True, exec_dynamic=True) def training_pipe(data_dir, interpolation, image_size, output_layout, automatic_augmentation, - dali_device="gpu", rank=0, world_size=1, send_filepaths=False): + dali_device="gpu", rank=0, world_size=1): rng = fn.random.coin_flip(probability=0.5) - if data_dir is None: - if send_filepaths: - filepaths = fn.external_source(name="images", no_copy=True) - jpegs = fn.io.file.read(filepaths) - else: - jpegs = fn.external_source(name="images", no_copy=True) + jpegs, labels = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank, + num_shards=world_size, random_shuffle=True, pad_last_batch=True) + + decoder_device = "mixed" if dali_device == "gpu" else "cpu" + images = fn.decoders.image_random_crop(jpegs, device=decoder_device, output_type=types.RGB, + random_aspect_ratio=[0.75, 4.0 / 3.0], + random_area=[0.08, 1.0]) + + images = fn.resize(images, size=[image_size, image_size], + interp_type=interpolation, antialias=False) + + # Make sure that from this point we are processing on GPU regardless of dali_device parameter + images = images.gpu() + + images = fn.flip(images, horizontal=rng) + + # Based on the specification, apply the automatic augmentation policy. Note, that from the point + # of Pipeline definition, this `if` statement relies on static scalar parameter, so it is + # evaluated exactly once during build - we either include automatic augmentations or not. + # We pass the shape of the image after the resize so the translate operations are done + # relative to the image size. + if automatic_augmentation == "autoaugment": + output = auto_augment.auto_augment_image_net(images, shape=[image_size, image_size]) + elif automatic_augmentation == "trivialaugment": + output = trivial_augment.trivial_augment_wide(images, shape=[image_size, image_size]) else: - jpegs, labels = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank, - num_shards=world_size, random_shuffle=True, pad_last_batch=True) + output = images - if dali_device == "gpu": - decoder_device = "mixed" - resize_device = "gpu" + output = fn.crop_mirror_normalize(output, dtype=types.FLOAT, output_layout=output_layout, + crop=(image_size, image_size), + mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], + std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) + return output, labels + + +@pipeline_def(exec_dynamic=True) +def validation_pipe(data_dir, interpolation, image_size, image_crop, output_layout, + dali_device="gpu", rank=0, world_size=1): + jpegs, label = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank, + num_shards=world_size, random_shuffle=False, pad_last_batch=True) + + decoder_device = "mixed" if dali_device == "gpu" else "cpu" + images = fn.decoders.image(jpegs, device=decoder_device, output_type=types.RGB) + + images = fn.resize(images, resize_shorter=image_size, + interp_type=interpolation, antialias=False) + + images = images.gpu() + + output = fn.crop_mirror_normalize(images, dtype=types.FLOAT, output_layout=output_layout, + crop=(image_crop, image_crop), + mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], + std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) + return output, label + + + +@pipeline_def(enable_conditionals=True, exec_dynamic=True) +def training_pipe_external_source(interpolation, image_size, output_layout, automatic_augmentation, + dali_device="gpu", send_filepaths=False): + rng = fn.random.coin_flip(probability=0.5) + if send_filepaths: + filepaths = fn.external_source(name="images", no_copy=True) + jpegs = fn.io.file.read(filepaths) else: - decoder_device = "cpu" - resize_device = "cpu" + jpegs = fn.external_source(name="images", no_copy=True) + decoder_device = "mixed" if dali_device == "gpu" else "cpu" images = fn.decoders.image_random_crop(jpegs, device=decoder_device, output_type=types.RGB, random_aspect_ratio=[0.75, 4.0 / 3.0], random_area=[0.08, 1.0]) - images = fn.resize(images, device=resize_device, size=[image_size, image_size], + images = fn.resize(images, size=[image_size, image_size], interp_type=interpolation, antialias=False) # Make sure that from this point we are processing on GPU regardless of dali_device parameter @@ -70,35 +121,22 @@ def training_pipe(data_dir, interpolation, image_size, output_layout, automatic_ mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) - if data_dir is None: - return output - else: - return output, labels + return output @pipeline_def(exec_dynamic=True) -def validation_pipe(data_dir, interpolation, image_size, image_crop, output_layout, - dali_device="gpu", rank=0, world_size=1, send_filepaths=False): - if data_dir is None: - if send_filepaths: - filepaths = fn.external_source(name="images", no_copy=True) - jpegs = fn.io.file.read(filepaths) - else: - jpegs = fn.external_source(name="images", no_copy=True) +def validation_pipe_external_source(interpolation, image_size, image_crop, output_layout, + dali_device="gpu", send_filepaths=False): + if send_filepaths: + filepaths = fn.external_source(name="images", no_copy=True) + jpegs = fn.io.file.read(filepaths) else: - jpegs, label = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank, - num_shards=world_size, random_shuffle=False, pad_last_batch=True) - - if dali_device == "gpu": - decoder_device = "mixed" - resize_device = "gpu" - else: - decoder_device = "cpu" - resize_device = "cpu" + jpegs = fn.external_source(name="images", no_copy=True) + decoder_device = "mixed" if dali_device == "gpu" else "cpu" images = fn.decoders.image(jpegs, device=decoder_device, output_type=types.RGB) - images = fn.resize(images, device=resize_device, resize_shorter=image_size, + images = fn.resize(images, resize_shorter=image_size, interp_type=interpolation, antialias=False) images = images.gpu() @@ -108,7 +146,4 @@ def validation_pipe(data_dir, interpolation, image_size, image_crop, output_layo mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) - if data_dir is None: - return output - else: - return output, label + return output diff --git a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py index 425caafbaf..ddd9fef125 100644 --- a/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py +++ b/docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py @@ -41,6 +41,7 @@ import nvidia.dali.types as types from image_classification.dali import training_pipe, validation_pipe + from image_classification.dali import training_pipe_external_source, validation_pipe_external_source DATA_BACKEND_CHOICES.append("dali") DATA_BACKEND_CHOICES.append("dali_proxy") @@ -157,7 +158,7 @@ def gdtl( pipe = training_pipe(data_dir=traindir, interpolation=interpolation, image_size=image_size, output_layout=output_layout, automatic_augmentation=augmentation, dali_device=dali_device, rank=rank, world_size=world_size, - prefetch_queue_depth=8, + prefetch_queue_depth=workers*2, **pipeline_kwargs) train_loader = DALIClassificationIterator( @@ -259,6 +260,7 @@ def expand(num_classes, dtype, tensor): class PrefetchedWrapper(object): + @staticmethod def prefetched_loader(loader, num_classes, one_hot, normalize, memory_format, data_layout): if normalize: mean = ( @@ -478,10 +480,11 @@ def get_impl(data_path, "seed": 12 + rank % torch.cuda.device_count(), } - pipe = training_pipe(data_dir=None, interpolation=interpolation, image_size=image_size, - output_layout=output_layout, automatic_augmentation=augmentation, - dali_device=dali_device, prefetch_queue_depth=8, send_filepaths=send_filepaths, - **pipeline_kwargs) + pipe = training_pipe_external_source( + interpolation=interpolation, image_size=image_size, + output_layout=output_layout, automatic_augmentation=augmentation, + dali_device=dali_device, prefetch_queue_depth=workers*2, send_filepaths=send_filepaths, + **pipeline_kwargs) pipe.build() dali_server = dali_proxy.DALIServer(pipe) @@ -550,12 +553,12 @@ def get_impl(data_path, "seed": 12 + rank % torch.cuda.device_count(), } - pipe = validation_pipe(data_dir=None, interpolation=interpolation, - image_size=image_size + crop_padding, image_crop=image_size, - output_layout=output_layout, - send_filepaths=send_filepaths, - **pipeline_kwargs) - + pipe = validation_pipe_external_source( + interpolation=interpolation, + image_size=image_size + crop_padding, image_crop=image_size, + output_layout=output_layout, + send_filepaths=send_filepaths, + **pipeline_kwargs) pipe.build() dali_server = dali_proxy.DALIServer(pipe) diff --git a/docs/examples/use_cases/pytorch/resnet50/main.py b/docs/examples/use_cases/pytorch/resnet50/main.py index 79b52dd43e..7db1ef1a24 100644 --- a/docs/examples/use_cases/pytorch/resnet50/main.py +++ b/docs/examples/use_cases/pytorch/resnet50/main.py @@ -123,19 +123,12 @@ def to_python_float(t): @pipeline_def(exec_dynamic=True) def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=False, is_training=True): - if data_dir is None: - if args.send_filepaths: - filepaths = fn.external_source(name="images", no_copy=True) - images = fn.io.file.read(filepaths) - else: - images = fn.external_source(name="images", no_copy=True) - else: - images, labels = fn.readers.file(file_root=data_dir, - shard_id=shard_id, - num_shards=num_shards, - random_shuffle=is_training, - pad_last_batch=True, - name="Reader") + images, labels = fn.readers.file(file_root=data_dir, + shard_id=shard_id, + num_shards=num_shards, + random_shuffle=is_training, + pad_last_batch=True, + name="Reader") dali_device = 'cpu' if dali_cpu else 'gpu' decoder_device = 'cpu' if dali_cpu else 'mixed' @@ -174,12 +167,58 @@ def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=Fa mean=[0.485 * 255,0.456 * 255,0.406 * 255], std=[0.229 * 255,0.224 * 255,0.225 * 255], mirror=mirror) + labels = labels.gpu() + return images, labels + + + +@pipeline_def(exec_dynamic=True) +def create_dali_pipeline_external_source(crop, size, dali_cpu=False, is_training=True): + if args.send_filepaths: + filepaths = fn.external_source(name="images", no_copy=True) + images = fn.io.file.read(filepaths) + else: + images = fn.external_source(name="images", no_copy=True) - if data_dir is None: - return images + dali_device = 'cpu' if dali_cpu else 'gpu' + decoder_device = 'cpu' if dali_cpu else 'mixed' + # ask HW NVJPEG to allocate memory ahead for the biggest image in the data set to avoid reallocations in runtime + preallocate_width_hint = 5980 if decoder_device == 'mixed' else 0 + preallocate_height_hint = 6430 if decoder_device == 'mixed' else 0 + if is_training: + images = fn.decoders.image_random_crop(images, + device=decoder_device, output_type=types.RGB, + preallocate_width_hint=preallocate_width_hint, + preallocate_height_hint=preallocate_height_hint, + random_aspect_ratio=[0.8, 1.25], + random_area=[0.1, 1.0], + num_attempts=100) + images = fn.resize(images, + device=dali_device, + resize_x=crop, + resize_y=crop, + interp_type=types.INTERP_TRIANGULAR) + mirror = fn.random.coin_flip(probability=0.5) else: - labels = labels.gpu() - return images, labels + images = fn.decoders.image(images, + device=decoder_device, + output_type=types.RGB) + images = fn.resize(images, + device=dali_device, + size=size, + mode="not_smaller", + interp_type=types.INTERP_TRIANGULAR) + mirror = False + + images = fn.crop_mirror_normalize(images.gpu(), + dtype=types.FLOAT, + output_layout="CHW", + crop=(crop, crop), + mean=[0.485 * 255,0.456 * 255,0.406 * 255], + std=[0.229 * 255,0.224 * 255,0.225 * 255], + mirror=mirror) + + return images def main(): @@ -300,12 +339,12 @@ def resume(): val_pipe = None dali_server_train = None dali_server_val = None - if not args.disable_dali: + if not args.disable_dali and not args.dali_proxy: train_pipe = create_dali_pipeline(batch_size=args.batch_size, num_threads=args.workers, device_id=args.local_rank, seed=12 + args.local_rank, - data_dir=traindir if not args.dali_proxy else None, + data_dir=traindir, crop=crop_size, size=val_size, dali_cpu=args.dali_cpu, @@ -318,7 +357,7 @@ def resume(): num_threads=args.workers, device_id=args.local_rank, seed=12 + args.local_rank, - data_dir=valdir if not args.dali_proxy else None, + data_dir=valdir, crop=crop_size, size=val_size, dali_cpu=args.dali_cpu, @@ -327,14 +366,35 @@ def resume(): is_training=False) val_pipe.build() - if not args.disable_dali and not args.dali_proxy: train_loader = DALIClassificationIterator(train_pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL, auto_reset=True) val_loader = DALIClassificationIterator(val_pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL, auto_reset=True) - elif args.dali_proxy: + elif not args.disable_dali and args.dali_proxy: + train_pipe = create_dali_pipeline_external_source( + batch_size=args.batch_size, + num_threads=args.workers, + device_id=args.local_rank, + seed=12 + args.local_rank, + crop=crop_size, + size=val_size, + dali_cpu=args.dali_cpu, + is_training=True) + train_pipe.build() + + val_pipe = create_dali_pipeline_external_source( + batch_size=args.batch_size, + num_threads=args.workers, + device_id=args.local_rank, + seed=12 + args.local_rank, + crop=crop_size, + size=val_size, + dali_cpu=args.dali_cpu, + is_training=False) + val_pipe.build() + assert train_pipe is not None assert val_pipe is not None dali_server_train = dali_proxy.DALIServer(train_pipe) diff --git a/docs/plugins/pytorch_dali_proxy.rst b/docs/plugins/pytorch_dali_proxy.rst index dc039e59ce..c7130e6c1e 100644 --- a/docs/plugins/pytorch_dali_proxy.rst +++ b/docs/plugins/pytorch_dali_proxy.rst @@ -153,7 +153,7 @@ If using a custom ``DataLoader``, call the DALI server explicitly: .. code-block:: python for data, _ in loader: - # Replaces instances of ``DALIPipelineRunRef`` with actual data + # Replaces instances of ``DALIOutputBatchRef`` with actual data processed_data = dali_server.produce_data(data) print(processed_data.shape) # data is now ready