Skip to content

Commit

Permalink
Merge pull request #25 from IvanKuchin/development
Browse files Browse the repository at this point in the history
Model output shape replaced from array to a single tensor
  • Loading branch information
IvanKuchin authored Aug 12, 2024
2 parents b9c27af + aa163c5 commit 5b05193
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
BATCH_SIZE = 1
BATCH_NORM_MOMENTUM = 0.8

GRADIENT_ACCUMULATION_STEPS = None # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam#args
GRADIENT_ACCUMULATION_STEPS = 4 # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam#args

# Option 1) HU range for pancreas in CT scans from 30 to 400
# https://radiopaedia.org/articles/windowing-ct?lang=us
Expand Down
2 changes: 1 addition & 1 deletion tools/craft_network/att_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def call(self, inputs):


__sum = self.add_g_x([theta_x, phi_g])
__activation_sum = tf.keras.layers.Activation("relu")(__sum)
__activation_sum = tf.keras.layers.LeakyReLU()(__sum)

psi = tf.keras.layers.Activation("sigmoid")(self.psi(__activation_sum))
psi = self.psi_upsample(psi)
Expand Down
2 changes: 1 addition & 1 deletion tools/craft_network/att_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def craft_network(checkpoint_file = None, apply_batchnorm=True, apply_instanceno

output_layer = tf.keras.layers.Conv3D(2, kernel_size = 1, padding = "same", kernel_initializer = "he_uniform", activation="softmax")(x)

model = tf.keras.models.Model(inputs = [inputs], outputs = [output_layer])
model = tf.keras.models.Model(inputs = [inputs], outputs = output_layer)

if checkpoint_file and os.path.exists(checkpoint_file):
print("Loading weights from checkpoint ", checkpoint_file)
Expand Down
2 changes: 1 addition & 1 deletion tools/craft_network/att_unet_dsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def craft_network(checkpoint_file = None, apply_batchnorm = True, apply_instance
output_layer = tf.keras.layers.Conv3D(2, kernel_size = 1, padding = "same", kernel_initializer = "he_uniform", activation="softmax")(concat_layer)
# output_layer = tf.keras.activations.softmax(output_layer)

model = tf.keras.models.Model(inputs = [inputs], outputs = [output_layer])
model = tf.keras.models.Model(inputs = [inputs], outputs = output_layer)

if checkpoint_file and os.path.exists(checkpoint_file):
print("Loading weights from checkpoint ", checkpoint_file)
Expand Down
4 changes: 3 additions & 1 deletion tools/craft_network/dsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def call(self, inputs, **kwargs):
if self.scale_factor > 1:
x = self.upsample(x)

x = tf.keras.activations.softmax(x)

return x

##########################
Expand All @@ -44,7 +46,7 @@ def get_config(self):
dsv_model = DSV(scale_factor = 4)
result = dsv_model(x_shape)

model = tf.keras.models.Model(inputs = [inp], outputs = [result])
model = tf.keras.models.Model(inputs = [inp], outputs = result)
outputs = model(rnd)

model.summary()
Expand Down
2 changes: 1 addition & 1 deletion tools/craft_network/unet_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def craft_network(checkpoint_file = None, apply_batchnorm=True):

output_layer = tf.keras.layers.Conv3D(2, kernel_size = 1, padding = "same", kernel_initializer = "he_uniform")(x)

model = tf.keras.models.Model(inputs = [inputs], outputs = [output_layer])
model = tf.keras.models.Model(inputs = [inputs], outputs = output_layer)

if checkpoint_file and os.path.exists(checkpoint_file):
print("Loading weights from checkpoint ", checkpoint_file)
Expand Down
2 changes: 1 addition & 1 deletion tools/craft_network/unet_shortcuts_every_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def craft_network(checkpoint_file = None, apply_batchnorm=True):

output_layer = tf.keras.layers.Conv3D(2, kernel_size = 1, padding = "same", kernel_initializer = "he_uniform")(x)

model = tf.keras.models.Model(inputs = [inputs], outputs = [output_layer])
model = tf.keras.models.Model(inputs = [inputs], outputs = output_layer)

if checkpoint_file and os.path.exists(checkpoint_file):
print("Loading weights from checkpoint ", checkpoint_file)
Expand Down
4 changes: 2 additions & 2 deletions train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def main():
mode = config.MONITOR_MODE)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor = config.MONITOR_METRIC, mode = config.MONITOR_MODE, patience = 200,
verbose = 1)
model.compile(optimizer = tf.keras.optimizers.Adam(
model.compile(optimizer = tf.keras.optimizers.AdamW(
learning_rate = config.INITIAL_LEARNING_RATE,
gradient_accumulation_steps = config.GRADIENT_ACCUMULATION_STEPS,
# gradient_accumulation_steps = config.GRADIENT_ACCUMULATION_STEPS,
),
loss = __dice_loss,
metrics = [
Expand Down

0 comments on commit 5b05193

Please sign in to comment.