Skip to content

Commit

Permalink
Allow PIL.Image as an input to the pipeline (can be converted to np.a…
Browse files Browse the repository at this point in the history
…rray)

Signed-off-by: Joaquin Anton Guirao <[email protected]>
  • Loading branch information
jantonguirao committed Dec 24, 2024
1 parent 511436b commit 926b3bc
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 66 deletions.
23 changes: 17 additions & 6 deletions dali/python/nvidia/dali/external_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,28 @@
SourceKind as _SourceKind,
)

def _get_shape(data):
if isinstance(data, (_tensors.TensorCPU, _tensors.TensorGPU)):
if callable(data.shape):
return data.shape()
else:
return data.shape
elif hasattr(data, "__array_interface__"):
return data.__array_interface__["shape"]
elif hasattr(data, "__array__"):
return data.__array__().shape
else:
raise RuntimeError(f"Don't know how to extract the shape out of {type(data)}")


def _get_batch_shape(data):
if isinstance(data, (list, tuple, _tensors.TensorListCPU, _tensors.TensorListGPU)):
if len(data) == 0:
return [], True
if callable(data[0].shape):
return [x.shape() for x in data], False
else:
return [x.shape for x in data], False
return [_get_shape(x) for x in data], False
else:
shape = data.shape
if callable(shape):
shape = data.shape()
shape = _get_shape(data)
return [shape[1:]] * shape[0], True


Expand Down Expand Up @@ -68,6 +77,8 @@ def to_numpy(x):
return x.asnumpy()
elif _types._is_torch_tensor(x):
return x.numpy()
elif hasattr(x, "__array__"):
return x.__array__()
else:
return x

Expand Down
108 changes: 48 additions & 60 deletions dali/python/nvidia/dali/plugin/pytorch/experimental/proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@
from queue import Empty
from nvidia.dali.plugin.pytorch.torch_utils import to_torch_tensor
import warnings
import numpy as np

try:
from PIL import Image
has_pil = True
except:
has_pil = False


# DALI proxy is a PyTorch specific layer that connects a multi-process PyTorch DataLoader with
# a single DALI pipeline. This allows to run CUDA processing in a single process, avoiding the
Expand Down Expand Up @@ -156,9 +148,9 @@ class DALIProcessedSampleRef:
def __init__(self, proxy, inputs):
self.proxy = proxy
self.inputs = inputs
if len(inputs) != len(self.proxy.dali_input_names):
if len(inputs) != len(self.proxy._dali_input_names):
raise RuntimeError(
f"Unexpected number of inputs. Expected: {self.dali_input_names}, got: {inputs}"
f"Unexpected number of inputs. Expected: {self._dali_input_names}, got: {inputs}"
)


Expand All @@ -175,7 +167,7 @@ def _collate_dali_processed_sample_ref_fn(samples, *, collate_fn_map=None):
for idx, input_ref in enumerate(sample.inputs):
inputs[idx].append(input_ref)
pipe_run_ref = proxy._create_pipe_run_ref(inputs)
if not proxy.deterministic:
if not proxy._deterministic:
proxy._schedule_batch(pipe_run_ref)
# No need for the inputs now
pipe_run_ref.inputs = None
Expand All @@ -191,12 +183,12 @@ def _collate_dali_processed_sample_ref_fn(samples, *, collate_fn_map=None):
class _DALIProxy:
def __init__(self, dali_input_names, dali_input_q, deterministic):
# External source instance names
self.dali_input_names = dali_input_names
self._dali_input_names = dali_input_names
# If True, the request is not sent to DALI upon creation, so that it can be scheduled
# always in the same order by the main process. This comes at a cost of performance
self.deterministic = deterministic
self._deterministic = deterministic
# Shared queue with the server
self.dali_input_q = dali_input_q
self._dali_input_q = dali_input_q
# Torch worker id, to be filled on first call to worker_id()
self._worker_id = None
# Iteration index for the current worker
Expand All @@ -221,22 +213,18 @@ def _create_pipe_run_ref(self, inputs):

