Merge branch 'dev' into dev
wyli authored Jul 28, 2023
2 parents 6134967 + e2fa53b commit 70d7828
Showing 22 changed files with 642 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docker.yml
@@ -89,7 +89,7 @@ jobs:
steps:
- name: Import
run: |
export CUDA_VISIBLE_DEVICES= # cpu-only
export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 CUDA_VISIBLE_DEVICES= # cpu-only
python -c 'import monai; monai.config.print_debug_info()'
cd /opt/monai
ls -al
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
@@ -151,6 +151,10 @@ make html
```
The above commands build the HTML documentation; they are used to automatically generate [https://docs.monai.io](https://docs.monai.io).

The Python code docstrings are written in
[reStructuredText](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html), and
the documentation pages can be in either [reStructuredText](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) or [Markdown](https://en.wikipedia.org/wiki/Markdown). In general, the Python docstrings follow the [Google style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings).
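
For illustration, a minimal Google-style docstring might look like the following (a generic example, not taken from the codebase):

```python
def add(a: int, b: int) -> int:
    """Add two integers.

    Args:
        a: the first operand.
        b: the second operand.

    Returns:
        The sum of ``a`` and ``b``.
    """
    return a + b
```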

Before submitting a pull request, it is recommended to:
- edit the relevant `.rst` files in [`docs/source`](./docs/source) accordingly.
- build html documentation locally
Binary file added docs/images/precision_options.png
7 changes: 7 additions & 0 deletions docs/source/data.rst
@@ -45,6 +45,13 @@ Generic Interfaces
:members:
:special-members: __getitem__

`GDSDataset`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: GDSDataset
:members:
:special-members: __getitem__


`CacheNTransDataset`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: CacheNTransDataset
6 changes: 6 additions & 0 deletions docs/source/index.rst
@@ -58,6 +58,12 @@ Technical documentation is available at `docs.monai.io <https://docs.monai.io>`_

installation

.. toctree::
:maxdepth: 1
:caption: Precision and Performance

precision_performance

.. toctree::
:maxdepth: 1
:caption: Contributing
39 changes: 39 additions & 0 deletions docs/source/precision_performance.md
@@ -0,0 +1,39 @@
# Precision and Performance

Modern GPU architectures can usually use reduced-precision tensor data or computational operations to save memory and increase throughput. However, in some cases the reduced precision causes numerical stability issues, which in turn cause reproducibility issues. Therefore, please ensure that you are using appropriate precision.

<!-- Maybe adding Automatic Mixed Precision, Float16 or BFloat16 in the future-->

## TensorFloat-32 (TF32)

### Introduction

NVIDIA introduced a new math mode, TensorFloat-32 (TF32), for NVIDIA Ampere GPUs and above; see [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/), [Training Neural Networks with Tensor Cores](https://nvlabs.github.io/eccv2020-mixed-precision-tutorial/files/dusan_stosic-training-neural-networks-with-tensor-cores.pdf), [CUDA 11](https://developer.nvidia.com/blog/cuda-11-features-revealed/), and the [Ampere architecture](https://developer.nvidia.com/blog/nvidia-ampere-architecture-in-depth/).

TF32 adopts 8 exponent bits, 10 bits of mantissa, and one sign bit.

![Precision options used for AI training.](../images/precision_options.png)
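
For reference, the commonly cited bit layouts of these training formats are summarized below (standard format definitions, listed here for convenience):

| Format | Sign bits | Exponent bits | Mantissa bits |
| ------ | --------- | ------------- | ------------- |
| FP32   | 1         | 8             | 23            |
| TF32   | 1         | 8             | 10            |
| FP16   | 1         | 5             | 10            |
| BF16   | 1         | 8             | 7             |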

### Potential Impact

Although NVIDIA has shown that TF32 mode can reach the same accuracy and convergence as float32 for most AI workloads, some users still observe significant effects on their applications; see [PyTorch and TensorFloat32](https://dev-discuss.pytorch.org/t/pytorch-and-tensorfloat32/504). Users who need high-precision matrix operations, such as traditional computer graphics operations and kernel methods, may be affected by TF32 precision.

Note that all operations that use `cuda.matmul` may be affected by TF32 mode, so the impact is very broad.

### Settings

[PyTorch TF32](https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) default values:
```python
torch.backends.cuda.matmul.allow_tf32 = False # in PyTorch 1.12 and later.
torch.backends.cudnn.allow_tf32 = True
```
Please note that there are environment variables that can override the flags above, for example the variables mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE`, which is used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.

We recommend that users print out these two flags for confirmation when unsure.
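
A minimal check (standard PyTorch API; shown here only as a confirmation sketch, not MONAI-specific code):

```python
import torch

# Print the effective TF32 flags for cuBLAS matmuls and cuDNN convolutions
print("matmul allow_tf32:", torch.backends.cuda.matmul.allow_tf32)
print("cudnn allow_tf32:", torch.backends.cudnn.allow_tf32)
```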

If you are using an [NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), the container includes the layer `ENV TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1`, so the default value of `torch.backends.cuda.matmul.allow_tf32` will be overridden to `True`.

If you can confirm through experiments that your model has no accuracy or convergence issues in TF32 mode and you have NVIDIA Ampere GPUs or above, you can set the two flags above to `True` to speed up your model.
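
For example, a minimal sketch to opt in (assuming you have already validated accuracy and convergence in TF32 mode on Ampere or newer GPUs):

```python
import torch

# Opt in to TF32 for both cuBLAS matmuls and cuDNN convolutions
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```
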
9 changes: 5 additions & 4 deletions monai/bundle/reference_resolver.py
@@ -127,6 +127,8 @@ def _resolve_one_item(
item = look_up_option(id, self.items, print_all_options=False, default=kwargs.get("default", "no_default"))
except ValueError as err:
raise KeyError(f"id='{id}' is not found in the config resolver.") from err
if not isinstance(item, ConfigItem):
return item
item_config = item.get_config()

if waiting_list is None:
@@ -151,11 +153,10 @@ def _resolve_one_item(
look_up_option(d, self.items, print_all_options=False)
except ValueError as err:
msg = f"the referring item `@{d}` is not defined in the config content."
if self.allow_missing_reference:
warnings.warn(msg)
continue
else:
if not self.allow_missing_reference:
raise ValueError(msg) from err
warnings.warn(msg)
continue
# recursively resolve the reference first
self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs)
waiting_list.discard(d)
1 change: 1 addition & 0 deletions monai/data/__init__.py
@@ -33,6 +33,7 @@
CSVDataset,
Dataset,
DatasetFunc,
GDSDataset,
LMDBDataset,
NPZDictItemDataset,
PersistentDataset,
183 changes: 181 additions & 2 deletions monai/data/dataset.py
@@ -34,6 +34,7 @@
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
from monai.transforms import (
Compose,
@@ -44,7 +45,7 @@
convert_to_contiguous,
reset_ops_id,
)
from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import
from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first

if TYPE_CHECKING:
@@ -54,8 +55,10 @@
else:
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

cp, _ = optional_import("cupy")
lmdb, _ = optional_import("lmdb")
pd, _ = optional_import("pandas")
kvikio_numpy, _ = optional_import("kvikio.numpy")


class Dataset(_TorchDataset):
@@ -326,7 +329,6 @@ def _pre_transform(self, item_transformed):
first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)

item_transformed = self.transform(item_transformed, end=first_random, threading=True)

if self.reset_ops_id:
@@ -1510,3 +1512,180 @@ def __init__(
dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs
)
super().__init__(data=data, transform=transform)


class GDSDataset(PersistentDataset):
"""
An extension of the PersistentDataset using a direct memory access (DMA) data path between
GPU memory and storage, thus avoiding a bounce buffer through the CPU. This direct path can increase system
bandwidth while decreasing latency and utilization load on the CPU and GPU.
A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb.
See also: https://github.com/rapidsai/kvikio
"""

def __init__(
self,
data: Sequence,
transform: Sequence[Callable] | Callable,
cache_dir: Path | str | None,
device: int,
hash_func: Callable[..., bytes] = pickle_hashing,
hash_transform: Callable[..., bytes] | None = None,
reset_ops_id: bool = True,
**kwargs: Any,
) -> None:
"""
Args:
data: input data file paths to load and transform to generate the dataset for the model.
`GDSDataset` expects the input data to be a list of serializable items
and hashes them as cache keys using `hash_func`.
transform: transforms to execute operations on input data.
cache_dir: If specified, this is the location for gpu direct storage
of pre-computed transformed data tensors. The cache_dir is computed once, and
persists on disk until explicitly removed. Different runs, programs, experiments
may share a common cache dir provided that the transforms pre-processing is consistent.
If `cache_dir` doesn't exist, will automatically create it.
If `cache_dir` is `None`, there is effectively no caching.
device: target device to put the output Tensor data. Note that only int can be used to
specify the gpu to be used.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
reset_ops_id: whether to set ``TraceKeys.ID`` to ``TraceKeys.NONE``, defaults to ``True``.
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.
"""
super().__init__(
data=data,
transform=transform,
cache_dir=cache_dir,
hash_func=hash_func,
hash_transform=hash_transform,
reset_ops_id=reset_ops_id,
**kwargs,
)
self.device = device
self._meta_cache: dict[Any, dict[Any, Any]] = {}

def _cachecheck(self, item_transformed):
"""
This function is rewritten to enable direct storage to the GPU when loading the hashfile.
Note that it always returns `torch.Tensor` when loading data from the cache.
Args:
item_transformed: The current data element to be mutated into transformed representation
Returns:
The transformed data_element, either from cache, or explicitly computing it.
Warning:
The current implementation does not encode transform information as part of the
hashing mechanism used for generating cache names when `hash_transform` is None.
If the transforms applied are changed in any way, the objects in the cache dir will be invalid.
"""
hashfile = None
# compute a cache id
if self.cache_dir is not None:
data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
data_item_md5 += self.transform_hash
hashfile = self.cache_dir / f"{data_item_md5}.pt"

if hashfile is not None and hashfile.is_file(): # cache hit
with cp.cuda.Device(self.device):
if isinstance(item_transformed, dict):
item: dict[Any, Any] = {} # type:ignore
for k in item_transformed:
meta_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta")
item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(()))
item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}")
item[f"{k}_meta_dict"] = meta_k
return item
elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):
_meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta")
_data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta["dtype"], like=cp.empty(()))
_data = convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}")
filtered_keys = list(filter(lambda key: key not in ["dtype", "shape"], _meta.keys()))
if bool(filtered_keys):
return (_data, _meta)
return _data
else:
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
for i, _item in enumerate(item_transformed):
for k in _item:
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
item_k = kvikio_numpy.fromfile(
f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(())
)
item_k = convert_to_tensor(item_k.reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k})
return item

# create new cache
_item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed
if hashfile is None:
return _item_transformed
if isinstance(_item_transformed, dict):
for k in _item_transformed:
data_hashfile = f"{hashfile}-{k}"
meta_hash_file_name = f"{hashfile.name}-{k}-meta"
if isinstance(_item_transformed[k], (np.ndarray, torch.Tensor)):
self._create_new_cache(_item_transformed[k], data_hashfile, meta_hash_file_name)
else:
return _item_transformed
elif isinstance(_item_transformed, (np.ndarray, torch.Tensor)):
data_hashfile = f"{hashfile}"
meta_hash_file_name = f"{hashfile.name}-meta"
self._create_new_cache(_item_transformed, data_hashfile, meta_hash_file_name)
else:
for i, _item in enumerate(_item_transformed):
for k in _item:
data_hashfile = f"{hashfile}-{k}-{i}"
meta_hash_file_name = f"{hashfile.name}-{k}-meta-{i}"
self._create_new_cache(_item, data_hashfile, meta_hash_file_name)
open(hashfile, "a").close() # store cacheid
return _item_transformed

def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
self._meta_cache[meta_hash_file_name] = copy(data.meta) if isinstance(data, MetaTensor) else {}
_item_transformed_data = data.array if isinstance(data, MetaTensor) else data
if isinstance(_item_transformed_data, torch.Tensor):
_item_transformed_data = _item_transformed_data.numpy()
self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape
self._meta_cache[meta_hash_file_name]["dtype"] = str(_item_transformed_data.dtype)
kvikio_numpy.tofile(_item_transformed_data, data_hashfile)
try:
# NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
# to make the cache more robust to manual killing of parent process
# which may leave partially written cache files in an incomplete state
with tempfile.TemporaryDirectory() as tmpdirname:
meta_hash_file = self.cache_dir / meta_hash_file_name
temp_hash_file = Path(tmpdirname) / meta_hash_file_name
torch.save(
obj=self._meta_cache[meta_hash_file_name],
f=temp_hash_file,
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
)
if temp_hash_file.is_file() and not meta_hash_file.is_file():
# On Unix, if target exists and is a file, it will be replaced silently if the
# user has permission.
# for more details: https://docs.python.org/3/library/shutil.html#shutil.move.
try:
shutil.move(str(temp_hash_file), meta_hash_file)
except FileExistsError:
pass
except PermissionError: # project-monai/monai issue #3613
pass

def _load_meta_cache(self, meta_hash_file_name):
if meta_hash_file_name in self._meta_cache:
return self._meta_cache[meta_hash_file_name]
else:
return torch.load(self.cache_dir / meta_hash_file_name) # type:ignore
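
For context, a minimal usage sketch of the new `GDSDataset` (assuming a CUDA-capable GPU with `cupy` and `kvikio` installed; the file paths and transforms below are illustrative only):

```python
from monai.data import GDSDataset
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

# Hypothetical input files; each item is hashed to produce its cache key
data = [{"image": "img0.nii.gz"}, {"image": "img1.nii.gz"}]
transform = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])

# Cached arrays are written with kvikio and read back directly onto GPU device 0
dataset = GDSDataset(data=data, transform=transform, cache_dir="./gds_cache", device=0)
print(dataset[0]["image"].shape)
```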