Commit

Code review changes
Signed-off-by: Joaquin Anton Guirao <[email protected]>
jantonguirao committed Dec 4, 2024
1 parent 0472490 commit a802a29
Showing 7 changed files with 373 additions and 366 deletions.
322 changes: 1 addition & 321 deletions dali/python/nvidia/dali/plugin/pytorch/__init__.py
@@ -27,24 +27,12 @@
from nvidia.dali.plugin.base_iterator import LastBatchPolicy

import torch
import torch.multiprocessing as mp
import torch.utils.dlpack as torch_dlpack # noqa: F401
from torch.utils import data
from torch.utils.data._utils.collate import collate
from torch.utils.data.dataloader import (
    DataLoader,
    _MultiProcessingDataLoaderIter,
    _SingleProcessDataLoaderIter,
    _BaseDataLoaderIter,
)
import ctypes
import numpy as np
import threading
from queue import Empty

from nvtx import nvtx

from . import fn # noqa: F401
from . import proxy # noqa: F401

from nvidia.dali.plugin.pytorch._torch_function import TorchPythonFunction as TorchPythonFunction

@@ -66,51 +54,6 @@
}


def to_torch_tensor(tensor_or_tl, device_id=0):
    """
    Copy the contents of a DALI tensor (or tensor list) into a newly
    allocated PyTorch tensor.

    Parameters
    ----------
    `tensor_or_tl` : TensorCPU, TensorGPU, TensorListCPU or TensorListGPU
        Source DALI tensor or tensor list to copy from.
    `device_id` : int
        CUDA device on which to allocate the destination tensor when the
        source resides on the GPU (ignored for CPU tensors). For GPU copies,
        PyTorch's current stream on that device is used.
    """
    if isinstance(tensor_or_tl, (TensorListGPU, TensorListCPU)):
        dali_tensor = tensor_or_tl.as_tensor()
    else:
        dali_tensor = tensor_or_tl

    if isinstance(dali_tensor, TensorGPU):
        torch_device = torch.device("cuda", device_id)
    else:
        torch_device = torch.device("cpu")

    out_torch = torch.empty(
        dali_tensor.shape(),
        dtype=to_torch_type[dali_tensor.dtype],
        device=torch_device,
    )

    # Turn the raw int into a C void pointer
    c_type_pointer = ctypes.c_void_p(out_torch.data_ptr())
    if isinstance(dali_tensor, TensorGPU):
        non_blocking = True
        cuda_stream = torch.cuda.current_stream(device=torch_device)
        cuda_stream = types._raw_cuda_stream(cuda_stream)
        stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream)
        tensor_or_tl.copy_to_external(c_type_pointer, stream, non_blocking)
    else:
        tensor_or_tl.copy_to_external(c_type_pointer)

    return out_torch
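
A minimal usage sketch (assuming a DALI pipeline `pipe` that has already been built and had a run scheduled; the pipeline itself is not part of this change):

# Hypothetical usage: copy the first output of a DALI iteration into PyTorch
pipe.schedule_run()
dali_outputs = pipe.outputs()
torch_tensor = to_torch_tensor(dali_outputs[0], device_id=pipe.device_id)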


def feed_ndarray(
    dali_tensor: Union[TensorCPU, TensorGPU, TensorListCPU, TensorListGPU],
    arr: torch.Tensor,
@@ -853,266 +796,3 @@ def __next__(self) -> List[Dict[str, torch.Tensor]]:
DENSE_TAG: str = "dense"
SPARSE_LIST_TAG: str = "sparse_list"
SPARSE_COO_TAG: str = "sparse_coo"


class DALIProxy:
    """
    Proxy used to send processing requests to a DALI pipeline running on the main loop.
    It is used by PyTorch data workers to offload some of the processing of the loaded
    samples, so that it can execute in the main process via the GPU.

    Background: since the PyTorch workers run in separate processes, using the GPU
    directly from each of them is not ideal for performance, due to the use of several
    CUDA contexts.
    """

    def __init__(self, input_names):
        """
        Initializes a new DALI proxy instance.

        Args:
            input_names (list): List of strings representing the inputs to the pipeline.
                These should match the names of the ``external_source`` nodes in the
                DALI pipeline.
        """
        self.input_names = input_names
        self.num_inputs = len(input_names)
        # Multi-process queue used to transfer data from the PyTorch workers
        # to the main process
        self.send_q = mp.Queue()
        # Multi-process queue used by the main process to remember the actual
        # order of execution of the requests
        self.order_q = mp.Queue()
        # Torch worker id, to be filled on the first call to worker_id()
        self._worker_id = None
        # Iteration index for the current worker
        self.data_idx = 0

    def worker_id(self):
        if self._worker_id is None:
            self._worker_id = torch.utils.data.get_worker_info().id
        return self._worker_id

    class PipelineOutputRef:
        """
        Placeholder for a pipeline output reference, used after the iteration has been
        scheduled to DALI.
        """

        def __init__(self, info):
            self.info = info

    def schedule_batch(self, inputs):
        """
        Schedules a pipeline run in DALI by queuing the worker id, the iteration index
        and the inputs.
        """
        # Identifier of this request
        info = (self.worker_id(), self.data_idx)
        with nvtx.annotate(f"dali_proxy.send_q.put {info}", color="blue"):
            self.send_q.put((info, inputs))
        self.data_idx += 1
        # Return a placeholder, which is replaced with the actual data once
        # the iteration completes
        return DALIProxy.PipelineOutputRef(info)

    class PipelineRunRef:
        """
        Placeholder for a pipeline run reference, which the data worker returns instead
        of the actual data. This trivial object contains only information about the
        proxy instance and the input data to the pipeline. Later, in the collate
        function, the data is sent to DALI for execution.
        """

        def __init__(self, dali_proxy, inputs):
            self.dali_proxy = dali_proxy
            self.inputs = inputs
            assert len(self.inputs) == dali_proxy.num_inputs

    def transform(self, *inputs):
        """
        Instead of transforming the data directly, returns a reference to a future
        pipeline run.
        """
        assert len(inputs) == self.num_inputs, f"Unexpected number of inputs: {len(inputs)}"
        return DALIProxy.PipelineRunRef(self, inputs)
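
To illustrate the intended flow, here is a hedged sketch of a dataset that uses the proxy as its transform (the dataset class and filenames are hypothetical, not part of this commit; it relies on the module's existing `data` and `np` imports):

# Hypothetical sketch: a dataset whose transform is the DALI proxy, so that
# __getitem__ returns a DALIProxy.PipelineRunRef placeholder instead of data.
class FileListDataset(data.Dataset):
    def __init__(self, filenames, transform):
        self.filenames = filenames
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        with open(self.filenames[idx], "rb") as f:
            encoded = np.frombuffer(f.read(), dtype=np.uint8)
        return self.transform(encoded)  # placeholder; DALI does the decoding later

dali_proxy = DALIProxy(input_names=["images"])
dataset = FileListDataset(filenames, transform=dali_proxy.transform)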


def collate_pipeline_run_ref_fn(pipe_out, *, collate_fn_map=None):
    """
    Special collate function that schedules a batch for execution.
    """
    assert len(pipe_out) > 0
    first_elem = pipe_out[0]
    inputs = [[] for _ in range(len(first_elem.inputs))]
    for elem in pipe_out:
        assert first_elem.dali_proxy == elem.dali_proxy
        for idx, input_ref in enumerate(elem.inputs):
            inputs[idx].append(input_ref)
    return first_elem.dali_proxy.schedule_batch(inputs)


def custom_collate(batch):
    """
    Collate function that registers a special handler for PipelineRunRef, so that the
    iteration is scheduled on the fly during collation.
    """
    # Copy the default map instead of mutating it, so that other users of the
    # default collate function are not affected
    collate_fn_map = dict(data._utils.collate.default_collate_fn_map)
    collate_fn_map.update({DALIProxy.PipelineRunRef: collate_pipeline_run_ref_fn})
    return collate(batch, collate_fn_map=collate_fn_map)
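
Conceptually, collating a batch of PipelineRunRef objects does not stack tensors; it schedules one DALI run for the whole batch and returns a single placeholder. An illustrative sketch (this would run inside a PyTorch worker process, where get_worker_info() is available; the sample data is made up):

# Illustrative only: four placeholder samples collapse into one scheduled
# DALI batch and a single PipelineOutputRef placeholder.
samples = [dali_proxy.transform(np.zeros((10,), dtype=np.uint8)) for _ in range(4)]
batch_ref = custom_collate(samples)
assert isinstance(batch_ref, DALIProxy.PipelineOutputRef)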


def flatten_tuple(nested_tuple):
    """
    Flattens a nested tuple.
    """
    flat_list = []

    def _flatten(t):
        for item in t:
            if isinstance(item, tuple):
                _flatten(item)
            else:
                flat_list.append(item)

    _flatten(nested_tuple)
    return tuple(flat_list)
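
For example:

# flatten_tuple((1, (2, (3, 4)), 5)) == (1, 2, 3, 4, 5)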


class DALIMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
    """
    Data loader iterator used by the DALI proxy data loader.
    """

    def __init__(self, loader):
        super().__init__(loader)
        self.loader = loader

    def _next_data(self):
        data = super()._next_data()
        if not hasattr(data, "__iter__"):
            print(
                "Warning: the dataloader returned a non-iterable object. Please"
                " review the code, since this usually indicates a bug in the pipeline."
            )
            data = [data]
        for data_idx, data_elem in enumerate(data):
            # If the loader returns a dictionary, the iterator iterates over its keys,
            # so we need to access the value. More cases may need to be handled here.
            if isinstance(data, dict):
                if isinstance(data[data_elem], DALIProxy.PipelineOutputRef):
                    data[data_elem] = self.loader.get_outputs(data[data_elem].info)
            if isinstance(data_elem, DALIProxy.PipelineOutputRef):
                data[data_idx] = self.loader.get_outputs(data_elem.info)
        return data


class DALIDataLoader(DataLoader):
    """
    DALI data loader to be used in the main loop. It runs the DALI pipeline that does
    the processing, asynchronously with respect to the training.
    """

    def __init__(self, pipe, dali_proxy, *args, **kwargs):
        if "collate_fn" in kwargs and kwargs["collate_fn"] is not None:
            print(
                "Warning: Make sure to handle DALIProxy.PipelineRunRef when providing"
                " a custom collate_fn"
            )
        else:
            kwargs["collate_fn"] = custom_collate
        super().__init__(*args, **kwargs)
        self.pipe = pipe
        self.dali_proxy = dali_proxy
        self.t = None
        self.thread_stop_event = None
        self.cache_outputs = dict()
        self.cache_inputs = dict()

    def outputs(self):
        # Get the information about the order of execution, so that we know
        # which iteration is the next one
        torch.cuda.nvtx.range_push("order_q.get")
        info = self.dali_proxy.order_q.get()
        torch.cuda.nvtx.range_pop()

        # Get the outputs from the current iteration
        torch.cuda.nvtx.range_push(f"pipe.outputs {info}")
        outputs = self.pipe.outputs()
        torch.cuda.nvtx.range_pop()

        # Return the information about the iteration, together with the data
        processed_outputs = tuple(
            to_torch_tensor(output, device_id=self.pipe.device_id) for output in outputs
        )
        return (info, processed_outputs)

    def get_outputs(self, req_info):
        req_outputs = None
        # If the data was already read, just return it (and clear the cache entry)
        if req_info in self.cache_outputs:
            req_outputs = self.cache_outputs[req_info]
            del self.cache_outputs[req_info]
            del self.cache_inputs[req_info]
        else:
            info = None
            # If it is not the data we are looking for, store it and keep
            # processing until we find it
            while req_info != info:
                info, processed_outputs = self.outputs()
                if info == req_info:
                    req_outputs = processed_outputs
                    del self.cache_inputs[req_info]
                else:
                    self.cache_outputs[info] = processed_outputs
        # Unpack single-element tuples
        if isinstance(req_outputs, tuple) and len(req_outputs) == 1:
            req_outputs = req_outputs[0]
        return req_outputs

    def thread_fn(self):
        """
        Asynchronous DALI thread that gets iteration data from the queue and schedules
        it for execution.
        """
        while not self.thread_stop_event.is_set():
            try:
                torch.cuda.nvtx.range_push("dali_proxy.send_q.get")
                info, inputs = self.dali_proxy.send_q.get(timeout=5)
                torch.cuda.nvtx.range_pop()
                self.cache_inputs[info] = inputs
            except Empty:
                continue
            torch.cuda.nvtx.range_push(f"dali_proxy.order_q.put {info}")
            self.dali_proxy.order_q.put(info)
            torch.cuda.nvtx.range_pop()

            torch.cuda.nvtx.range_push(f"feed_input {info}")
            for idx, input_name in enumerate(self.dali_proxy.input_names):
                self.pipe.feed_input(input_name, inputs[idx])
            torch.cuda.nvtx.range_pop()

            torch.cuda.nvtx.range_push(f"schedule_run {info}")
            self.pipe.schedule_run()
            torch.cuda.nvtx.range_pop()

    def start_thread(self):
        """
        Starts the DALI pipeline thread.
        """
        if self.t is not None:
            return
        self.t = threading.Thread(target=DALIDataLoader.thread_fn, args=(self,))
        self.thread_stop_event = threading.Event()
        self.t.start()

    def stop_thread(self):
        """
        Stops the DALI pipeline thread.
        """
        if self.thread_stop_event is None:
            return
        self.thread_stop_event.set()
        self.t.join()
        self.t = None
        self.thread_stop_event = None

    def __enter__(self):
        self.start_thread()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.stop_thread()

    def _get_iterator(self) -> "_BaseDataLoaderIter":
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return DALIMultiProcessingDataLoaderIter(self)
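
Putting the pieces together, a hedged end-to-end sketch (the pipeline factory, dataset, and filenames are illustrative assumptions, not part of this commit):

# Hypothetical end-to-end usage of DALIProxy with DALIDataLoader.
dali_proxy = DALIProxy(input_names=["images"])
pipe = my_pipeline_fn(batch_size=16, num_threads=3, device_id=0)  # hypothetical factory
pipe.build()  # the pipeline is expected to have fn.external_source(name="images")

dataset = FileListDataset(filenames, transform=dali_proxy.transform)  # sketch above
loader = DALIDataLoader(pipe, dali_proxy, dataset, batch_size=16, num_workers=4)

with loader:  # starts the background DALI thread; stops it on exit
    for batch in loader:
        pass  # batch contains torch tensors produced by the DALI pipeline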
