Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: tile dataset #29

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
# BACKGROUND_WEIGHT = 1 # must be calculated dynamically
# FOREGROUND_WEIGHT = 7 # must be calculated dynamically

INITIAL_LEARNING_RATE = 0.001
INITIAL_LEARNING_RATE = 1e-4
INSTANCE_NORM = False # not supported yet
BATCH_NORM = True
BATCH_SIZE = 1
BATCH_SIZE = 4
BATCH_NORM_MOMENTUM = 0.8

GRADIENT_ACCUMULATION_STEPS = 4 # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam#args
Expand All @@ -27,7 +27,7 @@
PANCREAS_MIN_HU = -512 # -512
PANCREAS_MAX_HU = 1024 # 1024

IMAGE_DIMENSION_X = 160
IMAGE_DIMENSION_X = 96
IMAGE_DIMENSION_Y = IMAGE_DIMENSION_X
IMAGE_DIMENSION_Z = IMAGE_DIMENSION_X

Expand All @@ -47,6 +47,8 @@
IMAGE_ORIGINAL_DIMENSION_Y = IMAGE_ORIGINAL_DIMENSION_X
IMAGE_ORIGINAL_DIMENSION_Z = IMAGE_ORIGINAL_DIMENSION_X

IS_TILE = True

# Dataset used for training
# consists of pickle files of 3d numpy arrays
TFRECORD_FOLDER = "c:/Users/ikuchin/Downloads/pancreas_data/dataset/"
Expand Down
2 changes: 2 additions & 0 deletions dataset/pomc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def __init__(self, patients_src_folder, labels_src_folder, TFRECORD_FOLDER):
self.min_HU = float("inf")
self.max_HU = float("-inf")

# self.saver = SaverFactory(config.IS_TILE)

def get_patient_id_from_folder(self, folder):
result = None
m = re.search("(\\w+)$", folder)
Expand Down
33 changes: 20 additions & 13 deletions dataset/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import config as config

class Slicer:
def __init__(self, data, label):
def __init__(self, data, label, augment_margin = [0, 0, 0]):
x = math.ceil(data.shape[0] / config.IMAGE_DIMENSION_X) * config.IMAGE_DIMENSION_X
y = math.ceil(data.shape[1] / config.IMAGE_DIMENSION_Y) * config.IMAGE_DIMENSION_Y
z = math.ceil(data.shape[2] / config.IMAGE_DIMENSION_Z) * config.IMAGE_DIMENSION_Z
Expand All @@ -23,26 +23,26 @@ def __init__(self, data, label):
self.data[:data.shape[0], :data.shape[1], :data.shape[2]] = data
self.label[:label.shape[0], :label.shape[1], :label.shape[2]] = label

self.augment_margin = augment_margin

def __iter__(self):
    """Yield ``(data, label, x, y, z)`` for each tile whose label has a non-zero maximum.

    The volume is walked in steps of ``config.IMAGE_DIMENSION_X/Y/Z``. Each
    tile window is widened by ``self.augment_margin`` on both sides of every
    axis and clamped to the bounds of ``self.data``. ``x``, ``y`` and ``z``
    are the un-widened tile origins.

    Tiles whose label window has a maximum of 0 are skipped.
    """
    margin = self.augment_margin
    shape = self.data.shape
    tile_x = config.IMAGE_DIMENSION_X
    tile_y = config.IMAGE_DIMENSION_Y
    tile_z = config.IMAGE_DIMENSION_Z

    for x in range(0, shape[0], tile_x):
        # Bounds along an axis depend only on that axis' index, so compute
        # them once per outer iteration instead of once per tile.
        x_window = slice(max(x - margin[0], 0),
                         min(x + tile_x + margin[0], shape[0]))
        for y in range(0, shape[1], tile_y):
            y_window = slice(max(y - margin[1], 0),
                             min(y + tile_y + margin[1], shape[1]))
            for z in range(0, shape[2], tile_z):
                z_window = slice(max(z - margin[2], 0),
                                 min(z + tile_z + margin[2], shape[2]))

                label = self.label[x_window, y_window, z_window]
                if np.max(label) == 0:
                    continue

                data = self.data[x_window, y_window, z_window]
                yield data, label, x, y, z


Expand All @@ -61,7 +61,14 @@ def save(self, src_data, label_data):
src_data = np.cast[np.float32](src_data)
label_data = np.cast[np.int8](label_data)

for (data, label, x, y, z) in Slicer(src_data, label_data):
augment_margin = [
int(config.IMAGE_DIMENSION_X * config.AUGMENTATIO_SHIFT_MARGIN),
int(config.IMAGE_DIMENSION_Y * config.AUGMENTATIO_SHIFT_MARGIN),
int(config.IMAGE_DIMENSION_Z * config.AUGMENTATIO_SHIFT_MARGIN)
]


for (data, label, x, y, z) in Slicer(src_data, label_data, augment_margin=augment_margin):
# print(f"Saving slice at {x}, {y}, {z}...")
np.savez_compressed(os.path.join(self.folder, self.subfolder, self.patient_id + f"_cut-{self.percentage}_slice-{x}-{y}-{z}.npz", ), [data, label])

Expand Down
23 changes: 22 additions & 1 deletion predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import nibabel as nib
import tools.craft_network as craft_network
import config as config
from tools.predict.predict_no_tile import PredictNoTile
from tools.predict.predict_tile import PredictTile


