From 84d8cf3c14af42c5a9c407cc6221a4be82a73d0f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 08:11:09 +0800 Subject: [PATCH 01/27] enable gpu load nifti Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 127 ++++++++++++++++++++++++++++++++++- monai/transforms/io/array.py | 2 + 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b4ae562911..20cd46994d 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -14,12 +14,15 @@ import glob import os import re +import gzip +import io import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any +import torch import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -41,8 +44,10 @@ import pydicom from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage + import cupy as cp + import kvikio - has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True + has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) nib, has_nib = optional_import("nibabel") @@ -50,8 +55,10 @@ PILImage, has_pil = optional_import("PIL.Image") pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) + cp, has_cp = optional_import("cupy") + kvikio, has_kvikio = optional_import("kvikio") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] class ImageReader(ABC): @@ -1024,6 +1031,122 @@ def _get_array_data(self, img): """ return np.asanyarray(img.dataobj, order="C") + + +@require_pkg(pkg_name="nibabel") +@require_pkg(pkg_name="cupy") +@require_pkg(pkg_name="kvikio") +class NibabelGPUReader(NibabelReader): + + def _gds_load(self, file_path): + file_size = os.path.getsize(file_path) + image = cp.empty(file_size, dtype=cp.uint8) + with kvikio.CuFile(file_path, "r") as f: + f.read(image) + + if file_path.endswith(".gz"): + # for compressed data, have to tansfer to CPU to decompress + # and then transfer back to GPU. It is not efficient compared to .nii file + # but it's still faster than Nibabel's default reader. + # TODO: can benchmark more, it may no need to do this since we don't have to use .gz + # since it's waste times especially in training + compressed_data = cp.asnumpy(image) + with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: + decompressed_data = gz_file.read() + + file_size = len(decompressed_data) + image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) + + return image + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is Nibabel image object or list of Nibabel image objects. + + Args: + data: file name or a list of file names to read. + + """ + img_: list[Nifti1Image] = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + img = self._gds_load(name) + img_.append(img) # type: ignore + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img): + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to present the output metadata. + + Args: + img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. + + """ + compatible_meta: dict = {} + img_array = [] + for i in ensure_tuple(img): + header = self._get_header(i) + data_offset = header.get_data_offset() + data_shape = header.get_data_shape() + data_dtype = header.get_data_dtype() + affine = header.get_best_affine() + meta = dict(header) + meta[MetaKeys.AFFINE] = affine + meta[MetaKeys.ORIGINAL_AFFINE] = affine + # TODO: as_closest_canonical + # TODO: correct_nifti_header_if_necessary + meta[MetaKeys.SPATIAL_SHAPE] = data_shape + # TODO: figure out why always RAS for NibabelReader ? + # meta[MetaKeys.SPACE] = SpaceKeys.RAS + + data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F") + # TODO: check channel + # if self.squeeze_non_spatial_dims: + img_array.append(data) + if self.channel_dim is None: # default to "no_channel" or -1 + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(header, compatible_meta) + + return self._stack_images(img_array, compatible_meta), compatible_meta + + def _get_header(self, img): + """ + Get the all the metadata of the image and convert to dict type. + + Args: + img: a Nibabel image object loaded from an image file. + + """ + header_bytes = cp.asnumpy(img[:348]) + header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes)) + # swap to little endian as PyTorch doesn't support big endian + try: + header = header.as_byteswapped("<") + except ValueError: + pass + return header + + def _stack_images(self, image_list: list, meta_dict: dict): + if len(image_list) <= 1: + return image_list[0] + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): + channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + return torch.cat(image_list, axis=channel_dim) + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + return torch.stack(image_list, dim=0) class NumpyReader(ImageReader): diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4e71870fc9..eb0a0b88d8 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -35,6 +35,7 @@ ImageReader, ITKReader, NibabelReader, + NibabelGPUReader, NrrdReader, NumpyReader, PILReader, @@ -69,6 +70,7 @@ "numpyreader": NumpyReader, "pilreader": PILReader, "nibabelreader": NibabelReader, + "nibabelgpureader": NibabelGPUReader, } From ca1cfb81d953459eed3f20620c9b9999c5c95cc8 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 08:35:06 +0800 Subject: [PATCH 02/27] fix issue Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 20cd46994d..54a0fd3b4c 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1069,7 +1069,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): data: file name or a list of file names to read. """ - img_: list[Nifti1Image] = [] + img_ = [] filenames: Sequence[PathLike] = ensure_tuple(data) kwargs_ = self.kwargs.copy() @@ -1113,12 +1113,12 @@ def get_data(self, img): # if self.squeeze_non_spatial_dims: img_array.append(data) if self.channel_dim is None: # default to "no_channel" or -1 - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 ) else: - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(header, compatible_meta) + meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(meta, compatible_meta) return self._stack_images(img_array, compatible_meta), compatible_meta From d3551cc1d1a61f82e765ac353a21fbbb95322694 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 2 Nov 2024 00:35:38 +0000 Subject: [PATCH 03/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 54a0fd3b4c..fe89fd3921 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1031,7 +1031,7 @@ def _get_array_data(self, img): """ return np.asanyarray(img.dataobj, order="C") - + @require_pkg(pkg_name="nibabel") @require_pkg(pkg_name="cupy") @@ -1053,12 +1053,12 @@ def _gds_load(self, file_path): compressed_data = cp.asnumpy(image) with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: decompressed_data = gz_file.read() - + file_size = len(decompressed_data) image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) return image - + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): """ Read image data from specified file or files, it can read a list of images @@ -1078,7 +1078,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img = self._gds_load(name) img_.append(img) # type: ignore return img_ if len(filenames) > 1 else img_[0] - + def get_data(self, img): """ Extract data array and metadata from loaded image and return them. @@ -1121,7 +1121,7 @@ def get_data(self, img): _copy_compatible_dict(meta, compatible_meta) return self._stack_images(img_array, compatible_meta), compatible_meta - + def _get_header(self, img): """ Get the all the metadata of the image and convert to dict type. @@ -1138,7 +1138,7 @@ def _get_header(self, img): except ValueError: pass return header - + def _stack_images(self, image_list: list, meta_dict: dict): if len(image_list) <= 1: return image_list[0] From 01a21e055acf5f4431d89ce13347ad30b5d8be35 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:10:45 +0800 Subject: [PATCH 04/27] update loadimage Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 108 ++++++++++++++--------------------- monai/transforms/io/array.py | 11 +++- 2 files changed, 53 insertions(+), 66 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index fe89fd3921..b9bacc303b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -23,7 +23,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any import torch - +from monai.data.meta_tensor import MetaTensor import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -1038,13 +1038,22 @@ def _get_array_data(self, img): @require_pkg(pkg_name="kvikio") class NibabelGPUReader(NibabelReader): - def _gds_load(self, file_path): - file_size = os.path.getsize(file_path) + def read(self, filename: PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is Nibabel image object or list of Nibabel image objects. + + Args: + data: file name. + + """ + file_size = os.path.getsize(filename) image = cp.empty(file_size, dtype=cp.uint8) - with kvikio.CuFile(file_path, "r") as f: + with kvikio.CuFile(filename, "r") as f: f.read(image) - if file_path.endswith(".gz"): + if filename.endswith(".gz"): # for compressed data, have to tansfer to CPU to decompress # and then transfer back to GPU. It is not efficient compared to .nii file # but it's still faster than Nibabel's default reader. @@ -1056,29 +1065,8 @@ def _gds_load(self, file_path): file_size = len(decompressed_data) image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) - return image - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Nibabel image object or list of Nibabel image objects. - - Args: - data: file name or a list of file names to read. - - """ - img_ = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = self._gds_load(name) - img_.append(img) # type: ignore - return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img): """ Extract data array and metadata from loaded image and return them. @@ -1088,39 +1076,38 @@ def get_data(self, img): and the metadata of the first image is used to present the output metadata. Args: - img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. + img: a Nibabel image object loaded from an image file. """ - compatible_meta: dict = {} - img_array = [] - for i in ensure_tuple(img): - header = self._get_header(i) - data_offset = header.get_data_offset() - data_shape = header.get_data_shape() - data_dtype = header.get_data_dtype() - affine = header.get_best_affine() - meta = dict(header) - meta[MetaKeys.AFFINE] = affine - meta[MetaKeys.ORIGINAL_AFFINE] = affine - # TODO: as_closest_canonical - # TODO: correct_nifti_header_if_necessary - meta[MetaKeys.SPATIAL_SHAPE] = data_shape - # TODO: figure out why always RAS for NibabelReader ? - # meta[MetaKeys.SPACE] = SpaceKeys.RAS - - data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F") - # TODO: check channel - # if self.squeeze_non_spatial_dims: - img_array.append(data) - if self.channel_dim is None: # default to "no_channel" or -1 - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(meta, compatible_meta) - return self._stack_images(img_array, compatible_meta), compatible_meta + # TODO: use a formal way for device + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + header = self._get_header(img) + data_offset = header.get_data_offset() + data_shape = header.get_data_shape() + data_dtype = header.get_data_dtype() + affine = header.get_best_affine() + meta = dict(header) + meta[MetaKeys.AFFINE] = affine + meta[MetaKeys.ORIGINAL_AFFINE] = affine + # TODO: as_closest_canonical + # TODO: correct_nifti_header_if_necessary + meta[MetaKeys.SPATIAL_SHAPE] = data_shape + # TODO: figure out why always RAS for NibabelReader ? + # meta[MetaKeys.SPACE] = SpaceKeys.RAS + + data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F") + # TODO: check channel + # if self.squeeze_non_spatial_dims: + if self.channel_dim is None: # default to "no_channel" or -1 + meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + + return MetaTensor(data, affine=affine, meta=meta, device=device) def _get_header(self, img): """ @@ -1139,15 +1126,6 @@ def _get_header(self, img): pass return header - def _stack_images(self, image_list: list, meta_dict: dict): - if len(image_list) <= 1: - return image_list[0] - if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): - channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) - return torch.cat(image_list, axis=channel_dim) - meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 - return torch.stack(image_list, dim=0) - class NumpyReader(ImageReader): """ diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index eb0a0b88d8..52f98ce8ee 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -258,6 +258,16 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img, err = None, [] if reader is not None: + if isinstance(reader, NibabelGPUReader): + buffer = reader.read(filename) + img = reader.get_data(buffer) + # TODO: check ensure channel first + if self.ensure_channel_first: + img = EnsureChannelFirst()(img) + if self.image_only: + return img + return img, img.meta + img = reader.read(filename) # runtime specified reader else: for reader in self.readers[::-1]: @@ -288,7 +298,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered: {self.readers}.\n{msg}" ) - img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] From be77a45be56f840fd096dd76fda235b50553deaa Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:13:02 +0800 Subject: [PATCH 05/27] add init Signed-off-by: Yiheng Wang --- monai/data/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 340c5eb8fa..14d0dfb193 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -50,7 +50,7 @@ from .folder_layout import FolderLayout, FolderLayoutBase from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NibabelGPUReader, NrrdReader, NumpyReader, PILReader, PydicomReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, From b4a747ce096f2a691a69fedcd94a782588730de5 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:47:17 +0800 Subject: [PATCH 06/27] update filename Signed-off-by: Yiheng Wang --- monai/transforms/io/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 52f98ce8ee..9012f2bb80 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -259,7 +259,8 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader img, err = None, [] if reader is not None: if isinstance(reader, NibabelGPUReader): - buffer = reader.read(filename) + # TODO: handle multiple filenames later + buffer = reader.read(filename[0]) img = reader.get_data(buffer) # TODO: check ensure channel first if self.ensure_channel_first: From f6af1202bd7ab913f5775219874c2e1e55974333 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:53:20 +0800 Subject: [PATCH 07/27] update supported reader Signed-off-by: Yiheng Wang --- monai/transforms/io/array.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 9012f2bb80..f465fe60a6 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -70,7 +70,6 @@ "numpyreader": NumpyReader, "pilreader": PILReader, "nibabelreader": NibabelReader, - "nibabelgpureader": NibabelGPUReader, } From 009fdf7d60d449e767ae1108cfdb54d3b35a72c5 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 09:59:39 +0800 Subject: [PATCH 08/27] update load image call Signed-off-by: Yiheng Wang --- monai/transforms/io/array.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index f465fe60a6..4e2fdfcda8 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -277,6 +277,16 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader break else: # try the user designated readers try: + if isinstance(reader, NibabelGPUReader): + # TODO: handle multiple filenames later + buffer = reader.read(filename[0]) + img = reader.get_data(buffer) + # TODO: check ensure channel first + if self.ensure_channel_first: + img = EnsureChannelFirst()(img) + if self.image_only: + return img + return img, img.meta img = reader.read(filename) except Exception as e: err.append(traceback.format_exc()) From 27d218a1a15f65cc96448a1438f4668a0a6e2831 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 11:36:15 +0800 Subject: [PATCH 09/27] remove useless header Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index b9bacc303b..68ef5420ae 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1088,7 +1088,7 @@ def get_data(self, img): data_shape = header.get_data_shape() data_dtype = header.get_data_dtype() affine = header.get_best_affine() - meta = dict(header) + meta = {} meta[MetaKeys.AFFINE] = affine meta[MetaKeys.ORIGINAL_AFFINE] = affine # TODO: as_closest_canonical From 1baa31b85fc887dd009f600c9db69964669c0dc7 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 2 Nov 2024 12:03:04 +0800 Subject: [PATCH 10/27] add filename Signed-off-by: Yiheng Wang --- monai/transforms/io/array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4e2fdfcda8..455e38ac08 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -261,6 +261,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader # TODO: handle multiple filenames later buffer = reader.read(filename[0]) img = reader.get_data(buffer) + img.meta[Key.FILENAME_OR_OBJ] = filename[0] # TODO: check ensure channel first if self.ensure_channel_first: img = EnsureChannelFirst()(img) @@ -281,6 +282,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader # TODO: handle multiple filenames later buffer = reader.read(filename[0]) img = reader.get_data(buffer) + img.meta[Key.FILENAME_OR_OBJ] = filename[0] # TODO: check ensure channel first if self.ensure_channel_first: img = EnsureChannelFirst()(img) From f4531588232449ad9231aff797e857a474f88397 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 8 Nov 2024 08:11:20 +0000 Subject: [PATCH 11/27] reformat to add gpu load support on nibabelreader Signed-off-by: Yiheng Wang --- monai/data/__init__.py | 2 +- monai/data/image_reader.py | 143 +++++++++++------------------------ monai/data/meta_tensor.py | 13 +++- monai/transforms/io/array.py | 31 ++------ 4 files changed, 59 insertions(+), 130 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 14d0dfb193..340c5eb8fa 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -50,7 +50,7 @@ from .folder_layout import FolderLayout, FolderLayoutBase from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd from .image_dataset import ImageDataset -from .image_reader import ImageReader, ITKReader, NibabelReader, NibabelGPUReader, NrrdReader, NumpyReader, PILReader, PydicomReader +from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader from .image_writer import ( SUPPORTED_WRITERS, ImageWriter, diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 68ef5420ae..ae94fcc053 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -58,7 +58,7 @@ cp, has_cp = optional_import("cupy") kvikio, has_kvikio = optional_import("kvikio") -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] class ImageReader(ABC): @@ -155,6 +155,17 @@ def _stack_images(image_list: list, meta_dict: dict): return np.stack(image_list, axis=0) +def _stack_gpu_images(image_list: list, meta_dict: dict): + if len(image_list) <= 1: + return image_list[0] + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): + channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + return cp.concatenate(image_list, axis=channel_dim) + # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + return cp.stack(image_list, axis=0) + + @require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ @@ -887,12 +898,15 @@ def __init__( channel_dim: str | int | None = None, as_closest_canonical: bool = False, squeeze_non_spatial_dims: bool = False, + gpu_load: bool = False, **kwargs, ): super().__init__() self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims + # TODO: add warning if not have required libs + self.gpu_load = gpu_load self.kwargs = kwargs def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: @@ -923,6 +937,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + self.filenames = filenames kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -946,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img_array: list[np.ndarray] = [] compatible_meta: dict = {} - for i in ensure_tuple(img): + for i, filename in zip(ensure_tuple(img), self.filenames): header = self._get_meta_dict(i) header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) @@ -956,7 +971,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: header[MetaKeys.AFFINE] = self._get_affine(i) header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) header[MetaKeys.SPACE] = SpaceKeys.RAS - data = self._get_array_data(i) + data = self._get_array_data(i, filename) if self.squeeze_non_spatial_dims: for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1): if data.shape[d - 1] == 1: @@ -969,7 +984,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim _copy_compatible_dict(header, compatible_meta) - + if self.gpu_load: + return _stack_gpu_images(img_array, compatible_meta), compatible_meta return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> dict: @@ -1022,7 +1038,7 @@ def _get_spatial_shape(self, img): spatial_rank = max(min(ndim, 3), 1) return np.asarray(size[:spatial_rank]) - def _get_array_data(self, img): + def _get_array_data(self, img, filename): """ Get the raw array data of the image, converted to Numpy array. @@ -1030,103 +1046,32 @@ def _get_array_data(self, img): img: a Nibabel image object loaded from an image file. """ + if self.gpu_load: + file_size = os.path.getsize(filename) + image = cp.empty(file_size, dtype=cp.uint8) + # suggestion from Ming: more tests, diff size + # cucim + nifti + with kvikio.CuFile(filename, "r") as f: + f.read(image) + if filename.endswith(".gz"): + # for compressed data, have to tansfer to CPU to decompress + # and then transfer back to GPU. It is not efficient compared to .nii file + # but it's still faster than Nibabel's default reader. + # TODO: can benchmark more, it may no need to do this since we don't have to use .gz + # since it's waste times especially in training + compressed_data = cp.asnumpy(image) + with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: + decompressed_data = gz_file.read() + + file_size = len(decompressed_data) + image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) + data_shape = img.shape + data_offset = img.dataobj.offset + data_dtype = img.dataobj.dtype + return image[data_offset:].view(data_dtype).reshape(data_shape, order="F") return np.asanyarray(img.dataobj, order="C") -@require_pkg(pkg_name="nibabel") -@require_pkg(pkg_name="cupy") -@require_pkg(pkg_name="kvikio") -class NibabelGPUReader(NibabelReader): - - def read(self, filename: PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Nibabel image object or list of Nibabel image objects. - - Args: - data: file name. - - """ - file_size = os.path.getsize(filename) - image = cp.empty(file_size, dtype=cp.uint8) - with kvikio.CuFile(filename, "r") as f: - f.read(image) - - if filename.endswith(".gz"): - # for compressed data, have to tansfer to CPU to decompress - # and then transfer back to GPU. It is not efficient compared to .nii file - # but it's still faster than Nibabel's default reader. - # TODO: can benchmark more, it may no need to do this since we don't have to use .gz - # since it's waste times especially in training - compressed_data = cp.asnumpy(image) - with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: - decompressed_data = gz_file.read() - - file_size = len(decompressed_data) - image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) - return image - - def get_data(self, img): - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to present the output metadata. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - - # TODO: use a formal way for device - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - - header = self._get_header(img) - data_offset = header.get_data_offset() - data_shape = header.get_data_shape() - data_dtype = header.get_data_dtype() - affine = header.get_best_affine() - meta = {} - meta[MetaKeys.AFFINE] = affine - meta[MetaKeys.ORIGINAL_AFFINE] = affine - # TODO: as_closest_canonical - # TODO: correct_nifti_header_if_necessary - meta[MetaKeys.SPATIAL_SHAPE] = data_shape - # TODO: figure out why always RAS for NibabelReader ? - # meta[MetaKeys.SPACE] = SpaceKeys.RAS - - data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F") - # TODO: check channel - # if self.squeeze_non_spatial_dims: - if self.channel_dim is None: # default to "no_channel" or -1 - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - - return MetaTensor(data, affine=affine, meta=meta, device=device) - - def _get_header(self, img): - """ - Get the all the metadata of the image and convert to dict type. - - Args: - img: a Nibabel image object loaded from an image file. - - """ - header_bytes = cp.asnumpy(img[:348]) - header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes)) - # swap to little endian as PyTorch doesn't support big endian - try: - header = header.as_byteswapped("<") - except ValueError: - pass - return header - - class NumpyReader(ImageReader): """ Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects. diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index ac171e8508..959108eb47 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -532,7 +532,12 @@ def clone(self, **kwargs): @staticmethod def ensure_torch_and_prune_meta( - im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "." + im: NdarrayTensor, + meta: dict | None, + simple_keys: bool = False, + pattern: str | None = None, + sep: str = ".", + device: None | str | torch.device = None, ): """ Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary, @@ -547,13 +552,13 @@ def ensure_torch_and_prune_meta( sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``. + device: target device to put the Tensor data. Returns: By default, a `MetaTensor` is returned. However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned. """ - img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray - + img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` if not isinstance(img, MetaTensor): return img @@ -565,7 +570,7 @@ def ensure_torch_and_prune_meta( if simple_keys: # ensure affine is of type `torch.Tensor` if MetaKeys.AFFINE in meta: - meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking + meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking remove_extra_metadata(meta) # bc-breaking if pattern is not None: diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 455e38ac08..2eb00ab38d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -35,7 +35,6 @@ ImageReader, ITKReader, NibabelReader, - NibabelGPUReader, NrrdReader, NumpyReader, PILReader, @@ -140,6 +139,7 @@ def __init__( prune_meta_pattern: str | None = None, prune_meta_sep: str = ".", expanduser: bool = True, + device: None | str | torch.device = None, *args, **kwargs, ) -> None: @@ -164,6 +164,7 @@ def __init__( e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``. expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is. args: additional parameters for reader if providing a reader name. + device: target device to put the loaded image. kwargs: additional parameters for reader if providing a reader name. Note: @@ -185,6 +186,7 @@ def __init__( self.pattern = prune_meta_pattern self.sep = prune_meta_sep self.expanduser = expanduser + self.device = device self.readers: list[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default @@ -257,18 +259,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img, err = None, [] if reader is not None: - if isinstance(reader, NibabelGPUReader): - # TODO: handle multiple filenames later - buffer = reader.read(filename[0]) - img = reader.get_data(buffer) - img.meta[Key.FILENAME_OR_OBJ] = filename[0] - # TODO: check ensure channel first - if self.ensure_channel_first: - img = EnsureChannelFirst()(img) - if self.image_only: - return img - return img, img.meta - img = reader.read(filename) # runtime specified reader else: for reader in self.readers[::-1]: @@ -278,17 +268,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader break else: # try the user designated readers try: - if isinstance(reader, NibabelGPUReader): - # TODO: handle multiple filenames later - buffer = reader.read(filename[0]) - img = reader.get_data(buffer) - img.meta[Key.FILENAME_OR_OBJ] = filename[0] - # TODO: check ensure channel first - if self.ensure_channel_first: - img = EnsureChannelFirst()(img) - if self.image_only: - return img - return img, img.meta img = reader.read(filename) except Exception as e: err.append(traceback.format_exc()) @@ -312,7 +291,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) - img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] + img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0] if not isinstance(meta_data, dict): raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.") # make sure all elements in metadata are little endian @@ -320,7 +299,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader img = MetaTensor.ensure_torch_and_prune_meta( - img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep + img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device ) if self.ensure_channel_first: img = EnsureChannelFirst()(img) From 8d8ba0ff710415f69297aed1efac27ef449d530f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 08:11:43 +0000 Subject: [PATCH 12/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index ae94fcc053..d602af2217 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -22,8 +22,6 @@ from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any -import torch -from monai.data.meta_tensor import MetaTensor import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern From 7eb890fcf522e6d2c478734a5a8faea8b0b8c7b8 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 12 Dec 2024 06:28:32 +0000 Subject: [PATCH 13/27] update Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 48 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index d602af2217..9312117740 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -142,28 +142,21 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict): ) -def _stack_images(image_list: list, meta_dict: dict): +def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False): if len(image_list) <= 1: return image_list[0] if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + if to_cupy and has_cp: + return cp.concatenate(image_list, axis=channel_dim) return np.concatenate(image_list, axis=channel_dim) # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + if to_cupy and has_cp: + return cp.stack(image_list, axis=0) return np.stack(image_list, axis=0) -def _stack_gpu_images(image_list: list, meta_dict: dict): - if len(image_list) <= 1: - return image_list[0] - if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): - channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) - return cp.concatenate(image_list, axis=channel_dim) - # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified - meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 - return cp.stack(image_list, axis=0) - - @require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ @@ -880,12 +873,16 @@ class NibabelReader(ImageReader): Load NIfTI format images based on Nibabel library. Args: - as_closest_canonical: if True, load the image as closest to canonical axis format. - squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) channel_dim: the channel dimension of the input image, default is None. this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. if None, `original_channel_dim` will be either `no_channel` or `-1`. most Nifti files are usually "channel last", no need to specify this argument for them. + as_closest_canonical: if True, load the image as closest to canonical axis format. + squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) + to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading. + Default is False. CuPy and Kvikio are required for this option. + Note: For compressed NIfTI files, some operations may still be performed on CPU memory, + and the acceleration may not be significant. kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py @@ -896,15 +893,22 @@ def __init__( channel_dim: str | int | None = None, as_closest_canonical: bool = False, squeeze_non_spatial_dims: bool = False, - gpu_load: bool = False, + to_gpu: bool = False, **kwargs, ): super().__init__() self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims - # TODO: add warning if not have required libs - self.gpu_load = gpu_load + if to_gpu is True: + if not has_cp: + warnings.warn("CuPy is not installed, fall back to use cpu load.") + to_gpu = False + if not has_kvikio: + warnings.warn("Kvikio is not installed, fall back to use cpu load.") + to_gpu = False + + self.to_gpu = to_gpu self.kwargs = kwargs def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: @@ -982,8 +986,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim _copy_compatible_dict(header, compatible_meta) - if self.gpu_load: - return _stack_gpu_images(img_array, compatible_meta), compatible_meta + if self.to_gpu: + return _stack_images(img_array, compatible_meta, to_cupy=True), compatible_meta return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> dict: @@ -1047,22 +1051,18 @@ def _get_array_data(self, img, filename): if self.gpu_load: file_size = os.path.getsize(filename) image = cp.empty(file_size, dtype=cp.uint8) - # suggestion from Ming: more tests, diff size - # cucim + nifti with kvikio.CuFile(filename, "r") as f: f.read(image) if filename.endswith(".gz"): # for compressed data, have to tansfer to CPU to decompress # and then transfer back to GPU. It is not efficient compared to .nii file # but it's still faster than Nibabel's default reader. - # TODO: can benchmark more, it may no need to do this since we don't have to use .gz - # since it's waste times especially in training compressed_data = cp.asnumpy(image) with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: decompressed_data = gz_file.read() file_size = len(decompressed_data) - image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8)) + image = cp.frombuffer(decompressed_data, dtype=cp.uint8) data_shape = img.shape data_offset = img.dataobj.offset data_dtype = img.dataobj.dtype From a62b1dc737a2448f4064dd29c9cf31bf4757cb55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Dec 2024 06:28:43 +0000 Subject: [PATCH 14/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 9312117740..7f8f59b57a 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -907,7 +907,7 @@ def __init__( if not has_kvikio: warnings.warn("Kvikio is not installed, fall back to use cpu load.") to_gpu = False - + self.to_gpu = to_gpu self.kwargs = kwargs From 5f9ac060b8d7a4f2a9970daa6c469b7fbd16752d Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 12 Dec 2024 06:39:00 +0000 Subject: [PATCH 15/27] update to_cupy Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 7f8f59b57a..605d043410 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -986,9 +986,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim _copy_compatible_dict(header, compatible_meta) - if self.to_gpu: - return _stack_images(img_array, compatible_meta, to_cupy=True), compatible_meta - return _stack_images(img_array, compatible_meta), compatible_meta + + return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta def _get_meta_dict(self, img) -> dict: """ From d052a5f1494e6a032540fee4c9169c49ced0e381 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 13 Dec 2024 06:20:43 +0000 Subject: [PATCH 16/27] add tests Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 17 ++++++++-------- monai/data/meta_tensor.py | 4 +++- tests/test_init_reader.py | 19 ++++++++++++++++++ tests/test_load_image.py | 41 +++++++++++++++++++++++++++++++++++++- 4 files changed, 71 insertions(+), 10 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 605d043410..9d61c21fe8 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -12,16 +12,17 @@ from __future__ import annotations import glob -import os -import re import gzip import io +import os +import re import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any + import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -36,14 +37,14 @@ from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg if TYPE_CHECKING: + import cupy as cp import itk + import kvikio import nibabel as nib import nrrd import pydicom from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage - import cupy as cp - import kvikio has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True else: @@ -948,7 +949,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_.append(img) # type: ignore return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> tuple[np.ndarray, dict]: + def get_data(self, img) -> tuple[np.ndarray | "cp.ndarray", dict]: """ Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. @@ -960,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ - img_array: list[np.ndarray] = [] + img_array: list[np.ndarray | "cp.ndarray"] = [] compatible_meta: dict = {} for i, filename in zip(ensure_tuple(img), self.filenames): @@ -1047,12 +1048,12 @@ def _get_array_data(self, img, filename): img: a Nibabel image object loaded from an image file. """ - if self.gpu_load: + if self.to_gpu: file_size = os.path.getsize(filename) image = cp.empty(file_size, dtype=cp.uint8) with kvikio.CuFile(filename, "r") as f: f.read(image) - if filename.endswith(".gz"): + if filename.endswith(".nii.gz"): # for compressed data, have to tansfer to CPU to decompress # and then transfer back to GPU. It is not efficient compared to .nii file # but it's still faster than Nibabel's default reader. diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 959108eb47..8c729088ee 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -558,7 +558,9 @@ def ensure_torch_and_prune_meta( By default, a `MetaTensor` is returned. However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned. """ - img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device) # potentially ascontiguousarray + img = convert_to_tensor( + im, track_meta=get_track_meta() and meta is not None, device=device + ) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` if not isinstance(img, MetaTensor): return img diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py index cb45cb5146..8331f742ec 100644 --- a/tests/test_init_reader.py +++ b/tests/test_init_reader.py @@ -30,6 +30,17 @@ def test_load_image(self): inst = LoadImaged("image", reader=r) self.assertIsInstance(inst, LoadImaged) + @SkipIfNoModule("nibabel") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + def test_load_image_to_gpu(self): + for to_gpu in [True, False]: + instance1 = LoadImage(reader="NibabelReader", to_gpu=to_gpu) + self.assertIsInstance(instance1, LoadImage) + + instance2 = LoadImaged("image", reader="NibabelReader", to_gpu=to_gpu) + self.assertIsInstance(instance2, LoadImaged) + @SkipIfNoModule("itk") @SkipIfNoModule("nibabel") @SkipIfNoModule("PIL") @@ -58,6 +69,14 @@ def test_readers(self): inst = NrrdReader() self.assertIsInstance(inst, NrrdReader) + @SkipIfNoModule("nibabel") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + def test_readers_to_gpu(self): + for to_gpu in [True, False]: + inst = NibabelReader(to_gpu=to_gpu) + self.assertIsInstance(inst, NibabelReader) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 0207079d7d..a3e6d7bcfc 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -29,7 +29,7 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import LoadImage from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_downloading_fails, testing_data_config +from tests.utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config itk, has_itk = optional_import("itk", allow_namespace_pkg=True) ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator") @@ -74,6 +74,22 @@ def get_data(self, _obj): TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_GPU_1 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii.gz"], (128, 128, 128)] + +TEST_CASE_GPU_2 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii"], (128, 128, 128)] + +TEST_CASE_GPU_3 = [ + {"reader": "nibabelreader", "to_gpu": True}, + ["test_image.nii", "test_image2.nii", "test_image3.nii"], + (3, 128, 128, 128), +] + +TEST_CASE_GPU_4 = [ + {"reader": "nibabelreader", "to_gpu": True}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + (3, 128, 128, 128), +] + TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)] TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)] @@ -196,6 +212,29 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape): assert_allclose(result.affine, torch.eye(4)) self.assertTupleEqual(result.shape, expected_shape) + @SkipIfNoModule("nibabel") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + @parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4]) + def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape): + test_image = np.random.rand(128, 128, 128) + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result = LoadImage(image_only=True, **input_param)(filenames) + ext = "".join(Path(name).suffixes) + self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext)) + self.assertEqual(result.meta["space"], "RAS") + assert_allclose(result.affine, torch.eye(4)) + self.assertTupleEqual(result.shape, expected_shape) + + # verify gpu and cpu loaded data are the same + input_param_cpu = input_param.copy() + input_param_cpu["to_gpu"] = False + result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames) + self.assertTrue(torch.equal(result_cpu, result.cpu())) + @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9]) def test_itk_reader(self, input_param, filenames, expected_shape): test_image = np.random.rand(128, 128, 128) From a987a943f4d12ef61ffc3573bfa588ea321444f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Dec 2024 06:20:56 +0000 Subject: [PATCH 17/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 9d61c21fe8..ff2b8cdf6c 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -949,7 +949,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_.append(img) # type: ignore return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> tuple[np.ndarray | "cp.ndarray", dict]: + def get_data(self, img) -> tuple[np.ndarray | cp.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. @@ -961,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray | "cp.ndarray", dict]: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ - img_array: list[np.ndarray | "cp.ndarray"] = [] + img_array: list[np.ndarray | cp.ndarray] = [] compatible_meta: dict = {} for i, filename in zip(ensure_tuple(img), self.filenames): From 1b12a39badaa643475a36840e9095cf73f56bbef Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 13 Dec 2024 07:17:00 +0000 Subject: [PATCH 18/27] add description on warm up Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index ff2b8cdf6c..f137171963 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -883,7 +883,10 @@ class NibabelReader(ImageReader): to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading. Default is False. CuPy and Kvikio are required for this option. Note: For compressed NIfTI files, some operations may still be performed on CPU memory, - and the acceleration may not be significant. + and the acceleration may not be significant. In some cases, it may be slower than loading on CPU. + #TODO: the first kvikio call is slow since it will initialize internal buffers, cuFile, GDS, etc. + In practical use, it's recommended to add a warm up call before the actual loading. + A related tutorial will be prepared in the future, and the document will be updated accordingly. kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py @@ -1056,7 +1059,8 @@ def _get_array_data(self, img, filename): if filename.endswith(".nii.gz"): # for compressed data, have to tansfer to CPU to decompress # and then transfer back to GPU. It is not efficient compared to .nii file - # but it's still faster than Nibabel's default reader. + # and may be slower than CPU loading in some cases. + warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.") compressed_data = cp.asnumpy(image) with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: decompressed_data = gz_file.read() From b70a5f58093adafa60914786551fbdfb62d3e3dd Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:08:58 +0800 Subject: [PATCH 19/27] Update monai/data/image_reader.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- monai/data/image_reader.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index f137171963..cecea02f9f 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -904,13 +904,9 @@ def __init__( self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims - if to_gpu is True: - if not has_cp: - warnings.warn("CuPy is not installed, fall back to use cpu load.") - to_gpu = False - if not has_kvikio: - warnings.warn("Kvikio is not installed, fall back to use cpu load.") - to_gpu = False + if to_gpu and (not has_cp or not has_kvikio): + warnings.warn("NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading.") + to_gpu = False self.to_gpu = to_gpu self.kwargs = kwargs From 83a1daf34d5b7f1bc2e471f1ffb7ff08e3d05329 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 16 Dec 2024 07:09:44 +0000 Subject: [PATCH 20/27] add doc string Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index cecea02f9f..32ffd9ff83 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1045,6 +1045,7 @@ def _get_array_data(self, img, filename): Args: img: a Nibabel image object loaded from an image file. + filename: file name of the image. """ if self.to_gpu: From acf2cba64326f56993dc224a9a0ebfcf645d9adf Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 16 Dec 2024 07:28:46 +0000 Subject: [PATCH 21/27] resolve comments Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 32ffd9ff83..6598fc3829 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -37,16 +37,14 @@ from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg if TYPE_CHECKING: - import cupy as cp import itk - import kvikio import nibabel as nib import nrrd import pydicom from nibabel.nifti1 import Nifti1Image from PIL import Image as PILImage - has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True + has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True else: itk, has_itk = optional_import("itk", allow_namespace_pkg=True) nib, has_nib = optional_import("nibabel") @@ -54,8 +52,9 @@ PILImage, has_pil = optional_import("PIL.Image") pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) - cp, has_cp = optional_import("cupy") - kvikio, has_kvikio = optional_import("kvikio") + +cp, has_cp = optional_import("cupy") +kvikio, has_kvikio = optional_import("kvikio") __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -905,7 +904,9 @@ def __init__( self.as_closest_canonical = as_closest_canonical self.squeeze_non_spatial_dims = squeeze_non_spatial_dims if to_gpu and (not has_cp or not has_kvikio): - warnings.warn("NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading.") + warnings.warn( + "NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading." + ) to_gpu = False self.to_gpu = to_gpu @@ -948,7 +949,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_.append(img) # type: ignore return img_ if len(filenames) > 1 else img_[0] - def get_data(self, img) -> tuple[np.ndarray | cp.ndarray, dict]: + def get_data(self, img) -> tuple[np.ndarray, dict]: """ Extract data array and metadata from loaded image and return them. This function returns two objects, first is numpy array of image data, second is dict of metadata. @@ -960,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray | cp.ndarray, dict]: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ - img_array: list[np.ndarray | cp.ndarray] = [] + img_array: list[np.ndarray] = [] compatible_meta: dict = {} for i, filename in zip(ensure_tuple(img), self.filenames): From e5d790786f532c6d53e288b7b1a07e947fe41f36 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 17 Dec 2024 09:44:44 +0000 Subject: [PATCH 22/27] update Signed-off-by: Yiheng Wang --- monai/data/image_reader.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 6598fc3829..86702f087a 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -16,6 +16,7 @@ import io import os import re +import tempfile import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -909,9 +910,30 @@ def __init__( ) to_gpu = False + if to_gpu: + self.warmup_kvikio() + self.to_gpu = to_gpu self.kwargs = kwargs + def warmup_kvikio(self): + """ + Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc. + This can accelerate the data loading process when `to_gpu` is set to True. + """ + if has_cp and has_kvikio: + print("warm up") + a = cp.arange(100) + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_name = tmp_file.name + f = kvikio.CuFile(tmp_file_name, "w") + f.write(a) + f.close() + + b = cp.empty_like(a) + f = kvikio.CuFile(tmp_file_name, "r") + f.read(b) + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by Nibabel reader. @@ -961,6 +983,9 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. """ + # TODO: the actual type is list[np.ndarray | cp.ndarray] + # should figure out how to define correct types without having cupy not found error + # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 img_array: list[np.ndarray] = [] compatible_meta: dict = {} From 5005846b7cdf850dde8c43a856629ccc8669eba5 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Thu, 19 Dec 2024 12:44:33 +0800 Subject: [PATCH 23/27] Update monai/data/image_reader.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- monai/data/image_reader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 86702f087a..4d5d7d2137 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -922,7 +922,6 @@ def warmup_kvikio(self): This can accelerate the data loading process when `to_gpu` is set to True. """ if has_cp and has_kvikio: - print("warm up") a = cp.arange(100) with tempfile.NamedTemporaryFile() as tmp_file: tmp_file_name = tmp_file.name From c43b916c1625b7ae0dc5e69ed710247e32160db9 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Thu, 19 Dec 2024 12:44:46 +0800 Subject: [PATCH 24/27] Update monai/data/image_reader.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- monai/data/image_reader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 4d5d7d2137..f3895b1f87 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -884,7 +884,6 @@ class NibabelReader(ImageReader): Default is False. CuPy and Kvikio are required for this option. Note: For compressed NIfTI files, some operations may still be performed on CPU memory, and the acceleration may not be significant. In some cases, it may be slower than loading on CPU. - #TODO: the first kvikio call is slow since it will initialize internal buffers, cuFile, GDS, etc. In practical use, it's recommended to add a warm up call before the actual loading. A related tutorial will be prepared in the future, and the document will be updated accordingly. kwargs: additional args for `nibabel.load` API. more details about available args: From 02c77f0254405a497d5f4e2aac6f27a205c44dcb Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Thu, 19 Dec 2024 12:45:20 +0800 Subject: [PATCH 25/27] Update monai/data/image_reader.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> --- monai/data/image_reader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index f3895b1f87..5bc38f69ea 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1086,7 +1086,6 @@ def _get_array_data(self, img, filename): with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file: decompressed_data = gz_file.read() - file_size = len(decompressed_data) image = cp.frombuffer(decompressed_data, dtype=cp.uint8) data_shape = img.shape data_offset = img.dataobj.offset From 6f4e5cb03b2d6cd5723bda6085764239b80e8111 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Dec 2024 06:59:41 +0000 Subject: [PATCH 26/27] remove device set Signed-off-by: Yiheng Wang --- monai/data/meta_tensor.py | 8 ++------ monai/transforms/io/array.py | 6 ++---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 8c729088ee..8282ca2098 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -537,7 +537,6 @@ def ensure_torch_and_prune_meta( simple_keys: bool = False, pattern: str | None = None, sep: str = ".", - device: None | str | torch.device = None, ): """ Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary, @@ -552,15 +551,12 @@ def ensure_torch_and_prune_meta( sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary). default is ".", see also :py:class:`monai.transforms.DeleteItemsd`. e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``. - device: target device to put the Tensor data. Returns: By default, a `MetaTensor` is returned. However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned. """ - img = convert_to_tensor( - im, track_meta=get_track_meta() and meta is not None, device=device - ) # potentially ascontiguousarray + img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` if not isinstance(img, MetaTensor): return img @@ -572,7 +568,7 @@ def ensure_torch_and_prune_meta( if simple_keys: # ensure affine is of type `torch.Tensor` if MetaKeys.AFFINE in meta: - meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device) # bc-breaking + meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE]) # bc-breaking remove_extra_metadata(meta) # bc-breaking if pattern is not None: diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 2eb00ab38d..7fa710b3a4 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -164,7 +164,6 @@ def __init__( e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``. expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is. args: additional parameters for reader if providing a reader name. - device: target device to put the loaded image. kwargs: additional parameters for reader if providing a reader name. Note: @@ -186,7 +185,6 @@ def __init__( self.pattern = prune_meta_pattern self.sep = prune_meta_sep self.expanduser = expanduser - self.device = device self.readers: list[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default @@ -291,7 +289,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader ) img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) - img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0] + img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0] if not isinstance(meta_data, dict): raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.") # make sure all elements in metadata are little endian @@ -299,7 +297,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader img = MetaTensor.ensure_torch_and_prune_meta( - img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device + img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep ) if self.ensure_channel_first: img = EnsureChannelFirst()(img) From b6fb2abc3e2ea92bb9cc4412485462c2137c065c Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Dec 2024 07:25:44 +0000 Subject: [PATCH 27/27] update format Signed-off-by: Yiheng Wang --- monai/data/meta_tensor.py | 6 +----- monai/transforms/io/array.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 8282ca2098..c4c491e1b9 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -532,11 +532,7 @@ def clone(self, **kwargs): @staticmethod def ensure_torch_and_prune_meta( - im: NdarrayTensor, - meta: dict | None, - simple_keys: bool = False, - pattern: str | None = None, - sep: str = ".", + im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "." ): """ Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 7fa710b3a4..1023cd7a7d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -139,7 +139,6 @@ def __init__( prune_meta_pattern: str | None = None, prune_meta_sep: str = ".", expanduser: bool = True, - device: None | str | torch.device = None, *args, **kwargs, ) -> None: