Skip to content

Commit

Permalink
enable gpu load nifti
Browse files Browse the repository at this point in the history
Signed-off-by: Yiheng Wang <[email protected]>
  • Loading branch information
yiheng-wang-nv committed Nov 2, 2024
1 parent c1ceea3 commit 84d8cf3
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 2 deletions.
127 changes: 125 additions & 2 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
import glob
import os
import re
import gzip
import io
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
import torch

import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern
Expand All @@ -41,17 +44,21 @@
import pydicom
from nibabel.nifti1 import Nifti1Image
from PIL import Image as PILImage
import cupy as cp
import kvikio

has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True
has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True
else:
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
nib, has_nib = optional_import("nibabel")
Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image")
PILImage, has_pil = optional_import("PIL.Image")
pydicom, has_pydicom = optional_import("pydicom")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)
cp, has_cp = optional_import("cupy")
kvikio, has_kvikio = optional_import("kvikio")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]


class ImageReader(ABC):
Expand Down Expand Up @@ -1024,6 +1031,122 @@ def _get_array_data(self, img):
"""
return np.asanyarray(img.dataobj, order="C")


@require_pkg(pkg_name="nibabel")
@require_pkg(pkg_name="cupy")
@require_pkg(pkg_name="kvikio")
class NibabelGPUReader(NibabelReader):

def _gds_load(self, file_path):
file_size = os.path.getsize(file_path)
image = cp.empty(file_size, dtype=cp.uint8)
with kvikio.CuFile(file_path, "r") as f:
f.read(image)

if file_path.endswith(".gz"):
# for compressed data, have to tansfer to CPU to decompress
# and then transfer back to GPU. It is not efficient compared to .nii file
# but it's still faster than Nibabel's default reader.
# TODO: can benchmark more, it may no need to do this since we don't have to use .gz
# since it's waste times especially in training
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

file_size = len(decompressed_data)
image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))

return image

def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
"""
Read image data from specified file or files, it can read a list of images
and stack them together as multi-channel data in `get_data()`.
Note that the returned object is Nibabel image object or list of Nibabel image objects.
Args:
data: file name or a list of file names to read.
"""
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
img = self._gds_load(name)
img_.append(img) # type: ignore
return img_ if len(filenames) > 1 else img_[0]

def get_data(self, img):
"""
Extract data array and metadata from loaded image and return them.
This function returns two objects, first is numpy array of image data, second is dict of metadata.
It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
When loading a list of files, they are stacked together at a new dimension as the first dimension,
and the metadata of the first image is used to present the output metadata.
Args:
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
"""
compatible_meta: dict = {}
img_array = []
for i in ensure_tuple(img):
header = self._get_header(i)
data_offset = header.get_data_offset()
data_shape = header.get_data_shape()
data_dtype = header.get_data_dtype()
affine = header.get_best_affine()
meta = dict(header)
meta[MetaKeys.AFFINE] = affine
meta[MetaKeys.ORIGINAL_AFFINE] = affine
# TODO: as_closest_canonical
# TODO: correct_nifti_header_if_necessary
meta[MetaKeys.SPATIAL_SHAPE] = data_shape
# TODO: figure out why always RAS for NibabelReader ?
# meta[MetaKeys.SPACE] = SpaceKeys.RAS

data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F")
# TODO: check channel
# if self.squeeze_non_spatial_dims:
img_array.append(data)
if self.channel_dim is None: # default to "no_channel" or -1
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
)
else:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
_copy_compatible_dict(header, compatible_meta)

return self._stack_images(img_array, compatible_meta), compatible_meta

def _get_header(self, img):
"""
Get the all the metadata of the image and convert to dict type.
Args:
img: a Nibabel image object loaded from an image file.
"""
header_bytes = cp.asnumpy(img[:348])
header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes))
# swap to little endian as PyTorch doesn't support big endian
try:
header = header.as_byteswapped("<")
except ValueError:
pass
return header

def _stack_images(self, image_list: list, meta_dict: dict):
if len(image_list) <= 1:
return image_list[0]
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
return torch.cat(image_list, axis=channel_dim)
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
return torch.stack(image_list, dim=0)


class NumpyReader(ImageReader):
Expand Down
2 changes: 2 additions & 0 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ImageReader,
ITKReader,
NibabelReader,
NibabelGPUReader,
NrrdReader,
NumpyReader,
PILReader,
Expand Down Expand Up @@ -69,6 +70,7 @@
"numpyreader": NumpyReader,
"pilreader": PILReader,
"nibabelreader": NibabelReader,
"nibabelgpureader": NibabelGPUReader,
}


Expand Down

0 comments on commit 84d8cf3

Please sign in to comment.