Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable gpu load nifti #8188

Merged
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
84d8cf3
enable gpu load nifti
yiheng-wang-nv Nov 2, 2024
ca1cfb8
fix issue
yiheng-wang-nv Nov 2, 2024
d3551cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
01a21e0
update loadimage
yiheng-wang-nv Nov 2, 2024
be77a45
add init
yiheng-wang-nv Nov 2, 2024
b4a747c
update filename
yiheng-wang-nv Nov 2, 2024
f6af120
update supported reader
yiheng-wang-nv Nov 2, 2024
009fdf7
update load image call
yiheng-wang-nv Nov 2, 2024
27d218a
remove useless header
yiheng-wang-nv Nov 2, 2024
1baa31b
add filename
yiheng-wang-nv Nov 2, 2024
da41742
Merge branch 'dev' into add-gds-support-on-niftireader
yiheng-wang-nv Nov 8, 2024
f453158
reformat to add gpu load support on nibabelreader
yiheng-wang-nv Nov 8, 2024
8d8ba0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2024
617729e
Merge branch 'dev' into add-gds-support-on-niftireader
yiheng-wang-nv Nov 27, 2024
59eccd4
Merge branch 'dev' into add-gds-support-on-niftireader
yiheng-wang-nv Dec 12, 2024
7eb890f
update
yiheng-wang-nv Dec 12, 2024
a62b1dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2024
5f9ac06
update to_cupy
yiheng-wang-nv Dec 12, 2024
d052a5f
add tests
yiheng-wang-nv Dec 13, 2024
a987a94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
1b12a39
add description on warm up
yiheng-wang-nv Dec 13, 2024
b70a5f5
Update monai/data/image_reader.py
yiheng-wang-nv Dec 16, 2024
83a1daf
add doc string
yiheng-wang-nv Dec 16, 2024
acf2cba
resolve comments
yiheng-wang-nv Dec 16, 2024
e5d7907
update
yiheng-wang-nv Dec 17, 2024
5005846
Update monai/data/image_reader.py
yiheng-wang-nv Dec 19, 2024
c43b916
Update monai/data/image_reader.py
yiheng-wang-nv Dec 19, 2024
02c77f0
Update monai/data/image_reader.py
yiheng-wang-nv Dec 19, 2024
ea2355b
Merge branch 'dev' into add-gds-support-on-niftireader
yiheng-wang-nv Dec 19, 2024
6f4e5cb
remove device set
yiheng-wang-nv Dec 20, 2024
b6fb2ab
update format
yiheng-wang-nv Dec 20, 2024
f7f59bf
Merge branch 'dev' into add-gds-support-on-niftireader
KumoLiu Dec 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 77 additions & 9 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
from __future__ import annotations

import glob
import gzip
import io
import os
import re
import tempfile
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
Expand Down Expand Up @@ -51,6 +54,9 @@
pydicom, has_pydicom = optional_import("pydicom")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)

cp, has_cp = optional_import("cupy")
kvikio, has_kvikio = optional_import("kvikio")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]


Expand Down Expand Up @@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
)


def _stack_images(image_list: list, meta_dict: dict):
def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
if len(image_list) <= 1:
return image_list[0]
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
if to_cupy and has_cp:
return cp.concatenate(image_list, axis=channel_dim)
return np.concatenate(image_list, axis=channel_dim)
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
if to_cupy and has_cp:
return cp.stack(image_list, axis=0)
return np.stack(image_list, axis=0)


Expand Down Expand Up @@ -864,12 +874,18 @@ class NibabelReader(ImageReader):
Load NIfTI format images based on Nibabel library.

Args:
as_closest_canonical: if True, load the image as closest to canonical axis format.
squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
channel_dim: the channel dimension of the input image, default is None.
this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
if None, `original_channel_dim` will be either `no_channel` or `-1`.
most Nifti files are usually "channel last", no need to specify this argument for them.
as_closest_canonical: if True, load the image as closest to canonical axis format.
squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
Default is False. CuPy and Kvikio are required for this option.
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
In practical use, it's recommended to add a warm up call before the actual loading.
A related tutorial will be prepared in the future, and the document will be updated accordingly.
kwargs: additional args for `nibabel.load` API. more details about available args:
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py

