From f478fa213600e815045ed5b60850dbb499fd8f03 Mon Sep 17 00:00:00 2001
From: "Ivan Kuchin (ikuchin)"
Date: Fri, 23 Aug 2024 18:35:56 -0400
Subject: [PATCH] wip: tile dataset

---
 config.py                           |  8 ++++---
 dataset/pomc_dataset.py             |  2 ++
 dataset/saver.py                    | 33 +++++++++++++++++------------
 predict.py                          | 23 +++++++++++++++++++-
 tools/craft_network/att_unet.py     |  3 ++-
 tools/craft_network/att_unet_dsv.py |  6 ++++--
 train_segmentation.py               |  7 +++---
 7 files changed, 59 insertions(+), 23 deletions(-)

diff --git a/config.py b/config.py
index 9cb0ef1..7ea9184 100644
--- a/config.py
+++ b/config.py
@@ -8,10 +8,10 @@
 # BACKGROUND_WEIGHT = 1 # must be calculated dynamically
 # FOREGROUND_WEIGHT = 7 # must be calculated dynamically
 
-INITIAL_LEARNING_RATE = 0.001
+INITIAL_LEARNING_RATE = 1e-4
 INSTANCE_NORM = False # not supported yet
 BATCH_NORM = True
-BATCH_SIZE = 1
+BATCH_SIZE = 4
 BATCH_NORM_MOMENTUM = 0.8
 GRADIENT_ACCUMULATION_STEPS = 4
 # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam#args
@@ -27,7 +27,7 @@
 PANCREAS_MIN_HU = -512 # -512
 PANCREAS_MAX_HU = 1024 # 1024
 
-IMAGE_DIMENSION_X = 160
+IMAGE_DIMENSION_X = 96
 IMAGE_DIMENSION_Y = IMAGE_DIMENSION_X
 IMAGE_DIMENSION_Z = IMAGE_DIMENSION_X
 
@@ -47,6 +47,8 @@
 IMAGE_ORIGINAL_DIMENSION_Y = IMAGE_ORIGINAL_DIMENSION_X
 IMAGE_ORIGINAL_DIMENSION_Z = IMAGE_ORIGINAL_DIMENSION_X
 
+IS_TILE = True
+
 # Dataset used for training
 # consists of pickle files of 3d numpy arrays
 TFRECORD_FOLDER = "c:/Users/ikuchin/Downloads/pancreas_data/dataset/"
diff --git a/dataset/pomc_dataset.py b/dataset/pomc_dataset.py
index e198366..cb0a5ec 100644
--- a/dataset/pomc_dataset.py
+++ b/dataset/pomc_dataset.py
@@ -47,6 +47,8 @@ def __init__(self, patients_src_folder, labels_src_folder, TFRECORD_FOLDER):
         self.min_HU = float("inf")
         self.max_HU = float("-inf")
 
+        # self.saver = SaverFactory(config.IS_TILE)
+
     def get_patient_id_from_folder(self, folder):
         result = None
         m = re.search("(\\w+)$", folder)
diff --git a/dataset/saver.py b/dataset/saver.py
index 21ace46..b85389b 100644
--- a/dataset/saver.py
+++ b/dataset/saver.py
@@ -12,7 +12,7 @@ import config as config
 
 
 class Slicer:
-    def __init__(self, data, label):
+    def __init__(self, data, label, augment_margin = [0, 0, 0]):
         x = math.ceil(data.shape[0] / config.IMAGE_DIMENSION_X) * config.IMAGE_DIMENSION_X
         y = math.ceil(data.shape[1] / config.IMAGE_DIMENSION_Y) * config.IMAGE_DIMENSION_Y
         z = math.ceil(data.shape[2] / config.IMAGE_DIMENSION_Z) * config.IMAGE_DIMENSION_Z
@@ -23,26 +23,26 @@ def __init__(self, data, label):
         self.data[:data.shape[0], :data.shape[1], :data.shape[2]] = data
         self.label[:label.shape[0], :label.shape[1], :label.shape[2]] = label
 
+        self.augment_margin = augment_margin
+
     def __iter__(self):
-        augment_margin = [
-            int(config.IMAGE_DIMENSION_X * config.AUGMENTATIO_SHIFT_MARGIN),
-            int(config.IMAGE_DIMENSION_Y * config.AUGMENTATIO_SHIFT_MARGIN),
-            int(config.IMAGE_DIMENSION_Z * config.AUGMENTATIO_SHIFT_MARGIN)
-        ]
         for x in range(0, self.data.shape[0], config.IMAGE_DIMENSION_X):
             for y in range(0, self.data.shape[1], config.IMAGE_DIMENSION_Y):
                 for z in range(0, self.data.shape[2], config.IMAGE_DIMENSION_Z):
-                    x_start = np.max([x - augment_margin[0], 0])
-                    y_start = np.max([y - augment_margin[1], 0])
-                    z_start = np.max([z - augment_margin[2], 0])
+                    x_start = np.max([x - self.augment_margin[0], 0])
+                    y_start = np.max([y - self.augment_margin[1], 0])
+                    z_start = np.max([z - self.augment_margin[2], 0])
 
-                    x_finish = np.min([x + config.IMAGE_DIMENSION_X + augment_margin[0], self.data.shape[0]])
-                    y_finish = np.min([y + config.IMAGE_DIMENSION_Y + augment_margin[1], self.data.shape[1]])
-                    z_finish = np.min([z + config.IMAGE_DIMENSION_Z + augment_margin[2], self.data.shape[2]])
+                    x_finish = np.min([x + config.IMAGE_DIMENSION_X + self.augment_margin[0], self.data.shape[0]])
+                    y_finish = np.min([y + config.IMAGE_DIMENSION_Y + self.augment_margin[1], self.data.shape[1]])
+                    z_finish = np.min([z + config.IMAGE_DIMENSION_Z + self.augment_margin[2], self.data.shape[2]])
 
                     data  = self.data [x_start:x_finish, y_start:y_finish, z_start:z_finish]
                     label = self.label[x_start:x_finish, y_start:y_finish, z_start:z_finish]
 
+                    if np.max(label) == 0:
+                        continue
+
                     yield data, label, x, y, z
 
 
@@ -61,7 +61,14 @@ def save(self, src_data, label_data):
         src_data = np.cast[np.float32](src_data)
         label_data = np.cast[np.int8](label_data)
 
-        for (data, label, x, y, z) in Slicer(src_data, label_data):
+        augment_margin = [
+            int(config.IMAGE_DIMENSION_X * config.AUGMENTATIO_SHIFT_MARGIN),
+            int(config.IMAGE_DIMENSION_Y * config.AUGMENTATIO_SHIFT_MARGIN),
+            int(config.IMAGE_DIMENSION_Z * config.AUGMENTATIO_SHIFT_MARGIN)
+        ]
+
+
+        for (data, label, x, y, z) in Slicer(src_data, label_data, augment_margin=augment_margin):
             # print(f"Saving slice at {x}, {y}, {z}...")
             np.savez_compressed(os.path.join(self.folder, self.subfolder, self.patient_id + f"_cut-{self.percentage}_slice-{x}-{y}-{z}.npz", ), [data, label])
diff --git a/predict.py b/predict.py
index d401ed6..10ee595 100644
--- a/predict.py
+++ b/predict.py
@@ -7,6 +7,8 @@
 import nibabel as nib
 import tools.craft_network as craft_network
 import config as config
+from tools.predict.predict_no_tile import PredictNoTile
+from tools.predict.predict_tile import PredictTile
 
 
 class Predict:
@@ -125,6 +127,15 @@ def __resize_segmentation_to_dcm_shape(self, mask, dcm_slices):
         return result
 
     def __save_img_to_nifti(self, data, affine, result_file_name):
+        # TODO: add meta information
+        # affine = meta['affine'][0].cpu().numpy()
+        # pixdim = meta['pixdim'][0].cpu().numpy()
+        # dim = meta['dim'][0].cpu().numpy()
+
+        # img = nib.Nifti1Image(input_nii_array, affine=affine)
+        # img.header['dim'] = dim
+        # img.header['pixdim'] = pixdim
+
         img_to_save = nib.Nifti1Image(data, affine)
         nib.save(img_to_save, result_file_name)
 
@@ -136,6 +147,15 @@ def __print_stat(self, data, title=""):
                                               tf.reduce_mean(tf.cast(data, dtype = tf.float32)),
                                               tf.reduce_max(data), tf.reduce_sum(data)))
 
+    def __predict(self, src_data, model):
+        if config.IS_TILE == True:
+            predict_class = PredictTile(model)
+        else:
+            predict_class = PredictNoTile(model)
+
+        prediction = predict_class(src_data)
+        return prediction
+
     def main(self, dcm_folder, result_file_name):
         dcm_slices = self.__read_dcm_slices(dcm_folder)
         raw_pixel_data = self.__get_pixel_data(dcm_slices)
@@ -146,7 +166,8 @@ def main(self, dcm_folder, result_file_name):
 
         # model.summary()
 
-        prediction = model.predict(src_data)
+        # prediction = model.predict(src_data)
+        prediction = self.__predict(src_data, model)
         mask = self.__create_segmentation(prediction)
         mask = self.__resize_segmentation_to_dcm_shape(mask, dcm_slices)
         mask = tf.squeeze(mask)
diff --git a/tools/craft_network/att_unet.py b/tools/craft_network/att_unet.py
index 9903564..6816e60 100644
--- a/tools/craft_network/att_unet.py
+++ b/tools/craft_network/att_unet.py
@@ -43,7 +43,8 @@ def res_block(filters, input_shape, kernel_size, apply_batchnorm, apply_instance
     if (apply_dropout):
         x = tf.keras.layers.Dropout(0.5)(x)
 
-    return tf.keras.models.Model(inputs = input_layer, outputs = x)
+    model = tf.keras.models.Model(inputs = input_layer, outputs = x, name = "res_block_{}_{}".format(input_shape[-1], filters))
+    return model
 
 def double_conv(filters, input_shape, kernel_size, apply_batchnorm, apply_instancenorm, apply_dropout=False):
     model = tf.keras.models.Sequential()
diff --git a/tools/craft_network/att_unet_dsv.py b/tools/craft_network/att_unet_dsv.py
index 991c547..dadad52 100644
--- a/tools/craft_network/att_unet_dsv.py
+++ b/tools/craft_network/att_unet_dsv.py
@@ -29,7 +29,7 @@ def craft_network(checkpoint_file = None, apply_batchnorm = True, apply_instance
         if idx < len(filters) - 1:
             x = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding = "same")(x)
 
-    gating_base = get_gating_base(filters[-2], apply_batchnorm)(x)
+    gating_base = get_gating_base(filters[-1], apply_batchnorm)(x)
 
     dsv_outputs = []
     skip_conns = reversed(generator_steps_output[:-1])
@@ -40,7 +40,9 @@ def craft_network(checkpoint_file = None, apply_batchnorm = True, apply_instance
             # --- don't gate signal due to no useful features at top level
             gated_skip = skip_conn
         else:
-            gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, gating_base))
+            if idx == 0:
+                gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, gating_base))
+            gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, x))
 
         x = tf.keras.layers.Concatenate(name = "concat_{}".format(_filter))([x, gated_skip])
         x = res_block(_filter, x.shape, kernel_size = config.KERNEL_SIZE, apply_batchnorm = apply_batchnorm, apply_instancenorm = apply_instancenorm)(x)
diff --git a/train_segmentation.py b/train_segmentation.py
index 1c24f6f..78f9606 100644
--- a/train_segmentation.py
+++ b/train_segmentation.py
@@ -22,7 +22,7 @@ def get_csv_dir():
 
 
 def __dice_coef(y_true, y_pred):
-    gamma = 100.0
+    gamma = 0.01
 
     y_true = tf.cast(y_true, dtype = tf.float32)
     y_pred = tf.cast(y_pred[..., 1:2], dtype = tf.float32)
@@ -30,11 +30,12 @@ def __dice_coef(y_true, y_pred):
     # print("y_pred shape: ", y_pred.shape)
 
     intersection = tf.reduce_sum(y_true * y_pred)
-    dice = (2. * intersection + gamma) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + gamma)
+    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
+    dice = (2. * intersection + gamma) / (union + gamma)
     return dice
 
 
 def __dice_loss(y_true, y_pred):
-    return -tf.math.log(__dice_coef(y_true, y_pred))
+    return 1 - __dice_coef(y_true, y_pred)
 
 def __weighted_loss(y_true, y_pred):
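
Note: the PredictTile / PredictNoTile classes that predict.py now imports from tools/predict/
are not included in this patch, so only the dispatch in __predict is visible above. As a rough,
hedged sketch of what the tile-based path could look like (the class name PredictTileSketch,
the plain 3-D numpy input, and the channel-1 foreground assumption are illustrative, not taken
from the repository), the snippet below pads the volume to a multiple of IMAGE_DIMENSION_X/Y/Z
the same way Slicer in dataset/saver.py does, predicts tile by tile, and stitches the per-tile
scores back together:

    # Hedged sketch only -- tools/predict/predict_tile.py is not part of this patch.
    import numpy as np
    import config as config

    class PredictTileSketch:
        """Assumed shape of a tile-based predictor: pad, predict per tile, stitch."""

        def __init__(self, model):
            self.model = model

        def __call__(self, src_data):
            # src_data is assumed to be a plain 3-D numpy volume (no batch/channel axes).
            dims = (config.IMAGE_DIMENSION_X, config.IMAGE_DIMENSION_Y, config.IMAGE_DIMENSION_Z)

            # Pad the volume up to a multiple of the tile size, as Slicer does.
            padded_shape = [int(np.ceil(s / d)) * d for s, d in zip(src_data.shape, dims)]
            padded = np.zeros(padded_shape, dtype = src_data.dtype)
            padded[:src_data.shape[0], :src_data.shape[1], :src_data.shape[2]] = src_data

            result = np.zeros(padded_shape, dtype = np.float32)
            for x in range(0, padded_shape[0], dims[0]):
                for y in range(0, padded_shape[1], dims[1]):
                    for z in range(0, padded_shape[2], dims[2]):
                        tile = padded[x:x + dims[0], y:y + dims[1], z:z + dims[2]]
                        # Add batch/channel axes; assumes channel 1 holds the foreground
                        # score, matching the y_pred[..., 1:2] slice in __dice_coef.
                        pred = self.model.predict(tile[np.newaxis, ..., np.newaxis])
                        result[x:x + dims[0], y:y + dims[1], z:z + dims[2]] = pred[0, ..., 1]

            # Crop the padding back off before returning.
            return result[:src_data.shape[0], :src_data.shape[1], :src_data.shape[2]]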
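
A quick numeric check of the __dice_coef / __dice_loss change (the values below are made up for
illustration, not taken from the dataset): with the smoothing term gamma lowered from 100.0 to
0.01, a tile that is empty in both y_true and y_pred still gets dice = (0 + gamma) / (0 + gamma) = 1,
but the constant no longer inflates the score of small foreground structures, and the loss switches
from -log(dice) to the bounded 1 - dice:

    # Illustrative values only: intersection and sums for a small foreground structure.
    intersection, y_true_sum, y_pred_sum = 50.0, 100.0, 120.0

    for gamma in (100.0, 0.01):
        dice = (2.0 * intersection + gamma) / (y_true_sum + y_pred_sum + gamma)
        print(f"gamma={gamma}: dice={dice:.4f}")
    # gamma=100.0 gives 0.6250, gamma=0.01 gives 0.4546: the large constant inflates the
    # overlap score, so 1 - dice would under-penalize poor segmentations of small objects.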