class Predict:
Expand Down Expand Up @@ -125,6 +127,15 @@ def __resize_segmentation_to_dcm_shape(self, mask, dcm_slices):
return result

def __save_img_to_nifti(self, data, affine, result_file_name):
    """Wrap ``data`` and ``affine`` in a ``nib.Nifti1Image`` and save it to ``result_file_name``."""
    # TODO: add meta information. Sketch of the intended header handling:
    #   affine = meta['affine'][0].cpu().numpy()
    #   pixdim = meta['pixdim'][0].cpu().numpy()
    #   dim = meta['dim'][0].cpu().numpy()
    #
    #   img = nib.Nifti1Image(input_nii_array, affine=affine)
    #   img.header['dim'] = dim
    #   img.header['pixdim'] = pixdim
    nib.save(nib.Nifti1Image(data, affine), result_file_name)

Expand All @@ -136,6 +147,15 @@ def __print_stat(self, data, title=""):
tf.reduce_mean(tf.cast(data, dtype = tf.float32)),
tf.reduce_max(data), tf.reduce_sum(data)))

def __predict(self, src_data, model):
    """Run ``model`` on ``src_data`` through the predictor selected by ``config.IS_TILE``.

    Args:
        src_data: input passed unchanged to the predictor.
        model: model handed to the predictor's constructor.

    Returns:
        Whatever the chosen predictor returns when called with ``src_data``.
    """
    # PredictTile when tiling is enabled in the config, PredictNoTile otherwise.
    predictor_cls = PredictTile if config.IS_TILE else PredictNoTile
    return predictor_cls(model)(src_data)

def main(self, dcm_folder, result_file_name):
dcm_slices = self.__read_dcm_slices(dcm_folder)
raw_pixel_data = self.__get_pixel_data(dcm_slices)
Expand All @@ -146,7 +166,8 @@ def main(self, dcm_folder, result_file_name):

# model.summary()

prediction = model.predict(src_data)
# prediction = model.predict(src_data)
prediction = self.__predict(src_data, model)
mask = self.__create_segmentation(prediction)
mask = self.__resize_segmentation_to_dcm_shape(mask, dcm_slices)
mask = tf.squeeze(mask)
Expand Down
3 changes: 2 additions & 1 deletion tools/craft_network/att_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def res_block(filters, input_shape, kernel_size, apply_batchnorm, apply_instance
if (apply_dropout):
x = tf.keras.layers.Dropout(0.5)(x)

return tf.keras.models.Model(inputs = input_layer, outputs = x)
model = tf.keras.models.Model(inputs = input_layer, outputs = x, name = "res_block_{}_{}".format(input_shape[-1], filters))
return model

def double_conv(filters, input_shape, kernel_size, apply_batchnorm, apply_instancenorm, apply_dropout=False):
model = tf.keras.models.Sequential()
Expand Down
6 changes: 4 additions & 2 deletions tools/craft_network/att_unet_dsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def craft_network(checkpoint_file = None, apply_batchnorm = True, apply_instance
if idx < len(filters) - 1:
x = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding = "same")(x)

gating_base = get_gating_base(filters[-2], apply_batchnorm)(x)
gating_base = get_gating_base(filters[-1], apply_batchnorm)(x)

dsv_outputs = []
skip_conns = reversed(generator_steps_output[:-1])
Expand All @@ -40,7 +40,9 @@ def craft_network(checkpoint_file = None, apply_batchnorm = True, apply_instance
# --- don't gate signal due to no useful features at top level
gated_skip = skip_conn
else:
gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, gating_base))
if idx == 0:
gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, gating_base))
gated_skip = AttGate(apply_batchnorm = apply_batchnorm)((skip_conn, x))

x = tf.keras.layers.Concatenate(name = "concat_{}".format(_filter))([x, gated_skip])
x = res_block(_filter, x.shape, kernel_size = config.KERNEL_SIZE, apply_batchnorm = apply_batchnorm, apply_instancenorm = apply_instancenorm)(x)
Expand Down
7 changes: 4 additions & 3 deletions train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@ def get_csv_dir():


def __dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of ``y_true`` and channel 1 of ``y_pred``.

    Args:
        y_true: ground-truth mask; cast to float32.
        y_pred: prediction; only the slice ``[..., 1:2]`` of its last axis is
            used (presumably the foreground channel -- confirm against the
            model output).

    Returns:
        Scalar tensor
        ``(2 * sum(y_true * y_pred) + gamma) / (sum(y_true) + sum(y_pred) + gamma)``.
    """
    gamma = 0.01  # smoothing term: the ratio is gamma / gamma = 1 when both sums are 0
    y_true = tf.cast(y_true, dtype=tf.float32)
    y_pred = tf.cast(y_pred[..., 1:2], dtype=tf.float32)

    intersection = tf.reduce_sum(y_true * y_pred)
    # Sum of both mask totals (the Dice denominator), not a set union.
    total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return (2. * intersection + gamma) / (total + gamma)

def __dice_loss(y_true, y_pred):
    """Return the Dice loss, ``1 - __dice_coef(y_true, y_pred)``."""
    return 1 - __dice_coef(y_true, y_pred)


def __weighted_loss(y_true, y_pred):
Expand Down
Loading