Update GDSDataset (#6787)
Fixes #6786.

### Description

- Update `docs/source/data.rst` to document `GDSDataset`
- Update the type of the cached `dtype` to `str` (see the round-trip sketch after the file summary below)

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
KumoLiu authored Jul 27, 2023
1 parent 87d0ede commit d89b457
Showing 4 changed files with 52 additions and 14 deletions.
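The heart of the fix is caching `dtype` as a plain string rather than a dtype object, so the cached value reads back cleanly. Below is a minimal round-trip sketch, not part of the patch: plain numpy with hypothetical names, standing in for the kvikio write/read pair.

```python
import numpy as np

arr = np.arange(6, dtype=np.float16).reshape(2, 3)
cached_dtype = str(arr.dtype)  # "float16": a plain string, easy to persist in the meta cache
raw = arr.tobytes()  # stand-in for the array bytes kvikio writes to disk
restored = np.frombuffer(raw, dtype=np.dtype(cached_dtype)).reshape(arr.shape)
assert restored.dtype == arr.dtype and np.array_equal(restored, arr)
```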
2 changes: 1 addition & 1 deletion .github/workflows/docker.yml
```diff
@@ -89,7 +89,7 @@ jobs:
     steps:
       - name: Import
         run: |
-          export CUDA_VISIBLE_DEVICES= # cpu-only
+          export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 CUDA_VISIBLE_DEVICES= # cpu-only
           python -c 'import monai; monai.config.print_debug_info()'
           cd /opt/monai
           ls -al
```
7 changes: 7 additions & 0 deletions docs/source/data.rst
```diff
@@ -45,6 +45,13 @@ Generic Interfaces
     :members:
     :special-members: __getitem__

+`GDSDataset`
+~~~~~~~~~~~~~~~~~~~
+.. autoclass:: GDSDataset
+    :members:
+    :special-members: __getitem__
+
+
 `CacheNTransDataset`
 ~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: CacheNTransDataset
```
15 changes: 10 additions & 5 deletions monai/data/dataset.py
```diff
@@ -1521,6 +1521,8 @@ class GDSDataset(PersistentDataset):
         bandwidth while decreasing latency and utilization load on the CPU and GPU.
+        A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb.
+        See also: https://github.com/rapidsai/kvikio
     """

     def __init__(
@@ -1607,17 +1609,20 @@ def _cachecheck(self, item_transformed):
                     return item
                 elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):
                     _meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta")
-                    _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(()))
-                    _data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}")
-                    if bool(_meta):
+                    _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta["dtype"], like=cp.empty(()))
+                    _data = convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}")
+                    filtered_keys = list(filter(lambda key: key not in ["dtype", "shape"], _meta.keys()))
+                    if bool(filtered_keys):
                         return (_data, _meta)
                     return _data
                 else:
                     item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))]  # type:ignore
                     for i, _item in enumerate(item_transformed):
                         for k in _item:
                             meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
-                            item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=np.float32, like=cp.empty(()))
+                            item_k = kvikio_numpy.fromfile(
+                                f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(())
+                            )
                             item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
                             item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k})
                     return item
@@ -1653,7 +1658,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
         if isinstance(_item_transformed_data, torch.Tensor):
             _item_transformed_data = _item_transformed_data.numpy()
         self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape
-        self._meta_cache[meta_hash_file_name]["dtype"] = _item_transformed_data.dtype
+        self._meta_cache[meta_hash_file_name]["dtype"] = str(_item_transformed_data.dtype)
         kvikio_numpy.tofile(_item_transformed_data, data_hashfile)
         try:
             # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
```
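For orientation, here is a usage sketch of `GDSDataset` modeled on the tests below. It assumes a CUDA device plus the `kvikio`/`cupy` extras, and uses `monai.transforms.Identity` as a stand-in transform; the snippet itself is not part of the patch.

```python
import tempfile

import numpy as np

from monai.data import GDSDataset
from monai.transforms import Identity

items = [np.arange(24, dtype=np.float32).reshape(1, 4, 6)]
with tempfile.TemporaryDirectory() as cache_dir:
    ds = GDSDataset(data=items, transform=Identity(), cache_dir=cache_dir, device=0)
    _ = ds[0]  # first access: runs the transform and writes the cache via kvikio
    cached = ds[0]  # later accesses: read back from the GDS cache onto cuda:0
    print(cached.dtype)  # float32 is preserved thanks to the cached dtype string
```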
42 changes: 34 additions & 8 deletions tests/test_gdsdataset.py
```diff
@@ -17,6 +17,7 @@
 import unittest

 import numpy as np
+import torch
 from parameterized import parameterized

 from monai.data import GDSDataset, json_hashing
@@ -48,6 +49,19 @@

 TEST_CASE_3 = [None, (128, 128, 128)]

+DTYPES = {
+    np.dtype(np.uint8): torch.uint8,
+    np.dtype(np.int8): torch.int8,
+    np.dtype(np.int16): torch.int16,
+    np.dtype(np.int32): torch.int32,
+    np.dtype(np.int64): torch.int64,
+    np.dtype(np.float16): torch.float16,
+    np.dtype(np.float32): torch.float32,
+    np.dtype(np.float64): torch.float64,
+    np.dtype(np.complex64): torch.complex64,
+    np.dtype(np.complex128): torch.complex128,
+}
+

 class _InplaceXform(Transform):
     def __call__(self, data):
@@ -93,16 +107,28 @@ def test_metatensor(self):
         shape = (1, 10, 9, 8)
         items = [TEST_NDARRAYS[-1](np.arange(0, np.prod(shape)).reshape(shape))]
         with tempfile.TemporaryDirectory() as tempdir:
-            ds = GDSDataset(
-                data=items,
-                transform=_InplaceXform(),
-                cache_dir=tempdir,
-                device=0,
-                pickle_module="pickle",
-                pickle_protocol=pickle.HIGHEST_PROTOCOL,
-            )
+            ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
             assert_allclose(ds[0], ds[0][0], type_test=False)

+    def test_dtype(self):
+        shape = (1, 10, 9, 8)
+        data = np.arange(0, np.prod(shape)).reshape(shape)
+        for _dtype in DTYPES.keys():
+            items = [np.array(data).astype(_dtype)]
+            with tempfile.TemporaryDirectory() as tempdir:
+                ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
+                ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
+                self.assertEqual(ds[0].dtype, _dtype)
+                self.assertEqual(ds1[0].dtype, DTYPES[_dtype])
+
+        for _dtype in DTYPES.keys():
+            items = [torch.tensor(data, dtype=DTYPES[_dtype])]
+            with tempfile.TemporaryDirectory() as tempdir:
+                ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
+                ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)
+                self.assertEqual(ds[0].dtype, DTYPES[_dtype])
+                self.assertEqual(ds1[0].dtype, DTYPES[_dtype])
+
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
     def test_shape(self, transform, expected_shape):
         test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
```
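The `DTYPES` table above encodes the standard numpy-to-torch dtype correspondence. A quick standalone spot-check of that mapping (plain numpy/torch, independent of MONAI and of this patch):

```python
import numpy as np
import torch

# torch.from_numpy keeps each numpy dtype as its torch counterpart.
for np_dtype, torch_dtype in [
    (np.float16, torch.float16),
    (np.int32, torch.int32),
    (np.complex64, torch.complex64),
]:
    t = torch.from_numpy(np.zeros(3, dtype=np_dtype))
    assert t.dtype == torch_dtype, (np_dtype, torch_dtype)
```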
