Separate channel configs for 2 G's and 2 D's for CycleGAN #85

Merged
23 commits, merged Aug 20, 2021
Commits
24e8850
Added CycleGAN multimodal v2 module
cnmy-ro Apr 5, 2021
fac5f53
Multimodal CycleGAN v2 inherit instead of copy
cnmy-ro Apr 6, 2021
56ee7f0
Multimodal CycleGAN v2 minor bug fixes
cnmy-ro Apr 6, 2021
a662d17
Resolved minor conflict in config cleargrasp yaml
cnmy-ro Apr 7, 2021
143963a
Testing git on DSRI
cnmy-ro Apr 9, 2021
cb16f33
Added support for separate channel configs for CycleGAN Gs and Ds
cnmy-ro Apr 10, 2021
d872e07
Added MIND loss for multimodal CycleGAN V2
cnmy-ro Apr 13, 2021
2414402
Option to use structure loss with cycleGAN multimodal v1
cnmy-ro Apr 15, 2021
70cf3c2
Added CycleGAN multimodal v3
cnmy-ro Apr 18, 2021
d7e7d1e
Added HX4-PET train and val datasets
cnmy-ro Apr 29, 2021
0814a4e
Added code for HX4 pix2pix training and validation
cnmy-ro May 5, 2021
98f369f
Added unpaired patch sampler in hx4 project - a generalization of sto…
cnmy-ro May 6, 2021
5d5cef8
0.2 focal_region_proportion causes error due to too small focal regio…
cnmy-ro May 6, 2021
52a4b0b
Added HX4-CycleGAN-balanced code
cnmy-ro May 12, 2021
3b4a9fe
Truly unpaired training for CycleGANs with unreg HX4-PET and ldCT
cnmy-ro May 14, 2021
f4f7ad9
Added 2 new val metrics- Normalized Mutual Information and Histogram …
cnmy-ro May 22, 2021
fd82327
Handle 0/0 division in histogram chi2 distance metric
cnmy-ro May 22, 2021
f527a25
Updated unpaired patch sampler in HX4 project
cnmy-ro May 24, 2021
40dfa72
Refactored cleargrasp code
cnmy-ro May 25, 2021
6a6053e
syncing
cnmy-ro May 31, 2021
d33022d
resolve conflict
ibro45 Aug 14, 2021
433f251
clean up direction and domain specifiers
ibro45 Aug 17, 2021
bc987f1
redesigned separation of channels in G and D
ibro45 Aug 17, 2021
24 changes: 17 additions & 7 deletions midaGAN/configs/base.py
@@ -24,19 +24,29 @@ class BaseOptimizerConfig:
lr_D: float = 0.0001
lr_G: float = 0.0002


@dataclass
class BaseDiscriminatorConfig:
name: str = MISSING
in_channels: int = MISSING

class GeneratorInOutChannelsConfig:
AB: Tuple[int, int] = MISSING
BA: Optional[Tuple[int, int]] = II("train.gan.generator.in_out_channels.AB")

@dataclass
class BaseGeneratorConfig:
name: str = MISSING
in_channels: int = MISSING
out_channels: int = MISSING
# TODO: When OmegaConf implements Union, enable entering a single int when only AB is needed,
# or when AB and BA are the same. Otherwise use the GeneratorInOutChannelsConfig.
in_out_channels: GeneratorInOutChannelsConfig = GeneratorInOutChannelsConfig

@dataclass
class DiscriminatorInChannelsConfig:
B: int = MISSING
A: Optional[int] = II("train.gan.discriminator.in_channels.B")

@dataclass
class BaseDiscriminatorConfig:
name: str = MISSING
# TODO: When OmegaConf implements Union, enable entering a single int when only B is needed,
# or when B and A are the same. Otherwise use the DiscriminatorInChannelsConfig.
in_channels: DiscriminatorInChannelsConfig = DiscriminatorInChannelsConfig

@dataclass
class BaseGANConfig:
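Below is a minimal sketch (not part of the PR) of how the new per-direction channel settings could be populated; the dataclass field names come from the diff above, while the concrete channel counts and the direct use of `OmegaConf.create` are illustrative assumptions:

```python
from omegaconf import OmegaConf

# Hypothetical experiment values: G_AB maps a 1-channel A image to a
# 2-channel B image, G_BA the reverse. If "BA" were omitted, the II()
# interpolation in GeneratorInOutChannelsConfig would copy it from "AB".
conf = OmegaConf.create({
    "train": {
        "gan": {
            "generator": {"in_out_channels": {"AB": [1, 2], "BA": [2, 1]}},
            # D_B sees 2-channel B images, D_A sees 1-channel A images.
            "discriminator": {"in_channels": {"B": 2, "A": 1}},
        }
    }
})
print(conf.train.gan.generator.in_out_channels.BA)  # [2, 1]
```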
4 changes: 4 additions & 0 deletions midaGAN/configs/validation_testing.py
@@ -28,6 +28,10 @@ class BaseValTestMetricsConfig:
mse: bool = True
# Abs diff between the two images
mae: bool = True
# Normalized Mutual Information
nmi: bool = False
# Chi-squared Histogram Distance
histogram_chi2: bool = False


@dataclass
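As a hedged usage note, the two new flags could be switched on through OmegaConf's dotlist/CLI syntax, which `build_conf` (see `midaGAN/utils/builders.py` below) merges into the config; the exact `val.metrics` path here is an assumption for illustration:

```python
from omegaconf import OmegaConf

# Assumed config path; both metrics default to False in the dataclass above.
overrides = OmegaConf.from_dotlist(["val.metrics.nmi=true",
                                    "val.metrics.histogram_chi2=true"])
print(overrides.val.metrics.nmi, overrides.val.metrics.histogram_chi2)  # True True
```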
23 changes: 12 additions & 11 deletions midaGAN/engines/validator_tester.py
@@ -30,9 +30,9 @@ def run(self, current_idx=""):
# Collect visuals
device = self.model.device
self.visuals = {}
self.visuals["A"] = data["A"].to(device)
self.visuals["fake_B"] = self.infer(self.visuals["A"])
self.visuals["B"] = data["B"].to(device)
self.visuals["real_A"] = data["A"].to(device)
self.visuals["fake_B"] = self.infer(self.visuals["real_A"])
self.visuals["real_B"] = data["B"].to(device)

# Add masks if provided
if "masks" in data:
@@ -53,17 +53,17 @@ def run(self, current_idx=""):

def _calculate_metrics(self):
# TODO: Decide if cycle metrics also need to be scaled
original, pred, target = self.visuals["A"], self.visuals["fake_B"], self.visuals["B"]
original, pred, target = self.visuals["real_A"], self.visuals["fake_B"], self.visuals["real_B"]

# Metrics on input
compute_over_input = getattr(self.conf[self.conf.mode].metrics, "compute_over_input", False)

# Denormalize the data if dataset has `denormalize` method defined.
denormalize = getattr(self.current_data_loader.dataset, "denormalize", False)
if denormalize:
pred, target = denormalize(pred), denormalize(target)
pred, target = denormalize(pred.detach().clone()), denormalize(target.detach().clone())
if compute_over_input:
original = denormalize(original)
original = denormalize(original.detach().clone())

# Standard Metrics
metrics = self.metricizer.get_metrics(pred, target)
@@ -84,7 +84,7 @@ def _calculate_metrics(self):
key = f"{name}_{label}"
mask_metrics[key] = value

# Get metrics on priginal masked images
# Get metrics on original masked images
if compute_over_input:
for name, value in self.metricizer.get_metrics(original, target,
mask=mask).items():
@@ -96,14 +96,15 @@ def _calculate_metrics(self):

# Cycle Metrics
cycle_metrics = {}
if self.conf[self.conf.mode].metrics.cycle_metrics:
if "cycle" not in self.model.infer.__code__.co_varnames:
compute_cycle_metrics = getattr(self.conf[self.conf.mode].metrics, "cycle_metrics", False)
if compute_cycle_metrics:
if "direction" not in self.model.infer.__code__.co_varnames:
raise RuntimeError("If cycle metrics are enabled, please define"
" behavior of inference with a `cycle` flag in"
" behavior of inference with a `direction` flag in"
" the model's `infer()` method")

rec_A = self.infer(self.visuals["fake_B"], direction='BA')
cycle_metrics = self.metricizer.get_cycle_metrics(rec_A, self.visuals["A"])
cycle_metrics = self.metricizer.get_cycle_metrics(rec_A, self.visuals["real_A"])

metrics.update(mask_metrics)
metrics.update(cycle_metrics)
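Two side notes on this hunk: the added `detach().clone()` calls presumably keep `denormalize` from modifying the cached visuals in place, and the `direction` check inspects the compiled signature of the model's `infer()`. A standalone illustration of that check (plain Python, independent of midaGAN):

```python
# __code__.co_varnames lists a function's parameter names first, so it can
# verify that infer() accepts a `direction` argument before cycle metrics run.
def infer(x, direction="AB"):
    return x

assert "direction" in infer.__code__.co_varnames
```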
1 change: 0 additions & 1 deletion midaGAN/nn/discriminators/patchgan/ms_patchgan3d.py
@@ -32,7 +32,6 @@ def get_cropped_patch(input: torch.Tensor, scale: int = 1) -> torch.Tensor:
@dataclass
class MultiScalePatchGAN3DConfig(configs.base.BaseDiscriminatorConfig):
name: str = "MultiScalePatchGAN3D"
in_channels: int = 1
ndf: int = 64
n_layers: int = 3
kernel_size: Tuple[int] = (4, 4, 4)
1 change: 0 additions & 1 deletion midaGAN/nn/discriminators/patchgan/patchgan2d.py
@@ -10,7 +10,6 @@
@dataclass
class PatchGAN2DConfig(configs.base.BaseDiscriminatorConfig):
name: str = "PatchGAN2D"
in_channels: int = 1
ndf: int = 64
n_layers: int = 3
kernel_size: Tuple[int] = (4, 4)
3 changes: 1 addition & 2 deletions midaGAN/nn/discriminators/patchgan/patchgan3d.py
@@ -10,7 +10,6 @@
@dataclass
class PatchGAN3DConfig(configs.base.BaseDiscriminatorConfig):
name: str = "PatchGAN3D"
in_channels: int = 1
ndf: int = 64
n_layers: int = 3
kernel_size: Tuple[int] = (4, 4, 4)
@@ -60,7 +59,7 @@ def __init__(self, in_channels, ndf, n_layers, kernel_size, norm_type):
nn.LeakyReLU(0.2, True)
]

sequence += [nn.Conv3d(ndf * nf_mult, kernel_size=kw, stride=1, padding=padw)]
sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
self.model = nn.Sequential(*sequence)

def forward(self, input):
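The one-line change above restores the missing `out_channels=1` argument in the final convolution, so the discriminator ends in a single-channel patch map. A quick shape check under assumed values (plain PyTorch; `ndf`, `nf_mult`, `kw`, `padw` are illustrative):

```python
import torch
import torch.nn as nn

ndf, nf_mult, kw, padw = 64, 8, 4, 1
final = nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
x = torch.randn(1, ndf * nf_mult, 8, 8, 8)
print(final(x).shape)  # torch.Size([1, 1, 7, 7, 7]) -- one channel per patch score
```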
21 changes: 17 additions & 4 deletions midaGAN/nn/gans/base.py
@@ -45,12 +45,25 @@ def __init__(self, conf):
self.optimizers = {}
self.networks = {}

def init_networks(self):
for name in self.networks.keys():
def init_networks(self):

for name in self.networks.keys():

# Generator
if name.startswith('G'):
self.networks[name] = build_G(self.conf, self.device)
# Direction of the generator.
# 'AB' by default, only bi-directional GANs (e.g. CycleGAN) need
# generator for 'BA' direction as well.
direction = 'BA' if name.endswith('_BA') else 'AB'
self.networks[name] = build_G(self.conf, direction, self.device)

# Discriminator
elif name.startswith('D'):
self.networks[name] = build_D(self.conf, self.device)
# Discriminator's domain.
# 'B' by default, only bi-directional GANs (e.g. CycleGAN) need
# the 'A' domain discriminator as well.
domain = 'A' if name.endswith('_A') else 'B'
self.networks[name] = build_D(self.conf, domain, self.device)

@abstractmethod
def init_criterions(self):
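A hedged sketch of the naming convention this loop relies on, using the network names a CycleGAN would typically register (the dict values are placeholders):

```python
# Suffixes select the generator direction and discriminator domain.
networks = {"G_AB": None, "G_BA": None, "D_B": None, "D_A": None}
for name in networks:
    if name.startswith("G"):
        direction = "BA" if name.endswith("_BA") else "AB"
        print(f"{name}: generator, direction {direction}")
    elif name.startswith("D"):
        domain = "A" if name.endswith("_A") else "B"
        print(f"{name}: discriminator, domain {domain}")
```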
2 changes: 1 addition & 1 deletion midaGAN/nn/generators/unet/unet3d.py
@@ -12,7 +12,7 @@ class Unet3DConfig(configs.base.BaseGeneratorConfig):
name: str = 'Unet3D'
num_downs: int = 7
ngf: int = 64
use_dropout = False
use_dropout: bool = False


class Unet3D(nn.Module):
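This one-character fix is significant: dataclasses only pick up annotated assignments, so the un-annotated `use_dropout = False` was a plain class attribute that never became a config field. A minimal general-Python demonstration:

```python
from dataclasses import dataclass, fields

@dataclass
class Demo:
    annotated: bool = False
    unannotated = False  # ignored by @dataclass: no type annotation

print([f.name for f in fields(Demo)])  # ['annotated']
```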
45 changes: 35 additions & 10 deletions midaGAN/utils/builders.py
@@ -2,7 +2,7 @@
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from omegaconf import OmegaConf
import omegaconf

from midaGAN.configs.config import Config
from midaGAN.configs.utils import IMPORT_LOCATIONS, init_config
@@ -13,15 +13,15 @@


def build_conf():
cli = OmegaConf.from_cli()
cli = omegaconf.OmegaConf.from_cli()
conf = init_config(cli.pop("config"), config_class=Config)
return OmegaConf.merge(conf, cli)
return omegaconf.OmegaConf.merge(conf, cli)


def build_loader(conf):
"""Builds the dataloader(s). If the config for dataset is a single dataset, it
will return a dataloader for it, but if multiple datasets were specified,
a list of dataloaders, one for each dataset, will be returnet.
a list of dataloaders, one for each dataset, will be returned.
"""
############## Multi-dataset loaders #################
if "multi_dataset" in conf[conf.mode] and conf[conf.mode].multi_dataset is not None:
@@ -78,16 +78,24 @@ def build_gan(conf):
return model


def build_G(conf, device):
return build_network_by_role('generator', conf, device)
def build_G(conf, direction, device):
assert direction in ['AB', 'BA']
return build_network_by_role('generator', conf, direction, device)


def build_D(conf, device):
return build_network_by_role('discriminator', conf, device)
def build_D(conf, domain, device):
assert domain in ['B', 'A']
return build_network_by_role('discriminator', conf, domain, device)


def build_network_by_role(role, conf, device):
"""Builds a discriminator or generator. TODO: document """
def build_network_by_role(role, conf, label, device):
"""Builds a discriminator or generator. TODO: document better
Parameters:
role -- `generator` or `discriminator`
conf -- conf
label -- role-specific label
device -- torch device
"""
assert role in ['discriminator', 'generator']

name = conf.train.gan[role].name
@@ -96,6 +104,23 @@ def build_network_by_role(role, conf, label, device):
network_args = dict(conf.train.gan[role])
network_args.pop("name")
network_args["norm_type"] = conf.train.gan.norm_type

# Handle the network's channels settings
if role == 'generator':
in_out_channels = network_args.pop('in_out_channels')
# TODO: This will enable support for both Dict and a single Tuple as
# mentioned in the config (configs/base.py#GeneratorInOutChannelsConfig)
# when OmegaConf will allow Union. Update comment when that happens.
if isinstance(in_out_channels, omegaconf.dictconfig.DictConfig):
in_out_channels = in_out_channels[label]
network_args["in_channels"], network_args["out_channels"] = in_out_channels

elif role == 'discriminator':
# TODO: This will enable support for both Dict and a single Int as
# mentioned in the config (configs/base.py#DiscriminatorInChannelsConfig)
# when OmegaConf will allow Union. Update comment when that happens.
if isinstance(network_args["in_channels"] , omegaconf.dictconfig.DictConfig):
network_args["in_channels"] = network_args["in_channels"][label]

network = network_class(**network_args)
return init_net(network, conf, device)
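Standalone sketch of the channel-resolution branch above, runnable without midaGAN (the example channel counts are assumptions):

```python
import omegaconf
from omegaconf import OmegaConf

# A generator's per-direction channels, reduced to plain ints for one label.
in_out_channels = OmegaConf.create({"AB": [1, 2], "BA": [2, 1]})
label = "BA"
if isinstance(in_out_channels, omegaconf.dictconfig.DictConfig):
    in_out_channels = in_out_channels[label]
in_channels, out_channels = in_out_channels
print(in_channels, out_channels)  # 2 1
```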
64 changes: 60 additions & 4 deletions midaGAN/utils/metrics/val_test_metrics.py
@@ -1,8 +1,7 @@
# import midaGAN.nn.losses.ssim as ssim
import numpy as np
from typing import Optional

import numpy as np
from scipy.stats import entropy
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


@@ -85,7 +84,49 @@ def ssim(gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None) -> float:
return ssim_sum / size


METRIC_DICT = {"ssim": ssim, "mse": mse, "nmse": nmse, "psnr": psnr, "mae": mae}
def nmi(gt: np.ndarray, pred: np.ndarray) -> float:
"""Normalized Mutual Information.
Implementation taken from scikit-image 0.19.0.dev0 source --
https://github.com/scikit-image/scikit-image/blob/main/skimage/metrics/simple_metrics.py#L193-L261
Not using scikit-image because NMI is supported only in >=0.19.
"""
bins = 100 # 100 bins by default
hist, bin_edges = np.histogramdd(
[np.reshape(gt, -1), np.reshape(pred, -1)],
bins=bins,
density=True,
)
H0 = entropy(np.sum(hist, axis=0))
H1 = entropy(np.sum(hist, axis=1))
H01 = entropy(np.reshape(hist, -1))
nmi_value = (H0 + H1) / H01
return float(nmi_value)


def histogram_chi2(gt: np.ndarray, pred: np.ndarray) -> float:
"""Chi-squared distance computed between histograms of the GT and the prediction.
More about comparing two histograms --
https://stackoverflow.com/questions/6499491/comparing-two-histograms
"""
bins = 100 # 100 bins by default

# Compute histograms
gt_histogram, gt_bin_edges = np.histogram(gt, bins=bins)
pred_histogram, pred_bin_edges = np.histogram(pred, bins=bins)

# Normalize the histograms to convert them into discrete distributions
gt_histogram = gt_histogram / gt_histogram.sum()
pred_histogram = pred_histogram / pred_histogram.sum()

# Compute chi-squared distance
bin_to_bin_distances = (pred_histogram - gt_histogram)**2 / (pred_histogram + gt_histogram)
# Remove NaN values caused by 0/0 division. Equivalent to manually setting them as 0.
bin_to_bin_distances = bin_to_bin_distances[np.logical_not(np.isnan(bin_to_bin_distances))]
chi2_distance_value = np.sum(bin_to_bin_distances)
return float(chi2_distance_value)


METRIC_DICT = {"ssim": ssim, "mse": mse, "nmse": nmse, "psnr": psnr, "mae": mae, "nmi": nmi, "histogram_chi2": histogram_chi2}


class ValTestMetrics:
@@ -95,7 +136,22 @@ def __init__(self, conf):

def get_metrics(self, inputs, targets, mask=None):
metrics = {}


# Chinmay HX4-specific hack: If the tensors have 2 channels, take only the 1st channel (HX4-PET),
# because the 2nd channel is a dummy.
# Need this in case of HX4-CycleGAN-balanced.
if inputs.shape[1] == 2:
inputs = inputs[:, :1]
targets = targets[:, :1]

# Chinmay Cleargrasp-specific hack: If the tensors have 4 channels, take only the last channel (depthmap),
# because the first 3 are a dummy array.
# Need this in case of CycleGAN-balanced applied to Cleargrasp (i.e. version 3 in this project).
if inputs.shape[1] == 4:
inputs = inputs[:, 3:]
targets = targets[:, 3:]


inputs, targets = get_npy(inputs), get_npy(targets)

# Iterating over all metrics that need to be computed
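A hedged usage sketch for the two new metrics on synthetic data (the import path follows this file; comments note expected ranges, not exact values):

```python
import numpy as np
from midaGAN.utils.metrics.val_test_metrics import nmi, histogram_chi2

rng = np.random.default_rng(0)
gt = rng.normal(size=(64, 64))
pred = gt + 0.1 * rng.normal(size=(64, 64))

# NMI ranges from 1.0 (independent images) to 2.0 (identical images).
print(nmi(gt, pred))
# Chi-squared histogram distance is near 0 for similar intensity distributions.
print(histogram_chi2(gt, pred))
```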