From a21c7f1dd2c5c2b41d67e0cadf15c8180bf8f73f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 14 Jul 2023 13:29:13 -0600 Subject: [PATCH 01/20] Adds losses --- monai/losses/__init__.py | 3 + monai/losses/adversarial_loss.py | 171 +++++++++++++++ monai/losses/perceptual.py | 366 +++++++++++++++++++++++++++++++ monai/losses/spectral_loss.py | 88 ++++++++ tests/test_adversarial_loss.py | 93 ++++++++ tests/test_perceptual_loss.py | 82 +++++++ tests/test_spectral_loss.py | 86 ++++++++ 7 files changed, 889 insertions(+) create mode 100644 monai/losses/adversarial_loss.py create mode 100644 monai/losses/perceptual.py create mode 100644 monai/losses/spectral_loss.py create mode 100644 tests/test_adversarial_loss.py create mode 100644 tests/test_perceptual_loss.py create mode 100644 tests/test_spectral_loss.py diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 9e09b0b123..db51f0a2f8 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .adversarial_loss import PatchAdversarialLoss from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss from .dice import ( @@ -33,7 +34,9 @@ from .giou_loss import BoxGIoULoss, giou from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .multi_scale import MultiScaleLoss +from .perceptual import PerceptualLoss from .spatial_mask import MaskedLoss +from .spectral_loss import JukeboxLoss from .ssim_loss import SSIMLoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py new file mode 100644 index 0000000000..6f09228e38 --- /dev/null +++ b/monai/losses/adversarial_loss.py @@ -0,0 +1,171 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +import torch +from torch.nn.modules.loss import _Loss + +from monai.networks.layers.utils import get_act_layer +from monai.utils import LossReduction +from monai.utils.enums import StrEnum + + +class AdversarialCriterions(StrEnum): + BCE = "bce" + HINGE = "hinge" + LEAST_SQUARE = "least_squares" + + +class PatchAdversarialLoss(_Loss): + """ + Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator. + Warning: due to the possibility of using different criterions, the output of the discrimination + mustn't be passed to a final activation layer. That is taken care of internally within the loss. + + Args: + reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. + Defaults to ``"mean"``. + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs. 
+            Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs
+            through an activation layer prior to calling the loss.
+        no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.
+    """
+
+    def __init__(
+        self,
+        reduction: LossReduction | str = LossReduction.MEAN,
+        criterion: str = AdversarialCriterions.LEAST_SQUARE.value,
+        no_activation_leastsq: bool = False,
+    ) -> None:
+        super().__init__(reduction=LossReduction(reduction).value)
+
+        if criterion.lower() not in [m.value for m in AdversarialCriterions]:
+            raise ValueError(
+                "Unrecognised criterion entered for Adversarial Loss. Must be one of: %s"
+                % ", ".join([m.value for m in AdversarialCriterions])
+            )
+
+        # Depending on the criterion, a different activation layer is used.
+        self.real_label = 1.0
+        self.fake_label = 0.0
+        if criterion == AdversarialCriterions.BCE.value:
+            self.activation = get_act_layer("SIGMOID")
+            self.loss_fct = torch.nn.BCELoss(reduction=reduction)
+        elif criterion == AdversarialCriterions.HINGE.value:
+            self.activation = get_act_layer("TANH")
+            self.fake_label = -1.0
+        elif criterion == AdversarialCriterions.LEAST_SQUARE.value:
+            if no_activation_leastsq:
+                self.activation = None
+            else:
+                self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05}))
+            self.loss_fct = torch.nn.MSELoss(reduction=reduction)
+
+        self.criterion = criterion
+        self.reduction = reduction
+
+    def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> torch.Tensor:
+        """
+        Gets the ground truth tensor for the discriminator depending on whether the input is real or fake.
+
+        Args:
+            input: input tensor from the discriminator (output of discriminator, or output of one of the multi-scale
+                discriminators). This is used to match the shape.
+            target_is_real: whether the input is real or should pass as real (1s), or fake (0s).
+        Returns:
+            A tensor filled with the appropriate real or fake label, expanded to the shape of the input.
+        """
+        filling_label = self.real_label if target_is_real else self.fake_label
+        label_tensor = torch.tensor(1).fill_(filling_label).type(input.type()).to(input[0].device)
+        label_tensor.requires_grad_(False)
+        return label_tensor.expand_as(input)
+
+    def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor:
+        """
+        Gets a zero tensor.
+
+        Args:
+            input: tensor whose shape the zero tensor should match.
+        Returns:
+            A zero-filled tensor expanded to the shape of the input.
+        """
+
+        zero_label_tensor = torch.tensor(0).type(input[0].type()).to(input[0].device)
+        zero_label_tensor.requires_grad_(False)
+        return zero_label_tensor.expand_as(input)
+
+    def forward(
+        self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool
+    ) -> torch.Tensor | list[torch.Tensor]:
+        """
+
+        Args:
+            input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of
+                tensors or a tensor; they shouldn't have gone through an activation layer.
+            target_is_real: whether the input corresponds to discriminator output for real or fake images
+            for_discriminator: whether this is being calculated for discriminator or generator loss. In the latter
+                case, target_is_real is set to True, as the generator wants the input to be deemed as real.
+        Returns: if the reduction is ``"none"``, returns a list with the loss tensors of each discriminator if multi-scale
+            discriminator is active, or the loss tensor if there is just one discriminator. Otherwise, it returns the
+            summed or mean loss over the tensor and discriminator(s).
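+
+        A minimal usage sketch (the ``discriminator``, ``fake_images`` and ``adv_loss`` names below are
+        illustrative, not part of this API):
+
+        .. code-block:: python
+
+            adv_loss = PatchAdversarialLoss(criterion="least_squares")
+            logits = discriminator(fake_images)  # raw discriminator scores, no final activation
+            generator_loss = adv_loss(logits, target_is_real=True, for_discriminator=False)
+            discriminator_loss = adv_loss(logits.detach(), target_is_real=False, for_discriminator=True)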
+
+        """
+
+        if not for_discriminator and not target_is_real:
+            target_is_real = True  # With generator, we always want this to be true!
+            warnings.warn(
+                "Variable target_is_real has been set to False, but for_discriminator is set "
+                "to False. To optimise a generator, target_is_real must be set to True."
+            )
+
+        if type(input) is not list:
+            input = [input]
+        target_ = []
+        for _, disc_out in enumerate(input):
+            if self.criterion != AdversarialCriterions.HINGE.value:
+                target_.append(self.get_target_tensor(disc_out, target_is_real))
+            else:
+                target_.append(self.get_zero_tensor(disc_out))
+
+        # Loss calculation
+        loss = []
+        for disc_ind, disc_out in enumerate(input):
+            if self.activation is not None:
+                disc_out = self.activation(disc_out)
+            if self.criterion == AdversarialCriterions.HINGE.value and not target_is_real:
+                loss_ = self.forward_single(-disc_out, target_[disc_ind])
+            else:
+                loss_ = self.forward_single(disc_out, target_[disc_ind])
+            loss.append(loss_)
+
+        if loss is not None:
+            if self.reduction == LossReduction.MEAN.value:
+                loss = torch.mean(torch.stack(loss))
+            elif self.reduction == LossReduction.SUM.value:
+                loss = torch.sum(torch.stack(loss))
+
+        return loss
+
+    def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> torch.Tensor | None:
+        if (
+            self.criterion == AdversarialCriterions.BCE.value
+            or self.criterion == AdversarialCriterions.LEAST_SQUARE.value
+        ):
+            return self.loss_fct(input, target)
+        elif self.criterion == AdversarialCriterions.HINGE.value:
+            minval = torch.min(input - 1, self.get_zero_tensor(input))
+            return -torch.mean(minval)
+        else:
+            return None
diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py
new file mode 100644
index 0000000000..8fffb1c870
--- /dev/null
+++ b/monai/losses/perceptual.py
@@ -0,0 +1,366 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from lpips import LPIPS
+from torchvision.models import ResNet50_Weights, resnet50
+from torchvision.models.feature_extraction import create_feature_extractor
+
+
+class PerceptualLoss(nn.Module):
+    """
+    Perceptual loss using features from pretrained deep neural networks. The function supports networks
+    pretrained on ImageNet that use the LPIPS approach from Zhang, et al. "The unreasonable effectiveness of deep
+    features as a perceptual metric." https://arxiv.org/abs/1801.03924 ; RadImageNet from Mei, et al. "RadImageNet: An
+    Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"
+    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; MedicalNet from Chen et al. "Med3D: Transfer Learning for
+    3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 ;
+    and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html .
+
+    The fake 3D implementation is based on a 2.5D approach, where we calculate the 2D perceptual loss on slices from
+    the three axes.
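+
+    A minimal usage sketch (the network choice and tensor shapes here are illustrative, mirroring the tests):
+
+    .. code-block:: python
+
+        import torch
+
+        loss = PerceptualLoss(spatial_dims=3, network_type="squeeze", is_fake_3d=True, fake_3d_ratio=0.1)
+        result = loss(torch.rand(2, 1, 64, 64, 64), torch.rand(2, 1, 64, 64, 64))  # scalar tensor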
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        network_type: {``"alex"``, ``"vgg"``, ``"squeeze"``, ``"radimagenet_resnet50"``,
+            ``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``, ``"resnet50"``}
+            Specifies the network architecture to use. Defaults to ``"alex"``.
+        is_fake_3d: if True, use the 2.5D approach for a 3D perceptual loss.
+        fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.
+        cache_dir: path to cache directory to save the pretrained network weights.
+        pretrained: whether to load pretrained weights. This argument only works when using networks from
+            LPIPS or Torchvision. Defaults to ``True``.
+        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded
+            via this argument. This argument only works when ``"network_type"`` is "resnet50".
+            Defaults to `None`.
+        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
+            extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50".
+            Defaults to `None`.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        network_type: str = "alex",
+        is_fake_3d: bool = True,
+        fake_3d_ratio: float = 0.5,
+        cache_dir: str | None = None,
+        pretrained: bool = True,
+        pretrained_path: str | None = None,
+        pretrained_state_dict_key: str | None = None,
+    ):
+        super().__init__()
+
+        if spatial_dims not in [2, 3]:
+            raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.")
+
+        if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type:
+            raise ValueError(
+                "MedicalNet networks are only compatible with ``spatial_dims=3``. "
+                "Argument is_fake_3d must be set to False."
+            )
+
+        if cache_dir:
+            torch.hub.set_dir(cache_dir)
+
+        self.spatial_dims = spatial_dims
+        if spatial_dims == 3 and is_fake_3d is False:
+            self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
+        elif "radimagenet_" in network_type:
+            self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
+        elif network_type == "resnet50":
+            self.perceptual_function = TorchvisionModelPerceptualSimilarity(
+                net=network_type,
+                pretrained=pretrained,
+                pretrained_path=pretrained_path,
+                pretrained_state_dict_key=pretrained_state_dict_key,
+            )
+        else:
+            self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
+        self.is_fake_3d = is_fake_3d
+        self.fake_3d_ratio = fake_3d_ratio
+
+    def _calculate_axis_loss(self, input: torch.Tensor, target: torch.Tensor, spatial_axis: int) -> torch.Tensor:
+        """
+        Calculate the perceptual loss in one of the axes used in the 2.5D approach. After the slices of one spatial
+        axis are transformed into different instances in the batch, we compute the loss using the 2D approach.
+
+        Args:
+            input: input 5D tensor of shape BNHWD.
+            target: target 5D tensor of shape BNHWD.
+            spatial_axis: spatial axis to obtain the 2D slices.
+        """
+
+        def batchify_axis(x: torch.Tensor, fake_3d_perm: tuple) -> torch.Tensor:
+            """
+            Transform slices from one spatial axis into different instances in the batch.
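+            For example, a (B, C, H, W, D) input sliced along axis 2 is reshaped into a (B * H, C, W, D)
+            batch of 2D slices.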
+ """ + slices = x.float().permute((0,) + fake_3d_perm).contiguous() + slices = slices.view(-1, x.shape[fake_3d_perm[1]], x.shape[fake_3d_perm[2]], x.shape[fake_3d_perm[3]]) + + return slices + + preserved_axes = [2, 3, 4] + preserved_axes.remove(spatial_axis) + + channel_axis = 1 + input_slices = batchify_axis(x=input, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes)) + indices = torch.randperm(input_slices.shape[0])[: int(input_slices.shape[0] * self.fake_3d_ratio)].to( + input_slices.device + ) + input_slices = torch.index_select(input_slices, dim=0, index=indices) + target_slices = batchify_axis(x=target, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes)) + target_slices = torch.index_select(target_slices, dim=0, index=indices) + + axis_loss = torch.mean(self.perceptual_function(input_slices, target_slices)) + + return axis_loss + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNHW[D]. + target: the shape should be BNHW[D]. + """ + if target.shape != input.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + + if self.spatial_dims == 3 and self.is_fake_3d: + # Compute 2.5D approach + loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2) + loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3) + loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4) + loss = loss_sagittal + loss_axial + loss_coronal + else: + # 2D and real 3D cases + loss = self.perceptual_function(input, target) + + return torch.mean(loss) + + +class MedicalNetPerceptualSimilarity(nn.Module): + """ + Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer + Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from + "Warvito/MedicalNet-models". + + Args: + net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} + Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``. + verbose: if false, mute messages from torch Hub load function. + """ + + def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None: + super().__init__() + torch.hub._validate_not_a_forked_repo = lambda a, b, c: True + self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose) + self.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are inputted in the + pre-trained MedicalNet that is used for feature extraction. Then, these extracted features are normalised across + the channels. Finally, we compute the difference between the input and target features and calculate the mean + value from the spatial dimensions to obtain the perceptual loss. + + Args: + input: 3D input tensor with shape BCDHW. + target: 3D target tensor with shape BCDHW. 
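+
+        Returns:
+            A (B, 1, 1, 1, 1) tensor of per-item distances: squared feature differences summed over the channel
+            dimension and averaged over the spatial dimensions, with ``keepdim=True``.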
+ """ + input = medicalnet_intensity_normalisation(input) + target = medicalnet_intensity_normalisation(target) + + # Get model outputs + outs_input = self.model.forward(input) + outs_target = self.model.forward(target) + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results = (feats_input - feats_target) ** 2 + results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + +def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: + return x.mean([2, 3, 4], keepdim=keepdim) + + +def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def medicalnet_intensity_normalisation(volume): + """Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133""" + mean = volume.mean() + std = volume.std() + return (volume - mean) / std + + +class RadImageNetPerceptualSimilarity(nn.Module): + """ + Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et + al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class + uses torch Hub to download the networks from "Warvito/radimagenet-models". + + Args: + net: {``"radimagenet_resnet50"``} + Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``. + verbose: if false, mute messages from torch Hub load function. + """ + + def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: + super().__init__() + self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose) + self.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at + https://github.com/BMEII-AI/RadImageNet, we make sure that the input and target have 3 channels, reorder it from + 'RGB' to 'BGR', and then remove the mean components of each input data channel. The outputs are normalised + across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package). + """ + # If input has just 1 channel, repeat channel to have 3 channels + if input.shape[1] == 1 and target.shape[1] == 1: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + # Change order from 'RGB' to 'BGR' + input = input[:, [2, 1, 0], ...] + target = target[:, [2, 1, 0], ...] + + # Subtract mean used during training + input = subtract_mean(input) + target = subtract_mean(target) + + # Get model outputs + outs_input = self.model.forward(input) + outs_target = self.model.forward(target) + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results = (feats_input - feats_target) ** 2 + results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + +class TorchvisionModelPerceptualSimilarity(nn.Module): + """ + Component to perform the perceptual evaluation with TorchVision models. + Currently, only ResNet50 is supported. 
The network structure is based on: + https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html + + Args: + net: {``"resnet50"``} + Specifies the network architecture to use. Defaults to ``"resnet50"``. + pretrained: whether to load pretrained weights. Defaults to `True`. + pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded + via using this argument. Defaults to `None`. + pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to + extract the expected state dict. Defaults to `None`. + """ + + def __init__( + self, + net: str = "resnet50", + pretrained: bool = True, + pretrained_path: str | None = None, + pretrained_state_dict_key: str | None = None, + ) -> None: + super().__init__() + supported_networks = ["resnet50"] + if net not in supported_networks: + raise NotImplementedError( + f"'net' {net} is not supported, please select a network from {supported_networks}." + ) + + if pretrained_path is None: + network = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None) + else: + network = resnet50(weights=None) + if pretrained is True: + state_dict = torch.load(pretrained_path) + if pretrained_state_dict_key is not None: + state_dict = state_dict[pretrained_state_dict_key] + network.load_state_dict(state_dict) + self.final_layer = "layer4.2.relu_2" + self.model = create_feature_extractor(network, [self.final_layer]) + self.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at + https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights, + we make sure that the input and target have 3 channels, and then do Z-Score normalization. + The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar + approach to the lpips package). 
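+
+        Returns:
+            A (B, 1, 1, 1) tensor of per-item distances computed on the ``layer4.2.relu_2`` feature map.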
+ """ + # If input has just 1 channel, repeat channel to have 3 channels + if input.shape[1] == 1 and target.shape[1] == 1: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + # Input normalization + input = torchvision_zscore_norm(input) + target = torchvision_zscore_norm(target) + + # Get model outputs + outs_input = self.model.forward(input)[self.final_layer] + outs_target = self.model.forward(target)[self.final_layer] + + # Normalise through the channels + feats_input = normalize_tensor(outs_input) + feats_target = normalize_tensor(outs_target) + + results = (feats_input - feats_target) ** 2 + results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) + + return results + + +def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: + return x.mean([2, 3], keepdim=keepdim) + + +def torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor: + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0] + x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1] + x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2] + return x + + +def subtract_mean(x: torch.Tensor) -> torch.Tensor: + mean = [0.406, 0.456, 0.485] + x[:, 0, :, :] -= mean[0] + x[:, 1, :, :] -= mean[1] + x[:, 2, :, :] -= mean[2] + return x diff --git a/monai/losses/spectral_loss.py b/monai/losses/spectral_loss.py new file mode 100644 index 0000000000..311a64d590 --- /dev/null +++ b/monai/losses/spectral_loss.py @@ -0,0 +1,88 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch.fft import fftn +from torch.nn.modules.loss import _Loss + +from monai.utils import LossReduction + + +class JukeboxLoss(_Loss): + """ + Calculate spectral component based on the magnitude of Fast Fourier Transform (FFT). + + Based on: + Dhariwal, et al. 'Jukebox: A generative model for music.'https://arxiv.org/abs/2005.00341 + + Args: + spatial_dims: number of spatial dimensions. + fft_signal_size: signal size in the transformed dimensions. See torch.fft.fftn() for more information. + fft_norm: {``"forward"``, ``"backward"``, ``"ortho"``} Specifies the normalization mode in the fft. See + torch.fft.fftn() for more information. + + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. 
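+
+    A minimal usage sketch (tensor shapes are illustrative):
+
+    .. code-block:: python
+
+        import torch
+
+        loss = JukeboxLoss(spatial_dims=2)
+        result = loss(torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64))  # scalar with the default "mean" reduction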
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        fft_signal_size: tuple[int] | None = None,
+        fft_norm: str = "ortho",
+        reduction: LossReduction | str = LossReduction.MEAN,
+    ) -> None:
+        super().__init__(reduction=LossReduction(reduction).value)
+
+        self.spatial_dims = spatial_dims
+        self.fft_signal_size = fft_signal_size
+        self.fft_dim = tuple(range(1, spatial_dims + 2))
+        self.fft_norm = fft_norm
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        input_amplitude = self._get_fft_amplitude(input)
+        target_amplitude = self._get_fft_amplitude(target)
+
+        # Compute distance between amplitude of frequency components
+        # See Section 3.3 from https://arxiv.org/abs/2005.00341
+        loss = F.mse_loss(target_amplitude, input_amplitude, reduction="none")
+
+        if self.reduction == LossReduction.MEAN.value:
+            loss = loss.mean()
+        elif self.reduction == LossReduction.SUM.value:
+            loss = loss.sum()
+        elif self.reduction == LossReduction.NONE.value:
+            pass
+
+        return loss
+
+    def _get_fft_amplitude(self, images: torch.Tensor) -> torch.Tensor:
+        """
+        Calculate the amplitude of the Fourier transform representation of the images.
+
+        Args:
+            images: images that are to undergo fftn.
+
+        Returns:
+            Fourier transform amplitude.
+        """
+        img_fft = fftn(images, s=self.fft_signal_size, dim=self.fft_dim, norm=self.fft_norm)
+
+        amplitude = torch.sqrt(torch.real(img_fft) ** 2 + torch.imag(img_fft) ** 2)
+
+        return amplitude
diff --git a/tests/test_adversarial_loss.py b/tests/test_adversarial_loss.py
new file mode 100644
index 0000000000..77880725ec
--- /dev/null
+++ b/tests/test_adversarial_loss.py
@@ -0,0 +1,93 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.losses import PatchAdversarialLoss
+
+shapes_tensors = {"2d": [4, 1, 64, 64], "3d": [4, 1, 64, 64, 64]}
+reductions = ["sum", "mean"]
+criterion = ["bce", "least_squares", "hinge"]
+
+TEST_CASE_CREATION_FAIL = [{"reduction": "sum", "criterion": "invalid"}]
+
+TEST_CASES_LOSS_LOGIC_2D = []
+TEST_CASES_LOSS_LOGIC_3D = []
+
+for c in criterion:
+    for r in reductions:
+        TEST_CASES_LOSS_LOGIC_2D.append([{"reduction": r, "criterion": c}, shapes_tensors["2d"]])
+        TEST_CASES_LOSS_LOGIC_3D.append([{"reduction": r, "criterion": c}, shapes_tensors["3d"]])
+
+TEST_CASES_LOSS_LOGIC_LIST = []
+for c in criterion:
+    TEST_CASES_LOSS_LOGIC_LIST.append([{"reduction": "none", "criterion": c}, shapes_tensors["2d"]])
+    TEST_CASES_LOSS_LOGIC_LIST.append([{"reduction": "none", "criterion": c}, shapes_tensors["3d"]])
+
+
+class TestPatchAdversarialLoss(unittest.TestCase):
+    def get_input(self, shape, is_positive):
+        """
+        Get tensor for the tests. The tensor is around (-1) or (+1), depending on
+        is_positive.
+ """ + if is_positive: + offset = 1 + else: + offset = -1 + return torch.ones(shape) * (offset) + 0.01 * torch.randn(shape) + + def test_criterion(self): + """ + Make sure that unknown criterion fail. + """ + with self.assertRaises(ValueError): + PatchAdversarialLoss(**TEST_CASE_CREATION_FAIL[0]) + + @parameterized.expand(TEST_CASES_LOSS_LOGIC_2D + TEST_CASES_LOSS_LOGIC_3D) + def test_loss_logic(self, input_param: dict, shape_input: list): + """ + We want to make sure that the adversarial losses do what they should. + If the discriminator takes in a tensor that looks positive, yet the label is fake, + the loss should be bigger than that obtained with a tensor that looks negative. + Same for the real label, and for the generator. + """ + loss = PatchAdversarialLoss(**input_param) + fakes = self.get_input(shape_input, is_positive=False) + reals = self.get_input(shape_input, is_positive=True) + # Discriminator: fake label + loss_disc_f_f = loss(fakes, target_is_real=False, for_discriminator=True) + loss_disc_f_r = loss(reals, target_is_real=False, for_discriminator=True) + assert loss_disc_f_f < loss_disc_f_r + # Discriminator: real label + loss_disc_r_f = loss(fakes, target_is_real=True, for_discriminator=True) + loss_disc_r_r = loss(reals, target_is_real=True, for_discriminator=True) + assert loss_disc_r_f > loss_disc_r_r + # Generator: + loss_gen_f = loss(fakes, target_is_real=True, for_discriminator=False) # target_is_real is overridden + loss_gen_r = loss(reals, target_is_real=True, for_discriminator=False) # target_is_real is overridden + assert loss_gen_f > loss_gen_r + + @parameterized.expand(TEST_CASES_LOSS_LOGIC_LIST) + def test_multiple_discs(self, input_param: dict, shape_input): + shapes = [shape_input] + [shape_input[0:2] + [int(i / j) for i in shape_input[2:]] for j in range(1, 3)] + inputs = [self.get_input(shapes[i], is_positive=True) for i in range(len(shapes))] + loss = PatchAdversarialLoss(**input_param) + assert len(loss(inputs, for_discriminator=True, target_is_real=True)) == 3 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py new file mode 100644 index 0000000000..3a4e084b7e --- /dev/null +++ b/tests/test_perceptual_loss.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.losses import PerceptualLoss + +TEST_CASES = [ + [{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)], + [ + {"spatial_dims": 3, "network_type": "squeeze", "is_fake_3d": True, "fake_3d_ratio": 0.1}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], + [{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 1, 64, 64), (2, 1, 64, 64)], + [{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 3, 64, 64), (2, 3, 64, 64)], + [ + {"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], + [ + {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], + [ + {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2}, + (2, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], +] + + +class TestPerceptualLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, target_shape): + loss = PerceptualLoss(**input_param) + result = loss(torch.randn(input_shape), torch.randn(target_shape)) + self.assertEqual(result.shape, torch.Size([])) + + @parameterized.expand(TEST_CASES) + def test_identical_input(self, input_param, input_shape, target_shape): + loss = PerceptualLoss(**input_param) + tensor = torch.randn(input_shape) + result = loss(tensor, tensor) + self.assertEqual(result, torch.Tensor([0.0])) + + def test_different_shape(self): + loss = PerceptualLoss(spatial_dims=2, network_type="squeeze") + tensor = torch.randn(2, 1, 64, 64) + target = torch.randn(2, 1, 32, 32) + with self.assertRaises(ValueError): + loss(tensor, target) + + def test_1d(self): + with self.assertRaises(NotImplementedError): + PerceptualLoss(spatial_dims=1) + + def test_medicalnet_on_2d_data(self): + with self.assertRaises(ValueError): + PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets") + + with self.assertRaises(ValueError): + PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spectral_loss.py b/tests/test_spectral_loss.py new file mode 100644 index 0000000000..21b5c48de4 --- /dev/null +++ b/tests/test_spectral_loss.py @@ -0,0 +1,86 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import JukeboxLoss +from tests.utils import test_script_save + +TEST_CASES = [ + [ + {"spatial_dims": 2}, + { + "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), + "target": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), + }, + 0.070648, + ], + [ + {"spatial_dims": 2, "reduction": "sum"}, + { + "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), + "target": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), + }, + 0.8478, + ], + [ + {"spatial_dims": 3}, + { + "input": torch.tensor( + [ + [ + [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], + [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], + ] + ] + ), + "target": torch.tensor( + [ + [ + [[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], + [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], + ] + ] + ), + }, + 0.03838, + ], +] + + +class TestJukeboxLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_results(self, input_param, input_data, expected_val): + results = JukeboxLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) + + def test_2d_shape(self): + results = JukeboxLoss(spatial_dims=2, reduction="none").forward(**TEST_CASES[0][1]) + self.assertEqual(results.shape, (1, 2, 2, 3)) + + def test_3d_shape(self): + results = JukeboxLoss(spatial_dims=3, reduction="none").forward(**TEST_CASES[2][1]) + self.assertEqual(results.shape, (1, 2, 2, 2, 3)) + + def test_script(self): + loss = JukeboxLoss(spatial_dims=2) + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) + + +if __name__ == "__main__": + unittest.main() From 79a378452ac7f890cb9848637dbab1b067e8ac5c Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 14 Jul 2023 13:30:10 -0600 Subject: [PATCH 02/20] Adds LPIPS as a requirement --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 78e3b7381a..bf24ba5a3a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -53,3 +53,4 @@ onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr +lpips==0.1.4 From 89c4e83a7245be5ebd4df4b0d18627bd65693b23 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 19 Jul 2023 14:19:35 +0100 Subject: [PATCH 03/20] Updates docstrings --- docs/source/losses.rst | 15 +++++++++++++++ monai/losses/adversarial_loss.py | 26 ++++++++++++++------------ monai/losses/spectral_loss.py | 2 +- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 0f262894cf..39f1d0e4d1 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -99,6 +99,21 @@ Reconstruction Losses .. autoclass:: monai.losses.ssim_loss.SSIMLoss :members: +`PatchAdversarialLoss` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: PatchAdversarialLoss + :members: + +`PerceptualLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: PerceptualLoss + :members: + +`JukeboxLoss` +~~~~~~~~~~~~~~ +.. 
autoclass:: JukeboxLoss
+   :members:
+
 Loss Wrappers
 -------------
 
diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py
index 6f09228e38..62cff46200 100644
--- a/monai/losses/adversarial_loss.py
+++ b/monai/losses/adversarial_loss.py
@@ -34,14 +34,16 @@ class PatchAdversarialLoss(_Loss):
     mustn't be passed to a final activation layer. That is taken care of internally within the loss.
 
     Args:
-        reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output.
-            Defaults to ``"mean"``.
-            - ``"none"``: no reduction will be applied.
-            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
-            - ``"sum"``: the output will be summed.
+        reduction: {``"none"``, ``"mean"``, ``"sum"``}
+            Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+            - ``"none"``: no reduction will be applied.
+            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+            - ``"sum"``: the output will be summed.
+
         criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs.
-            Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs
-            through an activation layer prior to calling the loss.
+            Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs
+            through an activation layer prior to calling the loss.
         no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.
     """
 
@@ -112,14 +114,14 @@ def forward(
 
         Args:
-            input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of
-                tensors or a tensor; they shouldn't have gone through an activation layer.
+            input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of tensors
+                or a tensor; they shouldn't have gone through an activation layer.
             target_is_real: whether the input corresponds to discriminator output for real or fake images
             for_discriminator: whether this is being calculated for discriminator or generator loss. In the latter
-                case, target_is_real is set to True, as the generator wants the input to be deemed as real.
+            case, target_is_real is set to True, as the generator wants the input to be deemed as real.
         Returns: if the reduction is ``"none"``, returns a list with the loss tensors of each discriminator if multi-scale
-            discriminator is active, or the loss tensor if there is just one discriminator. Otherwise, it returns the
-            summed or mean loss over the tensor and discriminator(s).
+        discriminator is active, or the loss tensor if there is just one discriminator. Otherwise, it returns the
+        summed or mean loss over the tensor and discriminator(s).
diff --git a/monai/losses/spectral_loss.py b/monai/losses/spectral_loss.py
index 311a64d590..06714f3993 100644
--- a/monai/losses/spectral_loss.py
+++ b/monai/losses/spectral_loss.py
@@ -24,7 +24,7 @@ class JukeboxLoss(_Loss):
     Calculate spectral component based on the magnitude of Fast Fourier Transform (FFT).
 
     Based on:
-        Dhariwal, et al. 'Jukebox: A generative model for music.'https://arxiv.org/abs/2005.00341
+        Dhariwal, et al. 'Jukebox: A generative model for music.' https://arxiv.org/abs/2005.00341
 
     Args:
         spatial_dims: number of spatial dimensions.
From 3faaca51ded8b32ff4b0e488a773a6c41cb70496 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 19 Jul 2023 15:22:31 +0100 Subject: [PATCH 04/20] Adds external dependency to relevant files, excludes from min tests and uses optional_import to import it --- docs/source/installation.md | 4 ++-- monai/losses/perceptual.py | 5 ++++- setup.cfg | 3 +++ tests/min_tests.py | 1 + 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/source/installation.md b/docs/source/installation.md index eb7adb06fb..53e5167346 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr` and `lpips` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. 
diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 8fffb1c870..73d2e83bfa 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -13,7 +13,10 @@ import torch import torch.nn as nn -from lpips import LPIPS + +from monai.utils import optional_import + +LPIPS, _ = optional_import("lpips", name="LPIPS") from torchvision.models import ResNet50_Weights, resnet50 from torchvision.models.feature_extraction import create_feature_extractor diff --git a/setup.cfg b/setup.cfg index c218b133ee..5d29d2bce9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,6 +80,7 @@ all = onnx>=1.13.0 onnxruntime; python_version <= '3.10' zarr + lpips==0.1.4 nibabel = nibabel ninja = @@ -145,6 +146,8 @@ onnx = onnxruntime; python_version <= '3.10' zarr = zarr +lpips = + lpips==0.1.4 # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded diff --git a/tests/min_tests.py b/tests/min_tests.py index e3b09e7c84..e93652d215 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -204,6 +204,7 @@ def run_testsuit(): "test_spatial_combine_transforms", "test_bundle_workflow", "test_zarr_avg_merger", + "test_perceptual_loss", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 086b8a9de6ffef5dd315cf52c98f62ee9f628d5e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 19 Jul 2023 15:30:00 +0100 Subject: [PATCH 05/20] Uses optional_import for torchvision too --- monai/losses/perceptual.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 73d2e83bfa..82c9ede91c 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -17,6 +17,7 @@ from monai.utils import optional_import LPIPS, _ = optional_import("lpips", name="LPIPS") +torchvision, _ = optional_import("torchvision") from torchvision.models import ResNet50_Weights, resnet50 from torchvision.models.feature_extraction import create_feature_extractor @@ -302,16 +303,18 @@ def __init__( ) if pretrained_path is None: - network = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None) + network = torchvision.models.resnet50( + weights=torchvision.models.ResNet50_Weights.DEFAULT if pretrained else None + ) else: - network = resnet50(weights=None) + network = torchvision.models.resnet50(weights=None) if pretrained is True: state_dict = torch.load(pretrained_path) if pretrained_state_dict_key is not None: state_dict = state_dict[pretrained_state_dict_key] network.load_state_dict(state_dict) self.final_layer = "layer4.2.relu_2" - self.model = create_feature_extractor(network, [self.final_layer]) + self.model = torchvision.models.feature_extraction.create_feature_extractor(network, [self.final_layer]) self.eval() for param in self.parameters(): From f500f4a66648d4b82a36ce4e690ce5fc37b38eb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jul 2023 14:32:06 +0000 Subject: [PATCH 06/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/perceptual.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 82c9ede91c..5cfa35ea98 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -18,8 +18,6 @@ LPIPS, _ = optional_import("lpips", 
name="LPIPS") torchvision, _ = optional_import("torchvision") -from torchvision.models import ResNet50_Weights, resnet50 -from torchvision.models.feature_extraction import create_feature_extractor class PerceptualLoss(nn.Module): From 105c3b8dad6662915cea7adac666059fe4336b50 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 31 Jul 2023 14:31:31 +0100 Subject: [PATCH 07/20] Fixes typing issues in perceptual loss --- monai/losses/perceptual.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 82c9ede91c..0b574395a3 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -18,8 +18,7 @@ LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") -from torchvision.models import ResNet50_Weights, resnet50 -from torchvision.models.feature_extraction import create_feature_extractor + class PerceptualLoss(nn.Module): @@ -79,6 +78,7 @@ def __init__( torch.hub.set_dir(cache_dir) self.spatial_dims = spatial_dims + self.perceptual_function : nn.Module if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) elif "radimagenet_" in network_type: @@ -168,7 +168,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("Warvito/MedicalNet-models", model=net, verbose=verbose) + self.model = torch.hub.load("marksgraham/MedicalNet-models", model=net, verbose=verbose, force_reload=True) self.eval() for param in self.parameters(): @@ -196,7 +196,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results = (feats_input - feats_target) ** 2 + results : torch.Tensor = (feats_input - feats_target) ** 2 results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) return results @@ -266,7 +266,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results = (feats_input - feats_target) ** 2 + results: torch.Tensor = (feats_input - feats_target) ** 2 results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) return results @@ -345,7 +345,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results = (feats_input - feats_target) ** 2 + results : torch.Tensor = (feats_input - feats_target) ** 2 results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) return results From 700096d959ff9dd9a5c55dd635ec482af3a6a4e2 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 31 Jul 2023 14:50:17 +0100 Subject: [PATCH 08/20] Fixes typing issues in adversarial loss --- monai/losses/adversarial_loss.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py index 62cff46200..2165fc8daa 100644 --- a/monai/losses/adversarial_loss.py +++ b/monai/losses/adversarial_loss.py @@ -64,6 +64,7 @@ def __init__( # Depending on the criterion, a different activation layer is used. 
self.real_label = 1.0 self.fake_label = 0.0 + self.loss_fct : _Loss if criterion == AdversarialCriterions.BCE.value: self.activation = get_act_layer("SIGMOID") self.loss_fct = torch.nn.BCELoss(reduction=reduction) @@ -80,7 +81,7 @@ def __init__( self.criterion = criterion self.reduction = reduction - def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> torch.Tensor: + def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor: """ Gets the ground truth tensor for the discriminator depending on whether the input is real or fake. @@ -95,7 +96,7 @@ def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> t label_tensor.requires_grad_(False) return label_tensor.expand_as(input) - def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor: + def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor: """ Gets a zero tensor. @@ -109,7 +110,7 @@ def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor: return zero_label_tensor.expand_as(input) def forward( - self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool + self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool ) -> torch.Tensor | list[torch.Tensor]: """ @@ -142,7 +143,7 @@ def forward( target_.append(self.get_zero_tensor(disc_out)) # Loss calculation - loss = [] + loss_list = [] for disc_ind, disc_out in enumerate(input): if self.activation is not None: disc_out = self.activation(disc_out) @@ -150,24 +151,24 @@ def forward( loss_ = self.forward_single(-disc_out, target_[disc_ind]) else: loss_ = self.forward_single(disc_out, target_[disc_ind]) - loss.append(loss_) + loss_list.append(loss_) - if loss is not None: + if loss_list is not None: if self.reduction == LossReduction.MEAN.value: - loss = torch.mean(torch.stack(loss)) + loss = torch.mean(torch.stack(loss_list)) elif self.reduction == LossReduction.SUM.value: - loss = torch.sum(torch.stack(loss)) + loss = torch.sum(torch.stack(loss_list)) return loss - def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> torch.Tensor | None: + def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + forward : torch.Tensor if ( self.criterion == AdversarialCriterions.BCE.value or self.criterion == AdversarialCriterions.LEAST_SQUARE.value ): - return self.loss_fct(input, target) + forward = self.loss_fct(input, target) elif self.criterion == AdversarialCriterions.HINGE.value: minval = torch.min(input - 1, self.get_zero_tensor(input)) - return -torch.mean(minval) - else: - return None + forward = -torch.mean(minval) + return forward From cd01b5955f4daec2f912d9d4d7a4859768186d6e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 31 Jul 2023 15:26:06 +0100 Subject: [PATCH 09/20] Fixes more typing errors --- monai/losses/adversarial_loss.py | 8 +++++--- monai/losses/perceptual.py | 7 +++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py index 2165fc8daa..9005b7f030 100644 --- a/monai/losses/adversarial_loss.py +++ b/monai/losses/adversarial_loss.py @@ -64,7 +64,7 @@ def __init__( # Depending on the criterion, a different activation layer is used. 
self.real_label = 1.0 self.fake_label = 0.0 - self.loss_fct : _Loss + self.loss_fct: _Loss if criterion == AdversarialCriterions.BCE.value: self.activation = get_act_layer("SIGMOID") self.loss_fct = torch.nn.BCELoss(reduction=reduction) @@ -153,16 +153,18 @@ def forward( loss_ = self.forward_single(disc_out, target_[disc_ind]) loss_list.append(loss_) + loss: torch.Tensor | list[torch.Tensor] if loss_list is not None: if self.reduction == LossReduction.MEAN.value: loss = torch.mean(torch.stack(loss_list)) elif self.reduction == LossReduction.SUM.value: loss = torch.sum(torch.stack(loss_list)) - + else: + loss = loss_list return loss def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - forward : torch.Tensor + forward: torch.Tensor if ( self.criterion == AdversarialCriterions.BCE.value or self.criterion == AdversarialCriterions.LEAST_SQUARE.value diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 0b574395a3..e9a801c532 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -20,7 +20,6 @@ torchvision, _ = optional_import("torchvision") - class PerceptualLoss(nn.Module): """ Perceptual loss using features from pretrained deep neural networks trained. The function supports networks @@ -78,7 +77,7 @@ def __init__( torch.hub.set_dir(cache_dir) self.spatial_dims = spatial_dims - self.perceptual_function : nn.Module + self.perceptual_function: nn.Module if spatial_dims == 3 and is_fake_3d is False: self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False) elif "radimagenet_" in network_type: @@ -196,7 +195,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results : torch.Tensor = (feats_input - feats_target) ** 2 + results: torch.Tensor = (feats_input - feats_target) ** 2 results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True) return results @@ -345,7 +344,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: feats_input = normalize_tensor(outs_input) feats_target = normalize_tensor(outs_target) - results : torch.Tensor = (feats_input - feats_target) ** 2 + results: torch.Tensor = (feats_input - feats_target) ** 2 results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True) return results From a5e092e4b442c913bde31c432f0a640cd6e5ab10 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 31 Jul 2023 15:41:31 +0100 Subject: [PATCH 10/20] Formatting fix --- monai/losses/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d1fba62979..75f4d181d0 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -11,7 +11,6 @@ from __future__ import annotations - from .adversarial_loss import PatchAdversarialLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss From a2419195d1565ea6743ab109ad0d8e05fc4e1678 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 31 Jul 2023 16:11:22 +0100 Subject: [PATCH 11/20] DCO Remediation Commit for Mark Graham I, Mark Graham , hereby add my Signed-off-by to this commit: a21c7f1dd2c5c2b41d67e0cadf15c8180bf8f73f I, Mark Graham , hereby add my Signed-off-by to this commit: 79a378452ac7f890cb9848637dbab1b067e8ac5c I, Mark Graham , hereby add my Signed-off-by to this commit: 89c4e83a7245be5ebd4df4b0d18627bd65693b23 I, Mark Graham , hereby add my Signed-off-by to 
this commit: 3faaca51ded8b32ff4b0e488a773a6c41cb70496 I, Mark Graham , hereby add my Signed-off-by to this commit: 086b8a9de6ffef5dd315cf52c98f62ee9f628d5e I, Mark Graham , hereby add my Signed-off-by to this commit: 105c3b8dad6662915cea7adac666059fe4336b50 I, Mark Graham , hereby add my Signed-off-by to this commit: 700096d959ff9dd9a5c55dd635ec482af3a6a4e2 I, Mark Graham , hereby add my Signed-off-by to this commit: cd01b5955f4daec2f912d9d4d7a4859768186d6e I, Mark Graham , hereby add my Signed-off-by to this commit: a5e092e4b442c913bde31c432f0a640cd6e5ab10 Signed-off-by: Mark Graham --- tests/test_perceptual_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 3a4e084b7e..879f68a6c6 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -79,4 +79,4 @@ def test_medicalnet_on_2d_data(self): if __name__ == "__main__": - unittest.main() + unittest.main() From 36d7ed69bca9f89a5195c828f54e7c2336e4a818 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 31 Jul 2023 16:11:42 +0100 Subject: [PATCH 12/20] Empty commit workaround undo --- tests/test_perceptual_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 879f68a6c6..3a4e084b7e 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -79,4 +79,4 @@ def test_medicalnet_on_2d_data(self): if __name__ == "__main__": - unittest.main() + unittest.main() From d9d11405fe076cd12f56cb8b0d4f437214dd3737 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 Jul 2023 15:11:51 +0000 Subject: [PATCH 13/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_perceptual_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 879f68a6c6..3a4e084b7e 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -79,4 +79,4 @@ def test_medicalnet_on_2d_data(self): if __name__ == "__main__": - unittest.main() + unittest.main() From c118d101d619e3886c91fef847cfd118c860775f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 31 Jul 2023 16:13:43 +0100 Subject: [PATCH 14/20] DCO Remediation Commit for Mark Graham I, Mark Graham , hereby add my Signed-off-by to this commit: 36d7ed69bca9f89a5195c828f54e7c2336e4a818 Signed-off-by: Mark Graham --- tests/test_perceptual_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 3a4e084b7e..879f68a6c6 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -79,4 +79,4 @@ def test_medicalnet_on_2d_data(self): if __name__ == "__main__": - unittest.main() + unittest.main() From f8f1a3fa0eaf8c8d8b7534efab048a76e544c2fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 Jul 2023 15:14:08 +0000 Subject: [PATCH 15/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_perceptual_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 879f68a6c6..3a4e084b7e 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -79,4 
+79,4 @@ def test_medicalnet_on_2d_data(self): if __name__ == "__main__": - unittest.main() + unittest.main() From a4e16c81d55c0c07f5a4247be33183b036a7be48 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 2 Aug 2023 16:24:36 +0100 Subject: [PATCH 16/20] Fix errors with earlier torchvision versions Signed-off-by: Mark Graham --- tests/test_perceptual_loss.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 879f68a6c6..7cc24fb483 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -17,7 +17,10 @@ from parameterized import parameterized from monai.losses import PerceptualLoss +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion +_, has_torchvision = optional_import("torchvision") TEST_CASES = [ [{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)], [ @@ -45,6 +48,8 @@ ] +@SkipIfBeforePyTorchVersion((1, 10)) +@unittest.skipUnless(has_torchvision, "Requires torchvision") class TestPerceptualLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, target_shape): @@ -79,4 +84,4 @@ def test_medicalnet_on_2d_data(self): if __name__ == "__main__": - unittest.main() + unittest.main() From 42414ff069b31682d52e5f10c0792a33a370dc8f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 2 Aug 2023 16:43:03 +0100 Subject: [PATCH 17/20] Adds warning if user specifies cache_dir Signed-off-by: Mark Graham --- monai/losses/perceptual.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index e9a801c532..a146b26856 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -11,6 +11,8 @@ from __future__ import annotations +import warnings + import torch import torch.nn as nn @@ -75,6 +77,10 @@ def __init__( if cache_dir: torch.hub.set_dir(cache_dir) + # raise a warning that this may change the default cache dir for all torch.hub calls + warnings.warn( + f"Setting cache_dir to {cache_dir}, this may change the default cache dir for all torch.hub calls." 
+ ) From cfb24b481f5f3d887a3de3e66bcc0e928493779b Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 3 Aug 2023 13:05:03 +0100 Subject: [PATCH 18/20] Reverts to warvito for model hub and changes perceptual test decorator version Signed-off-by: Mark Graham --- monai/losses/perceptual.py | 2 +- tests/test_perceptual_loss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index a146b26856..fe5a39bc2e 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -173,7 +173,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("marksgraham/MedicalNet-models", model=net, verbose=verbose, force_reload=True) + self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose) self.eval() for param in self.parameters(): diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index 7cc24fb483..2f807d8222 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -48,7 +48,7 @@ ] -@SkipIfBeforePyTorchVersion((1, 10)) +@SkipIfBeforePyTorchVersion((1, 11)) @unittest.skipUnless(has_torchvision, "Requires torchvision") class TestPerceptualLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) From ebfe2effd58fcbe4324752162d45d0864d5f72dc Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 3 Aug 2023 13:26:21 +0100 Subject: [PATCH 19/20] Addresses comments from mingxin-zheng Signed-off-by: Mark Graham --- monai/losses/adversarial_loss.py | 35 +++++++++++++++----------------- monai/losses/perceptual.py | 23 ++++++++++++++++++--- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py index 9005b7f030..e2be9ae036 100644 --- a/monai/losses/adversarial_loss.py +++ b/monai/losses/adversarial_loss.py @@ -50,28 +50,28 @@ class PatchAdversarialLoss(_Loss): def __init__( self, reduction: LossReduction | str = LossReduction.MEAN, - criterion: str = AdversarialCriterions.LEAST_SQUARE.value, + criterion: str = AdversarialCriterions.LEAST_SQUARE, no_activation_leastsq: bool = False, ) -> None: - super().__init__(reduction=LossReduction(reduction).value) + super().__init__(reduction=LossReduction(reduction)) - if criterion.lower() not in [m.value for m in AdversarialCriterions]: + if criterion.lower() not in [m for m in AdversarialCriterions]: raise ValueError( "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" - % ", ".join([m.value for m in AdversarialCriterions]) + % ", ".join([m for m in AdversarialCriterions]) ) # Depending on the criterion, a different activation layer is used.
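Both perceptual-loss commits above revolve around torch.hub: PATCH 17 warns because the hub cache directory is process-global, and PATCH 18 points the MedicalNet weights at a different hub repository. A hedged sketch of the behaviour in question follows; the cache path is an invented example, and the load call needs network access.

    import torch

    # torch.hub.set_dir is process-global: every later torch.hub call in this
    # process uses the new cache, which is what PerceptualLoss now warns about.
    torch.hub.set_dir("/tmp/monai_hub_cache")  # assumed writable path, for illustration only

    # MedicalNet weights then come from the repository the patch selects; without
    # force_reload=True, a cached copy is reused on later constructions.
    model = torch.hub.load("warvito/MedicalNet-models", model="medicalnet_resnet10_23datasets", verbose=False)
    model.eval()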
self.real_label = 1.0 self.fake_label = 0.0 self.loss_fct: _Loss - if criterion == AdversarialCriterions.BCE.value: + if criterion == AdversarialCriterions.BCE: self.activation = get_act_layer("SIGMOID") self.loss_fct = torch.nn.BCELoss(reduction=reduction) - elif criterion == AdversarialCriterions.HINGE.value: + elif criterion == AdversarialCriterions.HINGE: self.activation = get_act_layer("TANH") self.fake_label = -1.0 - elif criterion == AdversarialCriterions.LEAST_SQUARE.value: + elif criterion == AdversarialCriterions.LEAST_SQUARE: if no_activation_leastsq: self.activation = None else: @@ -137,7 +137,7 @@ def forward( input = [input] target_ = [] for _, disc_out in enumerate(input): - if self.criterion != AdversarialCriterions.HINGE.value: + if self.criterion != AdversarialCriterions.HINGE: target_.append(self.get_target_tensor(disc_out, target_is_real)) else: target_.append(self.get_zero_tensor(disc_out)) @@ -147,30 +147,27 @@ def forward( for disc_ind, disc_out in enumerate(input): if self.activation is not None: disc_out = self.activation(disc_out) - if self.criterion == AdversarialCriterions.HINGE.value and not target_is_real: - loss_ = self.forward_single(-disc_out, target_[disc_ind]) + if self.criterion == AdversarialCriterions.HINGE and not target_is_real: + loss_ = self._forward_single(-disc_out, target_[disc_ind]) else: - loss_ = self.forward_single(disc_out, target_[disc_ind]) + loss_ = self._forward_single(disc_out, target_[disc_ind]) loss_list.append(loss_) loss: torch.Tensor | list[torch.Tensor] if loss_list is not None: - if self.reduction == LossReduction.MEAN.value: + if self.reduction == LossReduction.MEAN: loss = torch.mean(torch.stack(loss_list)) - elif self.reduction == LossReduction.SUM.value: + elif self.reduction == LossReduction.SUM: loss = torch.sum(torch.stack(loss_list)) else: loss = loss_list return loss - def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def _forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: forward: torch.Tensor - if ( - self.criterion == AdversarialCriterions.BCE.value - or self.criterion == AdversarialCriterions.LEAST_SQUARE.value - ): + if self.criterion == AdversarialCriterions.BCE or self.criterion == AdversarialCriterions.LEAST_SQUARE: forward = self.loss_fct(input, target) - elif self.criterion == AdversarialCriterions.HINGE.value: + elif self.criterion == AdversarialCriterions.HINGE: minval = torch.min(input - 1, self.get_zero_tensor(input)) forward = -torch.mean(minval) return forward diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index fe5a39bc2e..5b32b4a574 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -17,11 +17,22 @@ import torch.nn as nn from monai.utils import optional_import +from monai.utils.enums import StrEnum LPIPS, _ = optional_import("lpips", name="LPIPS") torchvision, _ = optional_import("torchvision") +class PerceptualNetworkType(StrEnum): + alex = "alex" + vgg = "vgg" + squeeze = "squeeze" + radimagenet_resnet50 = "radimagenet_resnet50" + medicalnet_resnet10_23datasets = "medicalnet_resnet10_23datasets" + medicalnet_resnet50_23datasets = "medicalnet_resnet50_23datasets" + resnet50 = "resnet50" + + class PerceptualLoss(nn.Module): """ Perceptual loss using features from pretrained deep neural networks trained.
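The .value removals in PATCH 19 are safe because MONAI's StrEnum mixes str into Enum, so members behave as plain strings. A small sketch of the comparisons the patched code now relies on; the Fruit enum is invented purely for illustration.

    from monai.utils.enums import StrEnum

    class Fruit(StrEnum):  # invented example, not part of the patch
        apple = "apple"
        pear = "pear"

    print("apple" == Fruit.apple)        # True: members compare equal to plain strings
    print("pear" in [m for m in Fruit])  # True: membership checks no longer need .value
    print(", ".join(Fruit))              # "apple, pear": join() accepts the members directly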
The function supports networks @@ -32,8 +43,8 @@ class PerceptualLoss(nn.Module): 3D Medical Image Analysis" https://arxiv.org/abs/1904.00625 ; and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html . - The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual on slices from the - three axis. + The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all + three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss. Args: spatial_dims: number of spatial dimensions. @@ -56,7 +67,7 @@ class PerceptualLoss(nn.Module): def __init__( self, spatial_dims: int, - network_type: str = "alex", + network_type: str = PerceptualNetworkType.alex, is_fake_3d: bool = True, fake_3d_ratio: float = 0.5, cache_dir: str | None = None, @@ -75,6 +86,12 @@ def __init__( "Argument is_fake_3d must be set to False." ) + if network_type.lower() not in [m for m in PerceptualNetworkType]: + raise ValueError( + "Unrecognised network type entered for Perceptual Loss. Must be one in: %s" + % ", ".join([m for m in PerceptualNetworkType]) + ) + if cache_dir: torch.hub.set_dir(cache_dir) # raise a warning that this may change the default cache dir for all torch.hub calls From 840e5cc64c8ec7ab2254bb6b201ffbc9eae10af4 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Thu, 3 Aug 2023 13:46:30 +0100 Subject: [PATCH 20/20] Fixes codeformat list comprehension error Signed-off-by: Mark Graham --- monai/losses/adversarial_loss.py | 4 ++-- monai/losses/perceptual.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/losses/adversarial_loss.py b/monai/losses/adversarial_loss.py index e2be9ae036..f16fdee564 100644 --- a/monai/losses/adversarial_loss.py +++ b/monai/losses/adversarial_loss.py @@ -55,10 +55,10 @@ def __init__( ) -> None: super().__init__(reduction=LossReduction(reduction)) - if criterion.lower() not in [m for m in AdversarialCriterions]: + if criterion.lower() not in list(AdversarialCriterions): raise ValueError( "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" - % ", ".join([m for m in AdversarialCriterions]) + % ", ".join(AdversarialCriterions) ) # Depending on the criterion, a different activation layer is used. diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index 5b32b4a574..2207de5e64 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -86,10 +86,10 @@ def __init__( "Argument is_fake_3d must be set to False." ) - if network_type.lower() not in [m for m in PerceptualNetworkType]: + if network_type.lower() not in list(PerceptualNetworkType): raise ValueError( "Unrecognised network type entered for Perceptual Loss. Must be one in: %s" - % ", ".join([m for m in PerceptualNetworkType]) + % ", ".join(PerceptualNetworkType) ) if cache_dir:
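Taken together, the series leaves the two new losses usable roughly as follows. This is a hedged end-to-end sketch: the shapes follow the repository's test cases, the random tensor stands in for a real patch-discriminator output, and PerceptualLoss needs the lpips package and fetches pretrained weights on first use.

    import torch
    from monai.losses import PatchAdversarialLoss, PerceptualLoss

    # Generator-side adversarial term: fake patches should be scored as real.
    adv = PatchAdversarialLoss(criterion="least_squares", reduction="mean")
    disc_out = [torch.randn(2, 1, 8, 8)]  # stand-in for one scale of a patch discriminator
    g_adv = adv(disc_out, target_is_real=True, for_discriminator=False)

    # Perceptual term on 2D images, matching the (2, 1, 64, 64) "squeeze" test case.
    perc = PerceptualLoss(spatial_dims=2, network_type="squeeze")
    g_perc = perc(torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64))

    total = g_adv + 0.5 * g_perc  # 0.5 is an arbitrary example weight
    print(total.item())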