diff --git a/midaGAN/configs/base.py b/midaGAN/configs/base.py index d636aa97..e06ca592 100755 --- a/midaGAN/configs/base.py +++ b/midaGAN/configs/base.py @@ -24,19 +24,29 @@ class BaseOptimizerConfig: lr_D: float = 0.0001 lr_G: float = 0.0002 - @dataclass -class BaseDiscriminatorConfig: - name: str = MISSING - in_channels: int = MISSING - +class GeneratorInOutChannelsConfig: + AB: Tuple[int, int] = MISSING + BA: Optional[Tuple[int, int]] = II("train.gan.generator.in_out_channels.AB") @dataclass class BaseGeneratorConfig: name: str = MISSING - in_channels: int = MISSING - out_channels: int = MISSING + # TODO: When OmegaConf implements Union, enable entering a single int when only AB is needed, + # or when AB and BA are the same. Otherwise use the GeneratorInOutChannelsConfig. + in_out_channels: GeneratorInOutChannelsConfig = GeneratorInOutChannelsConfig + +@dataclass +class DiscriminatorInChannelsConfig: + B: int = MISSING + A: Optional[int] = II("train.gan.discriminator.in_channels.B") +@dataclass +class BaseDiscriminatorConfig: + name: str = MISSING + # TODO: When OmegaConf implements Union, enable entering a single int when only B is needed, + # or when B and A are the same. Otherwise use the DiscriminatorInChannelsConfig. + in_channels: DiscriminatorInChannelsConfig = DiscriminatorInChannelsConfig @dataclass class BaseGANConfig: diff --git a/midaGAN/configs/validation_testing.py b/midaGAN/configs/validation_testing.py index 0c1a6ee3..7ee2d519 100644 --- a/midaGAN/configs/validation_testing.py +++ b/midaGAN/configs/validation_testing.py @@ -28,6 +28,10 @@ class BaseValTestMetricsConfig: mse: bool = True # Abs diff between the two images mae: bool = True + # Normalized Mutual Information + nmi: bool = False + # Chi-squared Histogram Distance + histogram_chi2: bool = False @dataclass diff --git a/midaGAN/engines/validator_tester.py b/midaGAN/engines/validator_tester.py index d602448f..2e69e9f7 100644 --- a/midaGAN/engines/validator_tester.py +++ b/midaGAN/engines/validator_tester.py @@ -30,9 +30,9 @@ def run(self, current_idx=""): # Collect visuals device = self.model.device self.visuals = {} - self.visuals["A"] = data["A"].to(device) - self.visuals["fake_B"] = self.infer(self.visuals["A"]) - self.visuals["B"] = data["B"].to(device) + self.visuals["real_A"] = data["A"].to(device) + self.visuals["fake_B"] = self.infer(self.visuals["real_A"]) + self.visuals["real_B"] = data["B"].to(device) # Add masks if provided if "masks" in data: @@ -53,7 +53,7 @@ def run(self, current_idx=""): def _calculate_metrics(self): # TODO: Decide if cycle metrics also need to be scaled - original, pred, target = self.visuals["A"], self.visuals["fake_B"], self.visuals["B"] + original, pred, target = self.visuals["real_A"], self.visuals["fake_B"], self.visuals["real_B"] # Metrics on input compute_over_input = getattr(self.conf[self.conf.mode].metrics, "compute_over_input", False) @@ -61,9 +61,9 @@ def _calculate_metrics(self): # Denormalize the data if dataset has `denormalize` method defined. 
denormalize = getattr(self.current_data_loader.dataset, "denormalize", False) if denormalize: - pred, target = denormalize(pred), denormalize(target) + pred, target = denormalize(pred.detach().clone()), denormalize(target.detach().clone()) if compute_over_input: - original = denormalize(original) + original = denormalize(original.detach().clone()) # Standard Metrics metrics = self.metricizer.get_metrics(pred, target) @@ -84,7 +84,7 @@ def _calculate_metrics(self): key = f"{name}_{label}" mask_metrics[key] = value - # Get metrics on priginal masked images + # Get metrics on original masked images if compute_over_input: for name, value in self.metricizer.get_metrics(original, target, mask=mask).items(): @@ -96,14 +96,15 @@ def _calculate_metrics(self): # Cycle Metrics cycle_metrics = {} - if self.conf[self.conf.mode].metrics.cycle_metrics: - if "cycle" not in self.model.infer.__code__.co_varnames: + compute_cycle_metrics = getattr(self.conf[self.conf.mode].metrics, "cycle_metrics", False) + if compute_cycle_metrics: + if "direction" not in self.model.infer.__code__.co_varnames: raise RuntimeError("If cycle metrics are enabled, please define" - " behavior of inference with a `cycle` flag in" + " behavior of inference with a `direction` flag in" " the model's `infer()` method") rec_A = self.infer(self.visuals["fake_B"], direction='BA') - cycle_metrics = self.metricizer.get_cycle_metrics(rec_A, self.visuals["A"]) + cycle_metrics = self.metricizer.get_cycle_metrics(rec_A, self.visuals["real_A"]) metrics.update(mask_metrics) metrics.update(cycle_metrics) diff --git a/midaGAN/nn/discriminators/patchgan/ms_patchgan3d.py b/midaGAN/nn/discriminators/patchgan/ms_patchgan3d.py index 3d26cbdc..ec52a51e 100644 --- a/midaGAN/nn/discriminators/patchgan/ms_patchgan3d.py +++ b/midaGAN/nn/discriminators/patchgan/ms_patchgan3d.py @@ -32,7 +32,6 @@ def get_cropped_patch(input: torch.Tensor, scale: int = 1) -> torch.Tensor: @dataclass class MultiScalePatchGAN3DConfig(configs.base.BaseDiscriminatorConfig): name: str = "MultiScalePatchGAN3D" - in_channels: int = 1 ndf: int = 64 n_layers: int = 3 kernel_size: Tuple[int] = (4, 4, 4) diff --git a/midaGAN/nn/discriminators/patchgan/patchgan2d.py b/midaGAN/nn/discriminators/patchgan/patchgan2d.py index 036380f9..fa0262de 100644 --- a/midaGAN/nn/discriminators/patchgan/patchgan2d.py +++ b/midaGAN/nn/discriminators/patchgan/patchgan2d.py @@ -10,7 +10,6 @@ @dataclass class PatchGAN2DConfig(configs.base.BaseDiscriminatorConfig): name: str = "PatchGAN2D" - in_channels: int = 1 ndf: int = 64 n_layers: int = 3 kernel_size: Tuple[int] = (4, 4) diff --git a/midaGAN/nn/discriminators/patchgan/patchgan3d.py b/midaGAN/nn/discriminators/patchgan/patchgan3d.py index 0927e1f1..711f3f3d 100644 --- a/midaGAN/nn/discriminators/patchgan/patchgan3d.py +++ b/midaGAN/nn/discriminators/patchgan/patchgan3d.py @@ -10,7 +10,6 @@ @dataclass class PatchGAN3DConfig(configs.base.BaseDiscriminatorConfig): name: str = "PatchGAN3D" - in_channels: int = 1 ndf: int = 64 n_layers: int = 3 kernel_size: Tuple[int] = (4, 4, 4) @@ -60,7 +59,7 @@ def __init__(self, in_channels, ndf, n_layers, kernel_size, norm_type): nn.LeakyReLU(0.2, True) ] - sequence += [nn.Conv3d(ndf * nf_mult, kernel_size=kw, stride=1, padding=padw)] + sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] self.model = nn.Sequential(*sequence) def forward(self, input): diff --git a/midaGAN/nn/gans/base.py b/midaGAN/nn/gans/base.py index 8d43b4ea..d0742a77 100644 --- a/midaGAN/nn/gans/base.py +++ 
b/midaGAN/nn/gans/base.py @@ -45,12 +45,25 @@ def __init__(self, conf): self.optimizers = {} self.networks = {} - def init_networks(self): - for name in self.networks.keys(): + def init_networks(self): + + for name in self.networks.keys(): + + # Generator if name.startswith('G'): - self.networks[name] = build_G(self.conf, self.device) + # Direction of the generator. + # 'AB' by default, only bi-directional GANs (e.g. CycleGAN) need + # generator for 'BA' direction as well. + direction = 'BA' if name.endswith('_BA') else 'AB' + self.networks[name] = build_G(self.conf, direction, self.device) + + # Discriminator elif name.startswith('D'): - self.networks[name] = build_D(self.conf, self.device) + # Discriminator's domain. + # 'B' by default, only bi-directional GANs (e.g. CycleGAN) need + # the 'A' domain discriminator as well. + domain = 'A' if name.endswith('_A') else 'B' + self.networks[name] = build_D(self.conf, domain, self.device) @abstractmethod def init_criterions(self): diff --git a/midaGAN/nn/generators/unet/unet3d.py b/midaGAN/nn/generators/unet/unet3d.py index 34f78f76..6c79d3b9 100644 --- a/midaGAN/nn/generators/unet/unet3d.py +++ b/midaGAN/nn/generators/unet/unet3d.py @@ -12,7 +12,7 @@ class Unet3DConfig(configs.base.BaseGeneratorConfig): name: str = 'Unet3D' num_downs: int = 7 ngf: int = 64 - use_dropout = False + use_dropout: bool = False class Unet3D(nn.Module): diff --git a/midaGAN/utils/builders.py b/midaGAN/utils/builders.py index 00b347f2..11b288a1 100644 --- a/midaGAN/utils/builders.py +++ b/midaGAN/utils/builders.py @@ -2,7 +2,7 @@ import torch from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from omegaconf import OmegaConf +import omegaconf from midaGAN.configs.config import Config from midaGAN.configs.utils import IMPORT_LOCATIONS, init_config @@ -13,15 +13,15 @@ def build_conf(): - cli = OmegaConf.from_cli() + cli = omegaconf.OmegaConf.from_cli() conf = init_config(cli.pop("config"), config_class=Config) - return OmegaConf.merge(conf, cli) + return omegaconf.OmegaConf.merge(conf, cli) def build_loader(conf): """Builds the dataloader(s). If the config for dataset is a single dataset, it will return a dataloader for it, but if multiple datasets were specified, - a list of dataloaders, one for each dataset, will be returnet. + a list of dataloaders, one for each dataset, will be returned. """ ############## Multi-dataset loaders ################# if "multi_dataset" in conf[conf.mode] and conf[conf.mode].multi_dataset is not None: @@ -78,16 +78,24 @@ def build_gan(conf): return model -def build_G(conf, device): - return build_network_by_role('generator', conf, device) +def build_G(conf, direction, device): + assert direction in ['AB', 'BA'] + return build_network_by_role('generator', conf, direction, device) -def build_D(conf, device): - return build_network_by_role('discriminator', conf, device) +def build_D(conf, domain, device): + assert domain in ['B', 'A'] + return build_network_by_role('discriminator', conf, domain, device) -def build_network_by_role(role, conf, device): - """Builds a discriminator or generator. TODO: document """ +def build_network_by_role(role, conf, label, device): + """Builds a discriminator or generator. 
TODO: document better + Parameters: + role -- `generator` or `discriminator` + conf -- conf + label -- role-specific label + device -- torch device + """ assert role in ['discriminator', 'generator'] name = conf.train.gan[role].name @@ -96,6 +104,23 @@ def build_network_by_role(role, conf, device): network_args = dict(conf.train.gan[role]) network_args.pop("name") network_args["norm_type"] = conf.train.gan.norm_type + + # Handle the network's channels settings + if role == 'generator': + in_out_channels = network_args.pop('in_out_channels') + # TODO: This will enable support for both Dict and a single Tuple as + # mentioned in the config (configs/base.py#GeneratorInOutChannelsConfig) + # when OmegaConf will allow Union. Update comment when that happens. + if isinstance(in_out_channels, omegaconf.dictconfig.DictConfig): + in_out_channels = in_out_channels[label] + network_args["in_channels"], network_args["out_channels"] = in_out_channels + + elif role == 'discriminator': + # TODO: This will enable support for both Dict and a single Int as + # mentioned in the config (configs/base.py#DiscriminatorInChannelsConfig) + # when OmegaConf will allow Union. Update comment when that happens. + if isinstance(network_args["in_channels"] , omegaconf.dictconfig.DictConfig): + network_args["in_channels"] = network_args["in_channels"][label] network = network_class(**network_args) return init_net(network, conf, device) diff --git a/midaGAN/utils/metrics/val_test_metrics.py b/midaGAN/utils/metrics/val_test_metrics.py index 2d20c61b..0b13085d 100644 --- a/midaGAN/utils/metrics/val_test_metrics.py +++ b/midaGAN/utils/metrics/val_test_metrics.py @@ -1,8 +1,7 @@ # import midaGAN.nn.losses.ssim as ssim import numpy as np from typing import Optional - -import numpy as np +from scipy.stats import entropy from skimage.metrics import peak_signal_noise_ratio, structural_similarity @@ -85,7 +84,49 @@ def ssim(gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None) -> fl return ssim_sum / size -METRIC_DICT = {"ssim": ssim, "mse": mse, "nmse": nmse, "psnr": psnr, "mae": mae} +def nmi(gt: np.ndarray, pred: np.ndarray) -> float: + """Normalized Mutual Information. + Implementation taken from scikit-image 0.19.0.dev0 source -- + https://github.com/scikit-image/scikit-image/blob/main/skimage/metrics/simple_metrics.py#L193-L261 + Not using scikit-image because NMI is supported only in >=0.19. + """ + bins = 100 # 100 bins by default + hist, bin_edges = np.histogramdd( + [np.reshape(gt, -1), np.reshape(pred, -1)], + bins=bins, + density=True, + ) + H0 = entropy(np.sum(hist, axis=0)) + H1 = entropy(np.sum(hist, axis=1)) + H01 = entropy(np.reshape(hist, -1)) + nmi_value = (H0 + H1) / H01 + return float(nmi_value) + + +def histogram_chi2(gt: np.ndarray, pred: np.ndarray) -> float: + """Chi-squared distance computed between histograms of the GT and the prediction. + More about comparing two histograms -- + https://stackoverflow.com/questions/6499491/comparing-two-histograms + """ + bins = 100 # 100 bins by default + + # Compute histograms + gt_histogram, gt_bin_edges = np.histogram(gt, bins=bins) + pred_histogram, pred_bin_edges = np.histogram(pred, bins=bins) + + # Normalize the histograms to convert them into discrete distributions + gt_histogram = gt_histogram / gt_histogram.sum() + pred_histogram = pred_histogram / pred_histogram.sum() + + # Compute chi-squared distance + bin_to_bin_distances = (pred_histogram - gt_histogram)**2 / (pred_histogram + gt_histogram) + # Remove NaN values caused by 0/0 division. 
Equivalent to manually setting them as 0. + bin_to_bin_distances = bin_to_bin_distances[np.logical_not(np.isnan(bin_to_bin_distances))] + chi2_distance_value = np.sum(bin_to_bin_distances) + return float(chi2_distance_value) + + +METRIC_DICT = {"ssim": ssim, "mse": mse, "nmse": nmse, "psnr": psnr, "mae": mae, "nmi": nmi, "histogram_chi2": histogram_chi2} class ValTestMetrics: @@ -95,7 +136,22 @@ def __init__(self, conf): def get_metrics(self, inputs, targets, mask=None): metrics = {} - + + # Chinmay HX4-specific hack: If the tensors have 2 channels, take only the 1st channel (HX4-PET), + # because the 2nd channel is a dummy. + # Need this in case of HX4-CycleGAN-balanced. + if inputs.shape[1] == 2: + inputs = inputs[:, :1] + targets = targets[:, :1] + + # Chinmay Cleargrasp-specific hack: If the tensors have 4 channels, take only the last channel (depthmap), + # because the first 3 are a dummy array. + # Need this in case of CycleGAN-balanced applied to Cleargrasp (i.e. version 3 in this project). + if inputs.shape[1] == 4: + inputs = inputs[:, 3:] + targets = targets[:, 3:] + + inputs, targets = get_npy(inputs), get_npy(targets) # Iterating over all metrics that need to be computed diff --git a/midaGAN/utils/trackers/utils.py b/midaGAN/utils/trackers/utils.py index 8c732b3d..16d565d6 100644 --- a/midaGAN/utils/trackers/utils.py +++ b/midaGAN/utils/trackers/utils.py @@ -100,28 +100,35 @@ def _split_multimodal_visuals(visuals, multi_modality_split): return visuals splitted_visuals = {} - # For each tensor in visuals - for name in visuals: - # For each domain (A and B) - for domain in multi_modality_split: - # Names of visuals end with _A or _B - if name.endswith(domain): - channel_split = tuple(multi_modality_split[domain]) - # Possible that the split is defined for only one of the two domains - if channel_split is None: - # Then just copy the visual - splitted_visuals[name] = visuals[name] - continue - - # Num of channels in split need to correspond to the actual num of channels - if sum(channel_split) != visuals[name].shape[1]: - raise ValueError("Please specify channel-split correctly!") - - # Split the modalities and assign them to visuals - splitted_modalities = torch.split(visuals[name], channel_split, dim=1) - for i in range(len(channel_split)): - splitted_visuals[f"{name}{i+1}"] = splitted_modalities[i] + # For each tensor in visuals + for name in visuals.keys(): + # Consider only those visuals for splitting whose names contain `_A` or `_B` (for ex. images with names `real_A` or `fake_B`) + if '_A' in name or '_B' in name: + # For each domain (`A` and `B`) + for domain in multi_modality_split: + # Names of visuals ending with the domain name + if name.endswith(domain): + channel_split = tuple(multi_modality_split[domain]) + # Possible that the split is defined for only one of the two domains + if channel_split is None: + # Then just copy the visual + splitted_visuals[name] = visuals[name] + continue + + # Num of channels in split need to correspond to the actual num of channels + if sum(channel_split) != visuals[name].shape[1]: + raise ValueError("Please specify channel-split correctly!") + + # Split the modalities and assign them to visuals + splitted_modalities = torch.split(visuals[name], channel_split, dim=1) + for i in range(len(channel_split)): + splitted_visuals[f"{name}{i+1}"] = splitted_modalities[i] + + # No processing of visuals whose names do not contain `_A` or `_B` (for ex. 
masks with names `BODY` or `GTV`) + else: + splitted_visuals[name] = visuals[name] + return splitted_visuals diff --git a/midaGAN/utils/trackers/validation_testing.py b/midaGAN/utils/trackers/validation_testing.py index c0b43449..91f9d855 100644 --- a/midaGAN/utils/trackers/validation_testing.py +++ b/midaGAN/utils/trackers/validation_testing.py @@ -87,6 +87,7 @@ def log_visuals(): # When testing, there aren't multiple iters, so it isn't necessary. name += "/" if self.conf.mode == "val" else "_" name += f"{visuals_idx}" + self._save_image(visuals, name) def clear_buffers(): diff --git a/projects/cleargrasp_depth_estimation/datasets/train_val_cyclegan_dataset.py b/projects/cleargrasp_depth_estimation/datasets/old/train_val_cyclegan_dataset.py similarity index 62% rename from projects/cleargrasp_depth_estimation/datasets/train_val_cyclegan_dataset.py rename to projects/cleargrasp_depth_estimation/datasets/old/train_val_cyclegan_dataset.py index e157dedf..16f03ba1 100644 --- a/projects/cleargrasp_depth_estimation/datasets/train_val_cyclegan_dataset.py +++ b/projects/cleargrasp_depth_estimation/datasets/old/train_val_cyclegan_dataset.py @@ -16,18 +16,22 @@ from dataclasses import dataclass from midaGAN import configs + EXTENSIONS = ['.jpg', '.exr'] -# Max allowed intenity of depthmap images. Specified in metres. -# This value is chosen by analyzing max values throughout the dataset. -UPPER_DEPTH_INTENSITY_LIMIT = 8.0 + +# Max allowed intenity of depthmap images. Specified in metres. +# This value is chosen by analyzing max values throughout the dataset. +UPPER_DEPTH_INTENSITY_LIMIT = 8.0 + @dataclass class ClearGraspCycleGANDatasetConfig(configs.base.BaseDatasetConfig): name: str = "ClearGraspCycleGANDataset" load_size: Tuple[int, int] = (512, 256) - paired: bool = False # `True` for paired A-B + paired: bool = False # `True` for paired A-B. Need paired during validation + fetch_rgb_b: bool = False # Whether to fetch noisy RGB photo for domain B class ClearGraspCycleGANDataset(Dataset): @@ -36,15 +40,15 @@ class ClearGraspCycleGANDataset(Dataset): Curated from Cleargrasp robot-vision dataset. 
Here, the GAN translation task is: RGB + Normalmap --> Depthmap This is the CycleGAN version: - Domain A: RGB and Normalmap - Domain B: RGB and Depthmap + Domain A: RGB photo and Normalmap + Domain B: Several options -- (1) Depthmap (used with CycleGAN multimodal v1) + (2) Noisy RGB photo and Depthmap (used with v2 and v3) """ - def __init__(self, conf): - + self.mode = conf.mode - self.dir_rgb = Path(conf[conf.mode].dataset.root) / "rgb" + self.dir_rgb = Path(conf[conf.mode].dataset.root) / "rgb" self.dir_normal = Path(conf[conf.mode].dataset.root) / "normal" self.dir_depth = Path(conf[conf.mode].dataset.root) / "depth" @@ -55,20 +59,25 @@ def __init__(self, conf): self.load_size = conf[conf.mode].dataset.load_size self.load_resize_transform = transforms.Resize( - size=(load_size[1], load_size[0]), interpolation=transforms.InterpolationMode.BICUBIC) + size=(self.load_size[1], self.load_size[0]), interpolation=transforms.InterpolationMode.BICUBIC + ) + + self.paired = conf[conf.mode].dataset.paired + self.fetch_rgb_b = conf[conf.mode].dataset.fetch_rgb_b + def __getitem__(self, index): index_A = index % self.dataset_size index_B = index_A if self.paired else random.randint(0, self.dataset_size - 1) - + rgb_A_path = self.rgb_paths[index_A] - normal_path = self.normal_paths[index_A] - rgb_B_path = self.rgb_paths[index_B] + normal_path = self.normal_paths[index_A] + rgb_B_path = self.rgb_paths[index_B] if self.fetch_rgb_b else None depth_path = self.depth_paths[index_B] rgb_A = read_rgb_to_tensor(rgb_A_path) normalmap = read_normalmap_to_tensor(normal_path) - rgb_B = read_rgb_to_tensor(rgb_B_path) + rgb_B = read_rgb_to_tensor(rgb_B_path) if self.fetch_rgb_b else torch.zeros_like(rgb_A) depthmap = read_depthmap_to_tensor(depth_path) # Resize @@ -76,19 +85,33 @@ def __getitem__(self, index): normalmap = self.load_resize_transform(normalmap) rgb_B = self.load_resize_transform(rgb_B) depthmap = self.load_resize_transform(depthmap) - + # Transform rgb_A, normalmap, rgb_B, depthmap = self.apply_transforms(rgb_A, normalmap, rgb_B, depthmap) # Normalize rgb_A, normalmap, rgb_B, depthmap = self.normalize(rgb_A, normalmap, rgb_B, depthmap) - return {'A': torch.cat([rgb_img, normalmap], dim=0), - 'B': torch.cat([rgb_img, depthmap, depthmap, depthmap], dim=0)} + # Add noise in B's RGB photo + if self.fetch_rgb_b: + rgb_B = rgb_B + torch.normal(mean=0, std=0.05, size=(self.load_size[1], self.load_size[0])) + rgb_B = torch.clamp(rgb_B, -1, 1) # Clip to remove out-of-range values + + # Prepare A and B + A = torch.cat([rgb_A, normalmap], dim=0) + if self.fetch_rgb_b: + B = torch.cat([rgb_B, depthmap], dim=0) + else: + B = depthmap + + return {'A': A, 'B': B} + + def __len__(self): return self.dataset_size + def apply_transforms(self, rgb_A, normalmap, rgb_B, depthmap): """ TODO: What transform to use for augmentation? 
@@ -97,6 +120,7 @@ def apply_transforms(self, rgb_A, normalmap, rgb_B, depthmap): """ return rgb_A, normalmap, rgb_B, depthmap + def normalize(self, rgb_A, normalmap, rgb_B, depthmap): """ Scale intensities to [-1,1] range @@ -107,11 +131,20 @@ def normalize(self, rgb_A, normalmap, rgb_B, depthmap): depthmap_min, depthmap_max = 0.0, UPPER_DEPTH_INTENSITY_LIMIT # Normalize - rgb_A = (rgb_A-rgb_min)/(rgb_max-rgb_min) * 2 - 1 - rgb_B = (rgb_B-rgb_min)/(rgb_max-rgb_min) * 2 - 1 - depthmap = (depthmap-depthmap_min)/(depthmap_max-depthmap_min) * 2 - 1 - return torch.clamp(rgb_A, -1, 1), torch.clamp(normalmap, -1, 1), \ - torch.clamp(rgb_B, -1, 1), torch.clamp(depthmap, -1, 1) + rgb_A = (rgb_A - rgb_min) / (rgb_max - rgb_min) * 2 - 1 + if self.fetch_rgb_b: + rgb_B = (rgb_B - rgb_min) / (rgb_max - rgb_min) * 2 - 1 + depthmap = (depthmap - depthmap_min) / (depthmap_max - depthmap_min) * 2 - 1 + + # Clip to remove out-of-range overshoots + rgb_A = torch.clamp(rgb_A, -1, 1) + normalmap = torch.clamp(normalmap, -1, 1) + if self.fetch_rgb_b: + rgb_B = torch.clamp(rgb_B, -1, 1) + depthmap = torch.clamp(depthmap, -1, 1) + + return rgb_A, normalmap, rgb_B, depthmap + def read_rgb_to_tensor(path): @@ -119,26 +152,25 @@ def read_rgb_to_tensor(path): RGB reader based on cv2.imread(). Just for consistency with normalmap and depthmap readers. """ - bgr_img = cv2.imread(path) + bgr_img = cv2.imread(str(path)) rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) - rgb_img = rgb_img.transpose(2, 0, 1) # (H,W,C) to (C,H,W) format + rgb_img = rgb_img.transpose(2,0,1) # (H,W,C) to (C,H,W) return torch.tensor(rgb_img, dtype=torch.float32) - def read_normalmap_to_tensor(path): """ Read normalmap image from EXR format to tensor of form (3,H,W) """ - normalmap = cv2.imread(path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + normalmap = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) normalmap = cv2.cvtColor(normalmap, cv2.COLOR_BGR2RGB) - normalmap = normalmap.transpose(2, 0, 1) # (H,W,C) to (C,H,W) format + normalmap = normalmap.transpose(2,0,1) # (H,W,C) to (C,H,W) return torch.tensor(normalmap, dtype=torch.float32) - def read_depthmap_to_tensor(path): """ Read depthmap image from EXR format to tensor of form (1,H,W) """ - depthmap = cv2.imread(path, cv2.IMREAD_ANYDEPTH) + depthmap = cv2.imread(str(path), cv2.IMREAD_ANYDEPTH) depthmap = np.expand_dims(depthmap, axis=0) # (H,W) to (1,H,W) return torch.tensor(depthmap, dtype=torch.float32) + diff --git a/projects/cleargrasp_depth_estimation/datasets/train_val_pix2pix_dataset.py b/projects/cleargrasp_depth_estimation/datasets/old/train_val_pix2pix_dataset.py similarity index 94% rename from projects/cleargrasp_depth_estimation/datasets/train_val_pix2pix_dataset.py rename to projects/cleargrasp_depth_estimation/datasets/old/train_val_pix2pix_dataset.py index 3ff3e644..40dbed02 100644 --- a/projects/cleargrasp_depth_estimation/datasets/train_val_pix2pix_dataset.py +++ b/projects/cleargrasp_depth_estimation/datasets/old/train_val_pix2pix_dataset.py @@ -31,7 +31,6 @@ class ClearGraspPix2PixDatasetConfig(configs.base.BaseDatasetConfig): name: str = "ClearGraspPix2PixDataset" load_size: Tuple[int, int] = (512, 256) - paired: bool = True # `True` for paired A-B class ClearGraspPix2PixDataset(Dataset): @@ -41,7 +40,7 @@ class ClearGraspPix2PixDataset(Dataset): Here, the GAN translation task is: RGB + Normalmap --> Depthmap This is the Pix2Pix version: Domain A: RGB and Normalmap - Domain B: Depthmap + Domain B: Depthmap (paired with A) """ def __init__(self, conf): 
@@ -79,14 +78,11 @@ def __getitem__(self, index): depthmap = self.load_resize_transform(depthmap) # Transforms - rgb, normalmap, depthmap = self.apply_transforms(rgb_img, normalmap, depthmap) + rgb_img, normalmap, depthmap = self.apply_transforms(rgb_img, normalmap, depthmap) # Normalize rgb_img, normalmap, depthmap = self.normalize(rgb_img, normalmap, depthmap) - # Make 3-channel image from 1-channel - # depthmap = torch.cat([depthmap,depthmap,depthmap], dim=0) - return {'A': torch.cat([rgb_img, normalmap], dim=0), 'B': depthmap} diff --git a/projects/cleargrasp_depth_estimation/datasets/train_dataset.py b/projects/cleargrasp_depth_estimation/datasets/train_dataset.py new file mode 100644 index 00000000..64f7d125 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/datasets/train_dataset.py @@ -0,0 +1,193 @@ +import random +from pathlib import Path +from typing import Tuple + +import glob +import numpy as np +import cv2 +import torch +from torch.utils.data import Dataset +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF + +from midaGAN.utils.io import make_dataset_of_files + +# Config imports +from dataclasses import dataclass +from midaGAN import configs + +from midaGAN.data.utils.normalization import min_max_normalize + + +EXTENSIONS = ['.jpg', '.exr'] + + +# Max allowed intenity of depthmap images. Specified in metres. +# This value is chosen by analyzing max values throughout the dataset. +UPPER_DEPTH_INTENSITY_LIMIT = 8.0 + + + +@dataclass +class ClearGraspTrainDatasetConfig(configs.base.BaseDatasetConfig): + name: str = "ClearGraspTrainDataset" + load_size: Tuple[int, int] = (512, 256) + paired: bool = True # `True` for paired A-B. + require_domain_B_rgb: bool = False # Whether to fetch noisy RGB photo for domain B + + +class ClearGraspTrainDataset(Dataset): + """ + Multimodality dataset containing RGB photos, surface normalmaps and depthmaps. + Curated from Cleargrasp robot-vision dataset. 
+ The domain translation task is: RGB + Normalmap --> Depthmap + """ + def __init__(self, conf): + + # self.mode = conf.mode + self.paired = conf[conf.mode].dataset.paired + self.require_domain_B_rgb = conf[conf.mode].dataset.require_domain_B_rgb + + rgb_dir = Path(conf[conf.mode].dataset.root) / "rgb" + normalmap_dir = Path(conf[conf.mode].dataset.root) / "normal" + depthmap_dir = Path(conf[conf.mode].dataset.root) / "depth" + + self.image_paths = {'RGB': [], 'normalmap': [], 'depthmap': []} + self.image_paths['RGB'] = make_dataset_of_files(rgb_dir, EXTENSIONS) + self.image_paths['normalmap'] = make_dataset_of_files(normalmap_dir, EXTENSIONS) + self.image_paths['depthmap'] = make_dataset_of_files(depthmap_dir, EXTENSIONS) + self.dataset_size = len(self.image_paths['RGB']) + + self.load_size = conf[conf.mode].dataset.load_size + self.load_resize_transform = transforms.Resize( + size=(self.load_size[1], self.load_size[0]), interpolation=transforms.InterpolationMode.BICUBIC + ) + + # Clipping ranges + self.rgb_min, self.rgb_max = 0.0, 255.0 + self.normalmap_min, self.normalmap_max = -1.0, 1.0 + self.depthmap_min, self.depthmap_max = 0.0, UPPER_DEPTH_INTENSITY_LIMIT + + + def __len__(self): + return self.dataset_size + + + def __getitem__(self, index): + + # ------------ + # Fetch images + + index_A = index % self.dataset_size + index_B = index_A if self.paired else random.randint(0, self.dataset_size - 1) + index_A, index_B = 9, 1 ## + + image_path_A, image_path_B = {}, {} + image_path_A['RGB'] = self.image_paths['RGB'][index_A] + image_path_A['normalmap'] = self.image_paths['normalmap'][index_A] + image_path_B['depthmap'] = self.image_paths['depthmap'][index_B] + if self.require_domain_B_rgb: + image_path_B['RGB'] = self.image_paths['RGB'][index_B] + + images_A, images_B = {}, {} + images_A['RGB'] = read_rgb_to_tensor(image_path_A['RGB']) + images_A['normalmap'] = read_normalmap_to_tensor(image_path_A['normalmap']) + images_B['depthmap'] = read_depthmap_to_tensor(image_path_B['depthmap']) + if self.require_domain_B_rgb: + images_B['RGB'] = read_rgb_to_tensor(image_path_B['RGB']) + + + # ------ + # Resize + + for k in images_A.keys(): + images_A[k] = self.load_resize_transform(images_A[k]) + for k in images_B.keys(): + images_B[k] = self.load_resize_transform(images_B[k]) + + + # --------- + # Transform + + images_A, images_B = self.apply_transforms(images_A, images_B) + + + # ------------- + # Normalization + + # Clip and then rescale all intensties to range [-1, 1] + # Normalmap is already in this scale. 
+ images_A['RGB'] = clip_and_min_max_normalize(images_A['RGB'], self.rgb_min, self.rgb_max) + images_A['normalmap'] = torch.clamp(images_A['normalmap'], self.normalmap_min, self.normalmap_max) + images_B['depthmap'] = clip_and_min_max_normalize(images_B['depthmap'], self.depthmap_min, self.depthmap_max) + if self.require_domain_B_rgb: + images_B['RGB'] = clip_and_min_max_normalize(images_B['RGB'], self.rgb_min, self.rgb_max) + + + # ------------------------- + # Add noise in domain-B RGB + + if self.require_domain_B_rgb: + images_B['RGB'] = images_B['RGB'] + torch.normal(mean=0, std=0.05, size=(self.load_size[1], self.load_size[0])) + images_B['RGB'] = torch.clamp(images_B['RGB'], -1, 1) # Clip to remove out-of-range overshoots + + + # --------------------- + # Construct sample dict + + # A and B need to have dims (C,D,H,W) + A = torch.cat([images_A['RGB'], images_A['normalmap']], dim=0) + if self.require_domain_B_rgb: + B = torch.cat([images_B['RGB'], images_B['depthmap']], dim=0) + else: + B = images_B['depthmap'] + + sample_dict = {'A': A, 'B': B} + + return sample_dict + + + + def apply_transforms(self, images_A, images_B): + """ + TODO: What transform to use for augmentation? + Cannot naively apply random flip and crop, would mess up the normalmap and depthmap info, resp. + Maybe flipping + changing normalmap colour mapping (by changing order of its RGB channels) + """ + return images_A, images_B + + + + +def read_rgb_to_tensor(path): + """ + RGB reader based on cv2.imread(). + Just for consistency with normalmap and depthmap readers. + """ + bgr_img = cv2.imread(str(path)) + rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + rgb_img = rgb_img.transpose(2,0,1) # (H,W,C) to (C,H,W) + return torch.tensor(rgb_img, dtype=torch.float32) + +def read_normalmap_to_tensor(path): + """ + Read normalmap image from EXR format to tensor of form (3,H,W) + """ + normalmap = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + normalmap = cv2.cvtColor(normalmap, cv2.COLOR_BGR2RGB) + normalmap = normalmap.transpose(2,0,1) # (H,W,C) to (C,H,W) + return torch.tensor(normalmap, dtype=torch.float32) + +def read_depthmap_to_tensor(path): + """ + Read depthmap image from EXR format to tensor of form (1,H,W) + """ + depthmap = cv2.imread(str(path), cv2.IMREAD_ANYDEPTH) + depthmap = np.expand_dims(depthmap, axis=0) # (H,W) to (1,H,W) + return torch.tensor(depthmap, dtype=torch.float32) + + +def clip_and_min_max_normalize(tensor, min_value, max_value): + tensor = torch.clamp(tensor, min_value, max_value) + tensor = min_max_normalize(tensor, min_value, max_value) + return tensor diff --git a/projects/cleargrasp_depth_estimation/datasets/val_test_dataset.py b/projects/cleargrasp_depth_estimation/datasets/val_test_dataset.py new file mode 100644 index 00000000..fe306a5a --- /dev/null +++ b/projects/cleargrasp_depth_estimation/datasets/val_test_dataset.py @@ -0,0 +1,210 @@ +import os +import random +from pathlib import Path +from typing import Tuple + +import glob +import numpy as np +import cv2 +import torch +from torch.utils.data import Dataset +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF + +from midaGAN.utils.io import make_dataset_of_files + +# Config imports +from dataclasses import dataclass +from midaGAN import configs + +from midaGAN.data.utils.normalization import min_max_normalize, min_max_denormalize + + +EXTENSIONS = ['.jpg', '.exr'] + + +# Max allowed intenity of depthmap images. Specified in metres. 
+# This value is chosen by analyzing max values throughout the dataset. +UPPER_DEPTH_INTENSITY_LIMIT = 8.0 + + + +@dataclass +class ClearGraspValTestDatasetConfig(configs.base.BaseDatasetConfig): + """ + Note: Val dataset is paired, and does not supply RGB in domain-B + """ + name: str = "ClearGraspValTestDataset" + load_size: Tuple[int, int] = (512, 256) + model_is_cyclegan_balanced: bool = False + + +class ClearGraspValTestDataset(Dataset): + """ + Multimodality dataset containing RGB photos, surface normalmaps and depthmaps. + Curated from Cleargrasp robot-vision dataset. + The domain translation task is: RGB + Normalmap --> Depthmap + """ + def __init__(self, conf): + + rgb_dir = Path(conf[conf.mode].dataset.root) / "rgb" + normalmap_dir = Path(conf[conf.mode].dataset.root) / "normal" + depthmap_dir = Path(conf[conf.mode].dataset.root) / "depth" + + self.image_paths = {'RGB': [], 'normalmap': [], 'depthmap': []} + self.image_paths['RGB'] = make_dataset_of_files(rgb_dir, EXTENSIONS) + self.image_paths['normalmap'] = make_dataset_of_files(normalmap_dir, EXTENSIONS) + self.image_paths['depthmap'] = make_dataset_of_files(depthmap_dir, EXTENSIONS) + self.dataset_size = len(self.image_paths['RGB']) + + self.sample_ids = ['-'.join(str(path).split('/')[-1].split('.')[0].split('-')[:-1]) \ + for path in self.image_paths['RGB']] + + self.load_size = conf[conf.mode].dataset.load_size + self.load_resize_transform = transforms.Resize( + size=(self.load_size[1], self.load_size[0]), interpolation=transforms.InterpolationMode.BICUBIC + ) + + # Clipping ranges + self.rgb_min, self.rgb_max = 0.0, 255.0 + self.normalmap_min, self.normalmap_max = -1.0, 1.0 + self.depthmap_min, self.depthmap_max = 0.0, UPPER_DEPTH_INTENSITY_LIMIT + + # Using Cyclegan-balanced (v3) ? + self.model_is_cyclegan_balanced = conf[conf.mode].dataset.model_is_cyclegan_balanced + + + def __len__(self): + return self.dataset_size + + + def __getitem__(self, index): + + # ------------ + # Fetch images + + image_path = {} + image_path['RGB'] = self.image_paths['RGB'][index] + image_path['normalmap'] = self.image_paths['normalmap'][index] + image_path['depthmap'] = self.image_paths['depthmap'][index] + + images = {} + images['RGB'] = read_rgb_to_tensor(image_path['RGB']) + images['normalmap'] = read_normalmap_to_tensor(image_path['normalmap']) + images['depthmap'] = read_depthmap_to_tensor(image_path['depthmap']) + + + # Store the sample ID, need while saving the predicted image + metadata = { + 'sample_id': self.sample_ids[index] + } + + + # ------ + # Resize + + for k in images.keys(): + images[k] = self.load_resize_transform(images[k]) + + + # ------------- + # Normalization + + # Clip and then rescale all intensties to range [-1, 1] + # Normalmap is already in this scale. 
+ images['RGB'] = clip_and_min_max_normalize(images['RGB'], self.rgb_min, self.rgb_max) + images['normalmap'] = torch.clamp(images['normalmap'], self.normalmap_min, self.normalmap_max) + images['depthmap'] = clip_and_min_max_normalize(images['depthmap'], self.depthmap_min, self.depthmap_max) + + + # --------------------- + # Construct sample dict + + # A and B need to have dims (C,D,H,W) + A = torch.cat([images['RGB'], images['normalmap']], dim=0) + + if self.model_is_cyclegan_balanced: + zeros_dummy = torch.zeros_like(images['RGB']) + B = torch.cat([zeros_dummy, images['depthmap']], dim=0) + else: + B = images['depthmap'] + + sample_dict = {'A': A, 'B': B} + + # Include meta data + sample_dict['metadata'] = metadata + + return sample_dict + + + + def denormalize(self, tensor): + """Allows the Tester and Validator to calculate the metrics in + the original range of values. + `tensor` can be either the predicted or the ground truth depthmap image tensor + """ + tensor = min_max_denormalize(tensor, self.depthmap_min, self.depthmap_max) + return tensor + + + def save(self, tensor, save_dir, metadata): + """ Save predicted tensors as EXR + """ + + # If the model is CycleGAN-balanced, tensor is 4-channel with the + # last channel containing depthmap and first 3 channels containing a dummy array. + if self.model_is_cyclegan_balanced: # Convert from (C,H,W) to (H,W) format + tensor = tensor[3] # (4,H,W) -> (H,W) + else: + tensor = tensor.squeeze() # (1,H,W) -> (H,W) + + # Rescale back to [self.depthmap_min, self.depthmap_max] + tensor = min_max_denormalize(tensor.cpu(), self.depthmap_min, self.depthmap_max) + + # Write to file + os.makedirs(save_dir, exist_ok=True) + sample_id = metadata['sample_id'] + save_path = f"{save_dir}/{sample_id}.exr" + write_depthmap_tensor_to_exr(tensor, save_path) + + + +def read_rgb_to_tensor(path): + """ + RGB reader based on cv2.imread(). + Just for consistency with normalmap and depthmap readers. 
+ """ + bgr_img = cv2.imread(str(path)) + rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) + rgb_img = rgb_img.transpose(2,0,1) # (H,W,C) to (C,H,W) + return torch.tensor(rgb_img, dtype=torch.float32) + + +def read_normalmap_to_tensor(path): + """ + Read normalmap image from EXR format to tensor of form (3,H,W) + """ + normalmap = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + normalmap = cv2.cvtColor(normalmap, cv2.COLOR_BGR2RGB) + normalmap = normalmap.transpose(2,0,1) # (H,W,C) to (C,H,W) + return torch.tensor(normalmap, dtype=torch.float32) + + +def read_depthmap_to_tensor(path): + """ + Read depthmap image from EXR format to tensor of form (1,H,W) + """ + depthmap = cv2.imread(str(path), cv2.IMREAD_ANYDEPTH) + depthmap = np.expand_dims(depthmap, axis=0) # (H,W) to (1,H,W) + return torch.tensor(depthmap, dtype=torch.float32) + + +def write_depthmap_tensor_to_exr(depthmap, path): + depthmap = depthmap.numpy() + cv2.imwrite(path, depthmap, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + + +def clip_and_min_max_normalize(tensor, min_value, max_value): + tensor = torch.clamp(tensor, min_value, max_value) + tensor = min_max_normalize(tensor, min_value, max_value) + return tensor diff --git a/projects/cleargrasp_depth_estimation/experiments/cyclegan_balanced.yaml b/projects/cleargrasp_depth_estimation/experiments/cyclegan_balanced.yaml new file mode 100644 index 00000000..f72d8345 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/experiments/cyclegan_balanced.yaml @@ -0,0 +1,75 @@ +project_dir: "./projects/cleargrasp_depth_estimation/" + + + +train: + output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/cyclegan_balanced/" + cuda: True + n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") + n_iters_decay: 62500 # Extra 62500 iters with lr decay + batch_size: 1 + mixed_precision: False + seed: 1 + + logging: + freq: 50 + multi_modality_split: + A: [3, 3] + B: [3, 1] + wandb: + project: "cleargrasp_depth_estimation" + run: "cyclegan_balanced" + + checkpointing: + freq: 25000 + # load_iter: 125000 ## + + dataset: + name: "ClearGraspTrainDataset" + root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" + load_size: [512, 256] + paired: False # Unpaired + require_domain_B_rgb: True # Required here for cyclegan-balanced + num_workers: 8 + + gan: + name: "CycleGANMultiModalV3" + generator: + name: "Unet2D" + in_out_channels_AB: [6, 1] # RGB + Normal -> Depth + in_out_channels_BA: [4, 3] # RGB + Depth -> Normal + num_downs: 4 + ngf: 64 + use_dropout: True + + discriminator: + name: "PatchGAN2D" + in_channels_B: 1 + in_channels_A: 3 + n_layers: 3 + kernel_size: [4, 4] + ndf: 64 + + optimizer: + lr_D: 0.0001 + lr_G: 0.0002 + lambda_AB: 10.0 + lambda_BA: 10.0 + lambda_identity: 0 + proportion_ssim: 0 + + metrics: + discriminator_evolution: True + ssim: False + + +val: + freq: 2500 + dataset: + name: "ClearGraspValTestDataset" + root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" + load_size: [512, 256] + model_is_cyclegan_balanced: True # True + num_workers: 8 + metrics: + cycle_metrics: False diff --git a/projects/cleargrasp_depth_estimation/experiments/cyclegan_naive.yaml b/projects/cleargrasp_depth_estimation/experiments/cyclegan_naive.yaml new file mode 100644 index 00000000..17ccabf3 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/experiments/cyclegan_naive.yaml @@ -0,0 +1,74 @@ +project_dir: "./projects/cleargrasp_depth_estimation/" + +# output_dir: +train: + output_dir: 
"/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/cyclegan_naive/" + cuda: True + n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") + n_iters_decay: 62500 # Extra 62500 iters with lr decay + batch_size: 1 + mixed_precision: False + seed: 1 + + logging: + freq: 50 + multi_modality_split: + A: [3, 3] + B: [1] + wandb: + project: "cleargrasp_depth_estimation" + run: "cyclegan_naive" + + checkpointing: + freq: 25000 + # load_iter: 125000 ## + + dataset: + name: "ClearGraspTrainDataset" + root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" + load_size: [512, 256] + paired: False # Unpaired + require_domain_B_rgb: False # Not required for cyclegan-naive + num_workers: 8 + + gan: + name: "CycleGAN" + generator: + name: "Unet2D" + in_out_channels_AB: [6, 1] # RGB + Normal -> Depth + in_out_channels_BA: [1, 6] # Depth -> Normal + RGB + num_downs: 4 + ngf: 64 + use_dropout: True + + discriminator: + name: "PatchGAN2D" + in_channels_B: 1 + in_channels_A: 6 + n_layers: 3 + kernel_size: [4, 4] + ndf: 64 + + optimizer: + lr_D: 0.0001 + lr_G: 0.0002 + lambda_AB: 10.0 + lambda_BA: 10.0 + lambda_identity: 0 + proportion_ssim: 0 + + metrics: + discriminator_evolution: True + ssim: False + + +val: + freq: 2500 + dataset: + name: "ClearGraspValTestDataset" + root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" + load_size: [512, 256] + model_is_cyclegan_balanced: False # False + num_workers: 8 + metrics: + cycle_metrics: False diff --git a/projects/cleargrasp_depth_estimation/experiments/cyclegan.yaml b/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v1.yaml similarity index 55% rename from projects/cleargrasp_depth_estimation/experiments/cyclegan.yaml rename to projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v1.yaml index 51c0d08d..7812f263 100644 --- a/projects/cleargrasp_depth_estimation/experiments/cyclegan.yaml +++ b/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v1.yaml @@ -1,18 +1,21 @@ project_dir: "./projects/cleargrasp_depth_estimation/" train: - output_dir: "./checkpoints/cleargrasp_depth_estimation/" + output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/cyclegan_naive/" cuda: True - n_iters: 250000 # 2500 (images) x 100 ("epochs") - n_iters_decay: 5000 - batch_size: 2 + n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") + n_iters_decay: 62500 # Extra 125000 iters with lr decay + batch_size: 1 mixed_precision: False logging: - freq: 100 + freq: 50 + multi_modality_split: + A: [3, 3] + B: [1] wandb: project: "cleargrasp_depth_estimation" - run: "cyclegan_trial" + run: "cyclegan_naive" checkpointing: freq: 5000 @@ -22,27 +25,30 @@ train: root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" load_size: [512, 256] paired: False - shuffle: True + fetch_rgb_b: False # `False` for v1 num_workers: 8 - gan: # TODO: Add config after implementing custom MM cycleGAN + gan: name: "CycleGAN" generator: - name: "Unet2D" - in_channels: 6 - out_channels: 6 - num_downs: 5 + name: "Unet2D" + in_out_channels_AB: [6, 1] # RGB + Normal -> Depth + in_out_channels_BA: [1, 6] # Depth -> Normal + RGB + num_downs: 4 ngf: 64 + use_dropout: True discriminator: name: "PatchGAN2D" + in_channels_B: 1 + in_channels_A: 6 n_layers: 3 - in_channels: 3 + kernel_size: [4, 4] ndf: 64 optimizer: - lr_D: 0.0002 - lr_G: 0.0001 + lr_D: 0.0001 + lr_G: 0.0002 lambda_AB: 10.0 lambda_BA: 10.0 lambda_identity: 0 @@ -50,17 +56,17 @@ train: 
metrics: discriminator_evolution: True - ssim: True + ssim: False val: - freq: 5000 + freq: 2500 dataset: name: "ClearGraspCycleGANDataset" root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" load_size: [512, 256] paired: True - shuffle: False + fetch_rgb_b: False # `False` for v1 num_workers: 8 metrics: cycle_metrics: False diff --git a/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v1_structure.yaml b/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v1_structure.yaml new file mode 100644 index 00000000..59def115 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v1_structure.yaml @@ -0,0 +1,71 @@ +project_dir: "./projects/cleargrasp_depth_estimation/" + +train: + output_dir: "./checkpoints/cleargrasp_cycleganv1_struct0_5/" + cuda: True + n_iters: 250000 # (2500 (images) / 1 (batch_size)) x 100 ("epochs") + n_iters_decay: 5000 # Extra 5000 iters with lr decay + batch_size: 1 + mixed_precision: False + + logging: + freq: 50 + multi_modality_split: + A: [3, 3] + B: [1] + wandb: + project: "cleargrasp_depth_estimation" + run: "cyclegan_v1_structure0.5" + + checkpointing: + freq: 10000 + + dataset: + name: "ClearGraspCycleGANDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/Cleargrasp_rgbnormal2depth_resized/train" + load_size: [512, 256] + paired: False + fetch_rgb_b: False # `False` for v1 + num_workers: 8 + + gan: + name: "CycleGANMultiModalV1Structure" + generator: + name: "Unet2D" + in_out_channels_AB: [6, 1] + in_out_channels_BA: [1, 6] + num_downs: 5 + ngf: 64 + + discriminator: + name: "PatchGAN2D" + in_channels_B: 1 + in_channels_A: 6 + n_layers: 4 + ndf: 64 + + optimizer: + lr_D: 0.0001 + lr_G: 0.0002 + lambda_AB: 10.0 + lambda_BA: 10.0 + lambda_identity: 0 + proportion_ssim: 0 + lambda_structure: 0.5 + + metrics: + discriminator_evolution: True + ssim: True + + +val: + freq: 5000 + dataset: + name: "ClearGraspCycleGANDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/Cleargrasp_rgbnormal2depth_resized/val" + load_size: [512, 256] + fetch_rgb_b: False # `False` for v1 + paired: True + num_workers: 8 + metrics: + cycle_metrics: True diff --git a/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v2.yaml b/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v2.yaml new file mode 100644 index 00000000..0bd489b6 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v2.yaml @@ -0,0 +1,71 @@ +project_dir: "./projects/cleargrasp_depth_estimation/" + +train: + output_dir: "./checkpoints/cleargrasp_cycleganv2_5_struct0_5/" + cuda: True + n_iters: 250000 # (2500 (images) / 2 (batch_size)) x 200 ("epochs") + n_iters_decay: 5000 # Extra 5000 iters with lr decay + batch_size: 1 + mixed_precision: False + + logging: + freq: 50 + multi_modality_split: + A: [3, 3] + B: [3, 1] + wandb: + project: "cleargrasp_depth_estimation" + run: "cyclegan_v2.5_structure0.5" + + checkpointing: + freq: 10000 + + dataset: + name: "ClearGraspCycleGANDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/Cleargrasp_rgbnormal2depth_resized/train" + load_size: [512, 256] + paired: False + fetch_rgb_b: True # `True` for v2 + num_workers: 8 + + gan: + name: "CycleGANMultiModalV2" + generator: + name: "Unet2D" + in_out_channels_AB: [6, 4] + in_out_channels_BA: [4, 6] + num_downs: 5 + ngf: 64 + + discriminator: + name: "PatchGAN2D" + in_channels_B: 4 + in_channels_A: 6 + n_layers: 4 + ndf: 64 + + optimizer: + 
lr_D: 0.0001 + lr_G: 0.0002 + lambda_AB: 10.0 + lambda_BA: 10.0 + lambda_identity: 0 + proportion_ssim: 0 + lambda_structure: 0.5 + + metrics: + discriminator_evolution: True + ssim: True + + +val: + freq: 5000 + dataset: + name: "ClearGraspCycleGANDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/Cleargrasp_rgbnormal2depth_resized/val" + load_size: [512, 256] + fetch_rgb_b: True # `True` for v2 + paired: True + num_workers: 8 + metrics: + cycle_metrics: True diff --git a/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v3.yaml b/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v3.yaml new file mode 100644 index 00000000..65c5d9fb --- /dev/null +++ b/projects/cleargrasp_depth_estimation/experiments/old/cyclegan_multimodal_v3.yaml @@ -0,0 +1,72 @@ +project_dir: "./projects/cleargrasp_depth_estimation/" + +train: + output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/cyclegan_balanced/" + cuda: True + n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") + n_iters_decay: 62500 # Extra 62500 iters with lr decay + batch_size: 1 + mixed_precision: False + + logging: + freq: 50 + multi_modality_split: + A: [3, 3] + B: [3, 1] + wandb: + project: "cleargrasp_depth_estimation" + run: "cyclegan_balanced" + + checkpointing: + freq: 5000 + + dataset: + name: "ClearGraspCycleGANDataset" + root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" + load_size: [512, 256] + paired: False + fetch_rgb_b: True # `True` for v3 + num_workers: 8 + + gan: + name: "CycleGANMultiModalV3" + generator: + name: "Unet2D" + in_out_channels_AB: [6, 1] # RGB + Normal -> Depth + in_out_channels_BA: [4, 3] # RGB + Depth -> Normal + num_downs: 4 + ngf: 64 + use_dropout: True + + discriminator: + name: "PatchGAN2D" + in_channels_B: 1 + in_channels_A: 3 + n_layers: 3 + kernel_size: [4, 4] + ndf: 64 + + optimizer: + lr_D: 0.0001 + lr_G: 0.0002 + lambda_AB: 10.0 + lambda_BA: 10.0 + lambda_identity: 0 + proportion_ssim: 0 + + metrics: + discriminator_evolution: True + ssim: False + + +val: + freq: 2500 + dataset: + name: "ClearGraspCycleGANDataset" + root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" + load_size: [512, 256] + fetch_rgb_b: True # `True` for v3 + paired: True + num_workers: 8 + metrics: + cycle_metrics: False diff --git a/projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml b/projects/cleargrasp_depth_estimation/experiments/old/pix2pix.yaml similarity index 66% rename from projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml rename to projects/cleargrasp_depth_estimation/experiments/old/pix2pix.yaml index 30a6b12e..2663aa1c 100644 --- a/projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml +++ b/projects/cleargrasp_depth_estimation/experiments/old/pix2pix.yaml @@ -1,21 +1,21 @@ project_dir: "./projects/cleargrasp_depth_estimation/" train: - output_dir: "./checkpoints/cleargrasp_depth_estimation/" + output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/pix2pix_lambda100/" cuda: True - n_iters: 250000 # (2500 (images) / batch_size) x 200 ("epochs") - n_iters_decay: 5000 - batch_size: 2 + n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") + n_iters_decay: 62500 # Extra 62500 iters with lr decay + batch_size: 1 mixed_precision: False logging: - freq: 100 + freq: 50 multi_modality_split: - A: [3,3] + A: [3, 3] B: [1] wandb: project: "cleargrasp_depth_estimation" - run: "pix2pix_6layer_patchgan" + run: "pix2pix_lambda_100" checkpointing: freq: 
5000 @@ -24,43 +24,40 @@ train: name: "ClearGraspPix2PixDataset" root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" load_size: [512, 256] - paired: True num_workers: 8 gan: name: "Pix2PixConditionalGAN" generator: name: "Unet2D" - in_channels: 6 - out_channels: 1 - num_downs: 6 + in_out_channels: [6, 1] + num_downs: 4 ngf: 64 use_dropout: True discriminator: name: "PatchGAN2D" in_channels: 7 - n_layers: 6 - kernel_size: [4,4] - ndf: 32 + n_layers: 3 + kernel_size: [4, 4] + ndf: 64 optimizer: - lr_D: 0.00005 + lr_D: 0.0001 lr_G: 0.0002 - lambda_pix2pix: 10.0 + lambda_pix2pix: 100 metrics: discriminator_evolution: True - ssim: True + ssim: False val: - freq: 5000 + freq: 2500 dataset: name: "ClearGraspPix2PixDataset" root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" load_size: [512, 256] - paired: True num_workers: 8 metrics: cycle_metrics: False diff --git a/projects/cleargrasp_depth_estimation/experiments/pix2pix_new.yaml b/projects/cleargrasp_depth_estimation/experiments/pix2pix_new.yaml new file mode 100644 index 00000000..cb95b5e6 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/experiments/pix2pix_new.yaml @@ -0,0 +1,67 @@ +project_dir: "./projects/cleargrasp_depth_estimation/" + +train: + output_dir: "/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/pix2pix_lambda100/" + cuda: True + n_iters: 62500 # (2500 (images) / 1 (batch_size)) x 25 ("epochs") + n_iters_decay: 62500 # Extra 62500 iters with lr decay + batch_size: 1 + mixed_precision: False + seed: 1 + + logging: + freq: 50 + multi_modality_split: + A: [3, 3] + B: [1] + wandb: + project: "cleargrasp_depth_estimation" + run: "pix2pix_lambda100" + + checkpointing: + freq: 25000 + + dataset: + name: "ClearGraspTrainDataset" + root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/train" + load_size: [512, 256] + paired: True + require_domain_B_rgb: False + num_workers: 8 + + gan: + name: "Pix2PixConditionalGAN" + generator: + name: "Unet2D" + in_out_channels: [6, 1] + num_downs: 4 + ngf: 64 + use_dropout: True + + discriminator: + name: "PatchGAN2D" + in_channels: 7 + n_layers: 3 + kernel_size: [4, 4] + ndf: 64 + + optimizer: + lr_D: 0.0001 + lr_G: 0.0002 + lambda_pix2pix: 100 + + metrics: + discriminator_evolution: True + ssim: False + + +val: + freq: 2500 + dataset: + name: "ClearGraspValTestDataset" + root: "/home/zk315372/Chinmay/Datasets/Cleargrasp_rgbnormal2depth_resized/val" + load_size: [512, 256] + model_is_cyclegan_balanced: False + num_workers: 8 + metrics: + cycle_metrics: False diff --git a/projects/cleargrasp_depth_estimation/jobscript.sh b/projects/cleargrasp_depth_estimation/jobscript.sh index da3bdedf..8e91f10c 100644 --- a/projects/cleargrasp_depth_estimation/jobscript.sh +++ b/projects/cleargrasp_depth_estimation/jobscript.sh @@ -3,8 +3,8 @@ # Job configuration --- -#SBATCH --job-name=cleargrasp_depth_estimation_pix2pix -#SBATCH --output=/home/zk315372/Chinmay/Git/midaGAN/projects/cleargrasp_depth_estimation/slurm_logs/%j.log +#SBATCH --job-name=pix2pix_lambda100 +#SBATCH --output=/home/zk315372/Chinmay/Cleargrasp-Depth-Estimation/pix2pix_lambda100/slurm-%j.log ## OpenMP settings #SBATCH --cpus-per-task=8 @@ -13,7 +13,7 @@ ## Request for a node with 2 Tesla P100 GPUs #SBATCH --gres=gpu:pascal:2 -#SBATCH --time=5:00:00 +#SBATCH --time=4:00:00 ## TO use the UM DKE project account # #SBATCH --account=um_dke @@ -30,7 +30,7 @@ echo; echo # Execute training python_interpreter="/home/zk315372/miniconda3/envs/gan_env/bin/python3" 
training_file="/home/zk315372/Chinmay/Git/midaGAN/tools/train.py" -config_file="/home/zk315372/Chinmay/Git/midaGAN/projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml" +config_file="/home/zk315372/Chinmay/Git/midaGAN/projects/cleargrasp_depth_estimation/experiments/pix2pix_new.yaml" CUDA_VISIBLE_DEVICES=0 $python_interpreter $training_file config=$config_file @@ -40,4 +40,4 @@ CUDA_VISIBLE_DEVICES=0 $python_interpreter $training_file config=$config_file # CUDA_VISIBLE_DEVICES=0 python tools/train.py config="./projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml" # Run distributed example: -# python -m torch.distributed.launch --use_env --nproc_per_node 2 tools/train.py config="./projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml" \ No newline at end of file +# python -m torch.distributed.launch --use_env --nproc_per_node 2 tools/train.py config="./projects/cleargrasp_depth_estimation/experiments/pix2pix.yaml" diff --git a/projects/cleargrasp_depth_estimation/modules/__init__.py b/projects/cleargrasp_depth_estimation/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/projects/cleargrasp_depth_estimation/modules/cyclegan_losses_for_v3.py b/projects/cleargrasp_depth_estimation/modules/cyclegan_losses_for_v3.py new file mode 100644 index 00000000..5c465679 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/modules/cyclegan_losses_for_v3.py @@ -0,0 +1,35 @@ +import math +import torch + +from midaGAN.nn.losses import cyclegan_losses + + +class CycleGANLossesForV3(cyclegan_losses.CycleGANLosses): + """ Modified to make Cycle-consitency account for only + Normalmap images (in domain A) and depthmap images (in domain B), + and ignore RGB """ + + def __init__(self, conf): + self.lambda_AB = conf.train.gan.optimizer.lambda_AB + self.lambda_BA = conf.train.gan.optimizer.lambda_BA + + lambda_identity = conf.train.gan.optimizer.lambda_identity + proportion_ssim = conf.train.gan.optimizer.proportion_ssim + + # Cycle-consistency - L1, with optional weighted combination with SSIM + self.criterion_cycle = cyclegan_losses.CycleLoss(proportion_ssim) + + + def __call__(self, visuals): + # Separate out the normalmap and depthmap parts from the visuals tensors + real_A2, real_B2 = visuals['real_A'][:, 3:], visuals['real_B'][:, 3:] + fake_A2, fake_B2 = visuals['fake_A'][:, 3:], visuals['fake_B'][:, 3:] + rec_A2, rec_B2 = visuals['rec_A'][:, 3:], visuals['rec_B'][:, 3:] + + losses = {} + + # cycle-consistency loss + losses['cycle_A'] = self.lambda_AB * self.criterion_cycle(real_A2, rec_A2) + losses['cycle_B'] = self.lambda_BA * self.criterion_cycle(real_B2, rec_B2) + + return losses diff --git a/projects/cleargrasp_depth_estimation/modules/cyclegan_multimodal_v3.py b/projects/cleargrasp_depth_estimation/modules/cyclegan_multimodal_v3.py new file mode 100644 index 00000000..25286fd5 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/modules/cyclegan_multimodal_v3.py @@ -0,0 +1,125 @@ +from dataclasses import dataclass + +import torch + +from midaGAN.nn.gans.unpaired import cyclegan +from midaGAN.nn.losses.adversarial_loss import AdversarialLoss + +from projects.cleargrasp_depth_estimation.modules.cyclegan_losses_for_v3 \ + import CycleGANLossesForV3 + + +@dataclass +class CycleGANMultiModalV3Config(cyclegan.CycleGANConfig): + """ CycleGANMultiModalV3 Config """ + name: str = "CycleGANMultiModalV3" + + +class CycleGANMultiModalV3(cyclegan.CycleGAN): + """ CycleGAN for multimodal images -- Version 3 + a.k.a CycleGAN-balanced + + Notation: + A1, A2 
-- rgb_A, normalmap + B1, B2 -- rgb_B, depthmap + """ + + def __init__(self, conf): + super().__init__(conf) + + def init_criterions(self): + # Standard GAN loss + self.criterion_adv = AdversarialLoss( + self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device) + # Generator-related losses -- Cycle-consistency and Identity loss + self.criterion_G = CycleGANLossesForV3(self.conf) + + def forward(self): + """Run forward pass; called by both the optimization and inference routines.""" + real_A = self.visuals['real_A'] + real_B = self.visuals['real_B'] + + # Forward cycle G_AB (A to B) + fake_B2 = self.networks['G_AB'](real_A) # Compute depthmap + real_A1 = real_A[:, :3] # Get rgb_A + rec_A2 = self.networks['G_BA'](torch.cat([real_A1, fake_B2], dim=1)) # Compute normalmap recon. + + # Backward cycle G_BA (B to A) + fake_A2 = self.networks['G_BA'](real_B) # Compute normalmap + real_B1 = real_B[:, :3] # Get rgb_B + rec_B2 = self.networks['G_AB'](torch.cat([real_B1, fake_A2], dim=1)) # Compute depthmap recon. + + # Hack -- Use dummy zeros arrays to fill up the channels of rgb components + dummy_array = torch.zeros_like(real_A1) + self.visuals.update({ + 'fake_B': torch.cat([dummy_array, fake_B2], dim=1), + 'rec_A': torch.cat([dummy_array, rec_A2], dim=1), + 'fake_A': torch.cat([dummy_array, fake_A2], dim=1), + 'rec_B': torch.cat([dummy_array, rec_B2], dim=1), + }) + + def backward_D(self, discriminator): + """Calculate GAN loss for the discriminator""" + # D_B only evaluates depthmap + if discriminator == 'D_B': + real = self.visuals['real_B'][:, 3:] + fake = self.visuals['fake_B'][:, 3:] + fake = self.fake_B_pool.query(fake) + loss_id = 0 + + # D_A only evaluates normalmap + elif discriminator == 'D_A': + real = self.visuals['real_A'][:, 3:] + fake = self.visuals['fake_A'][:, 3:] + fake = self.fake_A_pool.query(fake) + loss_id = 1 + else: + raise ValueError('The discriminator has to be either "D_A" or "D_B".') + + self.pred_real = self.networks[discriminator](real) + + # Detaching fake: https://github.com/pytorch/examples/issues/116 + self.pred_fake = self.networks[discriminator](fake.detach()) + + loss_real = self.criterion_adv(self.pred_real, target_is_real=True) + loss_fake = self.criterion_adv(self.pred_fake, target_is_real=False) + self.losses[discriminator] = loss_real + loss_fake + + # backprop + self.backward(loss=self.losses[discriminator], optimizer=self.optimizers['D'], loss_id=2) + + def backward_G(self): + """Calculate the loss for generators G_AB and G_BA using all specified losses""" + # Get depthmap and normalmap + fake_B2 = self.visuals['fake_B'][:, 3:] # G_AB(A) + fake_A2 = self.visuals['fake_A'][:, 3:] # G_BA(B) + + # ------------------------- GAN Loss ---------------------------- + pred_B = self.networks['D_B'](fake_B2) # D_B(G_AB(A)) + pred_A = self.networks['D_A'](fake_A2) # D_A(G_BA(B)) + + # Forward GAN loss D_B(G_AB(A)) + self.losses['G_AB'] = self.criterion_adv(pred_B, target_is_real=True) + # Backward GAN loss D_A(G_BA(B)) + self.losses['G_BA'] = self.criterion_adv(pred_A, target_is_real=True) + # --------------------------------------------------------------- + + # ------------- G Losses (Cycle, Identity) ------------- + losses_G = self.criterion_G(self.visuals) + self.losses.update(losses_G) + # --------------------------------------------------------------- + + # combine losses and calculate gradients + combined_loss_G = sum(losses_G.values()) + self.losses['G_AB'] + self.losses['G_BA'] + self.backward(loss=combined_loss_G, optimizer=self.optimizers['G'], loss_id=0) + + +
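# Illustrative sketch (not part of the patch): how the zeros-dummy hack above keeps the
# fake/recon visuals channel-compatible with the reals, assuming the first 3 channels hold
# the shared RGB block and the remaining channels hold the second modality (normalmap or
# depthmap). Shapes are made up for the example.
import torch

rgb = torch.zeros(1, 3, 64, 64)             # shared RGB block (channels 0-2)
second_modality = torch.rand(1, 3, 64, 64)  # e.g. a generated normalmap/depthmap
padded = torch.cat([torch.zeros_like(rgb), second_modality], dim=1)

assert padded.shape[1] == rgb.shape[1] + second_modality.shape[1]
# The losses and discriminators above only ever read `padded[:, 3:]`, so the zeroed
# RGB block never influences the gradients.
assert torch.equal(padded[:, 3:], second_modality)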
+ def infer(self, input, direction='AB'): + assert direction in ['AB', 'BA'], "Specify which generator direction, AB or BA, to use." + assert f'G_{direction}' in self.networks.keys() + + with torch.no_grad(): + fake_B2 = self.networks[f'G_{direction}'](input) + real_A1 = input[:, :3] + return torch.cat([torch.zeros_like(real_A1), fake_B2], dim=1) \ No newline at end of file diff --git a/projects/cleargrasp_depth_estimation/modules/old/cyclegan_losses_with_structure.py b/projects/cleargrasp_depth_estimation/modules/old/cyclegan_losses_with_structure.py new file mode 100644 index 00000000..5f28b8b5 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/modules/old/cyclegan_losses_with_structure.py @@ -0,0 +1,182 @@ +import math +import torch + +from midaGAN.nn.losses import cyclegan_losses + + +MIND_DESCRIPTOR_CONFIG = {'non_local_region_size': 9, 'patch_size': 7, 'neighbor_size': 3, 'gaussian_patch_sigma': 2.0} + + +class CycleGANLossesWithStructure(cyclegan_losses.CycleGANLosses): + """ Additional constraint: Structure-consistency loss """ + + def __init__(self, conf, cyclegan_design_version): + super().__init__(conf) + + lambda_structure = conf.train.gan.optimizer.lambda_structure + use_cuda = conf.train.cuda + self.cyclegan_design_version = cyclegan_design_version + + if lambda_structure > 0: + self.criterion_structure = StructureLoss(lambda_structure, use_cuda, cyclegan_design_version) + else: + self.criterion_structure = None + + def __call__(self, visuals): + real_A, real_B = visuals['real_A'], visuals['real_B'] + fake_A, fake_B = visuals['fake_A'], visuals['fake_B'] + + losses = super().__call__(visuals) + + # A2-B2 structure-consistency loss + if self.criterion_structure is not None: + # MINDLoss(G_AB(real_A)[Depthmap], real_A[Normalmap]) + losses['structure_AB'] = self.lambda_AB * self.criterion_structure(real_A, fake_B) + # MINDLoss(G_BA(real_B)[Normalmap], real_B[Depthmap]) + losses['structure_BA'] = self.lambda_BA * self.criterion_structure(real_B, fake_A) + + return losses + + +class StructureLoss: + """ + Structure-consistency loss -- Yang et al.
(2018) - Unpaired Brain MR-to-CT Synthesis using a Structure-Constrained CycleGAN + Applied here in 2 different ways: + v1 -- Between A2 component (normalmap) and B (depthmap) + v2 -- Between A2 (normalmap )and B2 (depthmap) components + """ + def __init__(self, lambda_structure, use_cuda, cyclegan_design_version): + self.lambda_structure = lambda_structure + self.cyclegan_design_version = cyclegan_design_version + + self.nl_size = MIND_DESCRIPTOR_CONFIG['non_local_region_size'] + self.mind_descriptor = MINDDescriptor(**MIND_DESCRIPTOR_CONFIG) + if use_cuda: + self.mind_descriptor = self.mind_descriptor.cuda() + + def __call__(self, input_, fake): + + if self.cyclegan_design_version == 'v1': + # Get depthmap and normalmap from appropriate tensors + if fake.shape[1] == 1: + depthmap = fake + normalmap = input_[:, 3:] + elif fake.shape[1] == 6: + depthmap = input_ + normalmap = fake[:, 3:] + + # Convert normalmap to single channel grayscale (descriptor requires this) + normalmap = normalmap.mean(dim=1).unsqueeze(dim=1) + + # Extract MIND features + depthmap_features = self.mind_descriptor(depthmap) + normalmap_features = self.mind_descriptor(normalmap) + feature_diff = depthmap_features - normalmap_features + + elif self.cyclegan_design_version == 'v2': + # Take the 2nd modality from each tensor + input_2nd_mod, fake_2nd_mod = input_[:, 3:], fake[:, 3:] + + # Convert to single channel grayscale (descriptor requires this) + input_2nd_mod = input_2nd_mod.mean(dim=1).unsqueeze(dim=1) + fake_2nd_mod = fake_2nd_mod.mean(dim=1).unsqueeze(dim=1) + + # Extract MIND features + input_features = self.mind_descriptor(input_2nd_mod) + fake_features = self.mind_descriptor(fake_2nd_mod) + feature_diff = input_features - fake_features + + # Compute loss + l1_distance = torch.norm(feature_diff, 1) + loss_structure = l1_distance / (input_.shape[2] * input_.shape[3] * self.nl_size * self.nl_size) + return loss_structure * self.lambda_structure + + +class MINDDescriptor(torch.nn.Module): + """ + Taken from the public repository -- https://github.com/tomosu/MIND-pytorch. + Minor changes made in style for better readability. 
+ """ + def __init__(self, non_local_region_size=9, patch_size=7, neighbor_size=3, gaussian_patch_sigma=3.0): + super().__init__() + self.nl_size = non_local_region_size + self.p_size = patch_size + self.n_size = neighbor_size + self.sigma2 = gaussian_patch_sigma * gaussian_patch_sigma + + # Calculate shifted images in non-local region + self.image_shifter = torch.nn.Conv2d(in_channels=1, out_channels=self.nl_size * self.nl_size, + kernel_size=(self.nl_size, self.nl_size), + stride=1, padding=((self.nl_size - 1) // 2, (self.nl_size - 1) // 2), + dilation=1, groups=1, bias=False, padding_mode='zeros') + + for i in range(self.nl_size * self.nl_size): + t = torch.zeros((1, self.nl_size, self.nl_size)) + t[0, i % self.nl_size, i // self.nl_size] = 1 + self.image_shifter.weight.data[i] = t + + # Patch summation + self.summation_patcher = torch.nn.Conv2d(in_channels=self.nl_size * self.nl_size, out_channels=self.nl_size * self.nl_size, + kernel_size=(self.p_size, self.p_size), + stride=1, padding=((self.p_size - 1) // 2, (self.p_size - 1) // 2), + dilation=1, groups=self.nl_size * self.nl_size, bias=False, padding_mode='zeros') + + for i in range(self.nl_size * self.nl_size): + # Gaussian kernel + t = torch.zeros((1, self.p_size, self.p_size)) + cx = (self.p_size - 1) // 2 + cy = (self.p_size - 1) // 2 + for j in range(self.p_size * self.p_size): + x = j % self.p_size + y = j // self.p_size + d2 = torch.norm(torch.tensor([x - cx, y - cy]).float(), 2) + t[0, x, y] = math.exp(-d2 / self.sigma2) + + self.summation_patcher.weight.data[i] = t + + # Neighbor images + self.neighbors = torch.nn.Conv2d(in_channels=1, out_channels=self.n_size * self.n_size, + kernel_size=(self.n_size, self.n_size), + stride=1, padding=((self.n_size - 1) // 2, (self.n_size - 1) // 2), + dilation=1, groups=1, bias=False, padding_mode='zeros') + + for i in range(self.n_size*self.n_size): + t = torch.zeros((1, self.n_size, self.n_size)) + t[0, i % self.n_size, i // self.n_size] = 1 + self.neighbors.weight.data[i] = t + + # Neighbor patcher + self.neighbor_summation_patcher = torch.nn.Conv2d(in_channels=self.n_size * self.n_size, out_channels=self.n_size * self.n_size, + kernel_size=(self.p_size, self.p_size), + stride=1, padding=((self.p_size - 1) // 2, (self.p_size - 1) // 2), + dilation=1, groups=self.n_size*self.n_size, bias=False, padding_mode='zeros') + + for i in range(self.n_size * self.n_size): + t = torch.ones((1, self.p_size, self.p_size)) + self.neighbor_summation_patcher.weight.data[i] = t + + def forward(self, orig): + assert len(orig.shape) == 4 + assert orig.shape[1] == 1 + + # Get original image channel stack + orig_stack = torch.stack([orig.squeeze(dim=1) for i in range(self.nl_size * self.nl_size)], dim=1) + + # Get shifted images + shifted = self.image_shifter(orig) + + # Get image diff + diff_images = shifted - orig_stack + + # L2 norm of image diff + Dx_alpha = self.summation_patcher(torch.pow(diff_images, 2.0)) + + # Calculate neighbors' variance + neighbor_images = self.neighbor_summation_patcher(self.neighbors(orig)) + Vx = neighbor_images.var(dim=1).unsqueeze(dim=1) + + # Output MIND features + numerator = torch.exp(- Dx_alpha / (Vx + 1e-8)) + denominator = numerator.sum(dim=1).unsqueeze(dim=1) + mind_features = numerator / denominator + return mind_features \ No newline at end of file diff --git a/projects/cleargrasp_depth_estimation/modules/old/cyclegan_multimodal_v1_structure.py b/projects/cleargrasp_depth_estimation/modules/old/cyclegan_multimodal_v1_structure.py new file mode 100644 index 
00000000..3c7d5e74 --- /dev/null +++ b/projects/cleargrasp_depth_estimation/modules/old/cyclegan_multimodal_v1_structure.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass + +import torch + +from midaGAN.nn.gans.unpaired import cyclegan +from midaGAN.nn.losses.adversarial_loss import AdversarialLoss + +from projects.cleargrasp_depth_estimation.modules.cyclegan_losses_with_structure import CycleGANLossesWithStructure + + +@dataclass +class OptimizerV1StructureConfig(cyclegan.OptimizerConfig): + """ Structure consistency config for CycleGAN multimodal v1 """ + lambda_structure: float = 0 + + +@dataclass +class CycleGANMultiModalV1StructureConfig(cyclegan.CycleGANConfig): + """ CycleGANMultiModalV1Structure Config """ + name: str = "CycleGANMultiModalV1Structure" + optimizer: OptimizerV1StructureConfig = OptimizerV1StructureConfig() + + +class CycleGANMultiModalV1Structure(cyclegan.CycleGAN): + """ """ + + def __init__(self, conf): + super().__init__(conf) + + # Additional losses used by the model + structure_loss_names = ['structure_AB', 'structure_BA'] + self.losses.update({name: None for name in structure_loss_names}) + + + def init_criterions(self): + # Standard GAN loss + self.criterion_adv = AdversarialLoss( + self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device) + # G losses - Includes Structure-consistency loss + self.criterion_G = CycleGANLossesWithStructure(self.conf, cyclegan_design_version='v1') diff --git a/projects/cleargrasp_depth_estimation/modules/old/cyclegan_multimodal_v2.py b/projects/cleargrasp_depth_estimation/modules/old/cyclegan_multimodal_v2.py new file mode 100644 index 00000000..c44836fe --- /dev/null +++ b/projects/cleargrasp_depth_estimation/modules/old/cyclegan_multimodal_v2.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass + +import torch + +from midaGAN.nn.gans.unpaired import cyclegan +from midaGAN.nn.losses.adversarial_loss import AdversarialLoss + +from projects.cleargrasp_depth_estimation.modules.cyclegan_losses_with_structure import CycleGANLossesWithStructure + + +@dataclass +class OptimizerV2Config(cyclegan.OptimizerConfig): + """ Optimizer Config CycleGAN multimodal v2 """ + lambda_structure: float = 0 + + +@dataclass +class CycleGANMultiModalV2Config(cyclegan.CycleGANConfig): + """ CycleGANMultiModalV2 Config """ + name: str = "CycleGANMultiModalV2" + optimizer: OptimizerV2Config = OptimizerV2Config() + + +class CycleGANMultiModalV2(cyclegan.CycleGAN): + """ CycleGAN for multimodal images -- Version 2 """ + + def __init__(self, conf): + super().__init__(conf) + + # Additional losses used by the model + structure_loss_names = ['structure_AB', 'structure_BA'] + self.losses.update({name: None for name in structure_loss_names}) + + + def init_criterions(self): + # Standard GAN loss + self.criterion_adv = AdversarialLoss( + self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device) + # G losses - Includes Structure-consistency loss + self.criterion_G = CycleGANLossesWithStructure(self.conf, cyclegan_design_version='v2') diff --git a/projects/horse2zebra/experiments/default.yaml b/projects/horse2zebra/experiments/default.yaml index 0d68e675..0da855f7 100644 --- a/projects/horse2zebra/experiments/default.yaml +++ b/projects/horse2zebra/experiments/default.yaml @@ -31,13 +31,14 @@ train: generator: name: "Resnet2D" n_residual_blocks: 9 - in_channels: 3 - out_channels: 3 + in_out_channels: + AB: [3, 3] discriminator: name: "PatchGAN2D" n_layers: 3 - in_channels: 3 + in_channels: + B: 3 optimizer: lambda_AB: 10.0 diff --git 
a/projects/maastro_hx4_pet_translation/datasets/__init__.py b/projects/maastro_hx4_pet_translation/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/projects/maastro_hx4_pet_translation/datasets/train_dataset.py b/projects/maastro_hx4_pet_translation/datasets/train_dataset.py new file mode 100644 index 00000000..8f1d8652 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/datasets/train_dataset.py @@ -0,0 +1,190 @@ +""" +TODO list: +- What's a good way to use data augmentation ? +x Set proper value for `focal_region_proportion` +""" + +import os +import random +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from midaGAN import configs +from midaGAN.utils import sitk_utils + + +import projects.maastro_hx4_pet_translation.datasets.utils.patch_samplers as patch_samplers +from projects.maastro_hx4_pet_translation.datasets.utils.basic import (sitk2np, + np2tensor, + apply_body_mask, + clip_and_min_max_normalize) + + + +@dataclass +class HX4PETTranslationTrainDatasetConfig(configs.base.BaseDatasetConfig): + name: str = "HX4PETTranslationTrainDataset" + paired: bool = True # `True` only for Pix2Pix + require_ldct_for_training: bool = False # `True` only for HX4-CycleGAN-balanced + hu_range: Tuple[int, int] = (-1000, 2000) + fdg_suv_range: Tuple[float, float] = (0.0, 15.0) + hx4_tbr_range: Tuple[float, float] = (0.0, 3.0) + patch_size: Tuple[int, int, int] = (32, 128, 128) # DHW + patch_sampling: str = 'uniform-random-within-body' + # Focal region proportion only applies when training is unpaired + focal_region_proportion: Tuple[float, float, float] = (0.6, 0.3, 0.3) # DHW + +class HX4PETTranslationTrainDataset(Dataset): + + def __init__(self, conf): + + self.paired = conf.train.dataset.paired + self.require_ldct_for_training = conf.train.dataset.require_ldct_for_training + + # Image file paths + root_path = conf.train.dataset.root + self.patient_ids = sorted(os.listdir(root_path)) + + self.image_paths = {'FDG-PET': [], 'pCT': [], 'HX4-PET': [], 'body-mask-A': [], 'body-mask-B': []} + if self.require_ldct_for_training: + self.image_paths['ldCT'] = [] + + for p_id in self.patient_ids: + patient_image_paths = {} + + patient_image_paths['FDG-PET'] = f"{root_path}/{p_id}/fdg_pet.nrrd" + patient_image_paths['pCT'] = f"{root_path}/{p_id}/pct.nrrd" + patient_image_paths['body-mask-A'] = f"{root_path}/{p_id}/pct_body.nrrd" + + if self.paired: + # If paired, get HX4-PET-reg and use the pCT's body mask + patient_image_paths['HX4-PET'] = f"{root_path}/{p_id}/hx4_pet_reg.nrrd" + patient_image_paths['body-mask-B'] = patient_image_paths['body-mask-A'] + else: + # Else, get unregistered HX4-PET and use the ldCT's auto generated body mask + patient_image_paths['HX4-PET'] = f"{root_path}/{p_id}/hx4_pet.nrrd" + patient_image_paths['body-mask-B'] = f"{root_path}/{p_id}/ldct_body.nrrd" + + if self.require_ldct_for_training: + # If ldCT image is required to be fetched + patient_image_paths['ldCT'] = f"{root_path}/{p_id}/ldct.nrrd" + + for k in self.image_paths.keys(): + self.image_paths[k].append(patient_image_paths[k]) + + self.num_datapoints_A = len(self.image_paths['FDG-PET']) + self.num_datapoints_B = len(self.image_paths['HX4-PET']) + + # SUVmean_aorta values for normalizing HX4-PET SUV to TBR + suv_aorta_mean_file = f"{os.path.dirname(root_path)}/SUVmean_aorta_HX4.csv" + self.suv_aorta_mean_values = pd.read_csv(suv_aorta_mean_file, index_col=0) + 
self.suv_aorta_mean_values = self.suv_aorta_mean_values.to_dict()['HX4 aorta SUVmean baseline'] + + # Clipping ranges + self.hu_min, self.hu_max = conf.train.dataset.hu_range + self.fdg_suv_min, self.fdg_suv_max = conf.train.dataset.fdg_suv_range + self.hx4_tbr_min, self.hx4_tbr_max = conf.train.dataset.hx4_tbr_range + + # Patch sampler setup + patch_size = np.array(conf.train.dataset.patch_size) + patch_sampling = conf.train.dataset.patch_sampling + if self.paired: + self.patch_sampler = patch_samplers.PairedPatchSampler3D(patch_size, patch_sampling) + else: + focal_region_proportion = conf.train.dataset.focal_region_proportion + self.patch_sampler = patch_samplers.UnpairedPatchSampler3D(patch_size, patch_sampling, focal_region_proportion) + + + def __len__(self): + return max(self.num_datapoints_A, self.num_datapoints_B) + + + def __getitem__(self, index): + + # ------------ + # Fetch images + + index_A = index % self.num_datapoints_A + index_B = index_A if self.paired else random.randint(0, self.num_datapoints_B - 1) + + image_path_A, image_path_B = {}, {} + image_path_A['FDG-PET'] = self.image_paths['FDG-PET'][index_A] + image_path_A['pCT'] = self.image_paths['pCT'][index_A] + image_path_B['HX4-PET'] = self.image_paths['HX4-PET'][index_B] + + if self.require_ldct_for_training: + image_path_B['ldCT'] = self.image_paths['ldCT'][index_B] + + image_path_A['body-mask'] = self.image_paths['body-mask-A'][index_A] + image_path_B['body-mask'] = self.image_paths['body-mask-B'][index_B] + + # Load NRRD as SimpleITK objects (WHD) + images_A, images_B = {}, {} + for k in image_path_A.keys(): + images_A[k] = sitk_utils.load(image_path_A[k]) + for k in image_path_B.keys(): + images_B[k] = sitk_utils.load(image_path_B[k]) + + + # --------- + # Transform + # TODO: What's a good way to use data aug ? 
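# One simple possibility for the data augmentation TODO above (an illustrative sketch,
# not part of the patch): a random left-right flip applied identically to every array of
# a patient, so that images and body masks stay consistent. It would have to run after
# the `sitk2np` conversion below, once the images are numpy arrays in DHW order. The
# helper name is hypothetical.
import numpy as np

def random_flip_width(image_dict, p=0.5, rng=np.random):
    """Flip all arrays of one domain along the W axis with probability `p`."""
    if rng.random() < p:
        image_dict = {k: np.flip(v, axis=-1).copy() for k, v in image_dict.items()}
    return image_dict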
+ + + # --------------- + # Apply body mask + + # Convert to numpy (DHW) + images_A = sitk2np(images_A) + images_B = sitk2np(images_B) + + images_A = apply_body_mask(images_A) + images_B = apply_body_mask(images_B) + + + # -------------- + # Sample patches + + # Get patches + images_A, images_B = self.patch_sampler.get_patch_pair(images_A, images_B) + + # Convert to tensors + images_A = np2tensor(images_A) + images_B = np2tensor(images_B) + + + # ------------- + # Normalization + + # Normalize HX4-PET SUVs with SUVmean_aorta + patient_id = self.patient_ids[index_B] + images_B['HX4-PET'] = images_B['HX4-PET'] / self.suv_aorta_mean_values[patient_id] + + # Clip and then rescale all intensties to range [-1, 1] + images_A['FDG-PET'] = clip_and_min_max_normalize(images_A['FDG-PET'], self.fdg_suv_min, self.fdg_suv_max) + images_A['pCT'] = clip_and_min_max_normalize(images_A['pCT'], self.hu_min, self.hu_max) + images_B['HX4-PET'] = clip_and_min_max_normalize(images_B['HX4-PET'], self.hx4_tbr_min, self.hx4_tbr_max) + if self.require_ldct_for_training: + images_B['ldCT'] = clip_and_min_max_normalize(images_B['ldCT'], self.hu_min, self.hu_max) + + + # --------------------- + # Construct sample dict + + # A and B need to have dims (C,D,H,W) + A = torch.stack((images_A['FDG-PET'], images_A['pCT']), dim=0) + + if self.require_ldct_for_training: + B = torch.stack((images_B['HX4-PET'], images_B['ldCT']), dim=0) + else: + B = images_B['HX4-PET'].unsqueeze(dim=0) + + sample_dict = {'A': A, 'B': B} + + return sample_dict + diff --git a/projects/maastro_hx4_pet_translation/datasets/utils/__init__.py b/projects/maastro_hx4_pet_translation/datasets/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/projects/maastro_hx4_pet_translation/datasets/utils/basic.py b/projects/maastro_hx4_pet_translation/datasets/utils/basic.py new file mode 100644 index 00000000..5be89b34 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/datasets/utils/basic.py @@ -0,0 +1,54 @@ +import numpy as np +import torch +import SimpleITK as sitk + +from midaGAN.utils import sitk_utils +from midaGAN.data.utils.body_mask import get_body_mask +from midaGAN.data.utils.normalization import min_max_normalize + + +# Body mask settings +OUT_OF_BODY_HU = -1024 +OUT_OF_BODY_SUV = 0 +HU_THRESHOLD = -300 + + + +def apply_body_mask(image_dict, generate_body_mask=False): + + # If body mask doesn't exist, then create one from the available CT using morph. 
ops + if generate_body_mask: + assert image_dict['body-mask'] is None + assert any(['CT' in k for k in image_dict.keys()]) # There should be a CT in the dict to be able to generate a mask + ct_image_name = [k for k in image_dict.keys() if 'CT' in k][0] + image_dict['body-mask'] = get_body_mask(image_dict[ct_image_name], HU_THRESHOLD) + + # Apply masking to any CT or PET image present in image_dict + assert image_dict['body-mask'] is not None + body_mask = image_dict['body-mask'] + for k in image_dict.keys(): + if 'PET' in k: + image_dict[k] = np.where(body_mask, image_dict[k], OUT_OF_BODY_SUV) + elif 'CT' in k: + image_dict[k] = np.where(body_mask, image_dict[k], OUT_OF_BODY_HU) + + return image_dict + + +def clip_and_min_max_normalize(tensor, min_value, max_value): + tensor = torch.clamp(tensor, min_value, max_value) + tensor = min_max_normalize(tensor, min_value, max_value) + return tensor + + +def sitk2np(image_dict): + # WHD to DHW + for k in image_dict.keys(): + if isinstance(image_dict[k], sitk.SimpleITK.Image): + image_dict[k] = sitk_utils.get_npy(image_dict[k]) + return image_dict + +def np2tensor(image_dict): + for k in image_dict.keys(): + image_dict[k] = torch.tensor(image_dict[k]) + return image_dict diff --git a/projects/maastro_hx4_pet_translation/datasets/utils/patch_samplers.py b/projects/maastro_hx4_pet_translation/datasets/utils/patch_samplers.py new file mode 100644 index 00000000..cb4840a5 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/datasets/utils/patch_samplers.py @@ -0,0 +1,275 @@ +import numpy as np +from loguru import logger + + +PAIRED_SAMPLING_SCHEMES = ('uniform-random-within-body', 'fdg-pet-weighted') +UNPAIRED_SAMPLING_SCHEMES = ('uniform-random-within-body-sf', 'fdg-pet-weighted-sf') + + +class PairedPatchSampler3D(): + """3D patch sampler for paired training. + + Available patch sampling schemes: + 1. 'uniform-random-within-body' + 2. 'fdg-pet-weighted' + """ + def __init__(self, patch_size, sampling): + + if sampling not in PAIRED_SAMPLING_SCHEMES: + raise ValueError(f"`{sampling}` not a valid paired patch sampling scheme. \ + Available schemes: {PAIRED_SAMPLING_SCHEMES}") + + self.patch_size = np.array(patch_size) + self.sampling = sampling + + + def get_patch_pair(self, image_dict_A, image_dict_B): + + # Sample a single focal point to be used for both domain A and B images + # Domain A and domain B images are expected to be voxel-to-voxel paired + focal_point = self._sample_common_focal_point(image_dict_A) + + # Extract patches from all volumes given this focal point and the patch size + start_idx = focal_point - np.floor(self.patch_size/2) + end_idx = start_idx + self.patch_size + z1, y1, x1 = start_idx.astype(np.uint16) + z2, y2, x2 = end_idx.astype(np.uint16) + + patch_dict_A, patch_dict_B = {}, {} + for k in image_dict_A.keys(): + patch_dict_A[k] = image_dict_A[k][z1:z2, y1:y2, x1:x2] + for k in image_dict_B.keys(): + patch_dict_B[k] = image_dict_B[k][z1:z2, y1:y2, x1:x2] + + return patch_dict_A, patch_dict_B + + + def _sample_common_focal_point(self, image_dict_A): + body_mask = image_dict_A['body-mask'] + volume_size = body_mask.shape[-3:] # DHW + + # Initialize sampling probability map as a volumetric mask of body region contained inside the + # volume's valid patch region (i.e. 
suffieciently away from the volume borders) + sampling_prob_map = init_sampling_probability_map(volume_size, self.patch_size, body_mask) + + # Depending on the sampling technique, construct the probability map + if self.sampling == 'uniform-random-within-body': + # Uniform random over all valid focal points + sampling_prob_map = sampling_prob_map / np.sum(sampling_prob_map) + + elif self.sampling == 'fdg-pet-weighted': + # Random sampling, biased to high SUV regions in FDG-PET + FDG_PET_volume = image_dict_A['FDG-PET'] + # Clip negative values to zero + FDG_PET_volume = np.clip(FDG_PET_volume, 0, None) + # Update the probability map + sampling_prob_map = sampling_prob_map * FDG_PET_volume + sampling_prob_map = sampling_prob_map / np.sum(sampling_prob_map) + + # Sample focal points using this probability map + focal_point = sample_from_probability_map(sampling_prob_map) + + return np.array(focal_point).astype(np.uint16) + + + +class UnpairedPatchSampler3D(): + """3D patch sampler for unpaired training. + + Variations of Stochastic Focal patch sampling, where different schemes + differ in the way the focal point is sampled from domain A image(s). + Essentially, the schemes implement different "prior" patch sampling + probability distributions. + + + Available patch sampling schemes: + 1. 'uniform-random-sf' + 2. 'fdg-pet-weighted-sf' + """ + def __init__(self, patch_size, sampling, focal_region_proportion): + + if sampling not in UNPAIRED_SAMPLING_SCHEMES: + raise ValueError(f"`{sampling}` not a valid unpaired patch sampling scheme. \ + Available schemes: {UNPAIRED_SAMPLING_SCHEMES}") + + self.patch_size = np.array(patch_size) + self.sampling = sampling + self.focal_region_proportion = np.array(focal_region_proportion) + + + def get_patch_pair(self, image_dict_A, image_dict_B): + # Sample a focal point and its size-normlaized version for domain A images + focal_point_A, relative_focal_point = self._sample_focal_point_A(image_dict_A) + + # Sample a focal point for B images that is in relative neighborhood of the focal point of A images + focal_point_B = self._sample_focal_point_B(image_dict_B, relative_focal_point) + + # Extract patches from all volumes given this focal point and the patch size + start_idx_A = focal_point_A - np.floor(self.patch_size/2) + end_idx_A = start_idx_A + self.patch_size + z1_A, y1_A, x1_A = start_idx_A.astype(np.uint16) + z2_A, y2_A, x2_A = end_idx_A.astype(np.uint16) + + start_idx_B = focal_point_B - np.floor(self.patch_size/2) + end_idx_B = start_idx_B + self.patch_size + z1_B, y1_B, x1_B = start_idx_B.astype(np.uint16) + z2_B, y2_B, x2_B = end_idx_B.astype(np.uint16) + + patch_dict_A = {} + for k in image_dict_A.keys(): + patch_dict_A[k] = image_dict_A[k][z1_A:z2_A, y1_A:y2_A, x1_A:x2_A] + + patch_dict_B = {} + for k in image_dict_B.keys(): + patch_dict_B[k] = image_dict_B[k][z1_B:z2_B, y1_B:y2_B, x1_B:x2_B] + + return patch_dict_A, patch_dict_B + + + def _sample_focal_point_A(self, image_dict_A): + body_mask = image_dict_A['body-mask'] + volume_size = body_mask.shape # DHW + + # Initialize sampling probability map as a volumetric mask of body region contained inside the + # volume's valid patch region (i.e. 
suffieciently away from the volume borders) + sampling_prob_map = sampling_prob_map = init_sampling_probability_map(volume_size, self.patch_size, body_mask) + + # Depending on the sampling technique, construct the probability map + if self.sampling == 'uniform-random-within-body-sf': + # Uniform random over all valid focal points + sampling_prob_map = sampling_prob_map / np.sum(sampling_prob_map) + + elif self.sampling == 'fdg-pet-weighted-sf': + # Random sampling, biased to high SUV regions in FDG-PET + FDG_PET_volume = image_dict_A['FDG-PET'] + # Clip negative values to zero + FDG_PET_volume = np.clip(FDG_PET_volume, 0, None) + # Update the probability map + sampling_prob_map = sampling_prob_map * FDG_PET_volume + sampling_prob_map = sampling_prob_map / np.sum(sampling_prob_map) + + # Sample focal point using this probability map + focal_point = sample_from_probability_map(sampling_prob_map) + focal_point = np.array(focal_point) + + # Calculate the relative focal point by normalizing focal point indices with the volume size + relative_focal_point = focal_point / np.array(volume_size) + + return focal_point.astype(np.uint16), relative_focal_point + + + def _sample_focal_point_B(self, image_dict_B, relative_focal_point): + body_mask = image_dict_B['body-mask'] + volume_size = body_mask.shape # DHW + + focal_region_size = self.focal_region_proportion * np.array(volume_size) + focal_region_size = focal_region_size.astype(np.uint16) + + # Map relative point to corresponding point in this volume + focal_point = relative_focal_point * np.array(volume_size) + + # Intialize a sampling probability map for domain B images + sampling_prob_map = init_sampling_probability_map(volume_size, self.patch_size, body_mask) + + # Apply Stochastic focal sampling + focal_point_after_sf = self._apply_stochastic_focal_method(focal_point, focal_region_size, sampling_prob_map) + return focal_point_after_sf + + + def _apply_stochastic_focal_method(self, focal_point, focal_region_size, sampling_prob_map): + + # Create a focal region mask having the same size as the volume + volume_size = sampling_prob_map.shape + focal_region_min, focal_region_max = [], [] + + for axis in range(len(focal_point)): + # Find the lowest and highest position between which to focus for this axis + min_position = int(focal_point[axis] - focal_region_size[axis] / 2) + max_position = int(focal_point[axis] + focal_region_size[axis] / 2) + + # If one of the boundaries of the focus is outside of the volume size, cap it + min_position = max(min_position, 0) + max_position = min(max_position, volume_size[axis]) + + focal_region_min.append(min_position) + focal_region_max.append(max_position) + + z_min, y_min, x_min = focal_region_min + z_max, y_max, x_max = focal_region_max + focal_region_mask = np.zeros_like(sampling_prob_map) + focal_region_mask[z_min:z_max, y_min:y_max, x_min:x_max] = 1 + + # Update the sampling map by taking the intersection with the focal region mask. + # This is to make sure the sampled focal point is: + # 1. Within the volume's valid region + # 2. AND, Within body region + # 3. AND, Within focal region + intersection_mask = sampling_prob_map * focal_region_mask + if 1 not in list(np.unique(intersection_mask)): + # Edge case: If no intersection region is found between (1+2) and (3), + # just sample a B-image patch from anywhere within (1+2) region, i.e. valid body region + logger.warning("Stochastic focal sampling failed in a domain B image. \ + A likely cause might be a too small `focal_region_proportion` value. 
\ + Sampling a random valid patch from within the body region.") + sampling_prob_map = sampling_prob_map / np.sum(sampling_prob_map) + focal_point_B = sample_from_probability_map(sampling_prob_map) + return focal_point_B + + # Otherwise, continue with using the intersection mask and update the sampling probability map + sampling_prob_map = intersection_mask / np.sum(intersection_mask) + + # Sample focal point using this updated probability map + focal_point_after_sf = sample_from_probability_map(sampling_prob_map) + return focal_point_after_sf + + + +# -------------- +# Util functions + +def sample_from_probability_map(sampling_prob_map): + """TODO: Doc + """ + # Check if samplig prob map is a proper distribution (i.e. its sum is approx. equal to 1) + epsilon = 0.001 + assert np.sum(sampling_prob_map) > 1 - epsilon and np.sum(sampling_prob_map) < 1 + epsilon + + # Select relevant indices to sample from (i.e. those having a non-zero probability) + relevant_idxs = np.argwhere(sampling_prob_map > 0) + + # Using the sampling probability map, define the sampling distribution over these relevant indices + distribution = sampling_prob_map[sampling_prob_map > 0].flatten() + + # Sample a single voxel index. This is the focal point. + s = np.random.choice(len(relevant_idxs), p=distribution) + sampled_idx = relevant_idxs[s] + + return sampled_idx + + +def init_sampling_probability_map(volume_size, patch_size, body_mask=None): + """Initialize sampling probability map as a volumetric mask of body region contained inside the + volume's valid patch region (i.e. suffieciently away from the volume borders) + """ + # Initialize sampling probability map as zeros + sampling_prob_map = np.zeros(volume_size) + + # Get valid index range for focal points - upper-bound inclusive + valid_foc_pt_idx_min, valid_foc_pt_idx_max = get_valid_region_corner_points(volume_size, patch_size) + z_min, y_min, x_min = valid_foc_pt_idx_min.astype(np.uint16) + z_max, y_max, x_max = valid_foc_pt_idx_max.astype(np.uint16) + + # Set valid zone values as 1 + sampling_prob_map[z_min:z_max, y_min:y_max, x_min:x_max] = 1 + + # If body mask is given, filter out voxels outside the body region to avoid sampling patches from the background areas. + if body_mask is not None: + sampling_prob_map = sampling_prob_map * body_mask # Implemented as taking an intersection between the 2 volumes. 
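# Worked example of the valid-region arithmetic used here (illustrative numbers, not part
# of the patch): for volume_size = (64, 256, 256) and patch_size = (32, 128, 128),
# get_valid_region_corner_points() gives
#   valid_foc_pt_idx_min = floor(patch_size / 2)              = (16, 64, 64)
#   valid_foc_pt_idx_max = volume_size - ceil(patch_size / 2) = (48, 192, 192)
# so every focal point sampled from this map leaves room for a full patch on both sides
# of every axis.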
+ + return sampling_prob_map + + +def get_valid_region_corner_points(volume_size, patch_size): + valid_foc_pt_idx_min = np.zeros(3) + np.floor(patch_size/2) + valid_foc_pt_idx_max = np.array(volume_size) - np.ceil(patch_size/2) + return valid_foc_pt_idx_min.astype(np.int16), valid_foc_pt_idx_max.astype(np.int16) diff --git a/projects/maastro_hx4_pet_translation/datasets/val_test_dataset.py b/projects/maastro_hx4_pet_translation/datasets/val_test_dataset.py new file mode 100644 index 00000000..951d8c86 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/datasets/val_test_dataset.py @@ -0,0 +1,224 @@ +import os +from dataclasses import dataclass +from typing import Tuple +from loguru import logger + +import pandas as pd +import torch +from torch.utils.data import Dataset + +from midaGAN import configs +from midaGAN.utils import sitk_utils +from midaGAN.data.utils.normalization import min_max_denormalize +from midaGAN.data.utils.ops import pad + +from projects.maastro_hx4_pet_translation.datasets.utils.basic import (sitk2np, + np2tensor, + apply_body_mask, + clip_and_min_max_normalize) + + +@dataclass +class HX4PETTranslationValTestDatasetConfig(configs.base.BaseDatasetConfig): + """ + Note: Val dataset is paired, and does not supply ldCT + """ + name: str = "HX4PETTranslationValTestDataset" + hu_range: Tuple[int, int] = (-1000, 2000) + fdg_suv_range: Tuple[float, float] = (0.0, 15.0) + hx4_tbr_range: Tuple[float, float] = (0.0, 3.0) + # Use sliding window inference - If True, the val test engine takes care of it. + # Patch size value is interpolated from training patch size + use_patch_based_inference: bool = False + # Option to supply body and GTV masks in the sample_dict. If supplied, masked metrics will + # be computed additionally during validation which would slow down the training. + supply_masks: bool = False + # Is the model HX4CycleGANBalanced? If so, need to do a small hack while supplying HX4-PET + model_is_hx4_cyclegan_balanced: bool = False + + +class HX4PETTranslationValTestDataset(Dataset): + + def __init__(self, conf): + + # Image file paths + root_path = conf.val.dataset.root + + self.patient_ids = sorted(os.listdir(root_path)) + self.image_paths = {'FDG-PET': [], 'pCT': [], 'HX4-PET': [], 'body-mask': [], 'gtv-mask': []} + + for p_id in self.patient_ids: + patient_image_paths = {} + patient_image_paths['FDG-PET'] = f"{root_path}/{p_id}/fdg_pet.nrrd" + patient_image_paths['pCT'] = f"{root_path}/{p_id}/pct.nrrd" + patient_image_paths['HX4-PET'] = f"{root_path}/{p_id}/hx4_pet_reg.nrrd" + patient_image_paths['body-mask'] = f"{root_path}/{p_id}/pct_body.nrrd" + patient_image_paths['gtv-mask'] = f"{root_path}/{p_id}/pct_gtv.nrrd" + + for k in self.image_paths.keys(): + self.image_paths[k].append(patient_image_paths[k]) + + self.num_datapoints = len(self.image_paths['FDG-PET']) + + # SUVmean_aorta values for normalizing HX4-PET SUV to TBR + suv_aorta_mean_file = f"{os.path.dirname(root_path)}/SUVmean_aorta_HX4.csv" + self.suv_aorta_mean_values = pd.read_csv(suv_aorta_mean_file, index_col=0) + self.suv_aorta_mean_values = self.suv_aorta_mean_values.to_dict()['HX4 aorta SUVmean baseline'] + + # Clipping ranges + self.hu_min, self.hu_max = conf.val.dataset.hu_range + self.fdg_suv_min, self.fdg_suv_max = conf.val.dataset.fdg_suv_range + self.hx4_tbr_min, self.hx4_tbr_max = conf.val.dataset.hx4_tbr_range + + # Using sliding window inferer or performing full-image inference ? 
+ self.use_patch_based_inference = conf.val.dataset.use_patch_based_inference + + # Supply body and GTV masks ? + self.supply_masks = conf.val.dataset.supply_masks + + # Is HX4-CycleGAN-balanced the model being validated/tested ? + self.model_is_hx4_cyclegan_balanced = conf.val.dataset.model_is_hx4_cyclegan_balanced + + + def __len__(self): + return self.num_datapoints + + + def __getitem__(self, index): + + # ------------ + # Fetch images + index = index % self.num_datapoints + + image_path = {} + image_path['FDG-PET'] = self.image_paths['FDG-PET'][index] + image_path['pCT'] = self.image_paths['pCT'][index] + image_path['HX4-PET'] = self.image_paths['HX4-PET'][index] + image_path['body-mask'] = self.image_paths['body-mask'][index] + image_path['gtv-mask'] = self.image_paths['gtv-mask'][index] + + # Load NRRD as SimpleITK objects (WHD) + images = {} + for k in image_path.keys(): + # One patient in val set (N046) doesn't have a pCT body mask + try: + images[k] = sitk_utils.load(image_path[k]) + except RuntimeError: + if k == 'body-mask': + logger.warning(f"Patient {self.patient_ids[index]} does not have a body mask. It will be generated automatically") + # Set as `None` for now, handle it later in apply_body_mask() + # by creating a mask on the go using thresholding + images[k] = None + + + # ---------------------- + # Collect image metadata + metadata = { + 'patient_id': self.patient_ids[index], + 'size': images['FDG-PET'].GetSize(), + 'origin': images['FDG-PET'].GetOrigin(), + 'spacing': images['FDG-PET'].GetSpacing(), + 'direction': images['FDG-PET'].GetDirection(), + 'dtype': sitk_utils.get_npy_dtype(images['FDG-PET']) + } + + + # --------------- + # Apply body mask + + # Convert to numpy (DHW) + images = sitk2np(images) + + if self.patient_ids[index] == 'N046': + generate_body_mask = True + else: + generate_body_mask = False + + images = apply_body_mask(images, generate_body_mask) + + + # -------------------- + # Pad images if needed + + # If doing full-image inference, pad images to have a standard size of (64, 512, 512) + # to avoid issues with UNet's up- and downsampling + if not self.use_patch_based_inference: + for k in images.keys(): + images[k] = pad(images[k], target_shape=(64, 512, 512)) + + # Convert to tensors + images = np2tensor(images) + + # ------------- + # Normalization + + # Normalize HX4-PET SUVs with SUVmean_aorta + patient_id = self.patient_ids[index] + images['HX4-PET'] = images['HX4-PET'] / self.suv_aorta_mean_values[patient_id] + + # Clip and then rescale all intensties to range [-1, 1] + images['FDG-PET'] = clip_and_min_max_normalize(images['FDG-PET'], self.fdg_suv_min, self.fdg_suv_max) + images['pCT'] = clip_and_min_max_normalize(images['pCT'], self.hu_min, self.hu_max) + images['HX4-PET'] = clip_and_min_max_normalize(images['HX4-PET'], self.hx4_tbr_min, self.hx4_tbr_max) + + + # --------------------- + # Construct sample dict + + # A and B need to have dims (C,D,H,W) + A = torch.stack((images['FDG-PET'], images['pCT']), dim=0) + + if self.model_is_hx4_cyclegan_balanced: + # Create a dummy array to fill up the 2nd channel + zeros_dummy = torch.zeros_like(images['HX4-PET']) + B = torch.stack([images['HX4-PET'], zeros_dummy], dim=0) + else: + B = images['HX4-PET'].unsqueeze(dim=0) + + sample_dict = {'A': A, 'B': B} + + # Include masks, if needed + if self.supply_masks: + sample_dict['masks'] = {'BODY': images['body-mask'].unsqueeze(dim=0), + 'GTV': images['gtv-mask'].unsqueeze(dim=0)} + + # Include metadata + sample_dict['metadata'] = metadata + + return 
sample_dict + + + def denormalize(self, tensor): + """Allows the Tester and Validator to calculate the metrics in + the original range of values. + `tensor` can be either the predicted or the ground truth HX4-PET image tensor + """ + tensor = min_max_denormalize(tensor, self.hx4_tbr_min, self.hx4_tbr_max) + return tensor + + + def save(self, tensor, save_dir, metadata): + """ Save predicted tensors as NRRD + """ + + # If the model is HX4-CycleGAN-balanced, tensor is 2-channel with the + # 1st channel containing HX4-PET and 2nd channel containing a dummy array. + if self.model_is_hx4_cyclegan_balanced: + tensor = tensor[0] # Dim 1 is the channel dim + else: + tensor = tensor.squeeze() + + # Rescale back to [self.hx4_tbr_min, self.hx4_tbr_max] + tensor = min_max_denormalize(tensor.cpu(), self.hx4_tbr_min, self.hx4_tbr_max) + + # Denormalize TBR to SUV + patient_id = metadata['patient_id'] + tensor = tensor * self.suv_aorta_mean_values[patient_id] + + sitk_image = sitk_utils.tensor_to_sitk_image(tensor, metadata['origin'], + metadata['spacing'], metadata['direction'], + metadata['dtype']) + # Write to file + os.makedirs(save_dir, exist_ok=True) + save_path = f"{save_dir}/{patient_id}.nrrd" + sitk_utils.write(sitk_image, save_path) diff --git a/projects/maastro_hx4_pet_translation/experiments/cyclegan_balanced.yaml b/projects/maastro_hx4_pet_translation/experiments/cyclegan_balanced.yaml new file mode 100644 index 00000000..9f357b21 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/experiments/cyclegan_balanced.yaml @@ -0,0 +1,84 @@ +project_dir: "./projects/maastro_hx4_pet_translation/" + +train: + output_dir: "/workspace/Chinmay-Checkpoints-Ephemeral/HX4-PET-Translation/hx4_pet_cyclegan_balanced/" + cuda: True + n_iters: 30000 + n_iters_decay: 30000 + batch_size: 1 + mixed_precision: False + seed: 1 + + logging: + freq: 50 + multi_modality_split: + A: [1, 1] + B: [1, 1] # ldCT is the 2nd component + wandb: + project: "maastro_hx4_pet_translation" + run: "cyclegan_balanced_lambdas10" + + checkpointing: + freq: 1000 + + dataset: + name: "HX4PETTranslationTrainDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/train" + paired: False # Unpaired training + require_ldct_for_training: True # ldCT required for training + patch_size: [32, 128, 128] # (D,H,W) + patch_sampling: uniform-random-within-body-sf + focal_region_proportion: [0.6, 0.35, 0.35] # (D,H,W) + num_workers: 8 + + gan: + name: "HX4CycleGANBalanced" + generator: + name: "Unet3D" + in_out_channels_AB: [2, 1] # Both G's take 2 inputs and predict 1 output + in_out_channels_BA: [2, 1] # + num_downs: 4 + ngf: 64 + use_dropout: True + + discriminator: + name: "PatchGAN3D" + in_channels_B: 1 # Both D's evaluate a single modality (the PETs) + in_channels_A: 1 # + n_layers: 3 + kernel_size: [4, 4, 4] + ndf: 64 + + optimizer: + lr_D: 0.0001 + lr_G: 0.0002 + lambda_AB: 10.0 + lambda_BA: 10.0 + lambda_identity: 0 + proportion_ssim: 0 + + metrics: + discriminator_evolution: True + ssim: False # `False` because it's computed by with the dummy array included, and is hence wrong + + +val: + freq: 1000 + dataset: + name: "HX4PETTranslationValTestDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/val" + num_workers: 8 + supply_masks: False # Do not supply masks -> less number of val metrics to compute, faster training + model_is_hx4_cyclegan_balanced: True # Using HX4-CycleGAN-balanced + use_patch_based_inference: True # + sliding_window: # Enable sliding window inferer + 
window_size: ${train.dataset.patch_size} + metrics: # Enabled metrics: MSE, MAE, PSNR, SSIM, NMI, Chi-squared histogram distance + mse: True + mae: True + nmse: False + psnr: True + ssim: True + nmi: True + histogram_chi2: True + cycle_metrics: False # `False` because cycle in validation is hardcoded to be the default (naive) way diff --git a/projects/maastro_hx4_pet_translation/experiments/cyclegan_naive.yaml b/projects/maastro_hx4_pet_translation/experiments/cyclegan_naive.yaml new file mode 100644 index 00000000..d66cab8c --- /dev/null +++ b/projects/maastro_hx4_pet_translation/experiments/cyclegan_naive.yaml @@ -0,0 +1,83 @@ +project_dir: "./projects/maastro_hx4_pet_translation/" + +train: + output_dir: "/workspace/Chinmay-Checkpoints-Ephemeral/HX4-PET-Translation/hx4_pet_cyclegan_naive/" + cuda: True + n_iters: 30000 + n_iters_decay: 30000 + batch_size: 1 + mixed_precision: False + seed: 1 + + logging: + freq: 50 + multi_modality_split: + A: [1, 1] + B: [1] + wandb: + project: "maastro_hx4_pet_translation" + run: "cyclegan_naive_lambdas10" + + checkpointing: + freq: 1000 + + dataset: + name: "HX4PETTranslationTrainDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/train" + paired: False # Unpaired training + require_ldct_for_training: False # ldCT not required for training + patch_size: [32, 128, 128] # (D,H,W) + patch_sampling: uniform-random-within-body-sf + focal_region_proportion: [0.6, 0.35, 0.35] # (D,H,W) + num_workers: 8 + + gan: + name: "CycleGAN" + generator: + name: "Unet3D" + in_out_channels_AB: [2, 1] + in_out_channels_BA: [1, 2] + num_downs: 4 + ngf: 64 + use_dropout: True + + discriminator: + name: "PatchGAN3D" + in_channels_B: 1 + in_channels_A: 2 + n_layers: 3 + kernel_size: [4, 4, 4] + ndf: 64 + + optimizer: + lr_D: 0.0001 + lr_G: 0.0002 + lambda_AB: 10.0 + lambda_BA: 10.0 + lambda_identity: 0 + proportion_ssim: 0 + + metrics: + discriminator_evolution: True + ssim: False # `False`, to match with HX4-CycleGAN-balanced + + +val: + freq: 1000 + dataset: + name: "HX4PETTranslationValTestDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/val" + num_workers: 8 + supply_masks: False # Do not supply masks -> less number of val metrics to compute, faster training + use_patch_based_inference: True # + sliding_window: # Enable sliding window inferer + window_size: ${train.dataset.patch_size} + metrics: # Enabled metrics: MSE, MAE, PSNR, SSIM, NMI, Chi-squared histogram distance + mse: True + mae: True + nmse: False + psnr: True + ssim: True + nmi: True + histogram_chi2: True + cycle_metrics: False # `False`, to match with HX4-CycleGAN-balanced diff --git a/projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml b/projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml new file mode 100644 index 00000000..9411eb73 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml @@ -0,0 +1,75 @@ +project_dir: "./projects/maastro_hx4_pet_translation/" + +train: + output_dir: "/workspace/Chinmay-Checkpoints-Ephemeral/HX4-PET-Translation/hx4_pet_pix2pix_lambda10/" + cuda: True + n_iters: 30000 + n_iters_decay: 30000 + batch_size: 1 + mixed_precision: False + seed: 1 + + logging: + freq: 50 + multi_modality_split: + A: [1, 1] + B: [1] + wandb: + project: "maastro_hx4_pet_translation" + run: "pix2pix_lambda10" + + checkpointing: + freq: 1000 + + dataset: + name: "HX4PETTranslationTrainDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/train" + paired: 
True # Paired training + patch_size: [32, 128, 128] # (D,H,W) + patch_sampling: uniform-random-within-body + num_workers: 8 + + gan: + name: "Pix2PixConditionalGAN" + generator: + name: "Unet3D" + in_out_channels: [2, 1] + num_downs: 4 + ngf: 64 + use_dropout: True + + discriminator: + name: "PatchGAN3D" + in_channels: 3 + n_layers: 3 + kernel_size: [4, 4, 4] + ndf: 64 + + optimizer: + lr_D: 0.0001 + lr_G: 0.0002 + lambda_pix2pix: 10.0 + + metrics: + discriminator_evolution: True + + +val: + freq: 1000 + dataset: + name: "HX4PETTranslationValTestDataset" + root: "/workspace/Chinmay-Datasets-Ephemeral/HX4-PET-Translation-Processed/val" + num_workers: 8 + supply_masks: False # Do not supply masks -> less number of val metrics to compute, faster training + use_patch_based_inference: True # + sliding_window: # Enable sliding window inferer + window_size: ${train.dataset.patch_size} + metrics: # Enabled metrics: MSE, MAE, PSNR, SSIM, NMI, Chi-squared histogram distance + mse: True + mae: True + nmse: False + psnr: True + ssim: True + nmi: True + histogram_chi2: True + cycle_metrics: False \ No newline at end of file diff --git a/projects/maastro_hx4_pet_translation/jobscript.sh b/projects/maastro_hx4_pet_translation/jobscript.sh new file mode 100644 index 00000000..895dcc47 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/jobscript.sh @@ -0,0 +1,43 @@ +#!/usr/local_rwth/bin/zsh + + +# Job configuration --- + +#SBATCH --job-name=hx4_pet_pix2pix +#SBATCH --output=/home/zk315372/Chinmay/Git/midaGAN/projects/maastro_hx4_pet_translation/slurm_logs/%j.log + +## OpenMP settings +#SBATCH --cpus-per-task=8 +#SBATCH --mem-per-cpu=4G + +## Request for a node with 2 Tesla P100 GPUs +#SBATCH --gres=gpu:pascal:2 + +#SBATCH --time=5:00:00 + +## TO use the UM DKE project account +# #SBATCH --account=um_dke + + +# Load CUDA +module load cuda + +# Debug info +echo; echo +nvidia-smi +echo; echo + +# Execute training +python_interpreter="/home/zk315372/miniconda3/envs/gan_env/bin/python3" +training_file="/home/zk315372/Chinmay/Git/midaGAN/tools/train.py" +config_file="/home/zk315372/Chinmay/Git/midaGAN/projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml" + +CUDA_VISIBLE_DEVICES=0 $python_interpreter $training_file config=$config_file + + +# ---------------------- +# Run single GPU example: +# CUDA_VISIBLE_DEVICES=0 python tools/train.py config="./projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml" + +# Run distributed example: +# python -m torch.distributed.launch --use_env --nproc_per_node 2 tools/train.py config="./projects/maastro_hx4_pet_translation/experiments/pix2pix.yaml" \ No newline at end of file diff --git a/projects/maastro_hx4_pet_translation/modules/__init__.py b/projects/maastro_hx4_pet_translation/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/projects/maastro_hx4_pet_translation/modules/hx4_cyclegan_balanced.py b/projects/maastro_hx4_pet_translation/modules/hx4_cyclegan_balanced.py new file mode 100644 index 00000000..bd0d7660 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/modules/hx4_cyclegan_balanced.py @@ -0,0 +1,127 @@ +from dataclasses import dataclass + +import torch + +from midaGAN.nn.gans.unpaired import cyclegan +from midaGAN.nn.losses.adversarial_loss import AdversarialLoss + +from projects.maastro_hx4_pet_translation.modules.hx4_cyclegan_balanced_losses \ + import HX4CycleGANBalancedLosses + + +@dataclass +class HX4CycleGANBalancedConfig(cyclegan.CycleGANConfig): + """ HX4CycleGANBalanced Config """ + name: str = 
"HX4CycleGANBalanced" + + +class HX4CycleGANBalanced(cyclegan.CycleGAN): + """ Balanced CycleGAN for HX4-PET synthesis + Notation: + A1, A2 -- FDG-PET, pCT + B1, B2 -- HX4-PET, ldCT + """ + + def __init__(self, conf): + super().__init__(conf) + + def init_criterions(self): + # Standard GAN loss + self.criterion_adv = AdversarialLoss( + self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device) + # Generator-related losses -- Cycle-consistency and Identity loss + self.criterion_G = HX4CycleGANBalancedLosses(self.conf) + + def forward(self): + """Run forward pass; called by both methods and .""" + real_A = self.visuals['real_A'] # [real FDG-PET, real pCT] + real_B = self.visuals['real_B'] # [real HX4-PET, real ldCT] + + # Forward cycle G_AB (A to B) + fake_B1 = self.networks['G_AB'](real_A) # Compute [fake HX4-PET] + real_A2 = real_A[:, 1:] # Get [real pCT] + rec_A1 = self.networks['G_BA'](torch.cat([fake_B1, real_A2], dim=1)) # Compute [recon FDG-PET], given [fake HX4-PET, real pCT] + + # Backward cycle G_BA (B to A) + fake_A1 = self.networks['G_BA'](real_B) # Compute [fake FDG-PET], given [real HX4-PET, real ldCT] + real_B2 = real_B[:, 1:] # Get [real ldCT] + rec_B1 = self.networks['G_AB'](torch.cat([fake_A1, real_B2], dim=1)) # Compute [recon HX4-PET], given [fake FDG-PET, real ldCT] + + # In self.visuals, fake and recon A's and B's are expected to have 2 channels, because + # the real ones have 2 channels. This is because the multimodal channel split is specified + # for each domain A and B generally, but not for reals, fakes and recons separately. + # Hack -- Use dummy zeros arrays to fill up the channels of CT components (i.e. the 2nd channel) + zeros_dummy = torch.zeros_like(real_A2) + self.visuals.update({ + 'fake_B': torch.cat([fake_B1, zeros_dummy], dim=1), + 'rec_A': torch.cat([rec_A1, zeros_dummy], dim=1), + 'fake_A': torch.cat([fake_A1, zeros_dummy], dim=1), + 'rec_B': torch.cat([rec_B1, zeros_dummy], dim=1), + }) + + def backward_D(self, discriminator): + """Calculate GAN loss for the discriminator""" + # D_B only evaluates HX4-PET + if discriminator == 'D_B': + real = self.visuals['real_B'][:, :1] + fake = self.visuals['fake_B'][:, :1] + fake = self.fake_B_pool.query(fake) + loss_id = 0 + + # D_A only evaluates FDG-PET + elif discriminator == 'D_A': + real = self.visuals['real_A'][:, :1] + fake = self.visuals['fake_A'][:, :1] + fake = self.fake_A_pool.query(fake) + loss_id = 1 + else: + raise ValueError('The discriminator has to be either "D_A" or "D_B".') + + self.pred_real = self.networks[discriminator](real) + + # Detaching fake: https://github.com/pytorch/examples/issues/116 + self.pred_fake = self.networks[discriminator](fake.detach()) + + loss_real = self.criterion_adv(self.pred_real, target_is_real=True) + loss_fake = self.criterion_adv(self.pred_fake, target_is_real=False) + self.losses[discriminator] = loss_real + loss_fake + + # backprop + self.backward(loss=self.losses[discriminator], optimizer=self.optimizers['D'], loss_id=2) + + def backward_G(self): + """Calculate the loss for generators G_AB and G_BA using all specified losses""" + # Get HX4-PET and FDG-PET + fake_B1 = self.visuals['fake_B'][:, :1] # G_AB(A) + fake_A1 = self.visuals['fake_A'][:, :1] # G_BA(B) + + # ------------------------- GAN Loss ---------------------------- + pred_B = self.networks['D_B'](fake_B1) # D_B(G_AB(A)) + pred_A = self.networks['D_A'](fake_A1) # D_A(G_BA(B)) + + # Forward GAN loss D_A(G_AB(A)) + self.losses['G_AB'] = self.criterion_adv(pred_B, target_is_real=True) + # 
Backward GAN loss D_A(G_BA(B)) + self.losses['G_BA'] = self.criterion_adv(pred_A, target_is_real=True) + # --------------------------------------------------------------- + + # ------------- G Losses (Cycle, Identity) ---------------------- + losses_G = self.criterion_G(self.visuals) + self.losses.update(losses_G) + # --------------------------------------------------------------- + + # combine losses and calculate gradients + combined_loss_G = sum(losses_G.values()) + self.losses['G_AB'] + self.losses['G_BA'] + self.backward(loss=combined_loss_G, optimizer=self.optimizers['G'], loss_id=0) + + + + def infer(self, input, direction='AB'): + assert direction in ['AB', 'BA'], "Specify which generator direction, AB or BA, to use." + assert f'G_{direction}' in self.networks.keys() + + with torch.no_grad(): + fake_B1 = self.networks[f'G_{direction}'](input) # Compute [fake HX4-PET] + real_A2 = input[:, 1:] # Get the CT component (2nd channel) + zeros_dummy = torch.zeros_like(real_A2) # Create zeros dummy array to fill up the 2nd channel + return torch.cat([fake_B1, zeros_dummy], dim=1) diff --git a/projects/maastro_hx4_pet_translation/modules/hx4_cyclegan_balanced_losses.py b/projects/maastro_hx4_pet_translation/modules/hx4_cyclegan_balanced_losses.py new file mode 100644 index 00000000..c28d8c50 --- /dev/null +++ b/projects/maastro_hx4_pet_translation/modules/hx4_cyclegan_balanced_losses.py @@ -0,0 +1,35 @@ +import math +import torch + +from midaGAN.nn.losses import cyclegan_losses + + +class HX4CycleGANBalancedLosses(cyclegan_losses.CycleGANLosses): + """ Modified to make Cycle-consistency account for only + FDG-PET images (in domain A) and HX4-PET images (in domain B), + and ignore CT components """ + + def __init__(self, conf): + self.lambda_AB = conf.train.gan.optimizer.lambda_AB + self.lambda_BA = conf.train.gan.optimizer.lambda_BA + + lambda_identity = conf.train.gan.optimizer.lambda_identity + proportion_ssim = conf.train.gan.optimizer.proportion_ssim + + # Cycle-consistency - L1, with optional weighted combination with SSIM + self.criterion_cycle = cyclegan_losses.CycleLoss(proportion_ssim) + + + def __call__(self, visuals): + # Separate out the FDG-PET and HX4-PET parts from the visuals tensors + real_A1, real_B1 = visuals['real_A'][:, :1], visuals['real_B'][:, :1] + fake_A1, fake_B1 = visuals['fake_A'][:, :1], visuals['fake_B'][:, :1] + rec_A1, rec_B1 = visuals['rec_A'][:, :1], visuals['rec_B'][:, :1] + + losses = {} + + # cycle-consistency loss + losses['cycle_A'] = self.lambda_AB * self.criterion_cycle(real_A1, rec_A1) + losses['cycle_B'] = self.lambda_BA * self.criterion_cycle(real_B1, rec_B1) + + return losses \ No newline at end of file
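# Illustrative sketch (not part of the patch): the intensity round trip used by the HX4
# datasets above, assuming `min_max_normalize` maps a clipped value to [-1, 1] and
# `min_max_denormalize` inverts it, as the dataset comments state. The range values follow
# `hx4_tbr_range` from the experiment configs; the helper names here are local stand-ins,
# not midaGAN functions.
import torch

def to_minus_one_one(x, lo, hi):
    x = torch.clamp(x, lo, hi)
    return 2 * (x - lo) / (hi - lo) - 1

def from_minus_one_one(x, lo, hi):
    return (x + 1) * (hi - lo) / 2 + lo

tbr = torch.tensor([0.0, 1.5, 3.0])                  # HX4-PET TBR values
normalized = to_minus_one_one(tbr, 0.0, 3.0)         # -> tensor([-1., 0., 1.])
restored = from_minus_one_one(normalized, 0.0, 3.0)  # back to TBR for metric computation
assert torch.allclose(restored, tbr)
# The dataset's save() additionally multiplies by the patient's SUVmean_aorta to convert TBR back to SUV.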