Skip to content

Commit

Permalink
Fix the random seed when bootstrapping in survival analysis evaluation.
Browse files Browse the repository at this point in the history
  • Loading branch information
salan668 committed Feb 5, 2024
1 parent 4321dd8 commit 12b94c6
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
4 changes: 2 additions & 2 deletions SA/FeatureSelector.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ def ClusterSelect(dc, number):
sub_features = []

if number == 1:
mylog.warning('The minimum number of KMeans is 2, Kmeans doesn\'t apply')
# mylog.warning('The minimum number of KMeans is 2, Kmeans doesn\'t apply')
pccs = [abs(pearsonr(dc.df[one_feature].values, dc.event)[0]) for one_feature in dc.feature_name]
selected_feature = dc.feature_name[pccs.index(max(pccs))]
sub_features.append(selected_feature)
else:
clusters = KMeans(n_clusters=number, random_state=0, init='k-means++').fit_predict(dc.array.transpose())
clusters = KMeans(n_clusters=number, random_state=0, init='k-means++', n_init="auto").fit_predict(dc.array.transpose())

for i in range(number):
clustering_features = [name for cluster_index, name in zip(clusters, dc.feature_name) if cluster_index == i]
Expand Down
8 changes: 4 additions & 4 deletions SA/Fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import pickle
import random
from abc import abstractmethod
from lifelines import CoxPHFitter, AalenAdditiveFitter

from lifelines.utils.printer import Printer
Expand All @@ -15,6 +14,8 @@
from SA.Utility import mylog
from SA.DataContainer import DataContainer

random.seed(0)


class BaseFitter(object):
def __init__(self, fitter=None, name=None):
Expand Down Expand Up @@ -48,9 +49,8 @@ def Summary(self):


class CoxPH(BaseFitter):
def __init__(self):
random.seed(0)
super(CoxPH, self).__init__(CoxPHFitter(), self.__class__.__name__)
def __init__(self, penalizer=0.1):
super(CoxPH, self).__init__(CoxPHFitter(penalizer=penalizer), self.__class__.__name__)

def Fit(self, dc: DataContainer):
self.fitter.fit(dc.df, duration_col=dc.duration_name, event_col=dc.event_name)
Expand Down
20 changes: 15 additions & 5 deletions SA/PipelineManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import os
import csv

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -183,10 +184,15 @@ def RunWithoutCV(self, train_dc, test_dc=None, store_folder=None):
fs_store_folder = MakeFolder(reduce_store_folder, '{}_{}'.format(
feature_selector.name, feature_number))

feature_selector.selected_number = feature_number
if feature_number >= len(dr_train_dc.GetFeatureName()):
feature_selector.selected_number = len(dr_train_dc.GetFeatureName())
else:
feature_selector.selected_number = feature_number

feature_selector.Fit(dr_train_dc)
fs_train_dc = feature_selector.Transform(dr_train_dc, store_folder=fs_store_folder,
store_key=TRAIN)

if test_dc is not None and not test_dc.IsEmpty():
fs_test_dc = feature_selector.Transform(dr_test_dc, store_folder=fs_store_folder,
store_key=TEST)
Expand Down Expand Up @@ -259,10 +265,14 @@ def RunCV(self, dc, store_folder=None):
fs_store_folder = MakeFolder(reduce_store_folder, '{}_{}'.format(
feature_selector.name, feature_number))

feature_selector.selected_number = feature_number
feature_selector.Fit(dr_cv_train_dc)
fs_cv_train_dc = feature_selector.Transform(dr_cv_train_dc)
fs_cv_val_dc = feature_selector.Transform(dr_cv_val_dc)
if feature_number >= len(dr_cv_train_dc.GetFeatureName()):
fs_cv_train_dc = dr_cv_train_dc
fs_cv_val_dc = dr_cv_val_dc
else:
feature_selector.selected_number = feature_number
feature_selector.Fit(dr_cv_train_dc)
fs_cv_train_dc = feature_selector.Transform(dr_cv_train_dc)
fs_cv_val_dc = feature_selector.Transform(dr_cv_val_dc)

for fitter_index, fitter in enumerate(self.fitters):
fitter_store_folder = MakeFolder(fs_store_folder, fitter.name)
Expand Down
6 changes: 3 additions & 3 deletions SA/Utility/Matric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
--2021/1/18
"""
import numpy as np
from random import choices
import random
from pycox.evaluation import EvalSurv

from SA.Utility.Constant import *
Expand All @@ -17,11 +17,11 @@ def __init__(self, bootstrap_n=100):
self.bootstrap_n = bootstrap_n

def Bootstrap(self, surv, event: list, duration: list):
np.random.seed(42) # control reproducibility
random.seed(42) # control reproducibility

cindex, brier, nbll = [], [], []
for _ in range(self.bootstrap_n):
sampled_index = choices(range(surv.shape[1]), k=surv.shape[1])
sampled_index = random.choices(range(surv.shape[1]), k=surv.shape[1])

sampled_surv = surv.iloc[:, sampled_index]
sampled_event = [event[i] for i in sampled_index]
Expand Down

0 comments on commit 12b94c6

Please sign in to comment.