diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 2c809b9817..1702a1211c 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -89,7 +89,7 @@ jobs: steps: - name: Import run: | - export CUDA_VISIBLE_DEVICES= # cpu-only + export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 CUDA_VISIBLE_DEVICES= # cpu-only python -c 'import monai; monai.config.print_debug_info()' cd /opt/monai ls -al diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d97cabc2cb..c37ce2d425 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -151,6 +151,10 @@ make html ``` The above commands build html documentation, they are used to automatically generate [https://docs.monai.io](https://docs.monai.io). +The Python code docstring are written in +[reStructuredText](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) and +the documentation pages can be in either [reStructuredText](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) or [Markdown](https://en.wikipedia.org/wiki/Markdown). In general the Python docstrings follow the [Google style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings). + Before submitting a pull request, it is recommended to: - edit the relevant `.rst` files in [`docs/source`](./docs/source) accordingly. - build html documentation locally diff --git a/docs/images/precision_options.png b/docs/images/precision_options.png new file mode 100644 index 0000000000..269560d80f Binary files /dev/null and b/docs/images/precision_options.png differ diff --git a/docs/source/data.rst b/docs/source/data.rst index b789102b81..63d5e0e23d 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -45,6 +45,13 @@ Generic Interfaces :members: :special-members: __getitem__ +`GDSDataset` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: GDSDataset + :members: + :special-members: __getitem__ + + `CacheNTransDataset` ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: CacheNTransDataset diff --git a/docs/source/index.rst b/docs/source/index.rst index 1fde0b0ef3..54dc6e6922 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -58,6 +58,12 @@ Technical documentation is available at `docs.monai.io `_ installation +.. toctree:: + :maxdepth: 1 + :caption: Precision and Performance + + precision_performance + .. toctree:: :maxdepth: 1 :caption: Contributing diff --git a/docs/source/precision_performance.md b/docs/source/precision_performance.md new file mode 100644 index 0000000000..6e6c51d8c1 --- /dev/null +++ b/docs/source/precision_performance.md @@ -0,0 +1,39 @@ +# Precision and Performance + +Modern GPU architectures usually can use reduced precision tensor data or computational operations to save memory and increase throughput. However, in some cases, the reduced precision will cause numerical stability issues, and further cause reproducibility issues. Therefore, please ensure that you are using appropriate precision. + + + +## TensorFloat-32 (TF32) + +### Introduction + +NVIDIA introduced a new math mode TensorFloat-32 (TF32) for NVIDIA Ampere GPUs and above, see [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/), [TRAINING NEURAL NETWORKS +WITH TENSOR CORES](https://nvlabs.github.io/eccv2020-mixed-precision-tutorial/files/dusan_stosic-training-neural-networks-with-tensor-cores.pdf), [CUDA 11](https://developer.nvidia.com/blog/cuda-11-features-revealed/) and [Ampere architecture](https://developer.nvidia.com/blog/nvidia-ampere-architecture-in-depth/). + +TF32 adopts 8 exponent bits, 10 bits of mantissa, and one sign bit. + +![Precision options used for AI training.](../images/precision_options.png) + +### Potential Impact + +Although NVIDIA has shown that TF32 mode can reach the same accuracy and convergence as float32 for most AI workloads, some users still find some significant effect on their applications, see [PyTorch and TensorFloat32](https://dev-discuss.pytorch.org/t/pytorch-and-tensorfloat32/504). Users who need high-precision matrix operation, such as traditional computer graphics operation and kernel method, may be affected by TF32 precision. + +Note that all operations that use `cuda.matmul` may be affected +by TF32 mode so the impact is very wide. + +### Settings + +[PyTorch TF32](https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) default value: +```python +torch.backends.cuda.matmul.allow_tf32 = False # in PyTorch 1.12 and later. +torch.backends.cudnn.allow_tf32 = True +``` +Please note that there are environment variables that can override the flags above. For example, the environment variables mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden. + +We recommend that users print out these two flags for confirmation when unsure. + +If you are using an [NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), the container includes a layer `ENV TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1`. +The default value `torch.backends.cuda.matmul.allow_tf32` will be overridden to `True`. + +If you can confirm through experiments that your model has no accuracy or convergence issues in TF32 mode and you have NVIDIA Ampere GPUs or above, you can set the two flags above to `True` to speed up your model. diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index e09317aac2..9dfe82a992 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -127,6 +127,8 @@ def _resolve_one_item( item = look_up_option(id, self.items, print_all_options=False, default=kwargs.get("default", "no_default")) except ValueError as err: raise KeyError(f"id='{id}' is not found in the config resolver.") from err + if not isinstance(item, ConfigItem): + return item item_config = item.get_config() if waiting_list is None: @@ -151,11 +153,10 @@ def _resolve_one_item( look_up_option(d, self.items, print_all_options=False) except ValueError as err: msg = f"the referring item `@{d}` is not defined in the config content." - if self.allow_missing_reference: - warnings.warn(msg) - continue - else: + if not self.allow_missing_reference: raise ValueError(msg) from err + warnings.warn(msg) + continue # recursively resolve the reference first self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs) waiting_list.discard(d) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 9339897d7a..340c5eb8fa 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -33,6 +33,7 @@ CSVDataset, Dataset, DatasetFunc, + GDSDataset, LMDBDataset, NPZDictItemDataset, PersistentDataset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 912576bdcc..a20511267b 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -34,6 +34,7 @@ from torch.utils.data import Dataset as _TorchDataset from torch.utils.data import Subset +from monai.data.meta_tensor import MetaTensor from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing from monai.transforms import ( Compose, @@ -44,7 +45,7 @@ convert_to_contiguous, reset_ops_id, ) -from monai.utils import MAX_SEED, get_seed, look_up_option, min_version, optional_import +from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import from monai.utils.misc import first if TYPE_CHECKING: @@ -54,8 +55,10 @@ else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") +cp, _ = optional_import("cupy") lmdb, _ = optional_import("lmdb") pd, _ = optional_import("pandas") +kvikio_numpy, _ = optional_import("kvikio.numpy") class Dataset(_TorchDataset): @@ -326,7 +329,6 @@ def _pre_transform(self, item_transformed): first_random = self.transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) - item_transformed = self.transform(item_transformed, end=first_random, threading=True) if self.reset_ops_id: @@ -1510,3 +1512,180 @@ def __init__( dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs ) super().__init__(data=data, transform=transform) + + +class GDSDataset(PersistentDataset): + """ + An extension of the PersistentDataset using direct memory access(DMA) data path between + GPU memory and storage, thus avoiding a bounce buffer through the CPU. This direct path can increase system + bandwidth while decreasing latency and utilization load on the CPU and GPU. + + A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb. + + See also: https://github.com/rapidsai/kvikio + """ + + def __init__( + self, + data: Sequence, + transform: Sequence[Callable] | Callable, + cache_dir: Path | str | None, + device: int, + hash_func: Callable[..., bytes] = pickle_hashing, + hash_transform: Callable[..., bytes] | None = None, + reset_ops_id: bool = True, + **kwargs: Any, + ) -> None: + """ + Args: + data: input data file paths to load and transform to generate dataset for model. + `GDSDataset` expects input data to be a list of serializable + and hashes them as cache keys using `hash_func`. + transform: transforms to execute operations on input data. + cache_dir: If specified, this is the location for gpu direct storage + of pre-computed transformed data tensors. The cache_dir is computed once, and + persists on disk until explicitly removed. Different runs, programs, experiments + may share a common cache dir provided that the transforms pre-processing is consistent. + If `cache_dir` doesn't exist, will automatically create it. + If `cache_dir` is `None`, there is effectively no caching. + device: target device to put the output Tensor data. Note that only int can be used to + specify the gpu to be used. + hash_func: a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. + hash_transform: a callable to compute hash from the transform information when caching. + This may reduce errors due to transforms changing during experiments. Default to None (no hash). + Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`. + reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``. + When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors. + This is useful for skipping the transform instance checks when inverting applied operations + using the cached content and with re-created transform instances. + + """ + super().__init__( + data=data, + transform=transform, + cache_dir=cache_dir, + hash_func=hash_func, + hash_transform=hash_transform, + reset_ops_id=reset_ops_id, + **kwargs, + ) + self.device = device + self._meta_cache: dict[Any, dict[Any, Any]] = {} + + def _cachecheck(self, item_transformed): + """ + In order to enable direct storage to the GPU when loading the hashfile, rewritten this function. + Note that in this function, it will always return `torch.Tensor` when load data from cache. + + Args: + item_transformed: The current data element to be mutated into transformed representation + + Returns: + The transformed data_element, either from cache, or explicitly computing it. + + Warning: + The current implementation does not encode transform information as part of the + hashing mechanism used for generating cache names when `hash_transform` is None. + If the transforms applied are changed in any way, the objects in the cache dir will be invalid. + + """ + hashfile = None + # compute a cache id + if self.cache_dir is not None: + data_item_md5 = self.hash_func(item_transformed).decode("utf-8") + data_item_md5 += self.transform_hash + hashfile = self.cache_dir / f"{data_item_md5}.pt" + + if hashfile is not None and hashfile.is_file(): # cache hit + with cp.cuda.Device(self.device): + if isinstance(item_transformed, dict): + item: dict[Any, Any] = {} # type:ignore + for k in item_transformed: + meta_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta") + item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(())) + item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}") + item[f"{k}_meta_dict"] = meta_k + return item + elif isinstance(item_transformed, (np.ndarray, torch.Tensor)): + _meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta") + _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta["dtype"], like=cp.empty(())) + _data = convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}") + filtered_keys = list(filter(lambda key: key not in ["dtype", "shape"], _meta.keys())) + if bool(filtered_keys): + return (_data, _meta) + return _data + else: + item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore + for i, _item in enumerate(item_transformed): + for k in _item: + meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}") + item_k = kvikio_numpy.fromfile( + f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(()) + ) + item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}") + item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k}) + return item + + # create new cache + _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed + if hashfile is None: + return _item_transformed + if isinstance(_item_transformed, dict): + for k in _item_transformed: + data_hashfile = f"{hashfile}-{k}" + meta_hash_file_name = f"{hashfile.name}-{k}-meta" + if isinstance(_item_transformed[k], (np.ndarray, torch.Tensor)): + self._create_new_cache(_item_transformed[k], data_hashfile, meta_hash_file_name) + else: + return _item_transformed + elif isinstance(_item_transformed, (np.ndarray, torch.Tensor)): + data_hashfile = f"{hashfile}" + meta_hash_file_name = f"{hashfile.name}-meta" + self._create_new_cache(_item_transformed, data_hashfile, meta_hash_file_name) + else: + for i, _item in enumerate(_item_transformed): + for k in _item: + data_hashfile = f"{hashfile}-{k}-{i}" + meta_hash_file_name = f"{hashfile.name}-{k}-meta-{i}" + self._create_new_cache(_item, data_hashfile, meta_hash_file_name) + open(hashfile, "a").close() # store cacheid + return _item_transformed + + def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): + self._meta_cache[meta_hash_file_name] = copy(data.meta) if isinstance(data, MetaTensor) else {} + _item_transformed_data = data.array if isinstance(data, MetaTensor) else data + if isinstance(_item_transformed_data, torch.Tensor): + _item_transformed_data = _item_transformed_data.numpy() + self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape + self._meta_cache[meta_hash_file_name]["dtype"] = str(_item_transformed_data.dtype) + kvikio_numpy.tofile(_item_transformed_data, data_hashfile) + try: + # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation + # to make the cache more robust to manual killing of parent process + # which may leave partially written cache files in an incomplete state + with tempfile.TemporaryDirectory() as tmpdirname: + meta_hash_file = self.cache_dir / meta_hash_file_name + temp_hash_file = Path(tmpdirname) / meta_hash_file_name + torch.save( + obj=self._meta_cache[meta_hash_file_name], + f=temp_hash_file, + pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), + pickle_protocol=self.pickle_protocol, + ) + if temp_hash_file.is_file() and not meta_hash_file.is_file(): + # On Unix, if target exists and is a file, it will be replaced silently if the + # user has permission. + # for more details: https://docs.python.org/3/library/shutil.html#shutil.move. + try: + shutil.move(str(temp_hash_file), meta_hash_file) + except FileExistsError: + pass + except PermissionError: # project-monai/monai issue #3613 + pass + + def _load_meta_cache(self, meta_hash_file_name): + if meta_hash_file_name in self._meta_cache: + return self._meta_cache[meta_hash_file_name] + else: + return torch.load(self.cache_dir / meta_hash_file_name) # type:ignore diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 33f7b8a53c..fdf4997e58 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -202,7 +202,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: :py:func:`MetaTensor._copy_meta`). """ out = [] - metas = None + metas = None # optional output metadicts for each of the return value in `rets` is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch")) for idx, ret in enumerate(rets): # if not `MetaTensor`, nothing to do. @@ -219,55 +219,61 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: # the following is not implemented but the network arch may run into this case: # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args): # raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.") - - # If we have a batch of data, then we need to be careful if a slice of - # the data is returned. Depending on how the data are indexed, we return - # some or all of the metadata, and the return object may or may not be a - # batch of data (e.g., `batch[:,-1]` versus `batch[0]`). if is_batch: - # if indexing e.g., `batch[0]` - if func == torch.Tensor.__getitem__: - batch_idx = args[1] - if isinstance(batch_idx, Sequence): - batch_idx = batch_idx[0] - # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the - # first element will be `slice(None, None, None)` and `Ellipsis`, - # respectively. Don't need to do anything with the metadata. - if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0: - ret_meta = decollate_batch(args[0], detach=False)[batch_idx] - if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate - try: - ret_meta = list_data_collate(ret_meta) - except (TypeError, ValueError, RuntimeError, IndexError) as e: - raise ValueError( - "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, " - "please convert it into a torch Tensor using `x.as_tensor()` or " - "a numpy array using `x.array`." - ) from e - elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int - ret_meta.is_batch = False - if hasattr(ret_meta, "__dict__"): - ret.__dict__ = ret_meta.__dict__.copy() - # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. - # But we only want to split the batch if the `unbind` is along the 0th - # dimension. - elif func == torch.Tensor.unbind: - if len(args) > 1: - dim = args[1] - elif "dim" in kwargs: - dim = kwargs["dim"] - else: - dim = 0 - if dim == 0: - if metas is None: - metas = decollate_batch(args[0], detach=False) - ret.__dict__ = metas[idx].__dict__.copy() - ret.is_batch = False - + ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs) out.append(ret) # if the input was a tuple, then return it as a tuple return tuple(out) if isinstance(rets, tuple) else out + @classmethod + def _handle_batched(cls, ret, idx, metas, func, args, kwargs): + """utility function to handle batched MetaTensors.""" + # If we have a batch of data, then we need to be careful if a slice of + # the data is returned. Depending on how the data are indexed, we return + # some or all of the metadata, and the return object may or may not be a + # batch of data (e.g., `batch[:,-1]` versus `batch[0]`). + # if indexing e.g., `batch[0]` + if func == torch.Tensor.__getitem__: + if idx > 0 or len(args) < 2 or len(args[0]) < 1: + return ret + batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1] + # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the + # first element will be `slice(None, None, None)` and `Ellipsis`, + # respectively. Don't need to do anything with the metadata. + if batch_idx in (slice(None, None, None), Ellipsis, None) or isinstance(batch_idx, torch.Tensor): + return ret + dec_batch = decollate_batch(args[0], detach=False) + ret_meta = dec_batch[batch_idx] + if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate + try: + ret_meta = list_data_collate(ret_meta) + except (TypeError, ValueError, RuntimeError, IndexError) as e: + raise ValueError( + "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, " + "please consider converting it into a torch Tensor using `x.as_tensor()` or " + "a numpy array using `x.array`." + ) from e + elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int + ret_meta.is_batch = False + if hasattr(ret_meta, "__dict__"): + ret.__dict__ = ret_meta.__dict__.copy() + # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. + # But we only want to split the batch if the `unbind` is along the 0th dimension. + elif func == torch.Tensor.unbind: + if len(args) > 1: + dim = args[1] + elif "dim" in kwargs: + dim = kwargs["dim"] + else: + dim = 0 + if dim == 0: + if metas is None: + metas = decollate_batch(args[0], detach=False) + if hasattr(metas[idx], "__dict__"): + ret.__dict__ = metas[idx].__dict__.copy() + ret.is_batch = False + return ret + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: """Wraps all torch functions.""" diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 088ad50efd..4f53e3a7d8 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -268,6 +268,7 @@ def __init__( smooth_dr: a small constant added to the denominator to avoid nan. batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, intersection over union is computed from each item in the batch. + If True, the class-weighted intersection and union areas are first summed across the batches. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -360,8 +361,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1) w = w + infs * max_values - numer = 2.0 * (intersection * w) + self.smooth_nr - denom = (denominator * w) + self.smooth_dr + final_reduce_dim = 0 if self.batch else 1 + numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr + denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr f: torch.Tensor = 1.0 - (numer / denom) if self.reduction == LossReduction.MEAN.value: diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 8ea3eb116b..8ee1da7267 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -61,7 +61,7 @@ def __init__( """ super().__init__(reduction=LossReduction(reduction).value) self.spatial_dims = spatial_dims - self.data_range = data_range + self._data_range = data_range self.kernel_type = kernel_type if not isinstance(win_size, Sequence): @@ -77,7 +77,7 @@ def __init__( self.ssim_metric = SSIMMetric( spatial_dims=self.spatial_dims, - data_range=self.data_range, + data_range=self._data_range, kernel_type=self.kernel_type, win_size=self.kernel_size, kernel_sigma=self.kernel_sigma, @@ -85,6 +85,15 @@ def __init__( k2=self.k2, ) + @property + def data_range(self) -> float: + return self._data_range + + @data_range.setter + def data_range(self, value: float) -> None: + self._data_range = value + self.ssim_metric.data_range = value + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py index 697547093a..d89eb8ae03 100644 --- a/monai/networks/nets/vnet.py +++ b/monai/networks/nets/vnet.py @@ -16,6 +16,7 @@ from monai.networks.blocks.convolutions import Convolution from monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args +from monai.utils import deprecated_arg __all__ = ["VNet"] @@ -133,7 +134,7 @@ def __init__( out_channels: int, nconvs: int, act: tuple[str, dict] | str, - dropout_prob: float | None = None, + dropout_prob: tuple[float | None, float] = (None, 0.5), dropout_dim: int = 3, ): super().__init__() @@ -144,8 +145,8 @@ def __init__( self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2) self.bn1 = norm_type(out_channels // 2) - self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None - self.dropout2 = dropout_type(0.5) + self.dropout = dropout_type(dropout_prob[0]) if dropout_prob[0] is not None else None + self.dropout2 = dropout_type(dropout_prob[1]) self.act_function1 = get_acti_layer(act, out_channels // 2) self.act_function2 = get_acti_layer(act, out_channels) self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act) @@ -206,8 +207,9 @@ class VNet(nn.Module): The value should meet the condition that ``16 % in_channels == 0``. out_channels: number of output channels for the network. Defaults to 1. act: activation type in the network. Defaults to ``("elu", {"inplace": True})``. - dropout_prob: dropout ratio. Defaults to 0.5. - dropout_dim: determine the dimensions of dropout. Defaults to 3. + dropout_prob_down: dropout ratio for DownTransition blocks. Defaults to 0.5. + dropout_prob_up: dropout ratio for UpTransition blocks. Defaults to (0.5, 0.5). + dropout_dim: determine the dimensions of dropout. Defaults to (0.5, 0.5). - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel. - ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map). @@ -216,15 +218,29 @@ class VNet(nn.Module): According to `Performance Tuning Guide `_, if a conv layer is directly followed by a batch norm layer, bias should be False. + .. deprecated:: 1.2 + ``dropout_prob`` is deprecated in favor of ``dropout_prob_down`` and ``dropout_prob_up``. + """ + @deprecated_arg( + name="dropout_prob", + since="1.2", + new_name="dropout_prob_down", + msg_suffix="please use `dropout_prob_down` instead.", + ) + @deprecated_arg( + name="dropout_prob", since="1.2", new_name="dropout_prob_up", msg_suffix="please use `dropout_prob_up` instead." + ) def __init__( self, spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 1, act: tuple[str, dict] | str = ("elu", {"inplace": True}), - dropout_prob: float = 0.5, + dropout_prob: float | None = 0.5, # deprecated + dropout_prob_down: float | None = 0.5, + dropout_prob_up: tuple[float | None, float] = (0.5, 0.5), dropout_dim: int = 3, bias: bool = False, ): @@ -236,10 +252,10 @@ def __init__( self.in_tr = InputTransition(spatial_dims, in_channels, 16, act, bias=bias) self.down_tr32 = DownTransition(spatial_dims, 16, 1, act, bias=bias) self.down_tr64 = DownTransition(spatial_dims, 32, 2, act, bias=bias) - self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob, bias=bias) - self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob, bias=bias) - self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob) - self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob) + self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob_down, bias=bias) + self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob_down, bias=bias) + self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob_up) + self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob_up) self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act) self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act) self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act, bias=bias) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f8eadcfb1b..8cd15083c9 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -493,8 +493,8 @@ def __init__( fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling to ensure that the output has the same mean as the input. channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied - on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the - channel of the image if True. + on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the + channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. """ self.factor = factor @@ -633,12 +633,20 @@ class RandScaleIntensity(RandomizableTransform): backend = ScaleIntensity.backend - def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtype: DtypeLike = np.float32) -> None: + def __init__( + self, + factors: tuple[float, float] | float, + prob: float = 0.1, + channel_wise: bool = False, + dtype: DtypeLike = np.float32, + ) -> None: """ Args: factors: factor range to randomly scale by ``v = v * (1 + factor)``. if single number, factor value is picked from (-factors, factors). prob: probability of scale. + channel_wise: if True, scale on each channel separately. Please ensure + that the first dimension represents the channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. """ @@ -650,13 +658,17 @@ def __init__(self, factors: tuple[float, float] | float, prob: float = 0.1, dtyp else: self.factors = (min(factors), max(factors)) self.factor = self.factors[0] + self.channel_wise = channel_wise self.dtype = dtype def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) + if self.channel_wise: + self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore + else: + self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ @@ -664,12 +676,21 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen """ img = convert_to_tensor(img, track_meta=get_track_meta()) if randomize: - self.randomize() + self.randomize(img) if not self._do_transform: return convert_data_type(img, dtype=self.dtype)[0] - return ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img) + ret: NdarrayOrTensor + if self.channel_wise: + out = [] + for i, d in enumerate(img): + out_channel = ScaleIntensity(minv=None, maxv=None, factor=self.factor[i], dtype=self.dtype)(d) # type: ignore + out.append(out_channel) + ret = torch.stack(out) # type: ignore + else: + ret = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img) + return ret class RandBiasField(RandomizableTransform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 91acff0c3d..32052ad406 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -586,6 +586,7 @@ def __init__( keys: KeysCollection, factors: tuple[float, float] | float, prob: float = 0.1, + channel_wise: bool = False, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: @@ -597,13 +598,15 @@ def __init__( if single number, factor value is picked from (-factors, factors). prob: probability of scale. (Default 0.1, with 10% probability it returns a scaled array.) + channel_wise: if True, scale on each channel separately. Please ensure + that the first dimension represents the channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0) + self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0, channel_wise=channel_wise) def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None @@ -620,8 +623,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d + # expect all the specified keys have same spatial shape and share same random holes + first_key: Hashable = self.first_key(d) + if first_key == (): + for key in self.key_iterator(d): + d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) + return d + # all the keys share the same random scale factor - self.scaler.randomize(None) + self.scaler.randomize(d[first_key]) for key in self.key_iterator(d): d[key] = self.scaler(d[key], randomize=False) return d diff --git a/tests/test_cldice_loss.py b/tests/test_cldice_loss.py index 109186b5d1..071bd20d6c 100644 --- a/tests/test_cldice_loss.py +++ b/tests/test_cldice_loss.py @@ -17,14 +17,8 @@ from monai.losses import SoftclDiceLoss, SoftDiceclDiceLoss TEST_CASES = [ - [ # shape: (1, 4), (1, 4) - {"y_pred": torch.ones((100, 3, 256, 256)), "y_true": torch.ones((100, 3, 256, 256))}, - 0.0, - ], - [ # shape: (1, 5), (1, 5) - {"y_pred": torch.ones((100, 3, 256, 256, 5)), "y_true": torch.ones((100, 3, 256, 256, 5))}, - 0.0, - ], + [{"y_pred": torch.ones((7, 3, 11, 10)), "y_true": torch.ones((7, 3, 11, 10))}, 0.0], + [{"y_pred": torch.ones((2, 3, 13, 14, 5)), "y_true": torch.ones((2, 3, 13, 14, 5))}, 0.0], ] diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index d45a251f9b..8cecfe87cf 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -174,6 +174,7 @@ def test_parse(self, config, expected_ids, output_types): self.assertTrue(isinstance(v, cls)) # test default value self.assertEqual(parser.get_parsed_content(id="abc", default=ConfigItem(12345, "abc")), 12345) + self.assertEqual(parser.get_parsed_content(id="abcd", default=1), 1) @parameterized.expand([TEST_CASE_2]) def test_function(self, config): diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py new file mode 100644 index 0000000000..29f2d0096b --- /dev/null +++ b/tests/test_gdsdataset.py @@ -0,0 +1,222 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import pickle +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import GDSDataset, json_hashing +from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform +from monai.utils import optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_cp = optional_import("cupy") +nib, has_nib = optional_import("nibabel") +_, has_kvikio_numpy = optional_import("kvikio.numpy") + +TEST_CASE_1 = [ + Compose( + [ + LoadImaged(keys=["image", "label", "extra"], image_only=True), + SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), + ] + ), + (128, 128, 128), +] + +TEST_CASE_2 = [ + [ + LoadImaged(keys=["image", "label", "extra"], image_only=True), + SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), + ], + (128, 128, 128), +] + +TEST_CASE_3 = [None, (128, 128, 128)] + +DTYPES = { + np.dtype(np.uint8): torch.uint8, + np.dtype(np.int8): torch.int8, + np.dtype(np.int16): torch.int16, + np.dtype(np.int32): torch.int32, + np.dtype(np.int64): torch.int64, + np.dtype(np.float16): torch.float16, + np.dtype(np.float32): torch.float32, + np.dtype(np.float64): torch.float64, + np.dtype(np.complex64): torch.complex64, + np.dtype(np.complex128): torch.complex128, +} + + +class _InplaceXform(Transform): + def __call__(self, data): + data[0] = data[0] + 1 + return data + + +@unittest.skipUnless(has_cp, "Requires CuPy library.") +@unittest.skipUnless(has_nib, "Requires nibabel package.") +@unittest.skipUnless(has_kvikio_numpy, "Requires scikit-image library.") +class TestDataset(unittest.TestCase): + def test_cache(self): + """testing no inplace change to the hashed item""" + for p in TEST_NDARRAYS[:2]: + shape = (1, 10, 9, 8) + items = [p(np.arange(0, np.prod(shape)).reshape(shape))] + + with tempfile.TemporaryDirectory() as tempdir: + ds = GDSDataset( + data=items, + transform=_InplaceXform(), + cache_dir=tempdir, + device=0, + pickle_module="pickle", + pickle_protocol=pickle.HIGHEST_PROTOCOL, + ) + assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) + ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + assert_allclose(ds[0], ds1[0], type_test=False) + assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) + + ds = GDSDataset( + items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0 + ) + assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) + ds1 = GDSDataset( + items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0 + ) + assert_allclose(ds[0], ds1[0], type_test=False) + assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) + + def test_metatensor(self): + shape = (1, 10, 9, 8) + items = [TEST_NDARRAYS[-1](np.arange(0, np.prod(shape)).reshape(shape))] + with tempfile.TemporaryDirectory() as tempdir: + ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + assert_allclose(ds[0], ds[0][0], type_test=False) + + def test_dtype(self): + shape = (1, 10, 9, 8) + data = np.arange(0, np.prod(shape)).reshape(shape) + for _dtype in DTYPES.keys(): + items = [np.array(data).astype(_dtype)] + with tempfile.TemporaryDirectory() as tempdir: + ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + self.assertEqual(ds[0].dtype, _dtype) + self.assertEqual(ds1[0].dtype, DTYPES[_dtype]) + + for _dtype in DTYPES.keys(): + items = [torch.tensor(data, dtype=DTYPES[_dtype])] + with tempfile.TemporaryDirectory() as tempdir: + ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0) + self.assertEqual(ds[0].dtype, DTYPES[_dtype]) + self.assertEqual(ds1[0].dtype, DTYPES[_dtype]) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, transform, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) + test_data = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + }, + ] + + cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") + dataset_precached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0) + data1_precached = dataset_precached[0] + data2_precached = dataset_precached[1] + + dataset_postcached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0) + data1_postcached = dataset_postcached[0] + data2_postcached = dataset_postcached[1] + data3_postcached = dataset_postcached[0:2] + + if transform is None: + self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) + self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) + else: + self.assertTupleEqual(data1_precached["image"].shape, expected_shape) + self.assertTupleEqual(data1_precached["label"].shape, expected_shape) + self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) + self.assertTupleEqual(data2_precached["image"].shape, expected_shape) + self.assertTupleEqual(data2_precached["label"].shape, expected_shape) + self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) + + self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) + self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) + self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) + for d in data3_postcached: + self.assertTupleEqual(d["image"].shape, expected_shape) + + # update the data to cache + test_data_new = [ + { + "image": os.path.join(tempdir, "test_image1_new.nii.gz"), + "label": os.path.join(tempdir, "test_label1_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2_new.nii.gz"), + "label": os.path.join(tempdir, "test_label2_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), + }, + ] + dataset_postcached.set_data(data=test_data_new) + # test new exchanged cache content + if transform is None: + self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) + self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) + self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) + + def test_different_transforms(self): + """ + Different instances of `GDSDataset` with the same cache_dir, + same input data, but different transforms should give different results. + """ + shape = (1, 10, 9, 8) + im = np.arange(0, np.prod(shape)).reshape(shape) + with tempfile.TemporaryDirectory() as path: + im1 = GDSDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing, device=0)[0] + im2 = GDSDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing, device=0)[0] + l2 = ((im1 - im2) ** 2).sum() ** 0.5 + self.assertTrue(l2 > 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index b8256a41a9..d8ba496d03 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -48,7 +48,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.435035, + 0.469964, ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -56,7 +56,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.3837, + 0.414507, ], [ # shape: (2, 2, 3), (2, 1, 3) { @@ -71,7 +71,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 1.5348, + 0.829015, ], [ # shape: (2, 2, 3), (2, 1, 3) { @@ -86,7 +86,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [[[0.210949], [0.295351]], [[0.599976], [0.428522]]], + [[[0.273476]], [[0.555539]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8}, @@ -114,7 +114,7 @@ "input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1, 1, 0, 0]]]), }, - 0.26669, + 0.250023, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -136,7 +136,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - -8.55485, + -0.097833, ], ] diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 739955ea67..0cd0522036 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -413,6 +413,10 @@ def test_slicing(self): x.is_batch = True with self.assertRaises(ValueError): x[slice(0, 8)] + x = MetaTensor(np.zeros((3, 3, 4))) + x.is_batch = True + self.assertEqual(x[torch.tensor([True, False, True])].shape, (2, 3, 4)) + self.assertEqual(x[[True, False, True]].shape, (2, 3, 4)) @parameterized.expand(DTYPES) @SkipIfBeforePyTorchVersion((1, 8)) diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 5f5ca076a8..a857c0cefb 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -33,6 +33,22 @@ def test_value(self, p): expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)) assert_allclose(result, p(expected), rtol=1e-7, atol=0, type_test="tensor") + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, p): + scaler = RandScaleIntensity(factors=0.5, channel_wise=True, prob=1.0) + scaler.set_random_state(seed=0) + im = p(self.imt) + result = scaler(im) + np.random.seed(0) + # simulate the randomize() of transform + np.random.random() + channel_num = self.imt.shape[0] + factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)] + expected = p( + np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32) + ) + assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_scale_intensityd.py b/tests/test_rand_scale_intensityd.py index 6b5a04a8f3..8d928ac157 100644 --- a/tests/test_rand_scale_intensityd.py +++ b/tests/test_rand_scale_intensityd.py @@ -32,6 +32,22 @@ def test_value(self): expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) assert_allclose(result[key], p(expected), type_test="tensor") + def test_channel_wise(self): + key = "img" + for p in TEST_NDARRAYS: + scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0, channel_wise=True) + scaler.set_random_state(seed=0) + result = scaler({key: p(self.imt)}) + np.random.seed(0) + # simulate the randomize function of transform + np.random.random() + channel_num = self.imt.shape[0] + factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)] + expected = p( + np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32) + ) + assert_allclose(result[key], p(expected), type_test="tensor") + if __name__ == "__main__": unittest.main()