From f2cfa8324a6082ce3e408ad5d6ce8e3e6844752c Mon Sep 17 00:00:00 2001
From: Ioannis Vezakis
Date: Wed, 15 Mar 2023 20:59:57 +0200
Subject: [PATCH] first commit

---
 config/config.yaml                    |  25 +++
 config/datamodule/osteosarcoma.yaml   |   9 ++
 config/model/efficientnetb0.yaml      |   7 ++
 config/model/efficientnetb1.yaml      |   6 ++
 config/model/efficientnetb3.yaml      |   6 ++
 config/model/efficientnetb5.yaml      |   6 ++
 config/model/efficientnetb7.yaml      |   6 ++
 config/model/mobilenetv2.yaml         |   5 +
 config/model/resnet18.yaml            |   5 +
 config/model/resnet34.yaml            |   5 +
 config/model/resnet50.yaml            |   5 +
 config/model/vgg16.yaml               |   5 +
 config/model/vgg19.yaml               |   5 +
 config/model/vit.yaml                 |   4 +
 config/optimizer/adam.yaml            |   2 +
 config/optimizer/sgd.yaml             |   2 +
 config/scheduler/cosine.yaml          |   3 +
 dataloading/datagen.py                |  44 +++++++
 dataloading/osteosarcomaDataModule.py | 129 ++++++++++++++++++++++++++
 model_loading/mobilenetv2.py          |  12 +++
 model_loading/resnet18.py             |  15 +++
 model_loading/resnet34.py             |  15 +++
 model_loading/resnet50.py             |  15 +++
 model_loading/vgg16.py                |  18 ++++
 model_loading/vgg19.py                |  18 ++++
 model_loading/vit_b_16.py             |  20 ++++
 network_module.py                     |  90 ++++++++++++++++++
 trainer.py                            |  43 +++++++
 28 files changed, 525 insertions(+)
 create mode 100644 config/config.yaml
 create mode 100644 config/datamodule/osteosarcoma.yaml
 create mode 100644 config/model/efficientnetb0.yaml
 create mode 100644 config/model/efficientnetb1.yaml
 create mode 100644 config/model/efficientnetb3.yaml
 create mode 100644 config/model/efficientnetb5.yaml
 create mode 100644 config/model/efficientnetb7.yaml
 create mode 100644 config/model/mobilenetv2.yaml
 create mode 100644 config/model/resnet18.yaml
 create mode 100644 config/model/resnet34.yaml
 create mode 100644 config/model/resnet50.yaml
 create mode 100644 config/model/vgg16.yaml
 create mode 100644 config/model/vgg19.yaml
 create mode 100644 config/model/vit.yaml
 create mode 100644 config/optimizer/adam.yaml
 create mode 100644 config/optimizer/sgd.yaml
 create mode 100644 config/scheduler/cosine.yaml
 create mode 100644 dataloading/datagen.py
 create mode 100644 dataloading/osteosarcomaDataModule.py
 create mode 100644 model_loading/mobilenetv2.py
 create mode 100644 model_loading/resnet18.py
 create mode 100644 model_loading/resnet34.py
 create mode 100644 model_loading/resnet50.py
 create mode 100644 model_loading/vgg16.py
 create mode 100644 model_loading/vgg19.py
 create mode 100644 model_loading/vit_b_16.py
 create mode 100644 network_module.py
 create mode 100644 trainer.py

