diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 2c809b9817..1702a1211c 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -89,7 +89,7 @@ jobs: steps: - name: Import run: | - export CUDA_VISIBLE_DEVICES= # cpu-only + export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 CUDA_VISIBLE_DEVICES= # cpu-only python -c 'import monai; monai.config.print_debug_info()' cd /opt/monai ls -al diff --git a/docs/source/data.rst b/docs/source/data.rst index b789102b81..63d5e0e23d 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -45,6 +45,13 @@ Generic Interfaces :members: :special-members: __getitem__ +`GDSDataset` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: GDSDataset + :members: + :special-members: __getitem__ + + `CacheNTransDataset` ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: CacheNTransDataset diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 6aebe47ed7..a20511267b 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1521,6 +1521,8 @@ class GDSDataset(PersistentDataset): bandwidth while decreasing latency and utilization load on the CPU and GPU. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb. + + See also: https://github.com/rapidsai/kvikio """ def __init__( @@ -1607,9 +1609,10 @@ def _cachecheck(self, item_transformed): return item elif isinstance(item_transformed, (np.ndarray, torch.Tensor)): _meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta") - _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(())) - _data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}") - if bool(_meta): + _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta["dtype"], like=cp.empty(())) + _data = convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}") + filtered_keys = list(filter(lambda key: key not in ["dtype", "shape"], _meta.keys())) + if bool(filtered_keys): return (_data, _meta) return _data else: @@ -1617,7 +1620,9 @@ def _cachecheck(self, item_transformed): for i, _item in enumerate(item_transformed): for k in _item: meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}") - item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=np.float32, like=cp.empty(())) + item_k = kvikio_numpy.fromfile( + f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(()) + ) item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}") item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k}) return item @@ -1653,7 +1658,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): if isinstance(_item_transformed_data, torch.Tensor): _item_transformed_data = _item_transformed_data.numpy() self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape - self._meta_cache[meta_hash_file_name]["dtype"] = _item_transformed_data.dtype + self._meta_cache[meta_hash_file_name]["dtype"] = str(_item_transformed_data.dtype) kvikio_numpy.tofile(_item_transformed_data, data_hashfile) try: # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index 2971b34fe7..29f2d0096b 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -17,6 +17,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.data import GDSDataset, json_hashing @@ -48,6 +49,19 @@ TEST_CASE_3 = [None, (128, 128, 128)] +DTYPES = { + np.dtype(np.uint8): torch.uint8, + np.dtype(np.int8): torch.int8, + np.dtype(np.int16): torch.int16, + np.dtype(np.int32): torch.int32, + np.dtype(np.int64): torch.int64, + np.dtype(np.float16): torch.float16, + np.dtype(np.float32): torch.float32, + np.dtype(np.float64): torch.float64, + np.dtype(np.complex64): torch.complex64, + np.dtype(np.complex128): torch.complex128, +} + class _InplaceXform(Transform): def __call__(self, data): @@ -93,16 +107,28 @@ def test_metatensor(self): shape = (1, 10, 9, 8) items = [TEST_NDARRAYS[-1](np.arange(0, np.prod(shape)).reshape(shape))] with tempfile.TemporaryDirectory() as tempdir: - ds = GDSDataset( - data=items, - transform=_InplaceXform(), - cache_dir=tempdir, - device=0, - pickle_module="pickle", - pickle_protocol=pickle.HIGHEST_PROTOCOL, - ) + ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) assert_allclose(ds[0], ds[0][0], type_test=False) + def test_dtype(self): + shape = (1, 10, 9, 8) + data = np.arange(0, np.prod(shape)).reshape(shape) + for _dtype in DTYPES.keys(): + items = [np.array(data).astype(_dtype)] + with tempfile.TemporaryDirectory() as tempdir: + ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + self.assertEqual(ds[0].dtype, _dtype) + self.assertEqual(ds1[0].dtype, DTYPES[_dtype]) + + for _dtype in DTYPES.keys(): + items = [torch.tensor(data, dtype=DTYPES[_dtype])] + with tempfile.TemporaryDirectory() as tempdir: + ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + self.assertEqual(ds[0].dtype, DTYPES[_dtype]) + self.assertEqual(ds1[0].dtype, DTYPES[_dtype]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))