diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b4ae562911..20cd46994d 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -14,12 +14,15 @@ import glob import os import re +import gzip +import io import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any +import torch import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -41,8 +44,10 @@ import pydicom from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage + import cupy as cp + import kvikio - has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True + has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) nib, has_nib = optional_import("nibabel") @@ -50,8 +55,10 @@ PILImage, has_pil = optional_import("PIL.Image") pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) + cp, has_cp = optional_import("cupy") + kvikio, has_kvikio = optional_import("kvikio") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] class ImageReader(ABC): @@ -1024,6 +1031,122 @@ def _get_array_data(self, img): """ return np.asanyarray(img.dataobj, order="C") + + +@require_pkg(pkg_name="nibabel") +@require_pkg(pkg_name="cupy") +@require_pkg(pkg_name="kvikio") +class NibabelGPUReader(NibabelReader): + + def _gds_load(self, file_path): + file_size = os.path.getsize(file_path) + image = cp.empty(file_size, dtype=cp.uint8) + with kvikio.CuFile(file_path, "r") as f: + f.read(image) + + if file_path.endswith(".gz"): + # for compressed data, have to tansfer to CPU to decompress + # and then transfer back to GPU. It is not efficient compared to .nii file + # but it's still faster than Nibabel's default reader. + # TODO: can benchmark more, it may no need to do this since we don't have to use .gz + # since it's waste times especially in training + compressed_data = cp.asnumpy(image) + with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: + decompressed_data = gz_file.read() + + file_size = len(decompressed_data) + image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) + + return image + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is Nibabel image object or list of Nibabel image objects. + + Args: + data: file name or a list of file names to read. + + """ + img_: list[Nifti1Image] = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + img = self._gds_load(name) + img_.append(img) # type: ignore + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img): + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to present the output metadata. + + Args: + img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. + + """ + compatible_meta: dict = {} + img_array = [] + for i in ensure_tuple(img): + header = self._get_header(i) + data_offset = header.get_data_offset() + data_shape = header.get_data_shape() + data_dtype = header.get_data_dtype() + affine = header.get_best_affine() + meta = dict(header) + meta[MetaKeys.AFFINE] = affine + meta[MetaKeys.ORIGINAL_AFFINE] = affine + # TODO: as_closest_canonical + # TODO: correct_nifti_header_if_necessary + meta[MetaKeys.SPATIAL_SHAPE] = data_shape + # TODO: figure out why always RAS for NibabelReader ? + # meta[MetaKeys.SPACE] = SpaceKeys.RAS + + data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F") + # TODO: check channel + # if self.squeeze_non_spatial_dims: + img_array.append(data) + if self.channel_dim is None: # default to "no_channel" or -1 + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(header, compatible_meta) + + return self._stack_images(img_array, compatible_meta), compatible_meta + + def _get_header(self, img): + """ + Get the all the metadata of the image and convert to dict type. + + Args: + img: a Nibabel image object loaded from an image file. + + """ + header_bytes = cp.asnumpy(img[:348]) + header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes)) + # swap to little endian as PyTorch doesn't support big endian + try: + header = header.as_byteswapped("<") + except ValueError: + pass + return header + + def _stack_images(self, image_list: list, meta_dict: dict): + if len(image_list) <= 1: + return image_list[0] + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): + channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + return torch.cat(image_list, axis=channel_dim) + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + return torch.stack(image_list, dim=0) class NumpyReader(ImageReader): diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4e71870fc9..eb0a0b88d8 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -35,6 +35,7 @@ ImageReader, ITKReader, NibabelReader, + NibabelGPUReader, NrrdReader, NumpyReader, PILReader, @@ -69,6 +70,7 @@ "numpyreader": NumpyReader, "pilreader": PILReader, "nibabelreader": NibabelReader, + "nibabelgpureader": NibabelGPUReader, }