Expand All @@ -880,14 +896,42 @@ def __init__(
channel_dim: str | int | None = None,
as_closest_canonical: bool = False,
squeeze_non_spatial_dims: bool = False,
to_gpu: bool = False,
**kwargs,
):
super().__init__()
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.as_closest_canonical = as_closest_canonical
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
if to_gpu and (not has_cp or not has_kvikio):
warnings.warn(
"NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
)
to_gpu = False

if to_gpu:
self.warmup_kvikio()

self.to_gpu = to_gpu
self.kwargs = kwargs

def warmup_kvikio(self):
"""
Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
This can accelerate the data loading process when `to_gpu` is set to True.
"""
if has_cp and has_kvikio:
a = cp.arange(100)
with tempfile.NamedTemporaryFile() as tmp_file:
tmp_file_name = tmp_file.name
f = kvikio.CuFile(tmp_file_name, "w")
f.write(a)
f.close()

b = cp.empty_like(a)
f = kvikio.CuFile(tmp_file_name, "r")
f.read(b)

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
"""
Verify whether the specified file or files format is supported by Nibabel reader.
Expand Down Expand Up @@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
self.filenames = filenames
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
Expand All @@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.

"""
# TODO: the actual type is list[np.ndarray | cp.ndarray]
# should figure out how to define correct types without having cupy not found error
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
img_array: list[np.ndarray] = []
compatible_meta: dict = {}

for i in ensure_tuple(img):
for i, filename in zip(ensure_tuple(img), self.filenames):
header = self._get_meta_dict(i)
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
Expand All @@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
header[MetaKeys.SPACE] = SpaceKeys.RAS
data = self._get_array_data(i)
data = self._get_array_data(i, filename)
if self.squeeze_non_spatial_dims:
for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
if data.shape[d - 1] == 1:
Expand All @@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
_copy_compatible_dict(header, compatible_meta)

return _stack_images(img_array, compatible_meta), compatible_meta
return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta

def _get_meta_dict(self, img) -> dict:
"""
Expand Down Expand Up @@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img):
spatial_rank = max(min(ndim, 3), 1)
return np.asarray(size[:spatial_rank])

def _get_array_data(self, img):
def _get_array_data(self, img, filename):
"""
Get the raw array data of the image, converted to Numpy array.

Args:
img: a Nibabel image object loaded from an image file.
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved

"""
filename: file name of the image.

"""
if self.to_gpu:
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
with kvikio.CuFile(filename, "r") as f:
f.read(image)
if filename.endswith(".nii.gz"):
# for compressed data, have to tansfer to CPU to decompress
# and then transfer back to GPU. It is not efficient compared to .nii file
# and may be slower than CPU loading in some cases.
warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.")
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

image = cp.frombuffer(decompressed_data, dtype=cp.uint8)
data_shape = img.shape
data_offset = img.dataobj.offset
data_dtype = img.dataobj.dtype
return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
return np.asanyarray(img.dataobj, order="C")


Expand Down
15 changes: 11 additions & 4 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,12 @@ def clone(self, **kwargs):

