-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit f2cfa83
Showing
28 changed files
with
525 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
defaults: | ||
- model: efficientnetb0 | ||
- optimizer: adam | ||
- scheduler: cosine | ||
- datamodule: osteosarcoma | ||
- _self_ | ||
|
||
experiment_name: osteosarcoma | ||
run_name: ${model.name} | ||
n_folds: 1 | ||
pretrained: True | ||
|
||
trainer: | ||
_target_: pytorch_lightning.Trainer | ||
gpus: 1 | ||
max_epochs: 100 | ||
log_every_n_steps: 50 | ||
callbacks: | ||
- _target_: pytorch_lightning.callbacks.LearningRateMonitor | ||
logging_interval: epoch | ||
|
||
criterion: | ||
_target_: torch.nn.CrossEntropyLoss | ||
|
||
num_classes: 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_target_: dataloading.osteosarcomaDataModule.OsteosarcomaDataModule | ||
batch_size: 8 | ||
train_val_ratio: 0.7 | ||
base_path: # Set this to the path containing the Osteosarcoma data from UT Southwestern/UT Dallas https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=52756935 | ||
num_workers: 12 | ||
img_size: | ||
- 512 | ||
- 512 | ||
n_splits: ${n_folds} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
name: efficientnetb0 | ||
object: | ||
_target_: monai.networks.nets.EfficientNetBN | ||
model_name: efficientnet-b0 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
name: efficientnetb1 | ||
object: | ||
_target_: monai.networks.nets.EfficientNetBN | ||
model_name: efficientnet-b1 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
name: efficientnetb3 | ||
object: | ||
_target_: monai.networks.nets.EfficientNetBN | ||
model_name: efficientnet-b3 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
name: efficientnetb5 | ||
object: | ||
_target_: monai.networks.nets.EfficientNetBN | ||
model_name: efficientnet-b5 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
name: efficientnetb7 | ||
object: | ||
_target_: monai.networks.nets.EfficientNetBN | ||
model_name: efficientnet-b7 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
name: mobilenetv2 | ||
object: | ||
_target_: model_loading.mobilenetv2.MobileNetV2 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
name: resnet18 | ||
object: | ||
_target_: model_loading.resnet18.ResNet18 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
name: resnet34 | ||
object: | ||
_target_: model_loading.resnet34.ResNet34 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
name: resnet50 | ||
object: | ||
_target_: model_loading.resnet50.ResNet50 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
name: vgg16 | ||
object: | ||
_target_: model_loading.vgg16.VGG16 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
name: vgg19 | ||
object: | ||
_target_: model_loading.vgg19.VGG19 | ||
num_classes: ${num_classes} | ||
pretrained: ${pretrained} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
name: vit_b_16 | ||
object: | ||
_target_: model_loading.vit_b_16.ViT_B_16 | ||
num_classes: ${num_classes} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
_target_: torch.optim.AdamW | ||
lr: 0.0003 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
_target_: torch.optim.SGD | ||
lr: 0.0001 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
_target_: torch.optim.lr_scheduler.CosineAnnealingLR | ||
T_max: ${trainer.max_epochs} | ||
eta_min: 0.00001 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
import torchvision.transforms.functional as TF | ||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
from torch.utils.data.dataset import Subset | ||
|
||
|
||
class CustomDataGen(Dataset): | ||
def __init__( | ||
self, | ||
df: Subset[pd.DataFrame], | ||
base_path: Path, | ||
transform=None, | ||
): | ||
self.base_path = base_path | ||
self.transform = transform | ||
|
||
if isinstance(df, Subset): | ||
subset = np.take(df.dataset, df.indices, axis=0) # type: ignore | ||
else: | ||
subset = df | ||
self.img_paths = subset["filename"].to_numpy() | ||
self.labels = subset["label"].to_numpy() | ||
|
||
def __len__(self): | ||
return len(self.img_paths) | ||
|
||
def __getitem__(self, idx): | ||
img_path = self.img_paths[idx] | ||
|
||
# the image might be in any subfolder of the base path | ||
# so we need to search for it first | ||
img_path = next(self.base_path.glob(f"**/{img_path}")) | ||
img = Image.open(img_path) | ||
img = TF.to_tensor(img) | ||
if self.transform: | ||
img = self.transform(img) | ||
|
||
label = torch.tensor(self.labels[idx], dtype=torch.long) | ||
return img, label |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytorch_lightning as pl | ||
import torch | ||
import torch.nn as nn | ||
import torchvision.transforms | ||
import torchvision.transforms.functional as TF | ||
from PIL import Image | ||
from sklearn.model_selection import KFold | ||
from sklearn.utils.class_weight import compute_class_weight | ||
from torch.utils.data import DataLoader, random_split | ||
|
||
from dataloading.datagen import CustomDataGen | ||
|
||
|
||
class OsteosarcomaDataModule(pl.LightningDataModule): | ||
def __init__( | ||
self, | ||
batch_size, | ||
base_path, | ||
num_workers=0, | ||
train_val_ratio=0.8, | ||
img_size=None, | ||
k=1, | ||
n_splits=1, | ||
random_state=42, | ||
): | ||
super().__init__() | ||
self.batch_size = batch_size | ||
self.train_val_ratio = train_val_ratio | ||
self.num_workers = num_workers | ||
self.base_path = Path(base_path) | ||
self.img_size = tuple(img_size) if img_size else None | ||
self.k = k | ||
self.n_splits = n_splits | ||
self.random_state = random_state | ||
|
||
self.mean = torch.tensor([0.485, 0.456, 0.406]) | ||
self.std = torch.tensor([0.229, 0.224, 0.225]) | ||
|
||
def prepare_data(self): | ||
f = self.base_path / "ML_Features_1144.csv" | ||
self.df = pd.read_csv(f) | ||
self.filenames = self.df["image.name"] | ||
self.labels = self.df["classification"] | ||
|
||
self.filenames = self.filenames.str.replace(" -", "-") | ||
self.filenames = self.filenames.str.replace("- ", "-") | ||
self.filenames = self.filenames.str.replace(" ", "-") | ||
self.filenames = self.filenames + ".jpg" | ||
|
||
self.labels = self.labels.str.lower() | ||
self.labels = self.labels.replace("non-tumor", 0) | ||
self.labels = self.labels.replace("viable", 1) | ||
self.labels = self.labels.replace("non-viable-tumor", 2) | ||
self.labels = self.labels.replace("viable: non-viable", 2) | ||
|
||
assert len(self.filenames) == len(self.labels) | ||
assert set(self.labels) == {0, 1, 2} | ||
|
||
self.num_classes = len(set(self.labels)) | ||
self.class_weights = torch.tensor( | ||
compute_class_weight("balanced", classes=[0, 1, 2], y=self.labels), dtype=torch.float | ||
) | ||
|
||
self.df = pd.DataFrame({"filename": self.filenames, "label": self.labels}) | ||
|
||
def get_preprocessing_transform(self): | ||
transforms = nn.Sequential( | ||
torchvision.transforms.Normalize(self.mean, self.std), | ||
torchvision.transforms.Resize(self.img_size) | ||
if self.img_size | ||
else nn.Identity(), | ||
) | ||
return transforms | ||
|
||
def get_augmentation_transform(self): | ||
transforms = nn.Sequential( | ||
torchvision.transforms.RandomVerticalFlip(), | ||
torchvision.transforms.RandomHorizontalFlip(), | ||
torchvision.transforms.RandomRotation(20), | ||
) | ||
return transforms | ||
|
||
def setup(self, stage=None): | ||
if self.n_splits != 1: | ||
kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state) | ||
splits = [k for k in kf.split(self.df)] | ||
|
||
self.train_subjects, self.val_subjects = self.df.iloc[splits[self.k][0]], self.df.iloc[splits[self.k][1]] | ||
else: | ||
num_subjects = len(self.df) | ||
num_train_subjects = int(round(num_subjects * self.train_val_ratio)) | ||
num_val_subjects = num_subjects - num_train_subjects | ||
splits = num_train_subjects, num_val_subjects | ||
self.train_subjects, self.val_subjects = random_split( | ||
self.df, splits, generator=torch.Generator().manual_seed(self.random_state) # type: ignore | ||
) | ||
|
||
def train_dataloader(self): | ||
return DataLoader( | ||
CustomDataGen( | ||
self.train_subjects, | ||
self.base_path, | ||
transform=torchvision.transforms.Compose( | ||
[ | ||
self.get_preprocessing_transform(), | ||
self.get_augmentation_transform(), | ||
] | ||
), | ||
), | ||
batch_size=self.batch_size, | ||
num_workers=self.num_workers, | ||
shuffle=True, | ||
) | ||
|
||
def val_dataloader(self): | ||
return DataLoader( | ||
CustomDataGen( | ||
self.val_subjects, | ||
self.base_path, | ||
transform=self.get_preprocessing_transform(), | ||
), | ||
batch_size=self.batch_size, | ||
num_workers=self.num_workers, | ||
shuffle=False, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class MobileNetV2(nn.Module): | ||
def __init__(self, num_classes, pretrained=False): | ||
super().__init__() | ||
self.model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=pretrained) | ||
self.model.classifier[1] = nn.Linear(1280, num_classes) | ||
|
||
def forward(self, x): | ||
return self.model(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch.nn as nn | ||
from torchvision.models import ResNet18_Weights, resnet18 | ||
|
||
|
||
class ResNet18(nn.Module): | ||
def __init__(self, num_classes, pretrained=False): | ||
super().__init__() | ||
if pretrained: | ||
self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) | ||
else: | ||
self.model = resnet18() | ||
self.model.fc = nn.Linear(512, num_classes) | ||
|
||
def forward(self, x): | ||
return self.model(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch.nn as nn | ||
from torchvision.models import ResNet34_Weights, resnet34 | ||
|
||
|
||
class ResNet34(nn.Module): | ||
def __init__(self, num_classes, pretrained=False): | ||
super().__init__() | ||
if pretrained: | ||
self.model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1) | ||
else: | ||
self.model = resnet34() | ||
self.model.fc = nn.Linear(512, num_classes) | ||
|
||
def forward(self, x): | ||
return self.model(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch.nn as nn | ||
from torchvision.models import ResNet50_Weights, resnet50 | ||
|
||
|
||
class ResNet50(nn.Module): | ||
def __init__(self, num_classes, pretrained=False): | ||
super().__init__() | ||
if pretrained: | ||
self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) | ||
else: | ||
self.model = resnet50() | ||
self.model.fc = nn.Linear(2048, num_classes) | ||
|
||
def forward(self, x): | ||
return self.model(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torchvision.models import VGG16_Weights, vgg16 | ||
|
||
|
||
class VGG16(nn.Module): | ||
def __init__(self, num_classes, pretrained=False): | ||
super().__init__() | ||
if pretrained: | ||
self.model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1) | ||
else: | ||
self.model = vgg16() | ||
self.model.classifier[0] = nn.Linear(in_features=512 * 7 * 7, out_features=512, bias=True) | ||
self.model.classifier[3] = nn.Linear(in_features=512, out_features=1024, bias=True) | ||
self.model.classifier[6] = nn.Linear(in_features=1024, out_features=num_classes, bias=True) | ||
|
||
def forward(self, x): | ||
return self.model(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torchvision.models import VGG19_Weights, vgg19 | ||
|
||
|
||
class VGG19(nn.Module): | ||
def __init__(self, num_classes, pretrained=False): | ||
super().__init__() | ||
if pretrained: | ||
self.model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1) | ||
else: | ||
self.model = vgg19() | ||
self.model.classifier[0] = nn.Linear(in_features=512 * 7 * 7, out_features=512, bias=True) | ||
self.model.classifier[3] = nn.Linear(in_features=512, out_features=1024, bias=True) | ||
self.model.classifier[6] = nn.Linear(in_features=1024, out_features=num_classes, bias=True) | ||
|
||
def forward(self, x): | ||
return self.model(x) |
Oops, something went wrong.