diff --git a/config/config.yaml b/config/config.yaml
new file mode 100644
index 0000000..c6117ed
--- /dev/null
+++ b/config/config.yaml
@@ -0,0 +1,25 @@
+defaults:
+  - model: efficientnetb0
+  - optimizer: adam
+  - scheduler: cosine
+  - datamodule: osteosarcoma
+  - _self_
+
+experiment_name: osteosarcoma
+run_name: ${model.name}
+n_folds: 1
+pretrained: True
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  gpus: 1
+  max_epochs: 100
+  log_every_n_steps: 50
+  callbacks:
+    - _target_: pytorch_lightning.callbacks.LearningRateMonitor
+      logging_interval: epoch
+
+criterion:
+  _target_: torch.nn.CrossEntropyLoss
+
+num_classes: 3
diff --git a/config/datamodule/osteosarcoma.yaml b/config/datamodule/osteosarcoma.yaml
new file mode 100644
index 0000000..d1855e3
--- /dev/null
+++ b/config/datamodule/osteosarcoma.yaml
@@ -0,0 +1,9 @@
+_target_: dataloading.osteosarcomaDataModule.OsteosarcomaDataModule
+batch_size: 8
+train_val_ratio: 0.7
+base_path: # Set this to the path containing the Osteosarcoma data from UT Southwestern/UT Dallas https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=52756935
+num_workers: 12
+img_size:
+  - 512
+  - 512
+n_splits: ${n_folds}
diff --git a/config/model/efficientnetb0.yaml b/config/model/efficientnetb0.yaml
new file mode 100644
index 0000000..61849a1
--- /dev/null
+++ b/config/model/efficientnetb0.yaml
@@ -0,0 +1,7 @@
+name: efficientnetb0
+object:
+  _target_: monai.networks.nets.EfficientNetBN
+  model_name: efficientnet-b0
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
+
\ No newline at end of file
diff --git a/config/model/efficientnetb1.yaml b/config/model/efficientnetb1.yaml
new file mode 100644
index 0000000..4a0569a
--- /dev/null
+++ b/config/model/efficientnetb1.yaml
@@ -0,0 +1,6 @@
+name: efficientnetb1
+object:
+  _target_: monai.networks.nets.EfficientNetBN
+  model_name: efficientnet-b1
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/efficientnetb3.yaml b/config/model/efficientnetb3.yaml
new file mode 100644
index 0000000..075483f
--- /dev/null
+++ b/config/model/efficientnetb3.yaml
@@ -0,0 +1,6 @@
+name: efficientnetb3
+object:
+  _target_: monai.networks.nets.EfficientNetBN
+  model_name: efficientnet-b3
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/efficientnetb5.yaml b/config/model/efficientnetb5.yaml
new file mode 100644
index 0000000..defb542
--- /dev/null
+++ b/config/model/efficientnetb5.yaml
@@ -0,0 +1,6 @@
+name: efficientnetb5
+object:
+  _target_: monai.networks.nets.EfficientNetBN
+  model_name: efficientnet-b5
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/efficientnetb7.yaml b/config/model/efficientnetb7.yaml
new file mode 100644
index 0000000..86093bc
--- /dev/null
+++ b/config/model/efficientnetb7.yaml
@@ -0,0 +1,6 @@
+name: efficientnetb7
+object:
+  _target_: monai.networks.nets.EfficientNetBN
+  model_name: efficientnet-b7
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/mobilenetv2.yaml b/config/model/mobilenetv2.yaml
new file mode 100644
index 0000000..e487364
--- /dev/null
+++ b/config/model/mobilenetv2.yaml
@@ -0,0 +1,5 @@
+name: mobilenetv2
+object:
+  _target_: model_loading.mobilenetv2.MobileNetV2
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
\ No newline at end of file
diff --git a/config/model/resnet18.yaml b/config/model/resnet18.yaml
new file mode 100644
index 0000000..babcceb
--- /dev/null
+++ b/config/model/resnet18.yaml
@@ -0,0 +1,5 @@
+name: resnet18
+object:
+  _target_: model_loading.resnet18.ResNet18
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/resnet34.yaml b/config/model/resnet34.yaml
new file mode 100644
index 0000000..3357668
--- /dev/null
+++ b/config/model/resnet34.yaml
@@ -0,0 +1,5 @@
+name: resnet34
+object:
+  _target_: model_loading.resnet34.ResNet34
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/resnet50.yaml b/config/model/resnet50.yaml
new file mode 100644
index 0000000..2f5cbbc
--- /dev/null
+++ b/config/model/resnet50.yaml
@@ -0,0 +1,5 @@
+name: resnet50
+object:
+  _target_: model_loading.resnet50.ResNet50
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/vgg16.yaml b/config/model/vgg16.yaml
new file mode 100644
index 0000000..fee3fc1
--- /dev/null
+++ b/config/model/vgg16.yaml
@@ -0,0 +1,5 @@
+name: vgg16
+object:
+  _target_: model_loading.vgg16.VGG16
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/vgg19.yaml b/config/model/vgg19.yaml
new file mode 100644
index 0000000..c9831a1
--- /dev/null
+++ b/config/model/vgg19.yaml
@@ -0,0 +1,5 @@
+name: vgg19
+object:
+  _target_: model_loading.vgg19.VGG19
+  num_classes: ${num_classes}
+  pretrained: ${pretrained}
diff --git a/config/model/vit.yaml b/config/model/vit.yaml
new file mode 100644
index 0000000..9bb812f
--- /dev/null
+++ b/config/model/vit.yaml
@@ -0,0 +1,4 @@
+name: vit_b_16
+object:
+  _target_: model_loading.vit_b_16.ViT_B_16
+  num_classes: ${num_classes}
diff --git a/config/optimizer/adam.yaml b/config/optimizer/adam.yaml
new file mode 100644
index 0000000..e258fab
--- /dev/null
+++ b/config/optimizer/adam.yaml
@@ -0,0 +1,2 @@
+_target_: torch.optim.AdamW
+lr: 0.0003
diff --git a/config/optimizer/sgd.yaml b/config/optimizer/sgd.yaml
new file mode 100644
index 0000000..8ada58a
--- /dev/null
+++ b/config/optimizer/sgd.yaml
@@ -0,0 +1,2 @@
+_target_: torch.optim.SGD
+lr: 0.0001
diff --git a/config/scheduler/cosine.yaml b/config/scheduler/cosine.yaml
new file mode 100644
index 0000000..f88fc98
--- /dev/null
+++ b/config/scheduler/cosine.yaml
@@ -0,0 +1,3 @@
+_target_: torch.optim.lr_scheduler.CosineAnnealingLR
+T_max: ${trainer.max_epochs}
+eta_min: 0.00001
diff --git a/dataloading/datagen.py b/dataloading/datagen.py
new file mode 100644
index 0000000..09f6d32
--- /dev/null
+++ b/dataloading/datagen.py
@@ -0,0 +1,44 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+import torchvision.transforms.functional as TF
+from PIL import Image
+from torch.utils.data import Dataset
+from torch.utils.data.dataset import Subset
+
+
+class CustomDataGen(Dataset):
+    def __init__(
+        self,
+        df: Subset[pd.DataFrame],
+        base_path: Path,
+        transform=None,
+    ):
+        self.base_path = base_path
+        self.transform = transform
+
+        if isinstance(df, Subset):
+            subset = np.take(df.dataset, df.indices, axis=0)  # type: ignore
+        else:
+            subset = df
+        self.img_paths = subset["filename"].to_numpy()
+        self.labels = subset["label"].to_numpy()
+
+    def __len__(self):
+        return len(self.img_paths)
+
+    def __getitem__(self, idx):
+        img_path = self.img_paths[idx]
+
+        # the image might be in any subfolder of the base path
+        # so we need to search for it first
+        img_path = next(self.base_path.glob(f"**/{img_path}"))
+        img = Image.open(img_path)
+        img = TF.to_tensor(img)
+        if self.transform:
+            img = self.transform(img)
+
+        label = torch.tensor(self.labels[idx], dtype=torch.long)
+        return img, label
diff --git a/dataloading/osteosarcomaDataModule.py b/dataloading/osteosarcomaDataModule.py
new file mode 100644
index 0000000..54b37c3
--- /dev/null
+++ b/dataloading/osteosarcomaDataModule.py
@@ -0,0 +1,129 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torchvision.transforms
+import torchvision.transforms.functional as TF
+from PIL import Image
+from sklearn.model_selection import KFold
+from sklearn.utils.class_weight import compute_class_weight
+from torch.utils.data import DataLoader, random_split
+
+from dataloading.datagen import CustomDataGen
+
+
+class OsteosarcomaDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        batch_size,
+        base_path,
+        num_workers=0,
+        train_val_ratio=0.8,
+        img_size=None,
+        k=1,
+        n_splits=1,
+        random_state=42,
+    ):
+        super().__init__()
+        self.batch_size = batch_size
+        self.train_val_ratio = train_val_ratio
+        self.num_workers = num_workers
+        self.base_path = Path(base_path)
+        self.img_size = tuple(img_size) if img_size else None
+        self.k = k
+        self.n_splits = n_splits
+        self.random_state = random_state
+
+        self.mean = torch.tensor([0.485, 0.456, 0.406])
+        self.std = torch.tensor([0.229, 0.224, 0.225])
+
+    def prepare_data(self):
+        f = self.base_path / "ML_Features_1144.csv"
+        self.df = pd.read_csv(f)
+        self.filenames = self.df["image.name"]
+        self.labels = self.df["classification"]
+
+        self.filenames = self.filenames.str.replace(" -", "-")
+        self.filenames = self.filenames.str.replace("- ", "-")
+        self.filenames = self.filenames.str.replace(" ", "-")
+        self.filenames = self.filenames + ".jpg"
+
+        self.labels = self.labels.str.lower()
+        self.labels = self.labels.replace("non-tumor", 0)
+        self.labels = self.labels.replace("viable", 1)
+        self.labels = self.labels.replace("non-viable-tumor", 2)
+        self.labels = self.labels.replace("viable: non-viable", 2)
+
+        assert len(self.filenames) == len(self.labels)
+        assert set(self.labels) == {0, 1, 2}
+
+        self.num_classes = len(set(self.labels))
+        self.class_weights = torch.tensor(
+            compute_class_weight("balanced", classes=[0, 1, 2], y=self.labels), dtype=torch.float
+        )
+
+        self.df = pd.DataFrame({"filename": self.filenames, "label": self.labels})
+
+    def get_preprocessing_transform(self):
+        transforms = nn.Sequential(
+            torchvision.transforms.Normalize(self.mean, self.std),
+            torchvision.transforms.Resize(self.img_size)
+            if self.img_size
+            else nn.Identity(),
+        )
+        return transforms
+
+    def get_augmentation_transform(self):
+        transforms = nn.Sequential(
+            torchvision.transforms.RandomVerticalFlip(),
+            torchvision.transforms.RandomHorizontalFlip(),
+            torchvision.transforms.RandomRotation(20),
+        )
+        return transforms
+
+    def setup(self, stage=None):
+        if self.n_splits != 1:
+            kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)
+            splits = [k for k in kf.split(self.df)]
+
+            self.train_subjects, self.val_subjects = self.df.iloc[splits[self.k][0]], self.df.iloc[splits[self.k][1]]
+        else:
+            num_subjects = len(self.df)
+            num_train_subjects = int(round(num_subjects * self.train_val_ratio))
+            num_val_subjects = num_subjects - num_train_subjects
+            splits = num_train_subjects, num_val_subjects
+            self.train_subjects, self.val_subjects = random_split(
+                self.df, splits, generator=torch.Generator().manual_seed(self.random_state)  # type: ignore
+            )
+
+    def train_dataloader(self):
+        return DataLoader(
+            CustomDataGen(
+                self.train_subjects,
+                self.base_path,
+                transform=torchvision.transforms.Compose(
+                    [
+                        self.get_preprocessing_transform(),
+                        self.get_augmentation_transform(),
+                    ]
+                ),
+            ),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            CustomDataGen(
+                self.val_subjects,
+                self.base_path,
+                transform=self.get_preprocessing_transform(),
+            ),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
diff --git a/model_loading/mobilenetv2.py b/model_loading/mobilenetv2.py
new file mode 100644
index 0000000..72bf530
--- /dev/null
+++ b/model_loading/mobilenetv2.py
@@ -0,0 +1,12 @@
+import torch
+import torch.nn as nn
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, num_classes, pretrained=False):
+        super().__init__()
+        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=pretrained)
+        self.model.classifier[1] = nn.Linear(1280, num_classes)
+
+    def forward(self, x):
+        return self.model(x)
\ No newline at end of file
diff --git a/model_loading/resnet18.py b/model_loading/resnet18.py
new file mode 100644
index 0000000..fb3c23f
--- /dev/null
+++ b/model_loading/resnet18.py
@@ -0,0 +1,15 @@
+import torch.nn as nn
+from torchvision.models import ResNet18_Weights, resnet18
+
+
+class ResNet18(nn.Module):
+    def __init__(self, num_classes, pretrained=False):
+        super().__init__()
+        if pretrained:
+            self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
+        else:
+            self.model = resnet18()
+        self.model.fc = nn.Linear(512, num_classes)
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/model_loading/resnet34.py b/model_loading/resnet34.py
new file mode 100644
index 0000000..687a8a2
--- /dev/null
+++ b/model_loading/resnet34.py
@@ -0,0 +1,15 @@
+import torch.nn as nn
+from torchvision.models import ResNet34_Weights, resnet34
+
+
+class ResNet34(nn.Module):
+    def __init__(self, num_classes, pretrained=False):
+        super().__init__()
+        if pretrained:
+            self.model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
+        else:
+            self.model = resnet34()
+        self.model.fc = nn.Linear(512, num_classes)
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/model_loading/resnet50.py b/model_loading/resnet50.py
new file mode 100644
index 0000000..d96441c
--- /dev/null
+++ b/model_loading/resnet50.py
@@ -0,0 +1,15 @@
+import torch.nn as nn
+from torchvision.models import ResNet50_Weights, resnet50
+
+
+class ResNet50(nn.Module):
+    def __init__(self, num_classes, pretrained=False):
+        super().__init__()
+        if pretrained:
+            self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
+        else:
+            self.model = resnet50()
+        self.model.fc = nn.Linear(2048, num_classes)
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/model_loading/vgg16.py b/model_loading/vgg16.py
new file mode 100644
index 0000000..9d281b8
--- /dev/null
+++ b/model_loading/vgg16.py
@@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+from torchvision.models import VGG16_Weights, vgg16
+
+
+class VGG16(nn.Module):
+    def __init__(self, num_classes, pretrained=False):
+        super().__init__()
+        if pretrained:
+            self.model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
+        else:
+            self.model = vgg16()
+        self.model.classifier[0] = nn.Linear(in_features=512 * 7 * 7, out_features=512, bias=True)
+        self.model.classifier[3] = nn.Linear(in_features=512, out_features=1024, bias=True)
+        self.model.classifier[6] = nn.Linear(in_features=1024, out_features=num_classes, bias=True)
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/model_loading/vgg19.py b/model_loading/vgg19.py
new file mode 100644
index 0000000..95f01c7
--- /dev/null
+++ b/model_loading/vgg19.py
@@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+from torchvision.models import VGG19_Weights, vgg19
+
+
+class VGG19(nn.Module):
+    def __init__(self, num_classes, pretrained=False):
+        super().__init__()
+        if pretrained:
+            self.model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
+        else:
+            self.model = vgg19()
+        self.model.classifier[0] = nn.Linear(in_features=512 * 7 * 7, out_features=512, bias=True)
+        self.model.classifier[3] = nn.Linear(in_features=512, out_features=1024, bias=True)
+        self.model.classifier[6] = nn.Linear(in_features=1024, out_features=num_classes, bias=True)
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/model_loading/vit_b_16.py b/model_loading/vit_b_16.py
new file mode 100644
index 0000000..f877525
--- /dev/null
+++ b/model_loading/vit_b_16.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+from torchvision.models import ViT_B_16_Weights, vit_b_16
+from torchvision.models.feature_extraction import create_feature_extractor
+
+
+class ViT_B_16(nn.Module):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
+        self.model = create_feature_extractor(self.model, return_nodes=['getitem_5'])
+        self.classifier = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(768, num_classes),
+        )
+
+    def forward(self, x):
+        x = self.model(x)
+        x = x["getitem_5"]
+        x = self.classifier(x)
+        return x
diff --git a/network_module.py b/network_module.py
new file mode 100644
index 0000000..9bbb89f
--- /dev/null
+++ b/network_module.py
@@ -0,0 +1,90 @@
+import os
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchmetrics as tm
+from hydra.utils import instantiate
+from torchmetrics.classification import MulticlassROC
+
+
+class Net(pl.LightningModule):
+    def __init__(self, model, criterion, num_classes, optimizer, scheduler=None):
+        super().__init__()
+        self.model = model
+        self.criterion = criterion
+        self.num_classes = num_classes
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+
+        self.confusion_matrix = tm.ConfusionMatrix(
+            task="multiclass", num_classes=num_classes
+        )
+        self.roc = MulticlassROC(num_classes=num_classes)
+
+        self.train_accuracy = tm.Accuracy(task="multiclass", num_classes=num_classes)
+        self.val_accuracy = tm.Accuracy(task="multiclass", num_classes=num_classes)
+
+    def configure_optimizers(self):
+        if self.scheduler:
+            return {
+                "optimizer": self.optimizer,
+                "lr_scheduler": instantiate(self.scheduler, optimizer=self.optimizer),
+                "monitor": "val_loss",
+            }
+        return self.optimizer
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = self.criterion(y_hat, y)
+        self.log("train_loss", loss)
+
+        self.train_accuracy.update(y_hat, y)
+
+        return loss
+
+    def training_epoch_end(self, outs):
+        self.log("train_accuracy", self.train_accuracy.compute())
+        self.train_accuracy.reset()
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = self.criterion(y_hat, y)
+        self.log("val_loss", loss)
+
+        self.confusion_matrix.update(y_hat, y)
+        self.roc.update(y_hat, y)
+        self.val_accuracy.update(y_hat, y)
+
+        return loss
+
+    def validation_epoch_end(self, outs):
+        self.log("val_accuracy", self.val_accuracy.compute())
+        self.val_accuracy.reset()
+
+        cm = self.confusion_matrix.compute()
+        roc = self.roc.compute()
+
+        self.confusion_matrix.reset()
+        self.roc.reset()
+
+        if self.logger:
+            run_name = self.logger.name + "/" + str(self.logger.version)
+        else:
+            run_name = "default"
+
+        os.makedirs(f"confusion_matrices/{run_name}", exist_ok=True)
+        os.makedirs(f"rocs/{run_name}", exist_ok=True)
+
+        torch.save(
+            cm, os.path.join(f"confusion_matrices/{run_name}", "confusion_matrix.pt")
+        )
+        torch.save(
+            roc, os.path.join(f"rocs/{run_name}", "roc.pt")
+        )
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000..c5e4951
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,43 @@
+import hydra
+from hydra.utils import instantiate
+from pytorch_lightning.loggers import TensorBoardLogger
+
+from dataloading.osteosarcomaDataModule import OsteosarcomaDataModule
+from network_module import Net
+
+
+@hydra.main(config_path="config", config_name="config", version_base=None)
+def main(cfg):
+    if cfg.pretrained:
+        base_run_name = str(cfg.run_name) + "_pretrained"
+    else:
+        base_run_name = str(cfg.run_name)
+
+    for k in range(cfg.n_folds):
+        run_name = f"{cfg.experiment_name}/{base_run_name}/{cfg.datamodule.img_size}"
+        if cfg.n_folds > 1:
+            run_name += f"/{k}_fold"
+
+        tensorboard_logger = TensorBoardLogger(
+            save_dir="logs",
+            name=run_name,
+        )
+
+        dm = instantiate(cfg.datamodule, k=k)
+        dm.prepare_data()
+
+        model = instantiate(cfg.model.object, num_classes=dm.num_classes)
+        net = Net(
+            model=model,
+            criterion=instantiate(cfg.criterion, weight=dm.class_weights),
+            num_classes=dm.num_classes,
+            optimizer=instantiate(cfg.optimizer, params=model.parameters()),
+            scheduler=cfg.scheduler,
+        )
+
+        trainer = instantiate(cfg.trainer, logger=tensorboard_logger)
+        trainer.fit(net, dm)
+
+
+if __name__ == "__main__":
+    main()