def _schedule_batch(self, pipe_run_ref):
torch.cuda.nvtx.range_push(f"dali_proxy.dali_input_q.put {pipe_run_ref.batch_id}")
self.dali_input_q.put((pipe_run_ref.batch_id, pipe_run_ref.inputs))
self._dali_input_q.put((pipe_run_ref.batch_id, pipe_run_ref.inputs))
torch.cuda.nvtx.range_pop()

def __call__(self, *inputs):
"""
Returns a reference to the pipeline run
"""
if len(inputs) != len(self.dali_input_names):
if len(inputs) != len(self._dali_input_names):
raise RuntimeError(
f"Unexpected number of inputs. Expected: {self.dali_input_names}, got: {inputs}"
f"Unexpected number of inputs. Expected: {self._dali_input_names}, got: {inputs}"
)

# Seamlessly handle PIL.Image (as it is typically done in PyTorch)
if has_pil:
inputs = [np.array(item) if isinstance(item, Image) else item for item in inputs]

return DALIProcessedSampleRef(self, inputs)


Expand Down Expand Up @@ -322,39 +310,39 @@ def read_filepath(path):
"""
assert isinstance(pipeline, Pipeline), f"Expected an NVIDIA DALI pipeline, got: {pipeline}"
self.pipe = pipeline
self.pipe_input_names = _external_source_node_names(self.pipe)
if len(self.pipe_input_names) == 0:
self._pipe = pipeline
self._pipe_input_names = _external_source_node_names(self._pipe)
if len(self._pipe_input_names) == 0:
raise RuntimeError("The provided pipeline doesn't have any inputs")
elif len(self.pipe_input_names) == 1:
assert input_names is None or input_names[0] == self.pipe_input_names[0]
self.dali_input_names = self.pipe_input_names
elif input_names is None or len(input_names) != len(self.pipe_input_names):
elif len(self._pipe_input_names) == 1:
assert input_names is None or input_names[0] == self._pipe_input_names[0]
self._dali_input_names = self._pipe_input_names
elif input_names is None or len(input_names) != len(self._pipe_input_names):
raise RuntimeError(
"The provided pipeline has more than one output. In such case, the argument "
"`input_names` should containi the same exact number of strings, one for "
"each pipeline input to be mapped by the proxy callable object"
)
self.num_inputs = len(self.dali_input_names)
self._num_inputs = len(self._dali_input_names)

# Multi-process queue used to transfer data from the pytorch workers to the main process
self.dali_input_q = mp.Queue()
self._dali_input_q = mp.Queue()
# Multi-process queue used by the main process to consume outputs from the DALI pipeline
self.dali_output_q = queue.Queue()
self._dali_output_q = queue.Queue()
# Thread
self.thread = None
self._thread = None
# Cache
self.cache_outputs = dict()
self._cache_outputs = dict()
# Whether we want the order of DALI execution to be reproducible
self.deterministic = deterministic
self._deterministic = deterministic

@property
def proxy(self):
return _DALIProxy(self.dali_input_names, self.dali_input_q, self.deterministic)
return _DALIProxy(self._dali_input_names, self._dali_input_q, self._deterministic)

def _schedule_batch(self, pipe_run_ref):
torch.cuda.nvtx.range_push(f"dali_proxy.dali_input_q.put {pipe_run_ref.batch_id}")
self.dali_input_q.put((pipe_run_ref.batch_id, pipe_run_ref.inputs))
self._dali_input_q.put((pipe_run_ref.batch_id, pipe_run_ref.inputs))
torch.cuda.nvtx.range_pop()

def _get_outputs(self, pipe_run_ref: DALIPipelineRunRef):
Expand All @@ -372,16 +360,16 @@ def _get_outputs(self, pipe_run_ref: DALIPipelineRunRef):
# Wait for the requested output to be ready
req_outputs = None
# If the data was already read, just return it (and clear the cache entry)
if req_batch_id in self.cache_outputs:
req_outputs = self.cache_outputs[req_batch_id]
del self.cache_outputs[req_batch_id]
if req_batch_id in self._cache_outputs:
req_outputs = self._cache_outputs[req_batch_id]
del self._cache_outputs[req_batch_id]

else:
curr_batch_id = None
# If not the data we are looking for, store it and keep processing until we find it
while req_batch_id != curr_batch_id:
torch.cuda.nvtx.range_push("dali_output_q.get")
curr_batch_id, curr_processed_outputs, err = self.dali_output_q.get()
curr_batch_id, curr_processed_outputs, err = self._dali_output_q.get()
torch.cuda.nvtx.range_pop()

if err is not None:
Expand All @@ -390,7 +378,7 @@ def _get_outputs(self, pipe_run_ref: DALIPipelineRunRef):
if curr_batch_id == req_batch_id:
req_outputs = curr_processed_outputs
else:
self.cache_outputs[curr_batch_id] = curr_processed_outputs
self._cache_outputs[curr_batch_id] = curr_processed_outputs
# Unpack single element tuples
if isinstance(req_outputs, tuple) and len(req_outputs) == 1:
req_outputs = req_outputs[0]
Expand Down Expand Up @@ -437,12 +425,12 @@ def _thread_fn(self):
Asynchronous DALI thread that gets iteration data from the queue and schedules it
for execution
"""
self.pipe.build() # just in case
self._pipe.build() # just in case