@staticmethod
def ensure_torch_and_prune_meta(
im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
im: NdarrayTensor,
meta: dict | None,
simple_keys: bool = False,
pattern: str | None = None,
sep: str = ".",
device: None | str | torch.device = None,
):
"""
Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
Expand All @@ -547,13 +552,15 @@ def ensure_torch_and_prune_meta(
sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.
device: target device to put the Tensor data.

Returns:
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray

img = convert_to_tensor(
im, track_meta=get_track_meta() and meta is not None, device=device
) # potentially ascontiguousarray
# if not tracking metadata, return `torch.Tensor`
if not isinstance(img, MetaTensor):
return img
Expand All @@ -565,7 +572,7 @@ def ensure_torch_and_prune_meta(
if simple_keys:
# ensure affine is of type `torch.Tensor`
if MetaKeys.AFFINE in meta:
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking
meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
remove_extra_metadata(meta) # bc-breaking

if pattern is not None:
Expand Down
8 changes: 5 additions & 3 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
prune_meta_pattern: str | None = None,
prune_meta_sep: str = ".",
expanduser: bool = True,
device: None | str | torch.device = None,
*args,
**kwargs,
) -> None:
Expand All @@ -163,6 +164,7 @@ def __init__(
e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``.
expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is.
args: additional parameters for reader if providing a reader name.
device: target device to put the loaded image.
kwargs: additional parameters for reader if providing a reader name.

Note:
Expand All @@ -184,6 +186,7 @@ def __init__(
self.pattern = prune_meta_pattern
self.sep = prune_meta_sep
self.expanduser = expanduser
self.device = device

self.readers: list[ImageReader] = []
for r in SUPPORTED_READERS: # set predefined readers as default
Expand Down Expand Up @@ -286,18 +289,17 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
f" The current registered: {self.readers}.\n{msg}"
)

img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0]
if not isinstance(meta_data, dict):
raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
# make sure all elements in metadata are little endian
meta_data = switch_endianness(meta_data, "<")

meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
img = MetaTensor.ensure_torch_and_prune_meta(
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device
)
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_init_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ def test_load_image(self):
inst = LoadImaged("image", reader=r)
self.assertIsInstance(inst, LoadImaged)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
def test_load_image_to_gpu(self):
for to_gpu in [True, False]:
instance1 = LoadImage(reader="NibabelReader", to_gpu=to_gpu)
self.assertIsInstance(instance1, LoadImage)

instance2 = LoadImaged("image", reader="NibabelReader", to_gpu=to_gpu)
self.assertIsInstance(instance2, LoadImaged)

@SkipIfNoModule("itk")
@SkipIfNoModule("nibabel")
@SkipIfNoModule("PIL")
Expand Down Expand Up @@ -58,6 +69,14 @@ def test_readers(self):
inst = NrrdReader()
self.assertIsInstance(inst, NrrdReader)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
def test_readers_to_gpu(self):
for to_gpu in [True, False]:
inst = NibabelReader(to_gpu=to_gpu)
self.assertIsInstance(inst, NibabelReader)


if __name__ == "__main__":
unittest.main()
41 changes: 40 additions & 1 deletion tests/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from monai.data.meta_tensor import MetaTensor
from monai.transforms import LoadImage
from monai.utils import optional_import
from tests.utils import assert_allclose, skip_if_downloading_fails, testing_data_config
from tests.utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config

itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator")
Expand Down Expand Up @@ -74,6 +74,22 @@ def get_data(self, _obj):

TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_GPU_1 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_GPU_2 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii"], (128, 128, 128)]

TEST_CASE_GPU_3 = [
{"reader": "nibabelreader", "to_gpu": True},
["test_image.nii", "test_image2.nii", "test_image3.nii"],
(3, 128, 128, 128),
]

TEST_CASE_GPU_4 = [
{"reader": "nibabelreader", "to_gpu": True},
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
(3, 128, 128, 128),
]

TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]

TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
Expand Down Expand Up @@ -196,6 +212,29 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape):
assert_allclose(result.affine, torch.eye(4))
self.assertTupleEqual(result.shape, expected_shape)

@SkipIfNoModule("nibabel")
@SkipIfNoModule("cupy")
@SkipIfNoModule("kvikio")
@parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])
def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
test_image = np.random.rand(128, 128, 128)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
result = LoadImage(image_only=True, **input_param)(filenames)
ext = "".join(Path(name).suffixes)
self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext))
self.assertEqual(result.meta["space"], "RAS")
assert_allclose(result.affine, torch.eye(4))
self.assertTupleEqual(result.shape, expected_shape)

# verify gpu and cpu loaded data are the same
input_param_cpu = input_param.copy()
input_param_cpu["to_gpu"] = False
result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)
self.assertTrue(torch.equal(result_cpu, result.cpu()))

@parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
def test_itk_reader(self, input_param, filenames, expected_shape):
test_image = np.random.rand(128, 128, 128)
Expand Down
Loading