ViT UNETR SSL #10

Draft
wants to merge 100 commits into
base: main
Changes from 44 commits
75a5f85
Add notebook debugging SSL pretraining transformations
valosekj Mar 8, 2024
e37ff59
Fix typo
valosekj Mar 8, 2024
80fd7db
Add script for Self-Supervised Pre-training using Vision Transformer …
valosekj Mar 8, 2024
4dc7e81
Fix typo
valosekj Mar 8, 2024
3a9baef
Do sanity dimension check after transformations
valosekj Mar 8, 2024
4517f7a
Use `patch_size=(16, 16, 16)` and `batch_size = 2`
valosekj Mar 8, 2024
a621a9a
rename 'crop_size' to 'spatial_size'
valosekj Mar 9, 2024
516f1ec
use 'contrast-agnostic' data augmentation for Training of the fine-tu…
valosekj Mar 9, 2024
fd8f943
Add script to finetune the 3D Single-Class Spinal Cord Lesion Segment…
valosekj Mar 9, 2024
d1b9c3d
update description
valosekj Mar 9, 2024
45ef357
use 'logger.info' instead of 'print'
valosekj Mar 9, 2024
0955689
make '--data' arg notrequired
valosekj Mar 9, 2024
9ea8bca
explicitly specify CUDA GPU number
valosekj Mar 9, 2024
3bc999a
change sliding_window_inference roi_size and batch_size
valosekj Mar 9, 2024
8e8a977
remove 'to_onehot_y=True' and 'softmax=True' because we work with sin…
valosekj Mar 9, 2024
42394e8
add notebook to debug transformations when using SC masks to crop pat…
valosekj Mar 12, 2024
622c3f6
rerun the notebook to show different slices
valosekj Mar 12, 2024
59f4db0
Add 'keys' arg
valosekj Mar 12, 2024
a3ae050
Add 'keys' arg
valosekj Mar 12, 2024
e1cfad1
Update 'RandCoarseDropoutd' params
valosekj Mar 12, 2024
d4e9995
Update comments
valosekj Mar 12, 2024
4c35d42
Fix imports
valosekj Mar 12, 2024
05ab59a
Change batch_size and NUM_WORKERS to 4
valosekj Mar 12, 2024
0b12c10
Make '--data' non required
valosekj Mar 12, 2024
f5881aa
change 'num_workers' to 0 to prevent 'RuntimeError: received 0 items …
valosekj Mar 12, 2024
7aa9858
Update input arg description
valosekj Mar 13, 2024
16691d7
Rerun the notebook
valosekj Mar 13, 2024
56ca187
Print hyper-parameters into the log file
valosekj Mar 13, 2024
ce7315a
Add '--cuda' input arg
valosekj Mar 13, 2024
1ebf833
track and save epoch time
valosekj Mar 15, 2024
208b2c1
Plot and save input and output validation images to see how the model…
valosekj Mar 15, 2024
5a9e69a
Add comment
valosekj Mar 15, 2024
eca6616
Do not plot 'outputs_v2' as it is a hidden representation
valosekj Mar 15, 2024
8e55985
Include the epoch number as master title
valosekj Mar 15, 2024
e7575c5
Add 'torch.multiprocessing.set_sharing_strategy('file_system')' to so…
valosekj Mar 15, 2024
3b6df14
Create validation_figures directory if it does not exist
valosekj Mar 15, 2024
59b061c
Use 3 leading zeros for the epoch number in the figures fname
valosekj Mar 15, 2024
f2d648b
Link issue
valosekj Mar 15, 2024
a6bc6ec
Add note for 'RandCropByPosNegLabeld'
valosekj Mar 15, 2024
c1553d6
Add 'number_of_holes' arg to specify the number of holes to be used f…
valosekj Mar 15, 2024
0e802df
typo
valosekj Mar 16, 2024
9255855
batch_size = 8
valosekj Mar 16, 2024
8296913
NUM_WORKERS = batch_size
valosekj Mar 16, 2024
93c5415
number_of_holes=5
valosekj Mar 16, 2024
a4182b5
Update transforms for training of the fine-tuned model
valosekj Mar 16, 2024
2573578
update comment, remove unused imports
valosekj Mar 16, 2024
cfb8f12
Add notebook with RandCoarseDropoutd transform debug
valosekj Mar 16, 2024
7f27340
Fix 'dropout_holes=True' and 'dropout_holes=False' comments
valosekj Mar 17, 2024
74b2f2c
remove unused 'max_spatial_size' arg
valosekj Mar 17, 2024
bdb2ac1
use 'fill_value=0' for 'RandCoarseDropoutd'
valosekj Mar 17, 2024
f7f3689
Plot also RandCoarseDropoutd dropout_holes=False fill_value=0
valosekj Mar 17, 2024
af9c6d0
Add note that the batch size is actually doubled (8*2=16), because we…
valosekj Mar 17, 2024
b416135
Add '--cuda' input arg
valosekj Mar 18, 2024
0303776
Remove 'AsDiscrete'
valosekj Mar 18, 2024
4534c57
Remove 'AsDiscrete'
valosekj Mar 18, 2024
94c8611
Add 'CUDA_NUM=args.cuda'
valosekj Mar 18, 2024
5fb325a
Add TODO to increase batch_size to 16
valosekj Mar 18, 2024
d1a03c8
Use 'roi_size' for 'define_finetune_train_transforms'
valosekj Mar 18, 2024
e3d8086
Use 'label_sc' to crop samples around the SC
valosekj Mar 18, 2024
7e628af
batch_size = 8
valosekj Mar 18, 2024
f832c94
NUM_WORKERS = batch_size
valosekj Mar 18, 2024
c68d65d
Add 'import torch.multiprocessing'
valosekj Mar 18, 2024
a7c7c77
Fix shape logging
valosekj Mar 18, 2024
1c6109c
Change 'img_size' to 'ROI_SIZE'
valosekj Mar 18, 2024
b92e7d5
'batch["label"]' --> 'batch["label_lesion"]'
valosekj Mar 18, 2024
f174fb9
Plot and save input and output validation images to see how the model…
valosekj Mar 18, 2024
532e24b
Fix ROI_SIZE for sliding_window_inference
valosekj Mar 18, 2024
5c10890
Crop samples of 64x64x64 also for Validation of the fine-tuned model
valosekj Mar 18, 2024
d33abfa
update docstring
valosekj Mar 18, 2024
efeccf9
log validation samples shapes
valosekj Mar 18, 2024
f34ece1
Save validation figure only if it contains a lesion
valosekj Mar 18, 2024
e0afafe
Plot GT together with image
valosekj Mar 19, 2024
e588a22
print unique values in the slice to see if it is binary
valosekj Mar 19, 2024
2d0b530
update output fig fname
valosekj Mar 19, 2024
988680e
Add 'AsDiscreted' for Training and Validation of the fine-tuned model
valosekj Mar 19, 2024
f675611
threshold val_labels_list and val_outputs_list by 0.5 threshold befor…
valosekj Mar 19, 2024
cc78896
add normalized relu normalization
naga-karthik Mar 19, 2024
92334ea
fix binarization bug
naga-karthik Mar 19, 2024
2c49ca2
remove 'logger.info(np.unique(output.detach().cpu().numpy()))'
valosekj Mar 19, 2024
fbe3dd7
overlay prediction over input image
valosekj Mar 19, 2024
ba39d28
Fix variable when getting probabilities from logits
valosekj Mar 20, 2024
c8d9a5d
Add debug lines
valosekj Mar 20, 2024
ac040a5
PEP8
valosekj Mar 20, 2024
1e2b61d
Set validation batch_size to 1
valosekj Mar 20, 2024
8f49545
improve comments
valosekj Mar 20, 2024
3a74e28
fix figure title
valosekj Mar 20, 2024
dea121c
comment 'AsDiscreted' transforms
valosekj Mar 20, 2024
013ca65
Make '--pretrained-model' non required to allow training from the scr…
valosekj Mar 21, 2024
f1aae1b
Add script to create spine-generic MSD dataset
valosekj Mar 22, 2024
e4c87be
run notebook again
valosekj Mar 22, 2024
4448356
Make 'create_msd_data.py' compatible with other BIDS datasets
valosekj Mar 23, 2024
a354a21
Add note that no testing set is created
valosekj Mar 23, 2024
141c96b
Add README with instructions on how to download T2w images from multi…
valosekj Mar 24, 2024
7f8e6bb
fix typo
valosekj Mar 24, 2024
46f1bd9
remove unused imports
valosekj Mar 24, 2024
afeb7c5
fix sc suffix for sci-paris
valosekj Mar 24, 2024
da1c131
Use os.path.abspath for 'args.path_data' and 'args.path_out'
valosekj Mar 24, 2024
ed42c68
Update logging message
valosekj Mar 24, 2024
7f28fed
Fix '.replace' to prevent '//' in the output string
valosekj Mar 24, 2024
99a936b
Add 'create_msd_data.py' commands to create MSD-style JSON datalists
valosekj Mar 24, 2024
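Several commits above deal with turning the single-channel network output into a binary mask before scoring (for example f675611, which thresholds the validation labels and outputs at 0.5). A minimal sketch of that idea in plain Python follows. It is not the code from this PR: the helper names `sigmoid`, `binarize` and `dice_score`, the toy values, and the convention of returning 1.0 for two empty masks are all assumptions made for illustration.

```python
import math


def sigmoid(x):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def binarize(values, threshold=0.5):
    """Return 1 where the value reaches the threshold, else 0."""
    return [1 if v >= threshold else 0 for v in values]


def dice_score(pred, target):
    """Dice overlap between two binary masks given as flat lists."""
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    if denominator == 0:
        return 1.0  # assumption: two empty masks count as a perfect match
    return 2.0 * intersection / denominator


# Hypothetical flattened logits and ground-truth voxels
logits = [-3.0, -0.2, 0.4, 2.5, 1.1, -1.7]
labels = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]

pred_mask = binarize([sigmoid(x) for x in logits])
label_mask = binarize(labels)
dice = dice_score(pred_mask, label_mask)
print(pred_mask, label_mask, dice)
```

With one output channel there is nothing to take an argmax over, which is why a sigmoid followed by a threshold is the usual choice here rather than softmax plus one-hot encoding.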
272 changes: 272 additions & 0 deletions vit_unetr_ssl/finetune.py
@@ -0,0 +1,272 @@
"""
Finetuning of the 3D Single-Class Spinal Cord Lesion Segmentation Model Using SSL Pre-trained Weights

This script is based on this MONAI tutorial:
https://github.com/Project-MONAI/tutorials/tree/main/self_supervised_pretraining/vit_unetr_ssl

Author: Jan Valosek
"""

import os
import argparse
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from loguru import logger
from monai.utils import set_determinism, first
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference

from monai.transforms import AsDiscrete

from monai.metrics import DiceMetric
from monai.networks.nets import UNETR

from monai.data import (
    Dataset,
    DataLoader,
    CacheDataset,
    decollate_batch,
)

from load_data import load_data
from transforms import define_finetune_train_transforms, define_finetune_val_transforms


def get_parser():
    # parse command line arguments
    parser = argparse.ArgumentParser(description='Run Fine-tuning.')
    parser.add_argument('--dataset-split', required=True, type=str,
                        help='Path to the JSON file with training/validation split. '
                             'If paths are absolute, you do NOT need to use --data. '
                             'If only filenames are provided, you need to use --data to specify the root directory '
                             'of the dataset.')
    parser.add_argument('--data', required=False, type=str, default="",
                        help='Path to the dataset root directory. If not provided, path to data specified in the JSON '
                             'file will be used.')
    parser.add_argument('--logdir', required=True, type=str,
                        help='Path to the directory for logging.')
    parser.add_argument('--pretrained-model', required=True, type=str,
                        help='Path to the pretrained model.')

    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    # -----------------------------------------------------
    # Define file paths & output directory path
    # -----------------------------------------------------
    json_path = os.path.abspath(args.dataset_split)
    data_root = os.path.abspath(args.data)
    logdir_path = os.path.abspath(args.logdir)
    pretrained_model_path = os.path.abspath(args.pretrained_model)
    use_pretrained = pretrained_model_path is not None

    # -----------------------------------------------------
    # Create result logging directories, manage data paths & set determinism
    # -----------------------------------------------------
    train_list, val_list = load_data(data_root, json_path, logdir_path, is_segmentation=True)

    # save output to a log file
    logger.add(os.path.join(logdir_path, "log.txt"), rotation="10 MB", level="INFO")

    logger.info("Total training data are {} and validation data are {}".format(len(train_list), len(val_list)))

    # Set Determinism
    set_determinism(seed=123)

    # -----------------------------------------------------
    # Define MONAI Transforms
    # -----------------------------------------------------
    SPATIAL_SIZE = (64, 256, 256)  # keeping the same image size as for pretraining
    train_transforms = define_finetune_train_transforms(spatial_size=SPATIAL_SIZE)
    val_transforms = define_finetune_val_transforms(spatial_size=SPATIAL_SIZE)

    # -----------------------------------------------------
    # Sanity Check for the transforms
    # -----------------------------------------------------
    check_ds = Dataset(data=train_list, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    logger.info(f'original image shape: {check_data["image"][0][0].shape}')
    logger.info(f'original label shape: {check_data["label"][0][0].shape}')

    # -----------------------------------------------------
    # Training Config
    # -----------------------------------------------------
    CUDA_NUM = 2

    device = torch.device(f"cuda:{CUDA_NUM}")
    model = UNETR(
        in_channels=1,
        out_channels=1,
        img_size=SPATIAL_SIZE,
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # -----------------------------------------------------
    # Load ViT backbone weights into UNETR
    # -----------------------------------------------------
    if use_pretrained is True:
        logger.info(f"Loading Weights from the Path {pretrained_model_path}")
        vit_dict = torch.load(pretrained_model_path)
        vit_weights = vit_dict["state_dict"]

        # Remove items of vit_weights if they are not in the ViT backbone (this is used in UNETR).
        # For example, some variable names like conv3d_transpose.weight, conv3d_transpose.bias,
        # conv3d_transpose_1.weight and conv3d_transpose_1.bias are used to match dimensions
        # while pretraining with ViTAutoEnc and are not a part of the ViT backbone.
        model_dict = model.vit.state_dict()

        vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
        model_dict.update(vit_weights)
        model.vit.load_state_dict(model_dict)
        del model_dict, vit_weights, vit_dict
        logger.info("Pretrained Weights Successfully Loaded !")

    else:
        logger.info("No weights were loaded, all weights being used are randomly initialized!")

    model.to(device)

    # Training Hyper-params
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100
    batch_size = 2
    loss_function = DiceCELoss()
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # The model has a single output channel, so binarize labels and predictions
    # instead of one-hot encoding 14 classes
    post_label = AsDiscrete(threshold=0.5)
    # Thresholding logits at 0 is equivalent to thresholding sigmoid probabilities at 0.5
    post_pred = AsDiscrete(threshold=0.0)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    # -----------------------------------------------------
    # Create dataloaders for training
    # -----------------------------------------------------
    NUM_WORKERS = 0

    train_dataset = CacheDataset(data=train_list, transform=train_transforms, cache_rate=0.5, num_workers=NUM_WORKERS)
    val_dataset = CacheDataset(data=val_list, transform=val_transforms, cache_rate=0.25, num_workers=NUM_WORKERS)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              pin_memory=True,
                              persistent_workers=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=NUM_WORKERS,
                            pin_memory=True,
                            persistent_workers=False)

    # -----------------------------------------------------
    # Training Loop with Validation
    # -----------------------------------------------------
    def validation(epoch_iterator_val):
        model.eval()
        dice_vals = []

        with torch.no_grad():
            for _step, batch in enumerate(epoch_iterator_val):
                val_inputs, val_labels = (batch["image"].cuda(CUDA_NUM), batch["label"].cuda(CUDA_NUM))
                val_outputs = sliding_window_inference(val_inputs, SPATIAL_SIZE, batch_size, model)
                val_labels_list = decollate_batch(val_labels)
                val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
                val_outputs_list = decollate_batch(val_outputs)
                val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                dice = dice_metric.aggregate().item()
                dice_vals.append(dice)
                # report progress against max_iterations, as the training progress bar does
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" % (global_step, max_iterations, dice))

            dice_metric.reset()

        mean_dice_val = np.mean(dice_vals)
        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        model.train()
        epoch_loss = 0
        epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator):
            step += 1
            x, y = (batch["image"].cuda(CUDA_NUM), batch["label"].cuda(CUDA_NUM))
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss))

            if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val)

                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(), os.path.join(logdir_path, "best_metric_model.pth"))
                    logger.info(f"Model Was Saved ! Current Best Avg. Dice: {dice_val_best} "
                                f"Current Avg. Dice: {dice_val}")
                else:
                    logger.info(f"Model Was Not Saved ! Current Best Avg. Dice: {dice_val_best} "
                                f"Current Avg. Dice: {dice_val}")

                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
                y = epoch_loss_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                x = [eval_num * (i + 1) for i in range(len(metric_values))]
                y = metric_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.savefig(os.path.join(logdir_path, "btcv_finetune_quick_update.png"))
                plt.clf()
                plt.close(1)

            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
    model.load_state_dict(torch.load(os.path.join(logdir_path, "best_metric_model.pth")))

    logger.info(f"train completed, best_metric: {dice_val_best:.4f} " f"at iteration: {global_step_best}")


if __name__ == "__main__":
    main()
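The weight-loading step in `finetune.py` keeps only the checkpoint entries whose names also exist in the ViT backbone. The sketch below shows that filtering on plain dictionaries so it runs without PyTorch. It is an illustration, not the script's code: `filter_backbone_weights` is a hypothetical helper, and every parameter name except `conv3d_transpose.weight` (mentioned in the script's comment) is made up.

```python
def filter_backbone_weights(pretrained, backbone):
    """Split checkpoint entries into those the backbone accepts and those it skips."""
    kept = {k: v for k, v in pretrained.items() if k in backbone}
    skipped = sorted(k for k in pretrained if k not in backbone)
    return kept, skipped


# Toy stand-ins for state dicts: parameter name -> values (names are hypothetical)
pretrained = {
    "patch_embedding.weight": [0.1, 0.2],
    "blocks.0.attn.weight": [0.3],
    "conv3d_transpose.weight": [0.9],  # decoder-only entry from pre-training
}
backbone = {
    "patch_embedding.weight": [0.0, 0.0],
    "blocks.0.attn.weight": [0.0],
    "norm.weight": [1.0],  # absent from the checkpoint, keeps its initial value
}

kept, skipped = filter_backbone_weights(pretrained, backbone)
merged = {**backbone, **kept}  # same effect as model_dict.update(vit_weights)
missing = sorted(k for k in backbone if k not in pretrained)

print(f"loaded {len(kept)}/{len(backbone)} backbone tensors")
print("skipped:", skipped)
print("left at initial values:", missing)
```

Logging the matched, skipped, and missing names may be worth considering in the script itself: if the checkpoint keys carried an unexpected prefix, the filter would match nothing and the model would train from random weights without any error.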
23 changes: 23 additions & 0 deletions vit_unetr_ssl/load_data.py
@@ -0,0 +1,23 @@
import os

from monai.data import load_decathlon_datalist


def load_data(data_root, json_path, logdir_path, is_segmentation=False):
    """
    Load data from the JSON file and return the training and validation data lists
    """
    # create the logging directory (including any missing parent directories)
    os.makedirs(logdir_path, exist_ok=True)

    train_list = load_decathlon_datalist(
        base_dir=data_root, data_list_file_path=json_path, is_segmentation=is_segmentation, data_list_key="training"
    )

    val_list = load_decathlon_datalist(
        base_dir=data_root, data_list_file_path=json_path, is_segmentation=is_segmentation, data_list_key="validation"
    )

    return train_list, val_list
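For readers unfamiliar with the datalist format, `load_data` expects a JSON file with `training` and `validation` keys, each holding a list of entries with image (and, for segmentation, label) paths. The following is a simplified stand-in that shows how relative entries combine with the `--data` root while absolute entries pass through unchanged. It is not MONAI's implementation: `resolve_datalist`, the file names, and the directory names are hypothetical, and `posixpath` is used so the behaviour is the same on every platform.

```python
import json
import posixpath

# Hypothetical datalist with one relative and one absolute entry
DATALIST_JSON = """
{
  "training": [
    {"image": "sub-01/anat/sub-01_T2w.nii.gz", "label": "labels/sub-01_lesion.nii.gz"}
  ],
  "validation": [
    {"image": "/data/abs/sub-02_T2w.nii.gz", "label": "/data/abs/sub-02_lesion.nii.gz"}
  ]
}
"""


def resolve_datalist(json_text, key, base_dir, path_keys=("image", "label")):
    """Return the entries stored under `key` with path fields joined to `base_dir`."""
    resolved = []
    for entry in json.loads(json_text)[key]:
        item = dict(entry)
        for path_key in path_keys:
            if path_key in item:
                # posixpath.join returns an absolute second argument unchanged
                item[path_key] = posixpath.join(base_dir, item[path_key])
        resolved.append(item)
    return resolved


train_list = resolve_datalist(DATALIST_JSON, "training", "/datasets/my-dataset")
val_list = resolve_datalist(DATALIST_JSON, "validation", "/datasets/my-dataset")
print(train_list[0]["image"])
print(val_list[0]["image"])
```

This mirrors the help text of `--dataset-split`: with absolute paths in the JSON file the `--data` root has no effect, while bare file names need it.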