while not self.thread_stop_event.is_set():
while not self._thread_stop_event.is_set():
try:
torch.cuda.nvtx.range_push("dali_input_q.get")
batch_id, inputs = self.dali_input_q.get(timeout=5)
batch_id, inputs = self._dali_input_q.get(timeout=5)
torch.cuda.nvtx.range_pop()
except mp.TimeoutError:
continue
Expand All @@ -453,41 +441,41 @@ def _thread_fn(self):
torch_outputs = None
torch.cuda.nvtx.range_push(f"schedule iteration {batch_id}")
try:
for idx, input_name in enumerate(self.dali_input_names):
self.pipe.feed_input(input_name, inputs[idx])
self.pipe._run_once()
pipe_outputs = self.pipe.outputs()
for idx, input_name in enumerate(self._dali_input_names):
self._pipe.feed_input(input_name, inputs[idx])
self._pipe._run_once()
pipe_outputs = self._pipe.outputs()
torch_outputs = tuple(
[
to_torch_tensor(out.as_tensor(), not self.pipe.exec_dynamic)
to_torch_tensor(out.as_tensor(), not self._pipe.exec_dynamic)
for out in pipe_outputs
]
)
except Exception as exception:
err = exception
self.dali_output_q.put((batch_id, torch_outputs, err))
self._dali_output_q.put((batch_id, torch_outputs, err))
torch.cuda.nvtx.range_pop()

def start_thread(self):
"""
Starts the DALI pipeline thread
"""
if self.thread is not None:
if self._thread is not None:
return
self.thread = threading.Thread(target=DALIServer._thread_fn, args=(self,))
self.thread_stop_event = threading.Event()
self.thread.start()
self._thread = threading.Thread(target=DALIServer._thread_fn, args=(self,))
self._thread_stop_event = threading.Event()
self._thread.start()

def stop_thread(self):
"""
Stops the DALI pipeline thread
"""
if self.thread_stop_event is None:
if self._thread_stop_event is None:
return
self.thread_stop_event.set()
self.thread.join()
self.thread = None
self.thread_stop_event = None
self._thread_stop_event.set()
self._thread.join()
self._thread = None
self._thread_stop_event = None

def __enter__(self):
self.start_thread()
Expand Down Expand Up @@ -517,7 +505,7 @@ def __init__(self, loader):

def _next_data(self):
data = super()._next_data()
if self.loader.dali_server.thread is None:
if self.loader.dali_server._thread is None:
raise RuntimeError("DALI server is not running")
data = self.loader.dali_server.produce_data(data)
return data
Expand Down

0 comments on commit 926b3bc

Please sign in to comment.