From b2033f1e59f81cfd5e5f3b264d4308a2be7d537e Mon Sep 17 00:00:00 2001 From: naga-karthik Date: Tue, 22 Jun 2021 00:53:43 -0400 Subject: [PATCH 1/8] removed files to avoid confusion --- modeling/models.py | 32 ------ modeling/train.py | 252 --------------------------------------------- 2 files changed, 284 deletions(-) delete mode 100644 modeling/models.py delete mode 100644 modeling/train.py diff --git a/modeling/models.py b/modeling/models.py deleted file mode 100644 index 239e425..0000000 --- a/modeling/models.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class TestModel(nn.Module): - """This is a simple test model with 3D convolutions and up-convolutions.""" - def __init__(self): - super(TestModel, self).__init__() - self.conv1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(3, 3, 3)) - self.conv2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(3, 3, 3)) - - self.upconv1 = nn.ConvTranspose3d(in_channels=16, out_channels=8, kernel_size=(3, 3, 3)) - self.upconv2 = nn.ConvTranspose3d(in_channels=8, out_channels=1, kernel_size=(3, 3, 3)) - - def forward(self, x1, x2): - # Quick check: only work with 3D volumes (extra dimension comes from batch size) - assert len(x1.shape) == 4 and len(x2.shape) == 4 - - # Expand dims (i.e. add channel dim. to beginning) - x1_, x2_ = x1.unsqueeze(1), x2.unsqueeze(1) - - # Pass to network - x1_, x2_ = F.relu(self.conv1(x1_)), F.relu(self.conv1(x2_)) - x1_, x2_ = F.relu(self.conv2(x1_)), F.relu(self.conv2(x2_)) - x1_, x2_ = F.relu(self.upconv1(x1_)), F.relu(self.upconv1(x2_)) - x1_, x2_ = self.upconv2(x1_), self.upconv2(x2_) - - # Simply get the differentiable diff. between the two feature maps for now - y_hat = torch.sigmoid(x2_ - x1_)[:, 0] - - return y_hat diff --git a/modeling/train.py b/modeling/train.py deleted file mode 100644 index 79a578b..0000000 --- a/modeling/train.py +++ /dev/null @@ -1,252 +0,0 @@ -import os -from copy import deepcopy -import argparse -from tqdm import tqdm - -import torch -import torch.nn as nn -import torch.multiprocessing as mp -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from transformers import AdamW, get_linear_schedule_with_warmup -from ivadomed.losses import AdapWingLoss, DiceLoss - -from datasets import MSSeg2Dataset -from utils import split_dataset -from models import TestModel - -parser = argparse.ArgumentParser(description='Script for training custom models for MSSeg2 Challenge 2021.') - -# Arguments for model, data, and training -parser.add_argument('-id', '--model_id', default='transunet', type=str, - help='Model ID to-be-used for saving the .pt saved model file') -parser.add_argument('-dr', '--dataset_root', default='/home/GRAMES.POLYMTL.CA/uzmac/duke/projects/ivadomed/tmp_ms_challenge_2021_preprocessed', type=str, - help='Root path to the BIDS- and ivadomed-compatible dataset') -parser.add_argument('-fd', '--fraction_data', default=1.0, type=float, - help='Fraction of data to use for the experiment. 
Helps with debugging.') - -parser.add_argument('-ne', '--num_epochs', default=200, type=int, - help='Number of epochs for the training process') -parser.add_argument('-bs', '--batch_size', default=20, type=int, - help='Batch size of the training and validation processes') -parser.add_argument('-nw', '--num_workers', default=4, type=int, - help='Number of workers for the dataloaders') - -parser.add_argument('-tlr', '--transformer_learning_rate', default=3e-5, type=float, - help='Learning rate for training the transformer') -parser.add_argument('-clr', '--custom_learning_rate', default=1e-3, type=float, - help='Learning rate for training the custom additions to the transformer') -parser.add_argument('-wd', '--weight_decay', type=float, default=0.01, - help='Weight decay (i.e. regularization) value in AdamW') -parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), - help='Decay terms for the AdamW optimizer') -parser.add_argument('--eps', type=float, default=1e-8, - help='Epsilon value for the AdamW optimizer') - -parser.add_argument('-sv', '--save', default='./saved_models', type=str, - help='Path to the saved models directory') -parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true', - help='Load model from checkpoint and continue training') -parser.add_argument('-se', '--seed', default=42, type=int, - help='Set seeds for reproducibility') - -# Arguments for parallelization -parser.add_argument('-loc', '--local_rank', type=int, default=-1, - help='Local rank for distributed training on GPUs; set to != -1 to start distributed GPU training') -parser.add_argument('--master_addr', type=str, default='localhost', - help='Address of master; master must be able to accept network traffic on the address and port') -parser.add_argument('--master_port', type=str, default='29500', - help='Port that master is listening on') - -args = parser.parse_args() - - -def main_worker(rank, world_size): - # Configure model ID - model_id = args.model_id - print('MODEL ID: %s' % model_id) - print('RANK: ', rank) - - # Configure saved models directory - if not os.path.isdir(args.save): - os.makedirs(args.save) - print('Trained model will be saved to: %s' % args.save) - - if args.local_rank == -1: - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - else: - torch.cuda.set_device(rank) - device = torch.device('cuda', rank) - # Use NCCL for GPU training, process must have exclusive access to GPUs - torch.distributed.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=world_size) - - # TODO: Implement new models in `models.py` and change this line - model = TestModel() - - # Load saved model if applicable - if args.continue_from_checkpoint: - load_path = os.path.join('saved_models', '%s.pt' % model_id) - print('Loading learned weights from %s' % load_path) - state_dict = torch.load(load_path) - state_dict_ = deepcopy(state_dict) - # Rename parameters to exclude the starting 'module.' string so they match - # NOTE: We have to do this because of DataParallel saving parameters starting with 'module.' 
- for param in state_dict: - state_dict_[param.replace('module.', '')] = state_dict_.pop(param) - model.load_state_dict(state_dict_) - else: - print('Initializing model from scratch') - - if torch.cuda.device_count() > 1 and args.local_rank == -1: - model = nn.DataParallel(model) - model.to(device) - elif args.local_rank == -1: - model.to(device) - else: - model.to(device) - try: - from apex.parallel import DistributedDataParallel as DDP - print("Found Apex!") - model = DDP(model) - except ImportError: - from torch.nn.parallel import DistributedDataParallel as DDP - print("Using PyTorch DDP - could not find Apex") - model = DDP(model, device_ids=[rank], find_unused_parameters=False) - # TODO: `find_unused_parameters` might have to be True for certain model types! - - # Load datasets - dataset = MSSeg2Dataset(root=args.dataset_root, - patch_size=(128, 128, 128), - stride_size=(64, 64, 64), - center_crop_size=(320, 384, 512), - fraction_data=args.fraction_data) - - train_dataset, val_dataset = split_dataset(dataset=dataset, val_size=0.3, seed=args.seed) - # TODO: We also need test set, right? - - if args.local_rank == -1: - train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, - pin_memory=True, num_workers=args.num_workers) - val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False, - pin_memory=True, num_workers=args.num_workers) - else: - train_sampler = DistributedSampler(dataset=train_dataset, num_replicas=world_size, rank=rank) - val_sampler = DistributedSampler(dataset=val_dataset, num_replicas=world_size, rank=rank) - train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, - shuffle=False, pin_memory=True, - num_workers=args.num_workers, sampler=train_sampler) - # NOTE: Train loader's shuffle is made False; using DistributedSampler instead - val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, - shuffle=False, pin_memory=True, - num_workers=args.num_workers, sampler=val_sampler) - - # Setup optimizer - no_weight_decay_ids = ['bias', 'LayerNorm.weight'] - grouped_model_parameters = [ - {'params': [param for name, param in model.named_parameters() - if not any(id_ in name for id_ in no_weight_decay_ids)], - 'lr': args.transformer_learning_rate, - 'betas': args.betas, - 'weight_decay': args.weight_decay, - 'eps': args.eps}, - {'params': [param for name, param in model.named_parameters() - if any(id_ in name for id_ in no_weight_decay_ids)], - 'lr': args.transformer_learning_rate, - 'betas': args.betas, - 'weight_decay': 0.0, - 'eps': args.eps}, - ] - - optimizer = AdamW(grouped_model_parameters) - num_training_steps = int(args.num_epochs * len(train_dataset) / args.batch_size) - num_warmup_steps = 0 # int(num_training_steps * 0.02) - scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps) - - # Setup loss function - criterion = AdapWingLoss(theta=0.5, alpha=2.1, omega=14, epsilon=1) - # TODO: The above params for the AdapWingLoss() are default given by `ivadomed`. Can - # we improve them for our application? 
- - # Setup other metrics - dice_metric = DiceLoss(smooth=1.0) - - # Training & Evaluation - for i in tqdm(range(args.num_epochs), desc='Iterating over Epochs'): - # -------------------------------- TRAINING ------------------------------------ - model.train() - train_epoch_loss = 0.0 - train_epoch_dice = 0.0 - - for batch in tqdm(train_loader, desc='Iterating over Training Examples'): - optimizer.zero_grad() - - x1, x2, y = batch - x1, x2, y = x1.to(device), x2.to(device), y.to(device) - - y_hat = model(x1, x2) - - loss = criterion(y_hat, y) - loss.backward() - optimizer.step() - - train_epoch_loss += loss.item() - train_epoch_dice += dice_metric(y_hat, y).item() - - train_epoch_loss /= len(train_loader) - train_epoch_dice /= len(train_loader) - - # -------------------------------- EVALUATION ------------------------------------ - model.eval() - val_epoch_loss = 0.0 - val_epoch_dice = 0.0 - - with torch.no_grad(): - for batch in tqdm(val_loader, desc='Iterating over Validation Examples'): - x1, x2, y = batch - x1, x2, y = x1.to(device), x2.to(device), y.to(device) - - y_hat = model(x1, x2) - - loss = criterion(y_hat, y) - - val_epoch_loss += loss.item() - val_epoch_dice += dice_metric(y_hat, y).item() - - val_epoch_loss /= len(val_loader) - val_epoch_dice /= len(val_loader) - - torch.save(model.state_dict(), os.path.join(args.save, '%s.pt' % model_id)) - - print('\n') # Do this in order to go below the second tqdm line - print(f'\tTrain Loss: %0.4f | Validation Loss: %0.4f' % (train_epoch_loss, val_epoch_loss)) - print(f'\tTrain Dice: %0.4f | Validation Dice: %0.4f' % (train_epoch_dice, val_epoch_dice)) - - # Apply learning rate decay before the beginning of next epoch if applicable - scheduler.step() - - # Cleanup DDP if applicable - if args.local_rank != -1: - torch.distributed.destroy_process_group() - - -def main(): - # Configure number of GPUs to be used - n_gpus = torch.cuda.device_count() - print('We are using %d GPUs' % n_gpus) - - # Spawn the training process - if args.local_rank == -1: - main_worker(rank=-1, world_size=1) - else: - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port - print('Spawning...') - # Number of processes spawned is equal to the number of GPUs available - mp.spawn(main_worker, nprocs=n_gpus, args=(n_gpus, )) - - -if __name__ == '__main__': - main() From 4938d215e7844ec1341918ceb6879e2d489c0f5a Mon Sep 17 00:00:00 2001 From: naga-karthik Date: Tue, 22 Jun 2021 00:54:40 -0400 Subject: [PATCH 2/8] minor modifications to config files --- config/softseg_unet3D_balanced.json | 60 +++++++++-------------- config/softseg_unet3D_unbalanced.json | 70 +++++++++++---------------- 2 files changed, 53 insertions(+), 77 deletions(-) diff --git a/config/softseg_unet3D_balanced.json b/config/softseg_unet3D_balanced.json index ede7d1c..07c0a8b 100644 --- a/config/softseg_unet3D_balanced.json +++ b/config/softseg_unet3D_balanced.json @@ -1,14 +1,14 @@ { "command": "train", - "gpu_ids": [1], - "log_directory": "msc2021_softseg_unet3D_balanced", + "gpu_ids": [0], + "log_directory": "/home/karthik7/projects/def-laporte1/karthik7/ivadomed_older_version/ivadomed/results/msc2021_softseg_unet3D_balanced", "model_name": "ms_brain", "debugging": true, "object_detection_params": { "object_detection_path": null }, "loader_parameters": { - "bids_path": "duke/projects/ivadomed/tmp_ms_challenge_2021_preprocessed", + "path_data": 
"/home/karthik7/projects/def-laporte1/karthik7/ivadomed_older_version/ivadomed/datasets/tmp_ms_challenge_2021_preprocessed", "target_suffix": ["_seg-lesion"], "roi_params": { "suffix": null, @@ -32,19 +32,19 @@ "fname_split": null, "random_seed": 42, "center_test": [], - "balance": null, + "balance": null, "method": "per_patient", "train_fraction": 0.6, "test_fraction": 0.2 }, "training_parameters": { - "batch_size": 20, + "batch_size": 16, "loss": { "name": "AdapWingLoss" }, "training_time": { "num_epochs": 200, - "early_stopping_patience": 50, + "early_stopping_patience": 75, "early_stopping_epsilon": 0.001 }, "scheduler": { @@ -55,7 +55,10 @@ "max_lr": 0.01 } }, - "balance_samples": {"applied": true, "type": "gt"}, + "balance_samples": { + "applied": true, + "type": "gt" + }, "mixup_alpha": null, "transfer_learning": { "retrain_model": null, @@ -67,11 +70,11 @@ "length_3D": [128, 128, 128], "stride_3D": [64, 64, 64], "attention": true, - "n_filters": 8 + "n_filters": 16 }, "default_model": { "name": "Unet", - "dropout_rate": 0.3, + "dropout_rate": 0.2, "bn_momentum": 0.9, "depth": 4, "folder_name": "ms_brain", @@ -79,41 +82,26 @@ "out_channel": 1, "final_activation": "relu" }, - "postprocessing": {"binarize_prediction": {"thr": 0.5}}, + "postprocessing": { + "binarize_prediction": {"thr": 0.5} + }, "transformation": { "CenterCrop": { - "size": [ - 320, - 384, - 512 - ], + "size": [ 320, 384, 512 ], "preprocessing": true }, "RandomAffine": { "degrees": 20, - "scale": [ - 0.1, - 0.1, - 0.1 - ], - "translate": [ - 0.1, - 0.1, - 0.1 - ], - "applied_to": [ - "im", - "gt" - ], - "dataset_type": [ - "training" - ] + "scale": [ 0.1, 0.1, 0.1 ], + "translate": [ 0.1, 0.1, 0.1 ], + "applied_to": [ "im","gt" ], + "dataset_type": [ "training" ] + }, + "NumpyToTensor": { + "applied_to": ["im", "gt"] }, - "NumpyToTensor": {}, "NormalizeInstance": { - "applied_to": [ - "im" - ] + "applied_to": ["im"] } } } \ No newline at end of file diff --git a/config/softseg_unet3D_unbalanced.json b/config/softseg_unet3D_unbalanced.json index b694713..6179c4d 100755 --- a/config/softseg_unet3D_unbalanced.json +++ b/config/softseg_unet3D_unbalanced.json @@ -1,14 +1,14 @@ { "command": "train", - "gpu_ids": [0], - "log_directory": "msc2021_softseg_unet3D_unbalanced", + "gpu_ids": [2], + "log_directory": "results/msc2021_softseg_unet3D_unbalanced", "model_name": "ms_brain", "debugging": true, "object_detection_params": { "object_detection_path": null }, "loader_parameters": { - "bids_path": "duke/projects/ivadomed/tmp_ms_challenge_2021_preprocessed", + "path_data": "duke/projects/ivadomed/tmp_ms_challenge_2021_preprocessed", "target_suffix": ["_seg-lesion"], "roi_params": { "suffix": null, @@ -32,30 +32,33 @@ "fname_split": null, "random_seed": 42, "center_test": [], - "balance": null, + "balance": null, "method": "per_patient", "train_fraction": 0.6, "test_fraction": 0.2 }, "training_parameters": { - "batch_size": 20, + "batch_size": 16, "loss": { "name": "AdapWingLoss" }, "training_time": { "num_epochs": 200, - "early_stopping_patience": 50, + "early_stopping_patience": 75, "early_stopping_epsilon": 0.001 }, "scheduler": { - "initial_lr": 5e-05, + "initial_lr": 2e-04, "lr_scheduler": { "name": "CosineAnnealingLR", - "base_lr": 1e-05, + "base_lr": 5e-05, "max_lr": 0.01 } }, - "balance_samples": {"applied": false, "type": "gt"}, + "balance_samples": { + "applied": false, + "type": "gt" + }, "mixup_alpha": null, "transfer_learning": { "retrain_model": null, @@ -64,14 +67,14 @@ }, "Modified3DUNet": { "applied": true, - 
"length_3D": [128, 128, 128], - "stride_3D": [64, 64, 64], + "length_3D": [64, 64, 64], + "stride_3D": [32, 32, 32], "attention": true, - "n_filters": 8 + "n_filters": 16 }, "default_model": { "name": "Unet", - "dropout_rate": 0.3, + "dropout_rate": 0.2, "bn_momentum": 0.9, "depth": 4, "folder_name": "ms_brain", @@ -79,41 +82,26 @@ "out_channel": 1, "final_activation": "relu" }, - "postprocessing": {"binarize_prediction": {"thr": 0.5}}, + "postprocessing": { + "binarize_prediction": {"thr": 0.5} + }, "transformation": { "CenterCrop": { - "size": [ - 320, - 384, - 512 - ], + "size": [ 320, 384, 512 ], "preprocessing": true }, "RandomAffine": { "degrees": 20, - "scale": [ - 0.1, - 0.1, - 0.1 - ], - "translate": [ - 0.1, - 0.1, - 0.1 - ], - "applied_to": [ - "im", - "gt" - ], - "dataset_type": [ - "training" - ] + "scale": [ 0.1, 0.1, 0.1 ], + "translate": [ 0.1, 0.1, 0.1 ], + "applied_to": [ "im", "gt" ], + "dataset_type": [ "training" ] + }, + "NumpyToTensor": { + "applied_to": ["im", "gt"] }, - "NumpyToTensor": {}, "NormalizeInstance": { - "applied_to": [ - "im" - ] + "applied_to": [ "im" ] } } -} \ No newline at end of file +} From 55a1c8c0b40ee81ea7e3abc054331a1eed1380d1 Mon Sep 17 00:00:00 2001 From: naga-karthik Date: Tue, 22 Jun 2021 00:56:57 -0400 Subject: [PATCH 3/8] modified for training on subvolumes and choosing GT at random --- modeling/datasets.py | 248 ++++++++++++++++++++++++++++++------------- 1 file changed, 176 insertions(+), 72 deletions(-) diff --git a/modeling/datasets.py b/modeling/datasets.py index dc7d122..40d9c02 100644 --- a/modeling/datasets.py +++ b/modeling/datasets.py @@ -1,42 +1,55 @@ import os +import sys +import random from tqdm import tqdm import random import pandas as pd +import numpy as np import nibabel as nib + import torch from torch.utils.data import Dataset +from torch.utils.data import DataLoader from ivadomed.transforms import CenterCrop, RandomAffine, NormalizeInstance class MSSeg2Dataset(Dataset): - """Custom PyTorch dataset for the MSSeg2 Challenge 2021. Works only with 3D patches.""" - def __init__(self, root, patch_size=(128, 128, 128), stride_size=(64, 64, 64), - center_crop_size=(512, 512, 512), fraction_data=1.0): + """Custom PyTorch dataset for the MSSeg2 Challenge 2021. Works only with 3D subvolumes.""" + def __init__(self, root, center_crop_size=(512, 512, 512), subvolume_size=(128, 128, 128), + stride_size=(64, 64, 64), patch_size=(32, 32, 32), fraction_data=1.0, + num_gt_experts=4, use_patches=False, seed=42): super(MSSeg2Dataset).__init__() # Quick argument checks if not os.path.exists(root): raise ValueError('Specified path=%s for the challenge data can NOT be found!' % root) - if len(patch_size) != 3: - raise ValueError('The `MSChallenge3D()` expects a 3D patch size (e.g. 128x128x128)!') - if len(stride_size) != 3: - raise ValueError('The `MSChallenge3D()` expects a 3D stride size (e.g. 64x64x64)!') if len(center_crop_size) != 3: raise ValueError('The `MSChallenge3D()` expects a 3D center crop size (e.g. 512x512x512)!') + if len(subvolume_size) != 3: + raise ValueError('The `MSChallenge3D()` expects a 3D subvolume size (e.g. 128x128x128)!') + if len(stride_size) != 3: + raise ValueError('The `MSChallenge3D()` expects a 3D stride size (e.g. 64x64x64)!') + if len(patch_size) != 3: + raise ValueError('The `MSChallenge3D()` expects a 3D patch size (e.g. 
32x32x32)!') - if any([center_crop_size[i] < patch_size[i] for i in range(3)]): - raise ValueError('The center crop size must be > patch size in all dimensions!') - if any([(center_crop_size[i] - patch_size[i]) % stride_size[i] != 0 for i in range(3)]): - raise ValueError('center_crop_size - patch_size % stride size must be 0 for all dimensions!') + if any([center_crop_size[i] < subvolume_size[i] for i in range(3)]): + raise ValueError('The center crop size must be >= subvolume size in all dimensions!') + if any([(center_crop_size[i] - subvolume_size[i]) % stride_size[i] != 0 for i in range(3)]): + raise ValueError('center_crop_size - subvolume_size % stride size must be 0 for all dimensions!') + if any([subvolume_size[i] < patch_size[i] for i in range(3)]): + raise ValueError('The subvolume size must be >= patch size in all dimensions!') if not 0.0 < fraction_data <= 1.0: raise ValueError('`fraction_data` needs to be between 0.0 and 1.0!') - self.patch_size = patch_size - self.stride_size = stride_size self.center_crop_size = center_crop_size + self.subvolume_size = subvolume_size + self.stride_size = stride_size + self.patch_size = patch_size + self.use_patches = use_patches + self.num_gt_experts = num_gt_experts self.train = False # Get all subjects @@ -47,95 +60,186 @@ def __init__(self, root, patch_size=(128, 128, 128), stride_size=(64, 64, 64), if fraction_data != 1.0: subjects = subjects[:int(len(subjects) * fraction_data)] - # Iterate over all subjects and extract patches - self.patches = [] + # Iterate over all subjects and extract subvolumes + self.subvolumes = [] + num_negatives, num_positives = 0, 0 for subject in tqdm(subjects, desc='Reading-in volumes'): - # Read-in volumes + # Read-in input volumes ses01 = nib.load(os.path.join(root, subject, 'anat', '%s_FLAIR.nii.gz' % subject)) ses02 = nib.load(os.path.join(root, subject, 'anat', '%s_T1w.nii.gz' % subject)) - gt = nib.load(os.path.join(root, 'derivatives', 'labels', subject, 'anat', '%s_FLAIR_seg-lesion.nii.gz' % subject)) + + # Read-in GT volumes of each expert + gts = [] + for expert_no in range(1, self.num_gt_experts+1): + temp = nib.load(os.path.join(root, 'derivatives', 'labels', subject, 'anat', '%s_FLAIR_acq-expert%d_lesion-manual.nii.gz' % (subject, expert_no))) + gts.append(temp) + + # if gt_type == 'consensus': + # gt = nib.load(os.path.join(root, 'derivatives', 'labels', subject, 'anat', '%s_FLAIR_seg-lesion.nii.gz' % subject)) # Check if image sizes and resolutions match - assert ses01.shape == ses02.shape == gt.shape - assert ses01.header['pixdim'].tolist() == ses02.header['pixdim'].tolist() == gt.header['pixdim'].tolist() + assert ses01.shape == ses02.shape == gts[0].shape + assert ses01.header['pixdim'].tolist() == ses02.header['pixdim'].tolist() == gts[0].header['pixdim'].tolist() # Convert to NumPy - ses01, ses02, gt = ses01.get_fdata(), ses02.get_fdata(), gt.get_fdata() + ses01, ses02 = ses01.get_fdata(), ses02.get_fdata() + # gt = gt.get_fdata(); print(gt.shape) + l, b, h = ses01.shape + gts_temp = np.zeros((num_gt_experts, l, b, h)) + for i in range(num_gt_experts): + gts_temp[i] = gts[i].get_fdata() # Apply center-cropping center_crop = CenterCrop(size=center_crop_size) ses01 = center_crop(sample=ses01, metadata={'crop_params': {}})[0] ses02 = center_crop(sample=ses02, metadata={'crop_params': {}})[0] - gt = center_crop(sample=gt, metadata={'crop_params': {}})[0] - - # Get patches from volumes and update the list - ses01_patches = self.volume2patches(volume=ses01) - ses02_patches = 
self.volume2patches(volume=ses02) - gt_patches = self.volume2patches(volume=gt) - assert len(ses01_patches) == len(ses02_patches) == len(gt_patches) - - for i in range(len(ses01_patches)): - self.patches.append({ - 'ses01': ses01_patches[i], - 'ses02': ses02_patches[i], - 'gt': gt_patches[i] - }) - - print('Extracted a total of %d patches!' % len(self.patches)) - - def set_train_mode(self): - """Enables training augmentations for the current dataset.""" - self.train = True - - def volume2patches(self, volume): - """Converts 3D volumes into 3D subvolumes, i.e. patches""" - patches = [] + # gt = center_crop(sample=gt, metadata={'crop_params': {}})[0] + gts_npy = np.zeros((num_gt_experts, center_crop_size[0], center_crop_size[1], center_crop_size[2])) + for i in range(num_gt_experts): + gts_npy[i] = center_crop(sample=gts_temp[i], metadata={'crop_params': {}})[0] + # sys.exit() + # Get subvolumes from volumes and update the list + ses01_subvolumes = self.volume2subvolumes(volume=ses01) + ses02_subvolumes = self.volume2subvolumes(volume=ses02) + # gt_subvolumes = self.volume2subvolumes(volume=gt) + gts_subvolumes = [] # so the length of this should be (4, no. of subvolumes) + for i in range(num_gt_experts): + gts_subvolumes.append(self.volume2subvolumes(volume=gts_npy[i])) + + assert len(ses01_subvolumes) == len(ses02_subvolumes) == len(gts_subvolumes[0]) + + for i in range(len(ses01_subvolumes)): + subvolumes_ = { + 'ses01': ses01_subvolumes[i], + 'ses02': ses02_subvolumes[i], + 'gt1': gts_subvolumes[0][i], + 'gt2': gts_subvolumes[1][i], + 'gt3': gts_subvolumes[2][i], + 'gt4': gts_subvolumes[3][i], + } + self.subvolumes.append(subvolumes_) + + # Skipping inbalance calculation for PU-Net implementation + # if np.any(gt_subvolumes[i]): + # num_positives += 1 + # else: + # num_negatives += 1 + + # print('Factor of inbalance is %d!' % (num_negatives // num_positives)) + + # Shuffle subvolumes just in case + # random.seed(seed) # not setting seed because we want as many different GTs as possible to be randomly chosen + random.shuffle(self.subvolumes) + print('Extracted a total of %d subvolumes!' 
% len(self.subvolumes)) + + def volume2subvolumes(self, volume): + """Converts 3D volumes into 3D subvolumes""" + subvolumes = [] assert volume.ndim == 3 - for x in range(0, (volume.shape[0] - self.patch_size[0]) + 1, self.stride_size[0]): - for y in range(0, (volume.shape[1] - self.patch_size[1]) + 1, self.stride_size[1]): - for z in range(0, (volume.shape[2] - self.patch_size[2]) + 1, self.stride_size[2]): - patch = volume[x: x + self.patch_size[0], - y: y + self.patch_size[1], - z: z + self.patch_size[2]] - patches.append(patch) - - return patches + for x in range(0, (volume.shape[0] - self.subvolume_size[0]) + 1, self.stride_size[0]): + for y in range(0, (volume.shape[1] - self.subvolume_size[1]) + 1, self.stride_size[1]): + for z in range(0, (volume.shape[2] - self.subvolume_size[2]) + 1, self.stride_size[2]): + subvolume = volume[x: x + self.subvolume_size[0], + y: y + self.subvolume_size[1], + z: z + self.subvolume_size[2]] + subvolumes.append(subvolume) + + return subvolumes + + # def subvolume2patches(self, subvolume): + # """Extracts 3D patches from 3D subvolumes; works with PyTorch tensors.""" + # patches = [] + # assert subvolume.ndim == 3 + # + # for x in range(0, (subvolume.shape[0] - self.patch_size[0]) + 1, self.patch_size[0]): + # for y in range(0, (subvolume.shape[1] - self.patch_size[1]) + 1, self.patch_size[1]): + # for z in range(0, (subvolume.shape[2] - self.patch_size[2]) + 1, self.patch_size[2]): + # patch = subvolume[x: x + self.patch_size[0], + # y: y + self.patch_size[1], + # z: z + self.patch_size[2]] + # patches.append(patch) + # + # num_patches = len(patches) + # patches = np.array(patches) + # assert patches.shape == (num_patches, *self.patch_size) + # + # return patches def __getitem__(self, index): - # Retrieve patches belonging to this index - patches = self.patches[index] - ses01_patches, ses02_patches, gt_patches = patches['ses01'], patches['ses02'], patches['gt'] + # Retrieve subvolumes belonging to this index + subvolumes = self.subvolumes[index] + ses01_subvolumes, ses02_subvolumes = subvolumes['ses01'], subvolumes['ses02'] + # Randomly select one of the four experts' segmented volumes + # print('gt'+str(random.randint(1, self.num_gt_experts))) + gt_subvolumes = subvolumes['gt'+str(random.randint(1, self.num_gt_experts))] # Training augmentations if self.train: # Apply random affine: rotation, translation, and scaling - # NOTE: Use of `metadata` ensures that the same affine is applied to all three patches + # NOTE: Use of `metadata` ensures that the same affine is applied to all three subvolumes random_affine = RandomAffine(degrees=20, translate=[0.1, 0.1, 0.1], scale=[0.1, 0.1, 0.1]) - ses01_patches, metadata = random_affine(sample=ses01_patches, metadata={}) - ses02_patches, _ = random_affine(sample=ses02_patches, metadata=metadata) - gt_patches, _ = random_affine(sample=gt_patches, metadata=metadata) + ses01_subvolumes, metadata = random_affine(sample=ses01_subvolumes, metadata={}) + ses02_subvolumes, _ = random_affine(sample=ses02_subvolumes, metadata=metadata) + gt_subvolumes, _ = random_affine(sample=gt_subvolumes, metadata=metadata) - # If patches are uniform: train-time -> skip to random sample, val-time -> mean-subtraction + # If subvolumes uniform: train-time -> skip to random sample, val-time -> mean-subtraction # NOTE: This will also help with discarding empty inputs! 
- if ses01_patches.std() < 1e-5 or ses02_patches.std() < 1e-5: + if ses01_subvolumes.std() < 1e-5 or ses02_subvolumes.std() < 1e-5: if self.train: return self.__getitem__(random.randint(0, self.__len__() - 1)) else: - ses01_patches = ses01_patches - ses01_patches.mean() - ses02_patches = ses02_patches - ses02_patches.mean() + ses01_subvolumes = ses01_subvolumes - ses01_subvolumes.mean() + ses02_subvolumes = ses02_subvolumes - ses02_subvolumes.mean() # Normalize images to zero mean and unit variance else: normalize_instance = NormalizeInstance() - ses01_patches, _ = normalize_instance(sample=ses01_patches, metadata={}) - ses02_patches, _ = normalize_instance(sample=ses02_patches, metadata={}) + ses01_subvolumes, _ = normalize_instance(sample=ses01_subvolumes, metadata={}) + ses02_subvolumes, _ = normalize_instance(sample=ses02_subvolumes, metadata={}) + # Return subvolumes, i.e. don't extract patches # Conversion to PyTorch tensors - x1 = torch.tensor(ses01_patches, dtype=torch.float) - x2 = torch.tensor(ses02_patches, dtype=torch.float) - y = torch.tensor(gt_patches, dtype=torch.float) - - return x1, x2, y + x1 = torch.tensor(ses01_subvolumes, dtype=torch.float) + x2 = torch.tensor(ses02_subvolumes, dtype=torch.float) + seg_y = torch.tensor(gt_subvolumes, dtype=torch.float) + + return x1, x2, seg_y + + # # Extract & return patches from subvolumes (if applicable -> needed for transformer model) + # if self.use_patches and self.subvolume_size != self.patch_size: + # ses01_patches = self.subvolume2patches(subvolume=ses01_subvolumes) + # ses02_patches = self.subvolume2patches(subvolume=ses02_subvolumes) + # gt_patches = self.subvolume2patches(subvolume=gt_subvolumes) + # + # # Compute classification GTs + # clf_gt_patches = [int(np.any(gt_patch)) for gt_patch in gt_patches] + # + # # Conversion to PyTorch tensors + # x1 = torch.tensor(ses01_patches, dtype=torch.float) + # x2 = torch.tensor(ses02_patches, dtype=torch.float) + # seg_y = torch.tensor(gt_patches, dtype=torch.float) + # clf_y = torch.tensor(clf_gt_patches, dtype=torch.long) + # # NOTE: The times two is added because of x1 and x2 logic in our current model + # + # return x1, x2, seg_y, clf_y + # + # # Return subvolumes, i.e. 
don't extract patches + # else: + # # Conversion to PyTorch tensors + # x1 = torch.tensor(ses01_subvolumes, dtype=torch.float) + # x2 = torch.tensor(ses02_subvolumes, dtype=torch.float) + # seg_y = torch.tensor(gt_subvolumes, dtype=torch.float) + # + # return x1, x2, seg_y def __len__(self): - return len(self.patches) + return len(self.subvolumes) + + +if __name__ == "__main__": + root = "/home/GRAMES.POLYMTL.CA/u114716/duke/projects/ivadomed/tmp_ms_challenge_2021_preprocessed" + data = MSSeg2Dataset(root) + data_loader = DataLoader(dataset=data, batch_size=1) + x1, x2, y = next(iter(data_loader)) + print(x1.shape, y.shape) + From 7d25bb83ac76d234f25f368970a5070424e69e9a Mon Sep 17 00:00:00 2001 From: naga-karthik Date: Tue, 22 Jun 2021 00:57:38 -0400 Subject: [PATCH 4/8] added utilities - weight initializations for PU-Net --- modeling/utils.py | 57 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/modeling/utils.py b/modeling/utils.py index 8b2de63..c7fb534 100644 --- a/modeling/utils.py +++ b/modeling/utils.py @@ -1,21 +1,58 @@ from copy import deepcopy from torch.utils.data import Subset from sklearn.model_selection import train_test_split +import torch.nn as nn -def split_dataset(dataset, val_size, seed): - train_idx, val_idx = train_test_split(list(range(len(dataset))), - test_size=val_size, - random_state=seed) +# Not using seed to ensure that as many different segmentation masks are chosen at random +def split_dataset(dataset, val_size): + train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=val_size) - # Copy the dataset into a new object, and set train mode for the old object - # NOTE: This is one "hack" to make sure val dataset doesn't have training augmentations - dataset_ = deepcopy(dataset) - dataset.set_train_mode() + # # Copy the dataset into a new object, and set train mode for the old object + # # NOTE: This is one "hack" to make sure val dataset doesn't have training augmentations + # dataset_ = deepcopy(dataset) + # dataset.set_train_mode() + # train_dataset, val_dataset = Subset(dataset, train_idx), Subset(dataset_, val_idx) - train_dataset, val_dataset = Subset(dataset, train_idx), Subset(dataset_, val_idx) + train_dataset, val_dataset = Subset(dataset, train_idx), Subset(dataset, val_idx) print('Divided base dataset of size %d into %d [1] and %d [2] sub-datasets!' 
% (len(dataset), len(train_dataset), len(val_dataset))) - return train_dataset, val_dataset \ No newline at end of file + return train_dataset, val_dataset + + +def truncated_normal_(tensor, mean=0.0, std=1.0): + size = tensor.shape + temp = tensor.new_empty(size + (4,)).normal_() + valid = (temp < 2) & (temp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(temp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + + +def init_weights(m): + if type(m) == nn.Conv3d: + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + truncated_normal_(m.bias, mean=0.0, std=0.001) + # nn.init.normal_(m.weight, std=0.001) + # nn.init.normal_(m.bias, std=0.001) + + +def init_weights_orthogonal_normal(m): + if type(m) == nn.Conv3d: + nn.init.orthogonal_(m.weight) + truncated_normal_(m.bias, mean=0.0, std=0.001) + + +def l2_regularization(m): + l2_reg = None + for w in m.parameters(): + if l2_reg is None: + l2_reg = w.norm(2) + else: + l2_reg = l2_reg + w.norm(2) + + return l2_reg + + From 51faf4064646fb8e2e99d50d169d7f337fbc5fd6 Mon Sep 17 00:00:00 2001 From: naga-karthik Date: Tue, 22 Jun 2021 00:58:25 -0400 Subject: [PATCH 5/8] added implementation of PU-Net --- modeling/probabilistic_unet.py | 332 +++++++++++++++++++++++++++++++++ modeling/unet.py | 122 ++++++++++++ 2 files changed, 454 insertions(+) create mode 100644 modeling/probabilistic_unet.py create mode 100644 modeling/unet.py diff --git a/modeling/probabilistic_unet.py b/modeling/probabilistic_unet.py new file mode 100644 index 0000000..80c148a --- /dev/null +++ b/modeling/probabilistic_unet.py @@ -0,0 +1,332 @@ +# Code adapted from https://github.com/stefanknegt/Probabilistic-Unet-Pytorch/blob/master/probabilistic_unet.py +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions import Normal, Independent, kl + +from unet import UpConvBlock, DownConvBlock +from unet import Unet +from utils import init_weights, init_weights_orthogonal_normal, l2_regularization + +from ivadomed.losses import DiceLoss + + +class Encoder(nn.Module): + """ + A simple CNN consisting of convs_per_block convolutional layers + Relu activations per block for len(num_feat_maps) + blocks with pooling between consecutive blocks + """ + def __init__(self, in_channels, num_feat_maps, convs_per_block, initializers, padding=True, posterior=False): + super(Encoder, self).__init__() + self.contracting_path = nn.ModuleList() + self.in_channels = in_channels + self.num_feat_maps = num_feat_maps + + if posterior: + # Recall the posterior net is conditioned upon the GT so 1 additional input channel + self.in_channels += 1 + + layers = [] + for i in range(len(self.num_feat_maps)): + in_dim = self.in_channels if i == 0 else out_dim + out_dim = num_feat_maps[i] + if i != 0: + layers.append(nn.AvgPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=True)) + + layers.append(nn.Conv3d(in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=int(padding))) + layers.append(nn.ReLU(inplace=True)) + for _ in range(convs_per_block-1): + layers.append(nn.Conv3d(in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=int(padding))) + layers.append(nn.ReLU(inplace=True)) + + self.layers = nn.Sequential(*layers) + self.layers.apply(init_weights) # for kaiming_normal initialization + # self.layers.apply(init_weights_orthogonal_normal) # for orthogonal weight initialization + + def forward(self, x): + output = self.layers(x) + # print("Reached: ", output.shape) + return output + + 
+class AxisAlignedConvGaussian(nn.Module): + """ + A ConvNet that parameterizes a Gaussian distribution with axis-aligned covariance matrix + """ + def __init__(self, in_channels, num_feat_maps, convs_per_block, latent_dim, initializers, posterior=False): + super(AxisAlignedConvGaussian, self).__init__() + self.in_channels = in_channels + self.channel_axis = 1 + self.num_feat_maps = num_feat_maps + self.convs_per_block = convs_per_block + self.latent_dim = latent_dim + self.posterior = posterior + if self.posterior: + self.name = "Posterior" + else: + self.name = "Prior" + self.encoder = Encoder(self.in_channels, self.num_feat_maps, self.convs_per_block, initializers, self.posterior) + self.conv_layer = nn.Conv3d(num_feat_maps[-1], 2*self.latent_dim, kernel_size=1, stride=1) + self.show_img = 0 + self.show_seg = 0 + self.show_concat = 0 + self.show_enc = 0 + self.sum_input = 0 + + nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu') + nn.init.normal_(self.conv_layer.bias) + # using the orthogonal weight initialization + # self.conv_layer.apply(init_weights_orthogonal_normal) + + def forward(self, inp, seg=None): + # If segmentation is not None, then concatenate it to the channel axis of the input + if seg is not None: + self.show_img = inp + self.show_seg = seg + # print("input shape: ", inp.shape); print("seg mask shape: ", seg.shape) + inp = torch.cat([inp, seg], dim=1) + # print("concatenated input and GT shape", inp.shape) + self.show_concat = inp + self.sum_input = torch.sum(inp) + + encoding = self.encoder(inp) + self.show_enc = encoding + # for getting the mean of the resulting volume --> (batch-size x Channels x Depth x Height x Width) + encoding = torch.mean(encoding, dim=2, keepdim=True) + encoding = torch.mean(encoding, dim=3, keepdim=True) + encoding = torch.mean(encoding, dim=4, keepdim=True) + + # Convert the encoding into 2xlatent_dim in the output_channels and split into mu and log_sigma + mu_log_sigma = self.conv_layer(encoding) # shape: (B x (2*latent_dim) x 1 x 1 x 1) + + # Squeeze all the singleton dimensions + mu_log_sigma = torch.squeeze(mu_log_sigma) # shape: (B x (2*latent_dim) ) + + mu = mu_log_sigma[:, :self.latent_dim] # take the first "latent_dim" samples as mu + log_sigma = mu_log_sigma[:, self.latent_dim:] # take the remaining as log_sigma + + # This is the multivariate normal with diagonal covariance matrix + # https://github.com/pytorch/pytorch/pull/11178 (see below comments) + dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) + return dist + + +class FComb(nn.Module): + """ + As in the paper, this class creates convs_fcomb number of 1x1 conv. layers that combines the random sample taken + from the latent space and concatenates it with the output of the U-Net (along the channel axis) to get + the final prediction mask. + """ + def __init__(self, num_feat_maps, latent_dim, num_out_channels, num_classes, convs_fcomb, + initializers, use_tile=True): + super(FComb, self).__init__() + self.num_out_channels = num_out_channels + self.num_classes = num_classes + self.channel_axis = 1 + # self.spatial_axes = [2, 3] # 2,3? 
for images' H x W dimensions + self.spatial_axes = [2, 3, 4] # for volumes' D x H X W dimensions + self.num_feat_maps = num_feat_maps + self.latent_dim = latent_dim + self.use_tile = use_tile + self.convs_fcomb = convs_fcomb + self.name = "FComb" + + if self.use_tile: + # creating a small decoder containing N - (1x1 Conv + ReLU) blocks except for the last layer + layers = [ + nn.Conv3d(self.num_feat_maps[0] + self.latent_dim, self.num_feat_maps[0], kernel_size=1), + nn.ReLU(inplace=True) + ] + + for _ in range(convs_fcomb-2): + layers.append(nn.Conv3d(self.num_feat_maps[0], self.num_feat_maps[0], kernel_size=1)) + layers.append(nn.ReLU(inplace=True)) + + self.layers = nn.Sequential(*layers) + self.last_layer = nn.Conv3d(self.num_feat_maps[0], self.num_classes, kernel_size=1) + + if initializers['w'] == 'orthogonal': + self.layers.apply(init_weights_orthogonal_normal) + self.last_layer.apply(init_weights_orthogonal_normal) + else: + self.layers.apply(init_weights) + self.last_layer.apply(init_weights) + + def tile(self, a, dim, n_tile): + """ + This function is taken form PyTorch forum and mimics the behavior of tf.tile. + Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 + """ + init_dim = a.size(dim) + repeat_idx = [1] * a.dim() + repeat_idx[dim] = n_tile + a = a.repeat(*repeat_idx) + order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda() + # order_index = torch.LongTensor(torch.cat([init_dim * torch.arange(n_tile) + i for i in range(init_dim)])) # .to(device) + return torch.index_select(a, dim, order_index) + + def forward(self, feature_map, z): + """ + Z is (batch_size x latent_dim) and feature_map is (batch_size x num_channels x P x P x P) + Z is broadcasted to (batch_size(B) x latent_dim(LD) x P x P x P). (Just like tensorflow's tf.tile function) + """ + if self.use_tile: + z = torch.unsqueeze(z, dim=2) # shape: (B x LD x 1) + z = self.tile(z, dim=2, n_tile=feature_map.shape[self.spatial_axes[0]]) # shape: (B x LD x P) + z = torch.unsqueeze(z, dim=3) # shape: (B x LD x P x 1) + z = self.tile(z, dim=3, n_tile=feature_map.shape[self.spatial_axes[1]]) # shape: (B x LD x P x P) + z = torch.unsqueeze(z, dim=4) # shape: (B x LD x P x P x 1) + z = self.tile(z, dim=4, n_tile=feature_map.shape[self.spatial_axes[2]]) # shape: (B x LD x P x P x P) + + # Concatenate UNet's output feature map and a sample taken from latent space + feature_map = torch.cat((feature_map, z), dim=self.channel_axis) + output = self.layers(feature_map) + return self.last_layer(output) + + +class ProbabilisticUnet(nn.Module): + """ + An implementation of the probabilistic U-net + in_channels = number of channels in the input image (1 greyscale and 3 rgb) (int) + num_classes = number of output classes to predict (int) + num_feat_maps = list of number of feature maps (filters) per layer/resolution + latent_dim = dimension of the latent space (int) + convs_fcomb = number of (1x1) convolutional layers per block combining the latent space sample to U-Net's feat. map + beta = weighting parameter for the cross-entropy loss and KL divergence. 
+ """ + def __init__(self, in_channels=2, num_classes=1, num_feat_maps=[32, 64, 128, 192], latent_dim=6, + convs_fcomb=4, beta=10.0): + super(ProbabilisticUnet, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + self.num_feat_maps = num_feat_maps + self.latent_dim = latent_dim + self.convs_per_block = 4 + self.convs_fcomb = convs_fcomb + self.initializers = {'w': 'he_normal', 'b': 'normal'} + # self.initializers = {'w': 'orthogonal', 'b': 'normal'} # orthogonal weight initialization + self.beta = beta + self.z_prior_sample = 0 + + # Instantiating the networks + self.unet = Unet(self.in_channels, self.num_classes, self.num_feat_maps, self.initializers, padding=True) + self.prior = AxisAlignedConvGaussian(self.in_channels, self.num_feat_maps, self.convs_per_block, + self.latent_dim, self.initializers, posterior=False) # .to(device) + # NOTE: in_channels + 1 is used because the encoder for posterior net was not getting initialized properly. + self.posterior = AxisAlignedConvGaussian(self.in_channels + 1, self.num_feat_maps, self.convs_per_block, + self.latent_dim, self.initializers, posterior=True) # .to(device) + self.fcomb = FComb(self.num_feat_maps, self.latent_dim, self.in_channels, self.num_classes, self.convs_fcomb, + initializers={'w': 'orthogonal', 'b': 'normal'}, use_tile=True) # .to(device) + + def forward(self, patch, seg_mask, training=True): + """ + Construct the prior latent space for the input patch and also pass it through the U-Net. + If training is true, construct a posterior latent space also. + """ + if training: + self.posterior_latent_space = self.posterior.forward(patch, seg=seg_mask) # conditioned upon GT + self.prior_latent_space = self.prior.forward(patch) # NOT conditioned upon the GT, just the input patch + self.unet_features = self.unet.forward(patch, False) + + def sample(self, testing=False): + """ + Sample a segmentation mask by picking a random sample and concatenating with U-Net's feature maps + Difference b/w "rsample()" and "sample()" in PyTorch Distributions: + rsample is used whenever the gradients of the distribution parameters w.r.t the functions of the samples + need to be computed (i.e. it supports differentiation through the sampler). Useful for reparameterization + trick in VAEs where backprop is possible through the mean and std parameters + More on this: https://stackoverflow.com/questions/60533150/what-is-the-difference-between-sample-and-rsample + and https://forum.pyro.ai/t/sample-vs-rsample/2344 + """ + if not testing: + z_prior = self.prior_latent_space.rsample() + self.z_prior_sample = z_prior + else: + z_prior = self.prior_latent_space.sample() + self.z_prior_sample = z_prior + + return self.fcomb.forward(self.unet_features, z_prior) + + def reconstruct(self, use_posterior_mean=False, calc_posterior=False, z_posterior=None): + """ + Reconstruct a segmentation map by sampling from the posterior latent space and combine it with U-Net + use_posterior_mean: i.e. use posterior mean instead of just sampling z from Q + calc_posterior: use a provided sample or sample fresh from the posterior latent space + """ + if use_posterior_mean: + z_posterior = self.posterior_latent_space.loc + else: + if calc_posterior: + z_posterior = self.posterior_latent_space.rsample() + + return self.fcomb.forward(self.unet_features, z_posterior) + + def kl_divergence(self, analytic=True, calc_posterior=False, z_posterior=None): + """ + Calculate the KL Divergence between the posterior and prior latent distributions i.e. 
KL(Q||P) + analytic: calculate KL div. analytically or by sampling from the posterior + calc_posterior: sample here if sampling is used to approximate KL or supply a sample + """ + if analytic: + kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) + else: + if calc_posterior: + z_posterior = self.posterior_latent_space.rsample() + log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior) + log_prior_prob = self.prior_latent_space.log_prob(z_posterior) + kl_div = log_posterior_prob - log_prior_prob + + return kl_div + + def elbo(self, seg_mask, analytic_kl=True, reconstruct_posterior_mean=False): + """ + Calculate the Evidence Lower Bound of the likelihood P(Y|X) + """ + z_posterior = self.posterior_latent_space.rsample() + # use the posterior sample above to get a predicted segmentation mask + self.reconstructed_mask = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, calc_posterior=False, + z_posterior=z_posterior) + # print("Shape of Predicted Mask: ", self.reconstructed_mask.shape) # shape: (4 x 1 x 128 x 128 x 128) + + # 1st half of the loss function + # criterion = nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction='none') + # reconstruction_loss = criterion(input=self.reconstructed_mask, target=seg_mask) + # self.reconstruction_loss = torch.sum(reconstruction_loss) + # self.mean_reconstruction_loss = torch.mean(reconstruction_loss) + + criterion = nn.BCEWithLogitsLoss() + self.reconstruction_loss = criterion(input=self.reconstructed_mask, target=seg_mask) + # self.reconstruction_loss = torch.sum(reconstruction_loss) + # self.mean_reconstruction_loss = torch.mean(reconstruction_loss) + + # # TODO: use DiceLoss as the criterion instead? --> Uncomment below lines + # criterion = DiceLoss(smooth=1.0) + # self.reconstruction_loss = criterion(input=self.reconstructed_mask, target=seg_mask) + + # 2nd half of the loss function + self.kl = torch.mean(self.kl_divergence(analytic=analytic_kl, calc_posterior=False, z_posterior=z_posterior)) + + # Full loss + final_loss = -(self.reconstruction_loss + self.beta * self.kl) + + print("\n") + print(f'\tReconstruction Loss: %0.6f | KL Divergence: %0.4f | Final ELBO: %0.4f' + % (self.reconstruction_loss.item(), self.kl.item(), final_loss.item())) + + return self.reconstructed_mask, final_loss + + +if __name__ == "__main__": + inp = torch.randn(4, 2, 128, 128, 128) + segm = torch.randn(4, 1, 128, 128, 128) + in_channels = 3 # works for posterior net + # in_channels = 2 # works for prior net + convs_per_block, latent_dim = 3, 6 + num_feat_maps = [32, 64, 128, 192] + initializers = {'w': 'he_normal', 'b': 'normal'} + posterior = True + + net = AxisAlignedConvGaussian(in_channels, num_feat_maps, convs_per_block, latent_dim, initializers, posterior) + net.forward(inp, segm) # use for posterior net + # net.forward(inp) # use for prior net diff --git a/modeling/unet.py b/modeling/unet.py new file mode 100644 index 0000000..ef2df30 --- /dev/null +++ b/modeling/unet.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +from utils import init_weights + + +class DownConvBlock(nn.Module): + """ + Each block consists of 3 conv layers with ReLU non-linear activation. A pooling layer is added b/w each block. 
+ """ + def __init__(self, in_dim, out_dim, initializers, padding, pool=True): + super(DownConvBlock, self).__init__() + layers = [] + if pool: + layers.append(nn.AvgPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=True)) + + layers.append(nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=int(padding))) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=int(padding))) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=int(padding))) + layers.append(nn.ReLU(inplace=True)) + + self.layers = nn.Sequential(*layers) + self.layers.apply(init_weights) + + def forward(self, input): + return self.layers(input) + + +class UpConvBlock(nn.Module): + """ + Consists of a trilinear upsampling layer followed by a convolutional layers and then a DownConvBlock + """ + def __init__(self, in_dim, out_dim, initializers, padding, trilinear=True): + super(UpConvBlock, self).__init__() + self.trilinear = trilinear + self.conv_block = DownConvBlock(in_dim, out_dim, initializers, padding, pool=False) + + def forward(self, x, bridge): + if self.trilinear: + up = nn.functional.interpolate(x, mode='trilinear', scale_factor=2, align_corners=True) + + # print(up.shape, "\t", bridge.shape) + assert up.shape[3] == bridge.shape[3] # checks if the first dimension of the inputs match + out = torch.cat([up, bridge], 1) + # print("shape after concatenation: ", out.shape) + out = self.conv_block(out) + # print("shape after Concat+Downblock: ", out.shape) + + return out + + +class Unet(nn.Module): + """ + Implementation of the standard U-Net module. + """ + def __init__(self, in_channels, num_classes, num_feat_maps, initializers, padding=True): + super(Unet, self).__init__() + # print(padding) + self.in_channels = in_channels + self.num_classes = num_classes + self.num_feat_maps = num_feat_maps + self.padding = padding + self.activation_maps = [] + self.contracting_path = nn.ModuleList() + + # Contractive Path + for i in range(len(self.num_feat_maps)): + inp = self.in_channels if i == 0 else out + out = self.num_feat_maps[i] + if i == 0: + pool = False + else: + pool = True + self.contracting_path.append(DownConvBlock(inp, out, initializers, padding, pool=pool)) + + # print(self.contracting_path) + + # Upsampling Path + self.upsampling_path = nn.ModuleList() + n = len(self.num_feat_maps) - 2 + for i in range(n, -1, -1): + inp = out + self.num_feat_maps[i] # sets the right no. of input channels for Concat+DownBlock + out = self.num_feat_maps[i] # sets the right no. 
of output channels for next (Concat+DownBlock)'s input + self.upsampling_path.append(UpConvBlock(inp, out, initializers, padding)) + + # print(self.upsampling_path) + + def forward(self, x, val): + blocks = [] + for i, down in enumerate(self.contracting_path): + # print(i, down) + x = down(x) + # print("After DownConv: ", i, "\t", x.shape) + if i != len(self.contracting_path)-1: + blocks.append(x) + + for i, up in enumerate(self.upsampling_path): + # print("Before UpConv: ", x.shape, "\t", blocks[-i-1].shape) + # print(up) + x = up(x, blocks[-i-1]) + # print("After UpConv: ", i, x.shape) + + del blocks + + # for saving the activations + if val: + self.activation_maps.append(x) + + return x + + +if __name__ == "__main__": + inp = torch.randn(4, 1, 128, 128, 128) + net = Unet(in_channels=1, num_classes=1, num_feat_maps=[16, 32, 64, 128]) + out = net(inp, False) + print(out.shape) + + From 528532fadbba9ae4744b40689a554985dffbf33b Mon Sep 17 00:00:00 2001 From: naga-karthik Date: Tue, 22 Jun 2021 00:58:51 -0400 Subject: [PATCH 6/8] added training method --- modeling/training.py | 273 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 modeling/training.py diff --git a/modeling/training.py b/modeling/training.py new file mode 100644 index 0000000..e3e9fe0 --- /dev/null +++ b/modeling/training.py @@ -0,0 +1,273 @@ +import os +import argparse +from tqdm import tqdm +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.multiprocessing as mp +from torch.utils.data import DataLoader +from torch.utils.data.sampler import SubsetRandomSampler +from torch.utils.data.distributed import DistributedSampler +import torch.distributed + +from ivadomed.losses import AdapWingLoss, DiceLoss + +from datasets import MSSeg2Dataset +from probabilistic_unet import ProbabilisticUnet +from utils import l2_regularization, split_dataset + +parser = argparse.ArgumentParser(description='Script for training Probabilistic U-Net for MSSeg2 Challenge 2021.') + +# Arguments for model, data, and training +parser.add_argument('-id', '--model_id', default='punet', type=str, + help='Model ID to-be-used for saving the .pt saved model file') +parser.add_argument('-dr', '--dataset_root', + default='/home/GRAMES.POLYMTL.CA/u114716/duke/projects/ivadomed/tmp_ms_challenge_2021_preprocessed', + type=str, help='Root path to the BIDS- and ivadomed-compatible dataset') +parser.add_argument('-fd', '--fraction_data', default=0.1, type=float, help='Fraction of data to use (for debugging)') + +parser.add_argument('-ne', '--num_epochs', default=20, type=int, help='Number of epochs for the training process') +parser.add_argument('-bs', '--batch_size', default=4, type=int, help='Batch size for training and validation processes') +parser.add_argument('-nw', '--num_workers', default=4, type=int, help='Number of workers for the dataloaders') + +parser.add_argument('-lr', '--learning_rate', default=1e-4, type=float, help='Learning rate for training') +parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), help='Decay terms for the AdamW optimizer') +parser.add_argument('-wd', '--weight_decay', type=float, default=0.01, help='Decay value in AdamW') +parser.add_argument('--eps', type=float, default=1e-8, help='Epsilon value for the AdamW optimizer') + +# PU-Net-specific Arguments +parser.add_argument('-ldim', '--latent_dim', default=6, type=int, help='Dimensionality of the latent space') +parser.add_argument('-nfc', '--num_fcomb', default=4, type=int, help='No. 
of 1x1 conv blocks to concat with UNet') +# TODO: change beta values and see how the training progresses! +parser.add_argument('--beta', default=10.0, type=float, help='Weighting factor for the ELBO loss function') + +parser.add_argument('-sv', '--save', default='../results', type=str, help='Path to the saved models directory') +parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true', + help='Load model from checkpoint and continue training') +parser.add_argument('-se', '--seed', default=42, type=int, help='Set seeds for reproducibility') + +# Arguments for parallelization +parser.add_argument('-loc', '--local_rank', type=int, default=-1, + help='Local rank for distributed training on GPUs; set != -1 to start distributed GPU training') +parser.add_argument('--master_addr', type=str, default='localhost', + help='Address of master; master must be able to accept network traffic on the address and port') +parser.add_argument('--master_port', type=str, default='29500', help='Port that master is listening on') + +args = parser.parse_args() + + +def main_worker(rank, world_size): + # Configure model ID + model_id = args.model_id + print('MODEL ID: %s' % model_id) + print('RANK: ', rank) + + # Configure saved models directory + if not os.path.isdir(args.save): + os.makedirs(args.save) + print('Trained model will be saved to: %s' % args.save) + + if args.local_rank == -1: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + else: + torch.cuda.set_device(rank) + device = torch.device('cuda', rank) + # Use NCCL for GPU training, process must have exclusive access to GPUs + torch.distributed.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=world_size) + + model = ProbabilisticUnet(in_channels=2, num_classes=1, num_feat_maps=[32, 64, 128, 192], + latent_dim=args.latent_dim, convs_fcomb=args.num_fcomb, beta=args.beta) + + # Load saved model if applicable + if args.continue_from_checkpoint: + load_path = os.path.join('saved_models', '%s.pt' % model_id) + print('Loading learned weights from %s' % load_path) + state_dict = torch.load(load_path) + state_dict_ = deepcopy(state_dict) + # Rename parameters to exclude the starting 'module.' string so they match + # NOTE: We have to do this because of DataParallel saving parameters starting with 'module.' + for param in state_dict: + state_dict_[param.replace('module.', '')] = state_dict_.pop(param) + model.load_state_dict(state_dict_) + else: + print('Initializing model from scratch') + + if torch.cuda.device_count() > 1 and args.local_rank == -1: + model = nn.DataParallel(model) + model.to(device) + elif args.local_rank == -1: + model.to(device) + else: + model.to(device) + try: + from apex.parallel import DistributedDataParallel as DDP + print("Found Apex!") + model = DDP(model) + except ImportError: + from torch.nn.parallel import DistributedDataParallel as DDP + print("Using PyTorch DDP - could not find Apex") + model = DDP(model, device_ids=[rank]) + + # Load datasets + dataset = MSSeg2Dataset(root=args.dataset_root, subvolume_size=(128, 128, 128), stride_size=(64, 64, 64), + center_crop_size=(320, 384, 512), fraction_data=args.fraction_data, num_gt_experts=4) + + train_dataset, val_dataset = split_dataset(dataset=dataset, val_size=0.3) + # TODO: We also need test set, right? 
+ + if args.local_rank == -1: + train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, + pin_memory=True, num_workers=args.num_workers) + val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False, + pin_memory=True, num_workers=args.num_workers) + else: + train_sampler = DistributedSampler(dataset=train_dataset, num_replicas=world_size, rank=rank) + val_sampler = DistributedSampler(dataset=val_dataset, num_replicas=world_size, rank=rank) + train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, + shuffle=False, pin_memory=True, + num_workers=args.num_workers, sampler=train_sampler) + # NOTE: Train loader's shuffle is made False; using DistributedSampler instead + val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, + shuffle=False, pin_memory=True, + num_workers=args.num_workers, sampler=val_sampler) + + # # Setup optimizer + # no_weight_decay_ids = ['bias', 'LayerNorm.weight'] + # grouped_model_parameters = [ + # {'params': [param for name, param in model.named_parameters() + # if not any(id_ in name for id_ in no_weight_decay_ids)], + # 'lr': args.transformer_learning_rate, + # 'betas': args.betas, + # 'weight_decay': args.weight_decay, + # 'eps': args.eps}, + # {'params': [param for name, param in model.named_parameters() + # if any(id_ in name for id_ in no_weight_decay_ids)], + # 'lr': args.transformer_learning_rate, + # 'betas': args.betas, + # 'weight_decay': 0.0, + # 'eps': args.eps}, + # ] + # + # optimizer = AdamW(grouped_model_parameters) + # num_training_steps = int(args.num_epochs * len(train_dataset) / args.batch_size) + # num_warmup_steps = 0 # int(num_training_steps * 0.02) + # scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, + # num_warmup_steps=num_warmup_steps, + # num_training_steps=num_training_steps) + # + # # Setup loss function + # criterion = AdapWingLoss(theta=0.5, alpha=2.1, omega=14, epsilon=1) + + # Setup optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=0.0) + # TODO: use a learning rate scheduler to decay lr upto 1e-6 + + # Setup other metrics + dice_metric = DiceLoss(smooth=1.0) + + # Training & Evaluation + for i in tqdm(range(args.num_epochs), desc='Iterating over Epochs'): + # -------------------------------- TRAINING ------------------------------------ + model.train() + train_epoch_loss, train_epoch_dice = 0.0, 0.0 + + for batch in tqdm(train_loader, desc='Iterating over Training Examples'): + # Glossary --> B: batch_size; SV: subvolume_size; P: patch_size + x1, x2, y = batch + # print(x1.shape, "\t", x2.shape, "\t", y.shape) + x1, x2, y = x1.to(device), x2.to(device), y.to(device) # size of x1,x2, y: (B x P x P x P) + + # Unsqueeze input patches in the channel dimension + x1, x2, y = x1.unsqueeze(dim=1), x2.unsqueeze(dim=1), torch.unsqueeze(y, dim=1) # size:(B x 1 x P x P x P) + # print(x1.shape, "\t", x2.shape, "\t", y.shape) + + # Concatenate time-points to get single input which is "2 x " the original input size + x = torch.cat([x1, x2], dim=1).to(device) # size of x: (B x 2 x P x P x P) + # print(x.shape) + + model.forward(patch=x, seg_mask=y, training=True) + + y_hat, elbo = model.elbo(y) # y_hat is the reconstructed mask - shape: (B x 1 x P x P x P) + reg_loss = l2_regularization(model.posterior) + l2_regularization(model.prior) + \ + l2_regularization(model.fcomb.layers) + loss = -elbo + 1e-5 * reg_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + 
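# [Editor's note] For clarity on the objective minimised above: in the Probabilistic U-Net
# formulation (Kohl et al., 2018) the ELBO combines a reconstruction term with a KL
# divergence between the posterior net (conditioned on image + mask) and the prior net
# (image only), roughly
#     ELBO ~ -(reconstruction_loss + beta * KL(posterior || prior))
# so `-elbo + 1e-5 * reg_loss` amounts to reconstruction loss + beta-weighted KL + a small
# L2 penalty on the posterior, prior, and fcomb weights. The exact form depends on how
# `ProbabilisticUnet.elbo()` is implemented in probabilistic_unet.py, which is not shown
# in this patch.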
train_epoch_loss += loss.item() + train_epoch_dice += dice_metric(y_hat, y).item() + + train_epoch_loss /= len(train_loader) + train_epoch_dice /= len(train_loader) + + # -------------------------------- EVALUATION ------------------------------------ + # Evaluation done by sampling from the prior latent distribution "NP" no. of times combining with + # U-Net's features. These NP predictions are concatenated and their mean is taken, which is the final predicted + # segmentation. The Dice score is calculated b/w this prediction and the GT. + # TODO: A random GT out of the 4 is chosen now. Should it be compared with the consensus (i.e. the 5th) GT? + + model.eval() + # val_epoch_loss = 0.0 # not calculating loss in validation, only Dice + val_epoch_dice = 0.0 + + with torch.no_grad(): + for batch in tqdm(val_loader, desc='Iterating over Validation Examples'): + x1, x2, y = batch + x1, x2, y = x1.to(device), x2.to(device), y.to(device) + + x1, x2 = x1.unsqueeze(dim=1), x2.unsqueeze(dim=1) + x = torch.cat([x1, x2], dim=1).to(device) + + model.forward(patch=x, seg_mask=None, training=False) + + num_predictions = 5 # NP + predictions = [] + for _ in range(num_predictions): + mask_pred = model.sample(testing=True) + mask_pred = (torch.sigmoid(mask_pred) > 0.5).float() # shape: (B x 1 x P x P x P) + predictions.append(mask_pred) + predictions = torch.cat(predictions, dim=1) # shape: (B x NP x P x P x P) + + y_pred = torch.squeeze(torch.mean(predictions, dim=1)) # y_pred is the mean of all NP predictions + # shape: (B x P x P x P) + + val_epoch_dice += dice_metric(y_pred, y) + + # val_epoch_loss /= len(val_loader) + val_epoch_dice /= len(val_loader) + + torch.save(model.state_dict(), os.path.join(args.save, '%s.pt' % model_id)) + + print('\n') # Do this in order to go below the second tqdm line + print(f'\tTrain Loss: %0.4f | Train Dice: %0.6f' % (train_epoch_loss, train_epoch_dice)) + print(f'\tValidation Dice: %0.6f' % val_epoch_dice) + + # # Apply learning rate decay before the beginning of next epoch if applicable + # scheduler.step() + + # Cleanup DDP if applicable + if args.local_rank != -1: + torch.distributed.destroy_process_group() + + +def main(): + # Configure number of GPUs to be used + n_gpus = torch.cuda.device_count() + print('We are using %d GPUs' % n_gpus) + + # Spawn the training process + if args.local_rank == -1: + main_worker(rank=-1, world_size=1) + else: + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + print('Spawning...') + # Number of processes spawned is equal to the number of GPUs available + mp.spawn(main_worker, nprocs=n_gpus, args=(n_gpus,)) + + +if __name__ == '__main__': + main() From a3d07225503ad5017b60b4f29ee7c025d91c5314 Mon Sep 17 00:00:00 2001 From: naga-karthik Date: Tue, 22 Jun 2021 01:13:36 -0400 Subject: [PATCH 7/8] added ReadMe file --- modeling/ReadMe.md | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 modeling/ReadMe.md diff --git a/modeling/ReadMe.md b/modeling/ReadMe.md new file mode 100644 index 0000000..eec7ea2 --- /dev/null +++ b/modeling/ReadMe.md @@ -0,0 +1,9 @@ +### Probabilistic U-Net for Longitudinal MS Lesion Segmentation +This contains the first attempt of using PU-Net for segmenting MS Lesions. + +The current (tested) version supports single-GPU training. 
How to run: + +``` +$ export CUDA_VISIBLE_DEVICES= +$ python training.py -fd 1.0 +``` From bdb1bf6a285627833b85782ff0968513cef8ca13 Mon Sep 17 00:00:00 2001 From: naga-karthik Date: Tue, 22 Jun 2021 23:06:50 -0400 Subject: [PATCH 8/8] added mixed-precision training functionality --- modeling/training.py | 55 +++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/modeling/training.py b/modeling/training.py index e3e9fe0..e7bfa5c 100644 --- a/modeling/training.py +++ b/modeling/training.py @@ -25,10 +25,10 @@ parser.add_argument('-dr', '--dataset_root', default='/home/GRAMES.POLYMTL.CA/u114716/duke/projects/ivadomed/tmp_ms_challenge_2021_preprocessed', type=str, help='Root path to the BIDS- and ivadomed-compatible dataset') -parser.add_argument('-fd', '--fraction_data', default=0.1, type=float, help='Fraction of data to use (for debugging)') +parser.add_argument('-fd', '--fraction_data', default=1.0, type=float, help='Fraction of data to use (for debugging)') parser.add_argument('-ne', '--num_epochs', default=20, type=int, help='Number of epochs for the training process') -parser.add_argument('-bs', '--batch_size', default=4, type=int, help='Batch size for training and validation processes') +parser.add_argument('-bs', '--batch_size', default=8, type=int, help='Batch size for training and validation processes') parser.add_argument('-nw', '--num_workers', default=4, type=int, help='Number of workers for the dataloaders') parser.add_argument('-lr', '--learning_rate', default=1e-4, type=float, help='Learning rate for training') @@ -166,6 +166,9 @@ def main_worker(rank, world_size): # Setup other metrics dice_metric = DiceLoss(smooth=1.0) + # for mixed-precision training + scaler = torch.cuda.amp.GradScaler() + # Training & Evaluation for i in tqdm(range(args.num_epochs), desc='Iterating over Epochs'): # -------------------------------- TRAINING ------------------------------------ @@ -186,16 +189,22 @@ def main_worker(rank, world_size): x = torch.cat([x1, x2], dim=1).to(device) # size of x: (B x 2 x P x P x P) # print(x.shape) - model.forward(patch=x, seg_mask=y, training=True) + with torch.cuda.amp.autocast(): + model.forward(patch=x, seg_mask=y, training=True) - y_hat, elbo = model.elbo(y) # y_hat is the reconstructed mask - shape: (B x 1 x P x P x P) - reg_loss = l2_regularization(model.posterior) + l2_regularization(model.prior) + \ - l2_regularization(model.fcomb.layers) - loss = -elbo + 1e-5 * reg_loss + y_hat, elbo = model.elbo(y) # y_hat is the reconstructed mask - shape: (B x 1 x P x P x P) + reg_loss = l2_regularization(model.posterior) + l2_regularization(model.prior) + \ + l2_regularization(model.fcomb.layers) + loss = -elbo + 1e-5 * reg_loss optimizer.zero_grad() - loss.backward() - optimizer.step() + # Backprop in Mixed-Precision training + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + # Backprop in Standard training + # loss.backward() + # optimizer.step() train_epoch_loss += loss.item() train_epoch_dice += dice_metric(y_hat, y).item() @@ -221,18 +230,22 @@ def main_worker(rank, world_size): x1, x2 = x1.unsqueeze(dim=1), x2.unsqueeze(dim=1) x = torch.cat([x1, x2], dim=1).to(device) - model.forward(patch=x, seg_mask=None, training=False) - - num_predictions = 5 # NP - predictions = [] - for _ in range(num_predictions): - mask_pred = model.sample(testing=True) - mask_pred = (torch.sigmoid(mask_pred) > 0.5).float() # shape: (B x 1 x P x P x P) - predictions.append(mask_pred) - predictions = 
torch.cat(predictions, dim=1) # shape: (B x NP x P x P x P) - - y_pred = torch.squeeze(torch.mean(predictions, dim=1)) # y_pred is the mean of all NP predictions - # shape: (B x P x P x P) + with torch.cuda.amp.autocast(): + model.forward(patch=x, seg_mask=None, training=False) + + num_predictions = 5 # NP + predictions = [] + for _ in range(num_predictions): + mask_pred = model.sample(testing=True) + # TODO: this line below gets hard predictions. Use just sigmoid and see how Dice performs. + mask_pred = torch.sigmoid(mask_pred) # getting a soft pred. ; shape: (B x 1 x P x P x P) + # uncomment the line below for getting hard predictions. + # mask_pred = (torch.sigmoid(mask_pred) > 0.5).float() # shape: (B x 1 x P x P x P) + predictions.append(mask_pred) + predictions = torch.cat(predictions, dim=1) # shape: (B x NP x P x P x P) + + y_pred = torch.squeeze(torch.mean(predictions, dim=1)) # y_pred is the mean of all NP predictions + # shape: (B x P x P x P) val_epoch_dice += dice_metric(y_pred, y)
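[Editor's note] Patch 8/8 wraps both the training forward pass and the validation sampling loop in `torch.cuda.amp.autocast()` and routes the backward pass through a `torch.cuda.amp.GradScaler`. A minimal, self-contained sketch of that standard mixed-precision pattern is given below; a toy linear model and random tensors stand in for the Probabilistic U-Net and the MSSeg2 dataloader, and are not names from this repository:

```
import torch
import torch.nn as nn

device = torch.device("cuda")   # AMP as used in this patch requires a CUDA device
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)  # toy stand-in
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):             # toy loop standing in for the training dataloader
    x = torch.randn(8, 32, device=device)
    y = torch.randn(8, 1, device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # forward pass runs in float16 where safe
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()        # scale the loss to avoid float16 gradient underflow
    scaler.step(optimizer)               # unscales gradients; skips the step on inf/NaN
    scaler.update()                      # adjusts the scale factor for the next iteration
```

The same `autocast()` context is reused (without the scaler) around the validation sampling loop, since no gradients are computed there.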