-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
265 lines (217 loc) · 11.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import os
import numpy as np
import pandas as pd
import random
import h5py
import pickle
import itertools
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from sklearn.utils.class_weight import compute_sample_weight
from sksurv.metrics import concordance_index_ipcw, brier_score, integrated_brier_score, cumulative_dynamic_auc
def seed_worker(worker_id):
    """DataLoader worker_init_fn: seed numpy and random from torch's per-worker seed.

    Keeps augmentation/shuffling reproducible across DataLoader workers.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)
def collate(batch):
    """Collate dataset items into one batch for the DataLoader.

    Tensors are concatenated/stacked; event times stay a numpy array
    (not needed as a tensor for training) and slide ids stay a list.
    """
    merged = torch.cat([item[0] for item in batch], dim=0)
    flattened = torch.cat([item[1] for item in batch], dim=0)
    labels = torch.LongTensor([item[2] for item in batch])
    event_times = np.array([item[3] for item in batch])
    censorships = torch.FloatTensor([item[4] for item in batch])
    stages = torch.LongTensor([item[5] for item in batch])
    slide_ids = [item[6] for item in batch]
    return [merged, flattened, labels, event_times, censorships, stages, slide_ids]
class FeatureBagsDataset(Dataset):
    """Dataset of per-slide feature bags loaded from `{slide_id}_Mergedfeatures.pt` files.

    Each item yields both a per-region mean-pooled feature tensor and the
    fully flattened features, plus label / event-time / censorship / stage
    metadata read from the dataframe.
    """

    def __init__(self, df, data_dir, input_feature_size, stage_class):
        # Copy + reset_index so positional and label indexing agree.
        self.slide_df = df.copy().reset_index(drop=True)
        self.data_dir = data_dir
        self.input_feature_size = input_feature_size
        self.stage_class = stage_class

    def _get_feature_path(self, slide_id):
        """Path of the saved feature file for one slide."""
        return os.path.join(self.data_dir, f"{slide_id}_Mergedfeatures.pt")

    def __getitem__(self, idx):
        row = self.slide_df.iloc[idx]
        slide_id = row["slide_id"]
        features = torch.load(self._get_feature_path(slide_id))
        # Mean-pool each entry's features ("merged" representation).
        features_merged = torch.from_numpy(np.array([x[0].mean(0) for x in features]))
        # Alternative representation: every feature vector, concatenated.
        features_flattened = torch.from_numpy(np.concatenate([x[0] for x in features]))
        return (
            features_merged,
            features_flattened,
            row["disc_label"],
            row["recurrence_years"],
            row["censorship"],
            row["stage"],
            slide_id,
        )

    def __len__(self):
        return len(self.slide_df)
def define_data_sampling(train_split, val_split, method, workers):
    """Build the train and validation DataLoaders.

    Only `method == "random"` (shuffled training) is implemented; the
    validation loader always iterates sequentially. A fixed-seed generator
    plus `seed_worker` keeps shuffling/worker RNG reproducible.
    """
    # Reproducibility of DataLoader.
    generator = torch.Generator()
    generator.manual_seed(0)

    if method != "random":
        raise Exception(f"Sampling method '{method}' not implemented.")

    # Options shared by both loaders; the model expects one bag of
    # features at a time, hence batch_size=1.
    shared = dict(
        batch_size=1,
        collate_fn=collate,
        num_workers=workers,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=generator,
    )

    print("random sampling setting")
    train_loader = DataLoader(dataset=train_split, shuffle=True, **shared)
    val_loader = DataLoader(dataset=val_split, sampler=SequentialSampler(val_split), **shared)
    return train_loader, val_loader
class MonitorBestModelEarlyStopping:
    """Early-stops training when validation loss stops improving; tracks best scores.

    Two "best" states are tracked independently:
      * the lowest evaluation loss (drives the early-stopping counter), and
      * the highest C-index, which may lag the loss minimum by a few epochs.
    """

    def __init__(self, patience=15, min_epochs=20, saving_checkpoint=True):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                Default: 15
            min_epochs (int): Earliest epoch possible for stopping.
                Default: 20
            saving_checkpoint (bool): If True, save_checkpoint() writes the model
                state dict to disk. Default: True
        """
        self.patience = patience
        self.min_epochs = min_epochs
        self.counter = 0
        self.early_stop = False
        # BUGFIX: np.Inf was removed in NumPy 2.0; float("inf") is equivalent.
        self.eval_loss_min = float("inf")
        self.best_loss_score = None
        self.best_epoch_loss = None
        self.best_CI_score = 0.0
        self.best_metrics_score = None
        self.best_epoch_CI = None
        self.saving_checkpoint = saving_checkpoint

    def __call__(self, epoch, eval_loss, eval_cindex, eval_other_metrics, model, log_dir):
        # Negate the loss so "higher is better" for both tracked scores.
        loss_score = -eval_loss
        CI_score = eval_cindex
        metrics_score = eval_other_metrics
        # Save scores at the first call and start monitoring.
        if self.best_loss_score is None:
            self._update_loss_scores(loss_score, eval_loss, epoch)
            self._update_metrics_scores(CI_score, metrics_score, epoch)
            #self.save_checkpoint(model, log_dir, epoch)
        # Eval loss starts increasing. Recommend running early stopping on the loss.
        elif loss_score < self.best_loss_score:
            self.counter += 1
            print(f'Evaluation loss does not decrease : Starting Early stopping counter {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.min_epochs:
                self.early_stop = True
        # Eval loss keeps decreasing.
        else:
            print(f'Epoch {epoch} validation loss decreased ({self.eval_loss_min:.6f} --> {eval_loss:.6f})')
            self._update_loss_scores(loss_score, eval_loss, epoch)
            #self.save_checkpoint(model, log_dir, epoch)
            self.counter = 0
        # There may be a small lag between the loss minimum and the best C-index,
        # so best-C-index state is tracked separately from the loss state.
        if CI_score > self.best_CI_score:
            self._update_metrics_scores(CI_score, metrics_score, epoch)
            #self.save_checkpoint(model, log_dir, epoch)

    def save_checkpoint(self, model, log_dir, epoch):
        """Save the model state dict as `{epoch}_checkpoint.pt` (skips existing files)."""
        filepath = os.path.join(log_dir, f"{epoch}_checkpoint.pt")
        if self.saving_checkpoint and not os.path.exists(filepath):
            print("Saving model")
            torch.save(model.state_dict(), filepath)

    def _update_loss_scores(self, loss_score, eval_loss, epoch):
        """Record the new best (lowest) evaluation loss."""
        self.eval_loss_min = eval_loss
        self.best_loss_score = loss_score
        self.best_epoch_loss = epoch
        print(f'Updating loss at epoch {self.best_epoch_loss} -> {self.eval_loss_min:.6f}')

    def _update_metrics_scores(self, CI_score, metrics_score, epoch):
        """Record the new best (highest) C-index and associated metrics.

        Even though the loss decreases, the C-index does not necessarily
        increase, so the best C-index is tracked separately.
        """
        self.best_CI_score = CI_score
        self.best_epoch_CI = epoch
        self.best_metrics_score = metrics_score
        print(f'Updating C-index at epoch {self.best_epoch_CI} -> {self.best_CI_score:.6f}')
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group (None if empty)."""
    groups = optimizer.param_groups
    return groups[0]["lr"] if groups else None
def print_model(model):
    """Print the model architecture and its trainable-parameter count."""
    print(model)
    trainable_counts = [p.numel() for p in model.parameters() if p.requires_grad]
    print(f"Model has {sum(trainable_counts)} parameters")
def get_survival_data_for_BS(df, time_col_name, censorship_col_name='censorship'):
    """Build the sksurv structured array needed by the Brier score (BS).

    The Brier score estimates the censoring distribution from a structured
    array whose first field is the binary event indicator (True = event
    occurred, False = censored) and whose second field is the event or
    censoring time. Note that the event indicator is the *inverse* of the
    censorship column.
    """
    # Invert censorship (0 = event) into a boolean event indicator.
    events = (1 - df[censorship_col_name].values).astype(bool)
    times = df[time_col_name].values
    y = np.empty(len(df), dtype=[('cens', '?'), ('time', '<f8')])
    y['cens'] = events
    y['time'] = times
    return df[time_col_name].max(), y
def get_bins_time_value(df, n_bins, time_col_name, label_time_col_name='disc_label', censorship_col_name='censorship'):
    """Return the upper edges of `n_bins` quantile bins of uncensored event times.

    Bins are computed with pd.qcut over the uncensored cases only. The last
    edge is replaced by +inf so any later time falls in the final bin, and
    the first edge (the minimum) is dropped — only upper edges are returned,
    so the result has length `n_bins`.

    Note: `label_time_col_name` is unused; it is kept for interface
    compatibility with existing callers.
    """
    uncensored_df = df[df[censorship_col_name] == 0]
    # Only the bin edges are needed; the per-row labels are discarded.
    _, q_bins = pd.qcut(uncensored_df[time_col_name], q=n_bins, retbins=True, labels=False)
    q_bins[-1] = float('inf')
    # q_bins has n_bins + 1 edges; drop the lower bound of the first bin.
    return q_bins[1:]
def compute_surv_metrics_eval(
        bins_values,
        all_survival_probs,
        all_risk_scores,
        train_BS,
        test_BS,
        years_of_interest=None,
        yearI_of_interest=1.0,
        yearF_of_interest=5.0,
        time_step=0.5):
    """Compute time-dependent survival metrics (BS, IBS, cumulative/dynamic AUC, IPCW C-index).

    The continuous time scale was discretized into N bins for training, so there
    are only N survival probabilities, one per time bin. Each requested year is
    mapped to the first bin whose upper edge exceeds it (bins_values must end
    with +inf so every year maps to a bin). BS, IBS and AUC are time- and
    step-dependent scores; the requested times must lie within the range of the
    training survival times.

    Args:
        bins_values: upper edges of the discretized time bins (increasing, last = inf).
        all_survival_probs: (n_samples, n_bins) survival probabilities.
        all_risk_scores: per-sample risk scores (used by AUC and C-index).
        train_BS / test_BS: sksurv structured arrays (see get_survival_data_for_BS).
        years_of_interest: times at which the Brier score is evaluated
            (default [1.0, 2.0, 3.0, 5.0]).
        yearI_of_interest / yearF_of_interest / time_step: range and step of the
            time grid used for IBS and AUC.

    Returns:
        (BS, years_of_interest), (IBS, yearI, yearF), (cumAUC, meanAUC), c_index_ipcw
    """
    # BUGFIX: a mutable list default argument is shared across calls; use None sentinel.
    if years_of_interest is None:
        years_of_interest = [1.0, 2.0, 3.0, 5.0]
    # Map each year to the first discretized bin that contains it.
    corresponding_bins = np.asarray([np.argwhere(i < bins_values)[0, 0] for i in years_of_interest])
    # Note: the Brier score does not return the exact same value with discretized vs
    # continuous time, because it uses the censoring distribution even within each
    # interval of constant survival probability.
    _, BS = brier_score(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_survival_probs[:, corresponding_bins],
        times=years_of_interest)
    # The Integrated Brier Score (IBS) summarizes performance over a time grid.
    # Both the time points (at least two) and the step are a choice; the step
    # size affects the IBS value. +time_step so the final year is included.
    years_of_interest_steps = np.arange(yearI_of_interest, yearF_of_interest + time_step, step=time_step, dtype=float)
    corresponding_bins = np.asarray([np.argwhere(i < bins_values)[0, 0] for i in years_of_interest_steps])
    IBS = integrated_brier_score(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_survival_probs[:, corresponding_bins],
        times=years_of_interest_steps)
    # Cumulative/dynamic AUC: how well the model separates subjects failing by
    # time t (ti <= t) from those failing after it (ti > t). Like the C-index,
    # it uses risk scores rather than event-free probabilities.
    cumAUC, meanAUC = cumulative_dynamic_auc(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_risk_scores,
        times=years_of_interest_steps, tied_tol=1e-08)
    c_index_ipwc = concordance_index_ipcw(
        survival_train=train_BS,
        survival_test=test_BS,
        estimate=all_risk_scores,
        tied_tol=1e-08)[0]
    return (BS, years_of_interest), (IBS, yearI_of_interest, yearF_of_interest), (cumAUC, meanAUC), (c_index_ipwc)