From 7c903ed92a5bd8f07963e4160bcdd96e4deeaaef Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 11 Oct 2024 12:10:35 -0400 Subject: [PATCH 001/119] test steps in diffusion plot --- ml4h/models/train.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 368a2aeee..8eb58cb41 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -210,6 +210,18 @@ def train_diffusion_control_model(args): layer_range=None, show_layer_activations=False, ) + tf.keras.utils.plot_model( + model.control_embed_model, + to_file=f"{args.output_folder}/{args.id}/architecture_{args.id}_control_embed.png", + show_shapes=True, + show_dtype=False, + show_layer_names=True, + rankdir="TB", + expand_nested=True, + dpi=args.dpi, + layer_range=None, + show_layer_activations=True, + ) if os.path.exists(checkpoint_path+'.index'): model.load_weights(checkpoint_path) @@ -230,7 +242,7 @@ def train_diffusion_control_model(args): plot_metric_history(history, args.training_steps, args.id, os.path.dirname(checkpoint_path)) if args.inspect_model: if model.input_map.axes() == 2: - model.plot_ecgs(num_rows=4, prefix=os.path.dirname(checkpoint_path)) + model.plot_ecgs(num_rows=args.test_steps, prefix=os.path.dirname(checkpoint_path)) else: - model.plot_images(num_rows=4, prefix=os.path.dirname(checkpoint_path)) + model.plot_images(num_rows=args.test_steps, prefix=os.path.dirname(checkpoint_path)) return model From 088020e75cc143b8607cebf49c4c63009c4bcc42 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 16 Oct 2024 09:37:32 -0400 Subject: [PATCH 002/119] test steps in diffusion plot --- ml4h/models/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 8eb58cb41..b6510ea4b 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -128,6 +128,7 @@ def train_diffusion_model(args): # calculate mean and variance of training dataset for normalization model.normalizer.adapt(feature_batch) if args.inspect_model: + model.network.summary(print_fn=logging.info, expand_nested=True) tf.keras.utils.plot_model( model.network, to_file=f"{args.output_folder}/{args.id}/architecture_diffusion_unet.png", @@ -198,6 +199,7 @@ def train_diffusion_control_model(args): # calculate mean and variance of training dataset for normalization model.normalizer.adapt(feature_batch) if args.inspect_model: + model.network.summary(print_fn=logging.info, expand_nested=True) tf.keras.utils.plot_model( model.network, to_file=f"{args.output_folder}/{args.id}/architecture_{args.id}_unet.png", From 938215910a40eb5bacc68b58c79b908c18768691 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 16 Oct 2024 09:48:47 -0400 Subject: [PATCH 003/119] test steps in diffusion plot --- ml4h/models/diffusion_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index a667c00e1..4ce09b78e 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -23,8 +23,8 @@ max_signal_rate = 0.95 # architecture -embedding_dims = 2048 -embedding_max_frequency = 2000.0 +embedding_dims = 256 +embedding_max_frequency = 1000.0 # optimization ema = 0.999 From c4ca4cfb6a1cad57edaebbd26d40b559657f2006 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 13:04:59 -0400 Subject: [PATCH 004/119] test steps in diffusion plot --- ml4h/models/train.py | 68 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index b6510ea4b..be9e90b2e 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -4,6 +4,10 @@ import logging from typing import List, Tuple, Iterable, Union +import datetime +import numpy as np +import matplotlib.pyplot as plt + import tensorflow as tf from tensorflow import keras import tensorflow_addons as tfa @@ -12,11 +16,12 @@ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback from ml4h.TensorMap import TensorMap +from ml4h.metrics import coefficient_of_determination from ml4h.models.diffusion_blocks import DiffusionModel, DiffusionController from ml4h.plots import plot_metric_history from ml4h.defines import IMAGE_EXT, MODEL_EXT from ml4h.models.inspect import plot_and_time_model -from ml4h.models.model_factory import get_custom_objects +from ml4h.models.model_factory import get_custom_objects, make_multimodal_multitask_model from ml4h.tensor_generators import test_train_valid_tensor_generators @@ -167,6 +172,53 @@ def train_diffusion_model(args): return model +def get_eval_model(args, model_file, output_tmap): + args.tensor_maps_out = [output_tmap] + eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) + return eval_model + + +def regress_on_batch(diffuser, regressor, controls, tm_out, batch_size): + control_batch = {} + control_batch[tm_out.output_name()] = controls + + control_embed = diffuser.control_embed_model(control_batch) + generated_images = diffuser.generate( + control_embed, + num_images=batch_size, + diffusion_steps=50, + ) + logging.info(f'generated_images control_batch was {generated_images.shape}') + control_predictions = regressor.predict(generated_images) + logging.info(f'Control zip preds was {list(zip(controls, control_predictions))} ') + return control_predictions[:, 0] + + +def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batch_size, std, prefix): + preds = [] + all_controls = [] + # controls = np.arange(-8, 8, 1) + + for _ in range(batches): + controls = np.random.normal(0, std, size=batch_size) + preds.append(regress_on_batch(diffuser, regressor, controls, tm_out, batch_size)) + all_controls.append(controls) + + preds = np.array(preds).flatten() + all_controls = np.array(all_controls).flatten() + print(f'Control Predictions was {np.array(preds).shape} Control true was {np.array(all_controls).shape}') + pearson = np.corrcoef(preds, all_controls)[1, 0] + print(f'Pearson correlation {pearson:0.3f} ') + plt.scatter(preds, all_controls) + plt.title(f'''Diffusion Phenotype: {tm_out.name} Control vs Predictions + Pearson correlation {pearson:0.3f}, $R^2$ {coefficient_of_determination(preds, all_controls):0.3f}, N = {len(preds)}''') + now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'metrics_{tm_out.name}_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path) + + def train_diffusion_control_model(args): generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) model = DiffusionController( @@ -239,12 +291,20 @@ def train_diffusion_control_model(args): validation_steps=args.validation_steps, callbacks=[checkpoint_callback], ) + plot_metric_history(history, args.training_steps, args.id, os.path.dirname(checkpoint_path)) model.load_weights(checkpoint_path) - plot_metric_history(history, args.training_steps, args.id, os.path.dirname(checkpoint_path)) if args.inspect_model: if model.input_map.axes() == 2: - model.plot_ecgs(num_rows=args.test_steps, prefix=os.path.dirname(checkpoint_path)) + model.plot_ecgs(num_rows=2, prefix=os.path.dirname(checkpoint_path)) else: - model.plot_images(num_rows=args.test_steps, prefix=os.path.dirname(checkpoint_path)) + model.plot_images(num_rows=2, prefix=os.path.dirname(checkpoint_path)) + + for tm_out, model_file in zip(args.tensor_maps_out, args.model_files): + args.tensor_maps_out = [tm_out] + args.model_file = model_file + eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) + regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, 5, + f'{args.output_folder}/{args.id}/') + return model From 26ea28c1f62b786f873e6eea97b1432cba576c73 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 13:12:19 -0400 Subject: [PATCH 005/119] test steps in diffusion plot --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index be9e90b2e..7796c519e 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -4,9 +4,9 @@ import logging from typing import List, Tuple, Iterable, Union -import datetime import numpy as np import matplotlib.pyplot as plt +from datetime import datetime import tensorflow as tf from tensorflow import keras From a07b7e554e6c6bef150b46bff87a96b98b1bf480 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 13:17:26 -0400 Subject: [PATCH 006/119] test steps in diffusion plot --- ml4h/models/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 7796c519e..c6a5187b5 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -179,8 +179,7 @@ def get_eval_model(args, model_file, output_tmap): def regress_on_batch(diffuser, regressor, controls, tm_out, batch_size): - control_batch = {} - control_batch[tm_out.output_name()] = controls + control_batch = {tm_out.output_name(): controls} control_embed = diffuser.control_embed_model(control_batch) generated_images = diffuser.generate( @@ -188,9 +187,7 @@ def regress_on_batch(diffuser, regressor, controls, tm_out, batch_size): num_images=batch_size, diffusion_steps=50, ) - logging.info(f'generated_images control_batch was {generated_images.shape}') control_predictions = regressor.predict(generated_images) - logging.info(f'Control zip preds was {list(zip(controls, control_predictions))} ') return control_predictions[:, 0] @@ -199,10 +196,13 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc all_controls = [] # controls = np.arange(-8, 8, 1) - for _ in range(batches): + for i in range(batches): controls = np.random.normal(0, std, size=batch_size) preds.append(regress_on_batch(diffuser, regressor, controls, tm_out, batch_size)) all_controls.append(controls) + if i % 4 == 0: + logging.info(f'Inferred on {i+1} synthetic diffusion batches of {batches}') + preds = np.array(preds).flatten() all_controls = np.array(all_controls).flatten() From a15baa26bf516fdfac5af060b59b550e308b41c5 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 13:20:15 -0400 Subject: [PATCH 007/119] test steps in diffusion plot --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index c6a5187b5..345b451a9 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -187,7 +187,7 @@ def regress_on_batch(diffuser, regressor, controls, tm_out, batch_size): num_images=batch_size, diffusion_steps=50, ) - control_predictions = regressor.predict(generated_images) + control_predictions = regressor.predict(generated_images, verbose=0) return control_predictions[:, 0] From a627f1f286419b8d70aa3477dd050b70591a21c1 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 14:08:05 -0400 Subject: [PATCH 008/119] test steps in diffusion plot --- ml4h/models/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 345b451a9..841d0360e 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -197,7 +197,10 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc # controls = np.arange(-8, 8, 1) for i in range(batches): - controls = np.random.normal(0, std, size=batch_size) + if tm_out.is_continuous(): + controls = np.random.normal(0, std, size=batch_size) + elif tm_out.is_categorical(): + controls = np.eye(tm_out.shape[-1])[np.random.choice(tm_out.shape[-1], batch_size)] preds.append(regress_on_batch(diffuser, regressor, controls, tm_out, batch_size)) all_controls.append(controls) if i % 4 == 0: From 01b1756decf679424d0932bcc3bb76b3e80835f2 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 14:13:21 -0400 Subject: [PATCH 009/119] test steps in diffusion plot --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 841d0360e..6c509f47f 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -188,7 +188,7 @@ def regress_on_batch(diffuser, regressor, controls, tm_out, batch_size): diffusion_steps=50, ) control_predictions = regressor.predict(generated_images, verbose=0) - return control_predictions[:, 0] + return control_predictions[:, 0] if tm_out.is_continuous() else control_predictions def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batch_size, std, prefix): From 84386f151de254c16e81530e5035ecbda26a4185 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 14:20:39 -0400 Subject: [PATCH 010/119] test steps in diffusion plot --- ml4h/models/train.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 6c509f47f..45afda87d 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -18,7 +18,7 @@ from ml4h.TensorMap import TensorMap from ml4h.metrics import coefficient_of_determination from ml4h.models.diffusion_blocks import DiffusionModel, DiffusionController -from ml4h.plots import plot_metric_history +from ml4h.plots import plot_metric_history, plot_roc from ml4h.defines import IMAGE_EXT, MODEL_EXT from ml4h.models.inspect import plot_and_time_model from ml4h.models.model_factory import get_custom_objects, make_multimodal_multitask_model @@ -210,11 +210,16 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc preds = np.array(preds).flatten() all_controls = np.array(all_controls).flatten() print(f'Control Predictions was {np.array(preds).shape} Control true was {np.array(all_controls).shape}') - pearson = np.corrcoef(preds, all_controls)[1, 0] - print(f'Pearson correlation {pearson:0.3f} ') - plt.scatter(preds, all_controls) - plt.title(f'''Diffusion Phenotype: {tm_out.name} Control vs Predictions - Pearson correlation {pearson:0.3f}, $R^2$ {coefficient_of_determination(preds, all_controls):0.3f}, N = {len(preds)}''') + if tm_out.is_continuous(): + pearson = np.corrcoef(preds, all_controls)[1, 0] + print(f'Pearson correlation {pearson:0.3f} ') + plt.scatter(preds, all_controls) + plt.title(f'''Diffusion Phenotype: {tm_out.name} Control vs Predictions + Pearson correlation {pearson:0.3f}, $R^2$ {coefficient_of_determination(preds, all_controls):0.3f}, N = {len(preds)}''') + elif tm_out.is_categorical(): + plot_roc(preds, all_controls, tm_out.channel_map, + f'Diffusion Phenotype: {tm_out.name} Control vs Predictions', prefix) + now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') figure_path = os.path.join(prefix, f'metrics_{tm_out.name}_{now_string}{IMAGE_EXT}') if not os.path.exists(os.path.dirname(figure_path)): From 343510762f01c7b673962c99d54c4a4e096cadbb Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 14:24:27 -0400 Subject: [PATCH 011/119] test steps in diffusion plot --- ml4h/models/train.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 45afda87d..4f51dc79c 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -206,11 +206,12 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc if i % 4 == 0: logging.info(f'Inferred on {i+1} synthetic diffusion batches of {batches}') - - preds = np.array(preds).flatten() - all_controls = np.array(all_controls).flatten() - print(f'Control Predictions was {np.array(preds).shape} Control true was {np.array(all_controls).shape}') + preds = np.array(preds) + all_controls = np.array(all_controls) + print(f'Control Predictions was {preds.shape} Control true was {all_controls.shape}') if tm_out.is_continuous(): + preds = preds.flatten() + all_controls = all_controls.flatten() pearson = np.corrcoef(preds, all_controls)[1, 0] print(f'Pearson correlation {pearson:0.3f} ') plt.scatter(preds, all_controls) From ae083af3a75048f25c353d837552f695559ff741 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 14:28:29 -0400 Subject: [PATCH 012/119] test steps in diffusion plot --- ml4h/models/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 4f51dc79c..e71abbe54 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -201,8 +201,8 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc controls = np.random.normal(0, std, size=batch_size) elif tm_out.is_categorical(): controls = np.eye(tm_out.shape[-1])[np.random.choice(tm_out.shape[-1], batch_size)] - preds.append(regress_on_batch(diffuser, regressor, controls, tm_out, batch_size)) - all_controls.append(controls) + preds.extend(regress_on_batch(diffuser, regressor, controls, tm_out, batch_size)) + all_controls.extend(controls) if i % 4 == 0: logging.info(f'Inferred on {i+1} synthetic diffusion batches of {batches}') From 50f8875cf6fcff756183bc5c5494b7b67d235724 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 16:23:19 -0400 Subject: [PATCH 013/119] test steps in diffusion plot --- ml4h/models/diffusion_blocks.py | 1 - ml4h/models/train.py | 10 +++++----- ml4h/plots.py | 3 +-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 4ce09b78e..ee811b3a9 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -568,7 +568,6 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") - plt.show() def get_network(input_shape, widths, block_depth, kernel_size): diff --git a/ml4h/models/train.py b/ml4h/models/train.py index e71abbe54..b4868560f 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -217,15 +217,15 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc plt.scatter(preds, all_controls) plt.title(f'''Diffusion Phenotype: {tm_out.name} Control vs Predictions Pearson correlation {pearson:0.3f}, $R^2$ {coefficient_of_determination(preds, all_controls):0.3f}, N = {len(preds)}''') + now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'scatter_{tm_out.name}_{now_string}{IMAGE_EXT}') + os.makedirs(os.path.dirname(figure_path), exist_ok=True) + plt.savefig(figure_path) elif tm_out.is_categorical(): plot_roc(preds, all_controls, tm_out.channel_map, f'Diffusion Phenotype: {tm_out.name} Control vs Predictions', prefix) - now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') - figure_path = os.path.join(prefix, f'metrics_{tm_out.name}_{now_string}{IMAGE_EXT}') - if not os.path.exists(os.path.dirname(figure_path)): - os.makedirs(os.path.dirname(figure_path)) - plt.savefig(figure_path) + def train_diffusion_control_model(args): diff --git a/ml4h/plots.py b/ml4h/plots.py index cbf7c8ce1..b068437b8 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -2574,8 +2574,7 @@ def plot_roc(prediction, truth, labels, title, prefix="./figures/", dpi=300, wid plt.title(f"ROC {title} n={np.sum(true_sums):.0f}") figure_path = os.path.join(prefix, "per_class_roc_" + title + IMAGE_EXT) - if not os.path.exists(os.path.dirname(figure_path)): - os.makedirs(os.path.dirname(figure_path)) + os.makedirs(os.path.dirname(figure_path), exist_ok=True) plt.savefig(figure_path) logging.info(f"Saved ROC curve at: {figure_path}") return labels_to_areas From a5c2ca3e7184e82bc8ba692adb465f353d69af93 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 16:23:46 -0400 Subject: [PATCH 014/119] test steps in diffusion plot --- ml4h/models/diffusion_blocks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index ee811b3a9..df1a428c6 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -426,7 +426,6 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, prefix='./f if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") - plt.show() def plot_reconstructions( self, batch, diffusion_amount=0, From 93978fdc43f6eb2096441e20ffcf191ad70c042b Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 16:56:59 -0400 Subject: [PATCH 015/119] test steps in diffusion plot --- ml4h/models/diffusion_blocks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index df1a428c6..f121a8429 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -426,6 +426,7 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, prefix='./f if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") + plt.close() def plot_reconstructions( self, batch, diffusion_amount=0, @@ -567,6 +568,7 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") + plt.close() def get_network(input_shape, widths, block_depth, kernel_size): From a95f7278295cba294e64c2605a20afcf6ba353df Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 18 Oct 2024 17:05:17 -0400 Subject: [PATCH 016/119] test steps in diffusion plot --- ml4h/plots.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ml4h/plots.py b/ml4h/plots.py index b068437b8..97e7ee86d 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -478,6 +478,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu f'Minimum validation loss: {min(history.history["val_loss"]):0.4f}', ) logging.info(f"Saved learning curves at:{figure_path}") + plt.close() def plot_rocs( From 7c90a475358dfb7229ab6b666b337d520c954750 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 29 Oct 2024 12:31:12 -0400 Subject: [PATCH 017/119] test steps in diffusion plot --- ml4h/tensormap/ukb/mri_brain.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index d7af36844..e83f0280e 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -439,3 +439,23 @@ def _masked_brain_tensor(tm, hd5, dependents={}): tensor_from_file=_mni_label_masked({'Left_accumbens': 19, 'Right_accumbens': 70}), normalization=ZeroMeanStd1(), ) + +def random_mni_slice_tensor(tm, hd5, dependents={}): + slice_index = np.random.randint(182) + tensor = pad_or_crop_array_to_shape( + (tm.shape[0], tm.shape[1], 1), np.array( + tm.hd5_first_dataset_in_group(hd5, f'{tm.path_prefix}/axial_{slice_index}'), dtype=np.float32, + ), + ) + return tensor + + +t1_mni_random_slice = TensorMap( + 't1_mni_random_slice', + Interpretation.CONTINUOUS, + shape=(192, 192, 1), + path_prefix='ukb_brain_mri/T1_brain_to_MNI/', + tensor_from_file=random_mni_slice_tensor, + normalization=ZeroMeanStd1(), +) + From 53af85f5f09dddddb504ba9006e5e6d7a6c1ebe4 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 29 Oct 2024 12:35:11 -0400 Subject: [PATCH 018/119] test steps in diffusion plot --- ml4h/tensormap/ukb/mri_brain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index e83f0280e..c70e99e96 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -444,7 +444,7 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): slice_index = np.random.randint(182) tensor = pad_or_crop_array_to_shape( (tm.shape[0], tm.shape[1], 1), np.array( - tm.hd5_first_dataset_in_group(hd5, f'{tm.path_prefix}/axial_{slice_index}'), dtype=np.float32, + tm.hd5_first_dataset_in_group(hd5, f'{tm.path_prefix}axial_{slice_index}/'), dtype=np.float32, ), ) return tensor From 9bfeb36f08938f2310eb78d703f68de956ab4117 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 29 Oct 2024 12:39:50 -0400 Subject: [PATCH 019/119] test steps in diffusion plot --- ml4h/tensormap/ukb/mri_brain.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index c70e99e96..7560b14dc 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -447,8 +447,10 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): tm.hd5_first_dataset_in_group(hd5, f'{tm.path_prefix}axial_{slice_index}/'), dtype=np.float32, ), ) + dependents[tm.dependent_map] = np.array(slice_index) return tensor +axial_index_map = TensorMap('axial_index', Interpretation.CONTINUOUS, shape=(1,), channel_map={'axial_index':0}) t1_mni_random_slice = TensorMap( 't1_mni_random_slice', @@ -457,5 +459,13 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): path_prefix='ukb_brain_mri/T1_brain_to_MNI/', tensor_from_file=random_mni_slice_tensor, normalization=ZeroMeanStd1(), + dependent_map=axial_index_map, ) +sax_random_slice_segmented = TensorMap( + 'sax_random_slice_segmented', Interpretation.CATEGORICAL, shape=(224, 224, len(MRI_SEGMENTED_CHANNEL_MAP)), channel_map=MRI_SEGMENTED_CHANNEL_MAP, +) +sax_random_slice = TensorMap( + 'sax_random_slice', shape=(224, 224, 1), tensor_from_file=sax_random_slice_tensor_maker('cine_segmented_sax_inlinevf/2', 'cine_segmented_sax_inlinevf_segmented/2'), + path_prefix='ukb_cardiac_mri', normalization=ZeroMeanStd1(), dependent_map=sax_random_slice_segmented, +) From c548720f501c1810340b790821f260dbd8f1bea6 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 29 Oct 2024 12:40:28 -0400 Subject: [PATCH 020/119] test steps in diffusion plot --- ml4h/tensormap/ukb/mri_brain.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index 7560b14dc..f4664ea72 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -461,11 +461,3 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): normalization=ZeroMeanStd1(), dependent_map=axial_index_map, ) - -sax_random_slice_segmented = TensorMap( - 'sax_random_slice_segmented', Interpretation.CATEGORICAL, shape=(224, 224, len(MRI_SEGMENTED_CHANNEL_MAP)), channel_map=MRI_SEGMENTED_CHANNEL_MAP, -) -sax_random_slice = TensorMap( - 'sax_random_slice', shape=(224, 224, 1), tensor_from_file=sax_random_slice_tensor_maker('cine_segmented_sax_inlinevf/2', 'cine_segmented_sax_inlinevf_segmented/2'), - path_prefix='ukb_cardiac_mri', normalization=ZeroMeanStd1(), dependent_map=sax_random_slice_segmented, -) From 5617694e84df4d615ab5033e8bbc54f69d80693d Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 29 Oct 2024 12:55:48 -0400 Subject: [PATCH 021/119] test steps in diffusion plot --- ml4h/tensormap/ukb/mri_brain.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index f4664ea72..b2a48385e 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -443,11 +443,15 @@ def _masked_brain_tensor(tm, hd5, dependents={}): def random_mni_slice_tensor(tm, hd5, dependents={}): slice_index = np.random.randint(182) tensor = pad_or_crop_array_to_shape( - (tm.shape[0], tm.shape[1], 1), np.array( + tm.shape, np.array( tm.hd5_first_dataset_in_group(hd5, f'{tm.path_prefix}axial_{slice_index}/'), dtype=np.float32, ), ) - dependents[tm.dependent_map] = np.array(slice_index) + dependents[tm.dependent_map] = np.zeros( + tm.dependent_map.shape, + dtype=np.float32, + ) + dependents[tm.dependent_map][0] = slice_index return tensor axial_index_map = TensorMap('axial_index', Interpretation.CONTINUOUS, shape=(1,), channel_map={'axial_index':0}) From 4cb6b041f9768cfca2b6d13467b7881754808fd1 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 29 Oct 2024 14:11:57 -0400 Subject: [PATCH 022/119] test steps in diffusion plot --- ml4h/tensormap/ukb/mri_brain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index b2a48385e..fbaedd9ca 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -451,7 +451,7 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): tm.dependent_map.shape, dtype=np.float32, ) - dependents[tm.dependent_map][0] = slice_index + dependents[tm.dependent_map][0] = (slice_index / 182.0) - 0.5 return tensor axial_index_map = TensorMap('axial_index', Interpretation.CONTINUOUS, shape=(1,), channel_map={'axial_index':0}) From 16a3ba2e6dcd8a4f5175c8e05ddd39cb76895287 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 30 Oct 2024 11:04:42 -0400 Subject: [PATCH 023/119] brain mri z index norm --- ml4h/models/train.py | 2 ++ ml4h/tensormap/ukb/mri_brain.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index b4868560f..ff6656389 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -215,6 +215,8 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc pearson = np.corrcoef(preds, all_controls)[1, 0] print(f'Pearson correlation {pearson:0.3f} ') plt.scatter(preds, all_controls) + plt.xlabel(f'Predicted {tm_out.name}') + plt.ylabel(f'Control {tm_out.name}') plt.title(f'''Diffusion Phenotype: {tm_out.name} Control vs Predictions Pearson correlation {pearson:0.3f}, $R^2$ {coefficient_of_determination(preds, all_controls):0.3f}, N = {len(preds)}''') now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index fbaedd9ca..45cf6d152 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -451,7 +451,7 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): tm.dependent_map.shape, dtype=np.float32, ) - dependents[tm.dependent_map][0] = (slice_index / 182.0) - 0.5 + dependents[tm.dependent_map][0] = float(slice_index) / 182.0 return tensor axial_index_map = TensorMap('axial_index', Interpretation.CONTINUOUS, shape=(1,), channel_map={'axial_index':0}) From 97af51e2cbe85e1dbef1858b7f99ea379303c22e Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 30 Oct 2024 12:51:57 -0400 Subject: [PATCH 024/119] brain mri z index norm --- ml4h/models/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index ff6656389..09a1374f6 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -191,14 +191,14 @@ def regress_on_batch(diffuser, regressor, controls, tm_out, batch_size): return control_predictions[:, 0] if tm_out.is_continuous() else control_predictions -def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batch_size, std, prefix): +def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batch_size, mean, std, prefix): preds = [] all_controls = [] # controls = np.arange(-8, 8, 1) for i in range(batches): if tm_out.is_continuous(): - controls = np.random.normal(0, std, size=batch_size) + controls = np.random.normal(mean, std, size=batch_size) elif tm_out.is_categorical(): controls = np.eye(tm_out.shape[-1])[np.random.choice(tm_out.shape[-1], batch_size)] preds.extend(regress_on_batch(diffuser, regressor, controls, tm_out, batch_size)) @@ -315,7 +315,7 @@ def train_diffusion_control_model(args): args.tensor_maps_out = [tm_out] args.model_file = model_file eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) - regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, 5, - f'{args.output_folder}/{args.id}/') + regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, + 0.5,0.2,f'{args.output_folder}/{args.id}/') return model From d476edf86f875bb7dbef18aa2dfc2ce02cfe5dbe Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 30 Oct 2024 13:34:36 -0400 Subject: [PATCH 025/119] brain mri z index norm --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 09a1374f6..6757ed519 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -316,6 +316,6 @@ def train_diffusion_control_model(args): args.model_file = model_file eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, - 0.5,0.2,f'{args.output_folder}/{args.id}/') + 0.5,2,f'{args.output_folder}/{args.id}/') return model From dde6044051575d931b84c85c63a15899dd84836e Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 31 Oct 2024 13:36:24 -0400 Subject: [PATCH 026/119] brain mri z index norm --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 6757ed519..2eb527dc3 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -316,6 +316,6 @@ def train_diffusion_control_model(args): args.model_file = model_file eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, - 0.5,2,f'{args.output_folder}/{args.id}/') + 0.4,2,f'{args.output_folder}/{args.id}/') return model From 0825553f8e0f7c0718cebbbce2b24ecd8667f53e Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 31 Oct 2024 13:36:36 -0400 Subject: [PATCH 027/119] brain mri z index norm --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 2eb527dc3..f8c3ba56b 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -316,6 +316,6 @@ def train_diffusion_control_model(args): args.model_file = model_file eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, - 0.4,2,f'{args.output_folder}/{args.id}/') + 0.5,1,f'{args.output_folder}/{args.id}/') return model From c716324eb0f06efe08477eb064c4a41bf8a9b707 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 1 Nov 2024 12:00:11 -0400 Subject: [PATCH 028/119] brain mri z index norm --- ml4h/tensormap/ukb/ecg.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ml4h/tensormap/ukb/ecg.py b/ml4h/tensormap/ukb/ecg.py index c71735f65..0be40c985 100755 --- a/ml4h/tensormap/ukb/ecg.py +++ b/ml4h/tensormap/ukb/ecg.py @@ -1266,3 +1266,14 @@ def ppg_from_hd5(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.nda 'ppg_2', shape=(100, 1), tensor_from_file=ppg_from_hd5, channel_map={'ppg_2': 0}, normalization=Standardize(mean=4824.6, std=3235.8), ) + +def uw_ecg_from_hd5(tm, hd5, dependents={}): + tensor = 6+(30*np.array(hd5[tm.path_prefix], dtype=np.float32)) + return tensor + +ecg_median_uw = TensorMap('ecg_rest_median_raw_10', + Interpretation.CONTINUOUS, + shape=(600, 12), + path_prefix='ecg.ecg_rest_median_raw_10', + channel_map=ECG_REST_MEDIAN_LEADS, + tensor_from_file=uw_ecg_from_hd5) From 1d1a19a37dc4de06b927e0045bb9c3584fb7d355 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 12:08:32 -0500 Subject: [PATCH 029/119] log auc and mse --- ml4h/models/train.py | 6 +++--- ml4h/plots.py | 15 +++++++++------ ml4h/tensormap/ukb/ecg.py | 4 +++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index f8c3ba56b..0a57744fb 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -208,19 +208,19 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc preds = np.array(preds) all_controls = np.array(all_controls) - print(f'Control Predictions was {preds.shape} Control true was {all_controls.shape}') + logging.info(f'Control Predictions was {preds.shape} Control true was {all_controls.shape}') if tm_out.is_continuous(): preds = preds.flatten() all_controls = all_controls.flatten() pearson = np.corrcoef(preds, all_controls)[1, 0] - print(f'Pearson correlation {pearson:0.3f} ') + logging.info(f'Pearson correlation {pearson:0.3f} ') plt.scatter(preds, all_controls) plt.xlabel(f'Predicted {tm_out.name}') plt.ylabel(f'Control {tm_out.name}') plt.title(f'''Diffusion Phenotype: {tm_out.name} Control vs Predictions Pearson correlation {pearson:0.3f}, $R^2$ {coefficient_of_determination(preds, all_controls):0.3f}, N = {len(preds)}''') now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') - figure_path = os.path.join(prefix, f'scatter_{tm_out.name}_{now_string}{IMAGE_EXT}') + figure_path = os.path.join(prefix, f'scatter_{tm_out.name}_r_{pearson:0.3f}_{now_string}{IMAGE_EXT}') os.makedirs(os.path.dirname(figure_path), exist_ok=True) plt.savefig(figure_path) elif tm_out.is_categorical(): diff --git a/ml4h/plots.py b/ml4h/plots.py index 97e7ee86d..816562784 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -470,12 +470,15 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path) - if 'loss' in history.history: - logging.info(f'Starting training loss: {history.history["loss"][0]:0.3f}, Final training loss: {history.history["loss"][-1]:0.4f}') - if 'val_loss' in history.history: + for log_label in ['loss', 'val_loss', 'n_loss', 'val_n_loss']: + if log_label not in history.history: + continue logging.info( - f'Starting validation loss: {history.history["val_loss"][0]:0.3f}, Final validation loss: {history.history["val_loss"][-1]:0.4f}, ' - f'Minimum validation loss: {min(history.history["val_loss"]):0.4f}', + f''' + Starting {log_label.replace('val_', 'validation ')}: {history.history[log_label][0]:0.3f}, + Final {log_label.replace('val_', 'validation ')}: {history.history[log_label][-1]:0.4f}, + Minimum {log_label.replace('val_', 'validation ')}: {min(history.history[log_label]):0.4f} + ''' ) logging.info(f"Saved learning curves at:{figure_path}") plt.close() @@ -2574,7 +2577,7 @@ def plot_roc(prediction, truth, labels, title, prefix="./figures/", dpi=300, wid plt.plot([0, 1], [0, 1], "k:", lw=0.5) plt.title(f"ROC {title} n={np.sum(true_sums):.0f}") - figure_path = os.path.join(prefix, "per_class_roc_" + title + IMAGE_EXT) + figure_path = os.path.join(prefix, f"per_class_roc_auc_{roc_auc[labels[key]]:0.3f}_{title}{IMAGE_EXT}") os.makedirs(os.path.dirname(figure_path), exist_ok=True) plt.savefig(figure_path) logging.info(f"Saved ROC curve at: {figure_path}") diff --git a/ml4h/tensormap/ukb/ecg.py b/ml4h/tensormap/ukb/ecg.py index 0be40c985..15d71f292 100755 --- a/ml4h/tensormap/ukb/ecg.py +++ b/ml4h/tensormap/ukb/ecg.py @@ -1268,7 +1268,9 @@ def ppg_from_hd5(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.nda ) def uw_ecg_from_hd5(tm, hd5, dependents={}): - tensor = 6+(30*np.array(hd5[tm.path_prefix], dtype=np.float32)) + new_mean = 6 + new_std = 30 + tensor = new_mean+(new_std*np.array(hd5[tm.path_prefix], dtype=np.float32)) return tensor ecg_median_uw = TensorMap('ecg_rest_median_raw_10', From 3c538b3688d6b89228fc92796a3f742ef6adef7e Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 12:49:20 -0500 Subject: [PATCH 030/119] log auc and mse --- ml4h/models/diffusion_blocks.py | 2 +- ml4h/models/train.py | 34 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index f121a8429..f22aefa53 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -405,7 +405,7 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, prefix='./f if 'Sex' in cm.name: control_batch[cm.output_name()][:, 0] = 1 # all female - print(f'\nControl batch keys: {list(control_batch.keys())}') + logging.info(f'\nControl batch keys: {list(control_batch.keys())}') control_embed = self.control_embed_model(control_batch) # plot random generated images for visual evaluation of generation quality generated_images = self.generate( diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 0a57744fb..1758610cd 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -228,6 +228,38 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc f'Diffusion Phenotype: {tm_out.name} Control vs Predictions', prefix) +def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, batch_size, prefix): + control_batch = {} + samples = np.arange(-5, 6, 3) + num_rows = len(samples) + num_cols = 4 + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) + for row, pheno_scale in enumerate(samples): + for cm in tensor_maps_out: + if cm == control_tm: + control_batch[cm.output_name()] = np.ones((batch_size,) + cm.shape) * pheno_scale + else: + control_batch[cm.output_name()] = np.zeros((batch_size,) + cm.shape) + + control_embed = diffuser.control_embed_model(control_batch) + generated_images = diffuser.generate( + control_embed, + num_images=batch_size, + diffusion_steps=50, + ) + + for col in range(num_cols): + index = row * num_cols + col + plt.subplot(num_rows, num_cols, index + 1) + plt.imshow(generated_images[index], cmap='gray') + plt.gca().set_title(f'{control_tm.name}: {pheno_scale}') + plt.axis("off") + + plt.tight_layout() + now_string = datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'interpolate_synthetic_{control_tm.name}_{now_string}{IMAGE_EXT}') + os.makedirs(os.path.dirname(figure_path), exist_ok=True) + plt.savefig(figure_path) def train_diffusion_control_model(args): @@ -306,6 +338,8 @@ def train_diffusion_control_model(args): model.load_weights(checkpoint_path) if args.inspect_model: + interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, + f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: model.plot_ecgs(num_rows=2, prefix=os.path.dirname(checkpoint_path)) else: From a543f43c62c07b025feb418d924627294c7877c9 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 13:08:36 -0500 Subject: [PATCH 031/119] log auc and mse --- ml4h/models/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 1758610cd..4926fe29d 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -246,6 +246,7 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba control_embed, num_images=batch_size, diffusion_steps=50, + reseed=12345, # hold everything constant except for control signal ) for col in range(num_cols): From dfcdb331c3628b20fafc44608e169ff0ca8e3179 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 13:12:00 -0500 Subject: [PATCH 032/119] log auc and mse --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 4926fe29d..37765a383 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -253,7 +253,7 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba index = row * num_cols + col plt.subplot(num_rows, num_cols, index + 1) plt.imshow(generated_images[index], cmap='gray') - plt.gca().set_title(f'{control_tm.name}: {pheno_scale}') + plt.gca().set_title(f'{control_tm.name[:10]}: {pheno_scale:0.1f}') plt.axis("off") plt.tight_layout() From 5d8132fb75e1d3309e70296dc18ed11c00ca03f3 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 13:12:28 -0500 Subject: [PATCH 033/119] log auc and mse --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 37765a383..6ab2e472b 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -230,7 +230,7 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, batch_size, prefix): control_batch = {} - samples = np.arange(-5, 6, 3) + samples = np.arange(-5, 6, 1) num_rows = len(samples) num_cols = 4 plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) From 8f357abc7aa5b2c8e1db045c7dc6125093f5a6f4 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 13:16:56 -0500 Subject: [PATCH 034/119] log auc and mse --- ml4h/models/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 6ab2e472b..e8cdce04a 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -249,10 +249,10 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba reseed=12345, # hold everything constant except for control signal ) - for col in range(num_cols): + for i_col, col in enumerate(range(num_cols)): index = row * num_cols + col plt.subplot(num_rows, num_cols, index + 1) - plt.imshow(generated_images[index], cmap='gray') + plt.imshow(generated_images[i_col], cmap='gray') plt.gca().set_title(f'{control_tm.name[:10]}: {pheno_scale:0.1f}') plt.axis("off") From 50ae738e4931087d00c0f33e7270a8f04384f037 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 13:27:42 -0500 Subject: [PATCH 035/119] log auc and mse --- ml4h/models/train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index e8cdce04a..57ccff03b 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -252,7 +252,11 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba for i_col, col in enumerate(range(num_cols)): index = row * num_cols + col plt.subplot(num_rows, num_cols, index + 1) - plt.imshow(generated_images[i_col], cmap='gray') + if len(generated_images.shape) == 3: + for lead in range(generated_images.shape[-1]): + plt.plot(generated_images[index, :, lead], label=lead) + elif len(generated_images.shape) == 4: + plt.imshow(generated_images[i_col], cmap='gray') plt.gca().set_title(f'{control_tm.name[:10]}: {pheno_scale:0.1f}') plt.axis("off") From 7b00f90ed320d71cef394994931a1f015103f43c Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 13:38:40 -0500 Subject: [PATCH 036/119] log auc and mse --- ml4h/models/diffusion_blocks.py | 2 -- ml4h/models/train.py | 2 +- ml4h/plots.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index f22aefa53..1b1794b8a 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -266,8 +266,6 @@ def generate(self, control_embed, num_images, diffusion_steps, reseed=None, reno # noise -> images -> denormalized images if reseed is not None: - if renoise is not None: - noiser = tf.random.normal(shape=(num_images,) + self.input_map.shape) * renoise tf.random.set_seed(reseed) initial_noise = tf.random.normal(shape=(num_images,) + self.input_map.shape) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 57ccff03b..e33fd4b50 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -230,7 +230,7 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, batch_size, prefix): control_batch = {} - samples = np.arange(-5, 6, 1) + samples = np.arange(-2, 2, 0.3) num_rows = len(samples) num_cols = 4 plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) diff --git a/ml4h/plots.py b/ml4h/plots.py index 816562784..525338d72 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -475,7 +475,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu continue logging.info( f''' - Starting {log_label.replace('val_', 'validation ')}: {history.history[log_label][0]:0.3f}, + Starting {log_label.replace('val_', 'validation ')}: {history.history[log_label][0]:0.4f}, Final {log_label.replace('val_', 'validation ')}: {history.history[log_label][-1]:0.4f}, Minimum {log_label.replace('val_', 'validation ')}: {min(history.history[log_label]):0.4f} ''' From 28a7ea309ad66380963d4f9ee6bb50367b410a30 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 13:44:10 -0500 Subject: [PATCH 037/119] log auc and mse --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index e33fd4b50..8c8552802 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -254,7 +254,7 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba plt.subplot(num_rows, num_cols, index + 1) if len(generated_images.shape) == 3: for lead in range(generated_images.shape[-1]): - plt.plot(generated_images[index, :, lead], label=lead) + plt.plot(generated_images[i_col, :, lead], label=lead) elif len(generated_images.shape) == 4: plt.imshow(generated_images[i_col], cmap='gray') plt.gca().set_title(f'{control_tm.name[:10]}: {pheno_scale:0.1f}') From f4734e7d640818f92e18c1ae4bdedfbb9380fc35 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 13:49:11 -0500 Subject: [PATCH 038/119] log auc and mse --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 8c8552802..7e7357808 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -230,7 +230,7 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, batch_size, prefix): control_batch = {} - samples = np.arange(-2, 2, 0.3) + samples = np.arange(-4, 5, 1) num_rows = len(samples) num_cols = 4 plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) From 35b3a6dd5857fe79f1ca22531a4ccd5b1ffedab5 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 8 Nov 2024 15:34:28 -0500 Subject: [PATCH 039/119] log auc and mse --- ml4h/models/train.py | 2 +- ml4h/plots.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 7e7357808..81c8b8d96 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -355,6 +355,6 @@ def train_diffusion_control_model(args): args.model_file = model_file eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, - 0.5,1,f'{args.output_folder}/{args.id}/') + 0.0,5,f'{args.output_folder}/{args.id}/') return model diff --git a/ml4h/plots.py b/ml4h/plots.py index 525338d72..b0e93018a 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -476,8 +476,8 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu logging.info( f''' Starting {log_label.replace('val_', 'validation ')}: {history.history[log_label][0]:0.4f}, - Final {log_label.replace('val_', 'validation ')}: {history.history[log_label][-1]:0.4f}, - Minimum {log_label.replace('val_', 'validation ')}: {min(history.history[log_label]):0.4f} + Final {log_label.replace('val_', 'validation ')}: {history.history[log_label][-1]:0.4f}, + Minimum {log_label.replace('val_', 'validation ')}: {min(history.history[log_label]):0.4f} ''' ) logging.info(f"Saved learning curves at:{figure_path}") From e510f821797b8511614108e2daf8050868acc73a Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 11 Nov 2024 12:01:41 -0500 Subject: [PATCH 040/119] log auc and mse --- ml4h/plots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index b0e93018a..b8586a675 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -781,7 +781,7 @@ def plot_scatter( sns.distplot(truth, label="Truth", color="b", ax=ax2) ax2.legend(loc="upper left") - figure_path = os.path.join(prefix, "scatter_" + title + IMAGE_EXT) + figure_path = os.path.join(prefix, f"scatter_{title}_r_{pearson:0.4f}{IMAGE_EXT}") if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) logging.info(f"Try to save scatter plot at: {figure_path}") From fa239b81c4e5773316a1e082c762b7474b00a03f Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 11 Nov 2024 12:04:05 -0500 Subject: [PATCH 041/119] log auc and mse --- ml4h/plots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index b8586a675..b8285f3b8 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -781,7 +781,7 @@ def plot_scatter( sns.distplot(truth, label="Truth", color="b", ax=ax2) ax2.legend(loc="upper left") - figure_path = os.path.join(prefix, f"scatter_{title}_r_{pearson:0.4f}{IMAGE_EXT}") + figure_path = os.path.join(prefix, f"scatter_r_{pearson:0.4f}_{title}{IMAGE_EXT}") if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) logging.info(f"Try to save scatter plot at: {figure_path}") From c5f718cb040e91bee3ec1a5e9e792a5c55c0ab15 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 15 Nov 2024 10:40:28 -0500 Subject: [PATCH 042/119] log auc and mse --- ml4h/models/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 81c8b8d96..3ac324c90 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -223,6 +223,7 @@ def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batc figure_path = os.path.join(prefix, f'scatter_{tm_out.name}_r_{pearson:0.3f}_{now_string}{IMAGE_EXT}') os.makedirs(os.path.dirname(figure_path), exist_ok=True) plt.savefig(figure_path) + plt.close() elif tm_out.is_categorical(): plot_roc(preds, all_controls, tm_out.channel_map, f'Diffusion Phenotype: {tm_out.name} Control vs Predictions', prefix) @@ -265,6 +266,7 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba figure_path = os.path.join(prefix, f'interpolate_synthetic_{control_tm.name}_{now_string}{IMAGE_EXT}') os.makedirs(os.path.dirname(figure_path), exist_ok=True) plt.savefig(figure_path) + plt.close() def train_diffusion_control_model(args): From dca9f3ed8be0d63962df108ab77bfd8ecb368f64 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 15 Nov 2024 14:00:52 -0500 Subject: [PATCH 043/119] log auc and mse --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 3ac324c90..544c7c728 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -357,6 +357,6 @@ def train_diffusion_control_model(args): args.model_file = model_file eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, - 0.0,5,f'{args.output_folder}/{args.id}/') + 0.0,10,f'{args.output_folder}/{args.id}/') return model From 953aa69d2b3d8f57c324be6c1c5e7727c69d2736 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 15 Nov 2024 14:10:52 -0500 Subject: [PATCH 044/119] log auc and mse --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 544c7c728..363b4c149 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -357,6 +357,6 @@ def train_diffusion_control_model(args): args.model_file = model_file eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, - 0.0,10,f'{args.output_folder}/{args.id}/') + 0.0,2.5,f'{args.output_folder}/{args.id}/') return model From 4e11ca25b5d2ee9968ad556a8d9433e684b943da Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 26 Nov 2024 10:38:22 -0500 Subject: [PATCH 045/119] log auc and mse --- scripts/jupyter.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/jupyter.sh b/scripts/jupyter.sh index e6e85d51b..be1e19d3f 100755 --- a/scripts/jupyter.sh +++ b/scripts/jupyter.sh @@ -105,6 +105,7 @@ ${DOCKER_COMMAND} run -it \ ${GPU_DEVICE} \ --rm \ --ipc=host \ +--cpus="32" \ -v /home/${USER}/:/home/${USER}/ \ -v /mnt/:/mnt/ \ -p 0.0.0.0:${PORT}:${PORT} \ From 1df78f7150e0853900525104ff352dc68a167e0d Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 26 Nov 2024 15:18:49 -0500 Subject: [PATCH 046/119] log auc and mse --- scripts/jupyter.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/jupyter.sh b/scripts/jupyter.sh index be1e19d3f..e6e85d51b 100755 --- a/scripts/jupyter.sh +++ b/scripts/jupyter.sh @@ -105,7 +105,6 @@ ${DOCKER_COMMAND} run -it \ ${GPU_DEVICE} \ --rm \ --ipc=host \ ---cpus="32" \ -v /home/${USER}/:/home/${USER}/ \ -v /mnt/:/mnt/ \ -p 0.0.0.0:${PORT}:${PORT} \ From f892448b4b6aee56a7aa83c8f6becc9306d96deb Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:04:29 -0500 Subject: [PATCH 047/119] cumc docker --- docker/ml4h_deploy/Dockerfile | 16 ++++++++++++++++ docker/ml4h_deploy/process_files.py | 23 +++++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 docker/ml4h_deploy/Dockerfile create mode 100644 docker/ml4h_deploy/process_files.py diff --git a/docker/ml4h_deploy/Dockerfile b/docker/ml4h_deploy/Dockerfile new file mode 100644 index 000000000..413d5a081 --- /dev/null +++ b/docker/ml4h_deploy/Dockerfile @@ -0,0 +1,16 @@ +FROM ghcr.io/broadinstitute/ml4h:tf2.9-latest-cpu + +# Set the working directory +WORKDIR /app + +# Install TensorFlow (or any other necessary libraries) +RUN pip install tensorflow + +# Copy the Keras model file into the Docker image +COPY ecg2af_quintuplet_v2024_01_13.h5 /app/ecg2af_quintuplet_v2024_01_13.h5 + +# Copy the Python script +COPY process_files.py /app/process_files.py + +# Define the command to run the script +CMD ["python", "process_files.py", "/data"] \ No newline at end of file diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py new file mode 100644 index 000000000..40d5a23b1 --- /dev/null +++ b/docker/ml4h_deploy/process_files.py @@ -0,0 +1,23 @@ +import sys +import os +from tensorflow.keras.models import load_model + +# Load the model +model = load_model('ecg2af_quintuplet_v2024_01_13.h5') + +def process_file(filepath): + # Placeholder for file processing logic + print(f"Processing file: {filepath}") + # Example: Use the model to make a prediction (add real processing logic here) + +def main(directory): + # Iterate over all files in the specified directory + for filename in os.listdir(directory): + filepath = os.path.join(directory, filename) + if os.path.isfile(filepath): + process_file(filepath) + +if __name__ == "__main__": + # Take directory path from command-line arguments + directory = sys.argv[1] if len(sys.argv) > 1 else "/data" + main(directory) \ No newline at end of file From 43c857e9d8981fe5ecd98ef9e09342bc56bc7d78 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:10:21 -0500 Subject: [PATCH 048/119] cumc docker --- docker/ml4h_deploy/process_files.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 40d5a23b1..d0bd94d4f 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -1,6 +1,14 @@ -import sys import os +import sys +import numpy as np from tensorflow.keras.models import load_model +from ml4h.models.model_factory import get_custom_objects +from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2 +from ml4h.tensormap.ukb.demographics import age_2_wide, af_dummy, sex_dummy3 + +output_tensormaps = {tm.output_name(): tm for tm in [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_dummy3]} +custom_dict = get_custom_objects(list(output_tensormaps.values())) +model = load_model('./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5', custom_objects=custom_dict) # Load the model model = load_model('ecg2af_quintuplet_v2024_01_13.h5') From e495b4801cde3e96d8ac717c03b24c15ef070aab Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:11:51 -0500 Subject: [PATCH 049/119] cumc docker --- docker/ml4h_deploy/Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/ml4h_deploy/Dockerfile b/docker/ml4h_deploy/Dockerfile index 413d5a081..651ae6efc 100644 --- a/docker/ml4h_deploy/Dockerfile +++ b/docker/ml4h_deploy/Dockerfile @@ -12,5 +12,7 @@ COPY ecg2af_quintuplet_v2024_01_13.h5 /app/ecg2af_quintuplet_v2024_01_13.h5 # Copy the Python script COPY process_files.py /app/process_files.py +RUN pip3 install ml4h + # Define the command to run the script CMD ["python", "process_files.py", "/data"] \ No newline at end of file From 13ee7457636096b977b312bd8560b9154cbcc8f2 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:13:20 -0500 Subject: [PATCH 050/119] cumc docker --- docker/ml4h_deploy/process_files.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index d0bd94d4f..03b9667e4 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -8,10 +8,7 @@ output_tensormaps = {tm.output_name(): tm for tm in [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_dummy3]} custom_dict = get_custom_objects(list(output_tensormaps.values())) -model = load_model('./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5', custom_objects=custom_dict) - -# Load the model -model = load_model('ecg2af_quintuplet_v2024_01_13.h5') +model = load_model('ecg2af_quintuplet_v2024_01_13.h5', custom_objects=custom_dict) def process_file(filepath): # Placeholder for file processing logic From aa5c162444000c9b938399ffd7f8c44ce9ff97f5 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:16:57 -0500 Subject: [PATCH 051/119] cumc docker --- docker/ml4h_deploy/process_files.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 03b9667e4..0fe15a6cc 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -2,11 +2,14 @@ import sys import numpy as np from tensorflow.keras.models import load_model +from ml4h.TensorMap import TensorMap, Interpretation from ml4h.models.model_factory import get_custom_objects from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2 -from ml4h.tensormap.ukb.demographics import age_2_wide, af_dummy, sex_dummy3 +from ml4h.tensormap.ukb.demographics import age_2_wide, af_dummy -output_tensormaps = {tm.output_name(): tm for tm in [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_dummy3]} +sex_tmap = TensorMap(name='sex', interpretation=Interpretation.CATEGORICAL, channel_map={'Female': 0, 'Male':1}) + +output_tensormaps = {tm.output_name(): tm for tm in [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_tmap]} custom_dict = get_custom_objects(list(output_tensormaps.values())) model = load_model('ecg2af_quintuplet_v2024_01_13.h5', custom_objects=custom_dict) From 37db2830049680a4130695df533d6c1db2c6b982 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:19:21 -0500 Subject: [PATCH 052/119] cumc docker --- docker/ml4h_deploy/process_files.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 0fe15a6cc..90f7ed833 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -4,12 +4,26 @@ from tensorflow.keras.models import load_model from ml4h.TensorMap import TensorMap, Interpretation from ml4h.models.model_factory import get_custom_objects -from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2 -from ml4h.tensormap.ukb.demographics import age_2_wide, af_dummy + +n_intervals = 25 + +af_tmap = TensorMap( + 'survival_curve_af', + Interpretation.SURVIVAL_CURVE, + shape=(n_intervals*2,), +) + +death_tmap = TensorMap( + 'death_event', + Interpretation.SURVIVAL_CURVE, + shape=(n_intervals*2,), +) sex_tmap = TensorMap(name='sex', interpretation=Interpretation.CATEGORICAL, channel_map={'Female': 0, 'Male':1}) +age_tmap = TensorMap(name='age_in_days', interpretation=Interpretation.CONTINUOUS, channel_map={'age_in_days': 0}) +af_in_read_tmap = TensorMap(name='af_in_read', interpretation=Interpretation.CATEGORICAL, channel_map={'no_af_in_read': 0, 'af_in_read':1}) -output_tensormaps = {tm.output_name(): tm for tm in [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_tmap]} +output_tensormaps = {tm.output_name(): tm for tm in [af_tmap, death_tmap, sex_tmap, age_tmap, af_in_read_tmap]} custom_dict = get_custom_objects(list(output_tensormaps.values())) model = load_model('ecg2af_quintuplet_v2024_01_13.h5', custom_objects=custom_dict) From 5befd1ee978ae3b331d99720fafef899c5ce560f Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:36:42 -0500 Subject: [PATCH 053/119] cumc docker --- docker/ml4h_deploy/process_files.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 90f7ed833..1ebf2d3c3 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -1,12 +1,22 @@ import os import sys + +import h5py import numpy as np from tensorflow.keras.models import load_model from ml4h.TensorMap import TensorMap, Interpretation +from ml4h.defines import ECG_REST_AMP_LEADS from ml4h.models.model_factory import get_custom_objects n_intervals = 25 +ecg_tmap = TensorMap( + 'ecg_5000_std', + Interpretation.CONTINUOUS, + shape=(5000, 12), + channel_map=ECG_REST_AMP_LEADS +) + af_tmap = TensorMap( 'survival_curve_af', Interpretation.SURVIVAL_CURVE, @@ -30,6 +40,11 @@ def process_file(filepath): # Placeholder for file processing logic print(f"Processing file: {filepath}") + with h5py.File(filepath, 'r') as hd5: + tensor = np.zeros(ecg_tmap.shape, dtype=np.float32) + for lead in ecg_tmap.channel_map: + tensor[:, ecg_tmap.channel_map[lead]] = hd5[f'/ukb_ecg_rest/strip_{lead}/instance_0'] + print(f"Got tensor: {tensor.mean():0.3f}") # Example: Use the model to make a prediction (add real processing logic here) def main(directory): From 87b4cbc23f662078393b5da9b706331b2c73da74 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:42:35 -0500 Subject: [PATCH 054/119] cumc docker --- docker/ml4h_deploy/process_files.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 1ebf2d3c3..f48d6111a 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -44,7 +44,15 @@ def process_file(filepath): tensor = np.zeros(ecg_tmap.shape, dtype=np.float32) for lead in ecg_tmap.channel_map: tensor[:, ecg_tmap.channel_map[lead]] = hd5[f'/ukb_ecg_rest/strip_{lead}/instance_0'] + + tensor -= tensor.mean() + tensor /= (tensor.std() + 1e-6) print(f"Got tensor: {tensor.mean():0.3f}") + prediction = model.predict(tensor, verbose=0) + if len(model.output_names) == 1: + prediction = [prediction] + predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} + print(f"Got predictions: {predictions_dict}") # Example: Use the model to make a prediction (add real processing logic here) def main(directory): From 02dc15a5158d09379579a19829dc2762c09873ae Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 15:43:53 -0500 Subject: [PATCH 055/119] cumc docker --- docker/ml4h_deploy/process_files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index f48d6111a..7946c0ea1 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -48,7 +48,7 @@ def process_file(filepath): tensor -= tensor.mean() tensor /= (tensor.std() + 1e-6) print(f"Got tensor: {tensor.mean():0.3f}") - prediction = model.predict(tensor, verbose=0) + prediction = model.predict(np.expand_dims(tensor, axis=0), verbose=0) if len(model.output_names) == 1: prediction = [prediction] predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} From 1a933049eeae4985d3c9c28543bda0d8ebd6142e Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 16:32:22 -0500 Subject: [PATCH 056/119] cumc docker --- docker/ml4h_deploy/process_files.py | 36 +++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 7946c0ea1..6950771e7 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -1,8 +1,10 @@ import os import sys +from collections import defaultdict import h5py import numpy as np +import pandas as pd from tensorflow.keras.models import load_model from ml4h.TensorMap import TensorMap, Interpretation from ml4h.defines import ECG_REST_AMP_LEADS @@ -36,8 +38,9 @@ output_tensormaps = {tm.output_name(): tm for tm in [af_tmap, death_tmap, sex_tmap, age_tmap, af_in_read_tmap]} custom_dict = get_custom_objects(list(output_tensormaps.values())) model = load_model('ecg2af_quintuplet_v2024_01_13.h5', custom_objects=custom_dict) +space_dict = defaultdict(list) -def process_file(filepath): +def process_file(filepath, space_dict): # Placeholder for file processing logic print(f"Processing file: {filepath}") with h5py.File(filepath, 'r') as hd5: @@ -47,20 +50,43 @@ def process_file(filepath): tensor -= tensor.mean() tensor /= (tensor.std() + 1e-6) - print(f"Got tensor: {tensor.mean():0.3f}") + #print(f"Got tensor: {tensor.mean():0.3f}") prediction = model.predict(np.expand_dims(tensor, axis=0), verbose=0) if len(model.output_names) == 1: prediction = [prediction] predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} - print(f"Got predictions: {predictions_dict}") + #print(f"Got predictions: {predictions_dict}") + + for otm in output_tensormaps.values(): + y = predictions_dict[otm.output_name()] + if otm.is_categorical(): + space_dict[f'{otm.name}_prediction'].append(y[0, 1]) + elif otm.is_continuous(): + space_dict[f'{otm.name}_prediction'].append(y[0, 0]) + elif otm.is_survival_curve(): + intervals = otm.shape[-1] // 2 + days_per_bin = 1 + (2 * otm.days_window) // intervals + predicted_survivals = np.cumprod(y[:, :intervals], axis=1) + space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[0, -1])) + # print(f' got target: {target[otm.output_name()].numpy().shape}') + # sick = np.sum(target[otm.output_name()].numpy()[:, intervals:], axis=-1) + # follow_up = np.cumsum(target[otm.output_name()].numpy()[:, :intervals], axis=-1)[:, -1] * days_per_bin + # space_dict[f'{otm.name}_event'].append(str(sick[b])) + # space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b])) # Example: Use the model to make a prediction (add real processing logic here) def main(directory): # Iterate over all files in the specified directory - for filename in os.listdir(directory): + space_dict = defaultdict(list) + for i,filename in enumerate(os.listdir(directory)): filepath = os.path.join(directory, filename) if os.path.isfile(filepath): - process_file(filepath) + process_file(filepath, space_dict) + if i > 100: + break + + df = pd.DataFrame.from_dict(space_dict) + df.to_csv('/output/ecg2af_quintuplet.csv', index=False) if __name__ == "__main__": # Take directory path from command-line arguments From 3393380ebfad17a98a61cb2a820a7e36f178011f Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 16:36:03 -0500 Subject: [PATCH 057/119] cumc docker --- docker/ml4h_deploy/process_files.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 6950771e7..c99937209 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -56,7 +56,8 @@ def process_file(filepath, space_dict): prediction = [prediction] predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} #print(f"Got predictions: {predictions_dict}") - + space_dict['sample_id'].append(os.path.basename(filepath).replace('.hd5', '')) + space_dict['ecg_path'].append(filepath) for otm in output_tensormaps.values(): y = predictions_dict[otm.output_name()] if otm.is_categorical(): From baa187544e66d915f10ef8c9d962e79891ea1642 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 16:40:08 -0500 Subject: [PATCH 058/119] cumc docker --- docker/ml4h_deploy/process_files.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index c99937209..6e48b37a3 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -58,6 +58,11 @@ def process_file(filepath, space_dict): #print(f"Got predictions: {predictions_dict}") space_dict['sample_id'].append(os.path.basename(filepath).replace('.hd5', '')) space_dict['ecg_path'].append(filepath) + if '/dates/atrial_fibrillation_or_flutter_date' in hd5: + space_dict['has_af'].append(1) + else: + space_dict['has_af'].append(0) + for otm in output_tensormaps.values(): y = predictions_dict[otm.output_name()] if otm.is_categorical(): From afa912dcf3d2c113b4411c5042fb816c80f51d36 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 16:40:26 -0500 Subject: [PATCH 059/119] cumc docker --- docker/ml4h_deploy/process_files.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 6e48b37a3..7338016c0 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -62,7 +62,7 @@ def process_file(filepath, space_dict): space_dict['has_af'].append(1) else: space_dict['has_af'].append(0) - + for otm in output_tensormaps.values(): y = predictions_dict[otm.output_name()] if otm.is_categorical(): @@ -88,7 +88,7 @@ def main(directory): filepath = os.path.join(directory, filename) if os.path.isfile(filepath): process_file(filepath, space_dict) - if i > 100: + if i > 1000: break df = pd.DataFrame.from_dict(space_dict) From 6fddd3d7076781bf23c25bba8a49abeef26b64e4 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 27 Nov 2024 16:51:36 -0500 Subject: [PATCH 060/119] cumc docker --- docker/ml4h_deploy/process_files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 7338016c0..2c2b21b72 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -88,7 +88,7 @@ def main(directory): filepath = os.path.join(directory, filename) if os.path.isfile(filepath): process_file(filepath, space_dict) - if i > 1000: + if i > 10000: break df = pd.DataFrame.from_dict(space_dict) From 3cef0242ef8cee17cc114857c1bd6b917e6ccd6e Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 3 Dec 2024 11:39:17 -0500 Subject: [PATCH 061/119] fix --- docker/ml4h_deploy/process_files.py | 167 ++++++++++++++++++++++++++-- 1 file changed, 160 insertions(+), 7 deletions(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index 2c2b21b72..df47d98c0 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -1,8 +1,11 @@ import os import sys +import base64 +import struct from collections import defaultdict import h5py +import xmltodict import numpy as np import pandas as pd from tensorflow.keras.models import load_model @@ -40,18 +43,18 @@ model = load_model('ecg2af_quintuplet_v2024_01_13.h5', custom_objects=custom_dict) space_dict = defaultdict(list) -def process_file(filepath, space_dict): +def process_ukb_hd5(filepath, space_dict): # Placeholder for file processing logic print(f"Processing file: {filepath}") with h5py.File(filepath, 'r') as hd5: - tensor = np.zeros(ecg_tmap.shape, dtype=np.float32) + ecg_array = np.zeros(ecg_tmap.shape, dtype=np.float32) for lead in ecg_tmap.channel_map: - tensor[:, ecg_tmap.channel_map[lead]] = hd5[f'/ukb_ecg_rest/strip_{lead}/instance_0'] + ecg_array[:, ecg_tmap.channel_map[lead]] = hd5[f'/ukb_ecg_rest/strip_{lead}/instance_0'] - tensor -= tensor.mean() - tensor /= (tensor.std() + 1e-6) + ecg_array -= ecg_array.mean() + ecg_array /= (ecg_array.std() + 1e-6) #print(f"Got tensor: {tensor.mean():0.3f}") - prediction = model.predict(np.expand_dims(tensor, axis=0), verbose=0) + prediction = model.predict(np.expand_dims(ecg_array, axis=0), verbose=0) if len(model.output_names) == 1: prediction = [prediction] predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} @@ -81,13 +84,163 @@ def process_file(filepath, space_dict): # space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b])) # Example: Use the model to make a prediction (add real processing logic here) +def decode_ekg_muse(raw_wave): + """ + Ingest the base64 encoded waveforms and transform to numeric + """ + # covert the waveform from base64 to byte array + arr = base64.b64decode(bytes(raw_wave, 'utf-8')) + + # unpack every 2 bytes, little endian (16 bit encoding) + unpack_symbols = ''.join([char * int(len(arr) / 2) for char in 'h']) + byte_array = struct.unpack(unpack_symbols, arr) + return byte_array + +def decode_ekg_muse_to_array(raw_wave, downsample=1): + """ + Ingest the base64 encoded waveforms and transform to numeric + + downsample: 0.5 takes every other value in the array. Muse samples at 500/s and the sample model requires 250/s. So take every other. + """ + try: + dwnsmpl = int(1 // downsample) + except ZeroDivisionError: + print("You must downsample by more than 0") + # covert the waveform from base64 to byte array + arr = base64.b64decode(bytes(raw_wave, 'utf-8')) + + # unpack every 2 bytes, little endian (16 bit encoding) + unpack_symbols = ''.join([char * int(len(arr) / 2) for char in 'h']) + byte_array = struct.unpack(unpack_symbols, arr) + return np.array(byte_array)[::dwnsmpl] + +def process_ge_muse_xml(filepath, space_dict): + with open(filepath, 'rb') as fd: + dic = xmltodict.parse(fd.read().decode('utf8')) + + """ + + Upload the ECG as numpy array with shape=[2500,12,1] ([time, leads, 1]). + + The voltage unit should be in 1 mv/unit and the sampling rate should be 250/second (total 10 second). + + The leads should be ordered as follow I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6. + + """ + try: + pt_id = dic['RestingECG']['PatientDemographics']['PatientID'] + except: + print("no PatientID") + pt_id = "none" + try: + PharmaUniqueECGID = dic['RestingECG']['PharmaData']['PharmaUniqueECGID'] + except: + print("no PharmaUniqueECGID") + PharmaUniqueECGID = "none" + try: + AcquisitionDateTime = dic['RestingECG']['TestDemographics']['AcquisitionDate'] + "_" + \ + dic['RestingECG']['TestDemographics']['AcquisitionTime'].replace(":", "-") + except: + print("no AcquisitionDateTime") + AcquisitionDateTime = "none" + + # try: + # requisition_number = dic['RestingECG']['Order']['RequisitionNumber'] + # except: + # print("no requisition_number") + # requisition_number = "none" + + # need to instantiate leads in the proper order for the model + lead_order = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] + + """ + Each EKG will have this data structure: + lead_data = { + 'I': np.array + } + """ + + lead_data = dict.fromkeys(lead_order) + # lead_data = {leadid: None for k in lead_order} + + # for all_lead_data in dic['RestingECG']['Waveform']: + # for single_lead_data in lead['LeadData']: + # leadname = single_lead_data['LeadID'] + # if leadname in (lead_order): + + for lead in dic['RestingECG']['Waveform']: + for leadid in range(len(lead['LeadData'])): + sample_length = len(decode_ekg_muse_to_array(lead['LeadData'][leadid]['WaveFormData'])) + # sample_length is equivalent to dic['RestingECG']['Waveform']['LeadData']['LeadSampleCountTotal'] + if sample_length == 5000: + lead_data[lead['LeadData'][leadid]['LeadID']] = decode_ekg_muse_to_array( + lead['LeadData'][leadid]['WaveFormData'], downsample=1) + elif sample_length == 2500: + lead_data[lead['LeadData'][leadid]['LeadID']] = decode_ekg_muse_to_array( + lead['LeadData'][leadid]['WaveFormData'], downsample=2) + else: + continue + # ensures all leads have 2500 samples and also passes over the 3 second waveform + + lead_data['III'] = (np.array(lead_data["II"]) - np.array(lead_data["I"])) + lead_data['aVR'] = -(np.array(lead_data["I"]) + np.array(lead_data["II"])) / 2 + lead_data['aVF'] = (np.array(lead_data["II"]) + np.array(lead_data["III"])) / 2 + lead_data['aVL'] = (np.array(lead_data["I"]) - np.array(lead_data["III"])) / 2 + + lead_data = {k: lead_data[k] for k in lead_order} + # drops V3R, V4R, and V7 if it was a 15-lead ECG + + # now construct and reshape the array + # converting the dictionary to an np.array + temp = [] + for key, value in lead_data.items(): + temp.append(value) + + # transpose to be [time, leads, ] + ecg_array = np.array(temp).T + + # expand dims to [time, leads, 1] + ecg_array = np.expand_dims(ecg_array, axis=-1) + filename = '{}_{}_{}.npy'.format(pt_id, AcquisitionDateTime,PharmaUniqueECGID) + print(f'would write npy to {filename} len lead III {len(lead_data["III"])}') + ecg_array -= ecg_array.mean() + ecg_array /= (ecg_array.std() + 1e-6) + #print(f"Got tensor: {tensor.mean():0.3f}") + prediction = model.predict(np.expand_dims(ecg_array, axis=0), verbose=0) + if len(model.output_names) == 1: + prediction = [prediction] + predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} + #print(f"Got predictions: {predictions_dict}") + space_dict['filepath'].append(os.path.basename(filepath)) + space_dict['patient_id'].append(pt_id) + space_dict['acquisition_datetime'].append(AcquisitionDateTime) + space_dict['pharma_unique_ecg_id'].append(PharmaUniqueECGID) + + for otm in output_tensormaps.values(): + y = predictions_dict[otm.output_name()] + if otm.is_categorical(): + space_dict[f'{otm.name}_prediction'].append(y[0, 1]) + elif otm.is_continuous(): + space_dict[f'{otm.name}_prediction'].append(y[0, 0]) + elif otm.is_survival_curve(): + intervals = otm.shape[-1] // 2 + days_per_bin = 1 + (2 * otm.days_window) // intervals + predicted_survivals = np.cumprod(y[:, :intervals], axis=1) + space_dict[f'{otm.name}_prediction'].append(str(1 - predicted_survivals[0, -1])) + # print(f' got target: {target[otm.output_name()].numpy().shape}') + # sick = np.sum(target[otm.output_name()].numpy()[:, intervals:], axis=-1) + # follow_up = np.cumsum(target[otm.output_name()].numpy()[:, :intervals], axis=-1)[:, -1] * days_per_bin + # space_dict[f'{otm.name}_event'].append(str(sick[b])) + # space_dict[f'{otm.name}_follow_up'].append(str(follow_up[b])) +# Example: Use the model to make a prediction (add real processing logic here) + def main(directory): # Iterate over all files in the specified directory space_dict = defaultdict(list) for i,filename in enumerate(os.listdir(directory)): filepath = os.path.join(directory, filename) if os.path.isfile(filepath): - process_file(filepath, space_dict) + process_ge_muse_xml(filepath, space_dict) if i > 10000: break From 82c8c9ad9f5e74490033bbabc68ae5a898ce91d5 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 3 Dec 2024 12:55:27 -0500 Subject: [PATCH 062/119] fix --- docker/ml4h_deploy/process_files.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/docker/ml4h_deploy/process_files.py b/docker/ml4h_deploy/process_files.py index df47d98c0..166cb982b 100644 --- a/docker/ml4h_deploy/process_files.py +++ b/docker/ml4h_deploy/process_files.py @@ -128,21 +128,21 @@ def process_ge_muse_xml(filepath, space_dict): """ try: - pt_id = dic['RestingECG']['PatientDemographics']['PatientID'] + patient_id = dic['RestingECG']['PatientDemographics']['PatientID'] except: print("no PatientID") - pt_id = "none" + patient_id = "none" try: - PharmaUniqueECGID = dic['RestingECG']['PharmaData']['PharmaUniqueECGID'] + pharma_unique_ecg_id = dic['RestingECG']['PharmaData']['PharmaUniqueECGID'] except: print("no PharmaUniqueECGID") - PharmaUniqueECGID = "none" + pharma_unique_ecg_id = "none" try: - AcquisitionDateTime = dic['RestingECG']['TestDemographics']['AcquisitionDate'] + "_" + \ + acquisition_date_time = dic['RestingECG']['TestDemographics']['AcquisitionDate'] + "_" + \ dic['RestingECG']['TestDemographics']['AcquisitionTime'].replace(":", "-") except: print("no AcquisitionDateTime") - AcquisitionDateTime = "none" + acquisition_date_time = "none" # try: # requisition_number = dic['RestingECG']['Order']['RequisitionNumber'] @@ -199,10 +199,7 @@ def process_ge_muse_xml(filepath, space_dict): # transpose to be [time, leads, ] ecg_array = np.array(temp).T - # expand dims to [time, leads, 1] - ecg_array = np.expand_dims(ecg_array, axis=-1) - filename = '{}_{}_{}.npy'.format(pt_id, AcquisitionDateTime,PharmaUniqueECGID) - print(f'would write npy to {filename} len lead III {len(lead_data["III"])}') + print(f'Writing row of ECG2AF predictions for ECG {patient_id}, at {acquisition_date_time}') ecg_array -= ecg_array.mean() ecg_array /= (ecg_array.std() + 1e-6) #print(f"Got tensor: {tensor.mean():0.3f}") @@ -212,9 +209,9 @@ def process_ge_muse_xml(filepath, space_dict): predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)} #print(f"Got predictions: {predictions_dict}") space_dict['filepath'].append(os.path.basename(filepath)) - space_dict['patient_id'].append(pt_id) - space_dict['acquisition_datetime'].append(AcquisitionDateTime) - space_dict['pharma_unique_ecg_id'].append(PharmaUniqueECGID) + space_dict['patient_id'].append(patient_id) + space_dict['acquisition_datetime'].append(acquisition_date_time) + space_dict['pharma_unique_ecg_id'].append(pharma_unique_ecg_id) for otm in output_tensormaps.values(): y = predictions_dict[otm.output_name()] From aa2e4ea25d7e899aeaff6efb833d1b16f88cd37e Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 10 Dec 2024 13:43:53 -0500 Subject: [PATCH 063/119] fix --- ml4h/arguments.py | 13 +++++ ml4h/models/diffusion_blocks.py | 93 ++++++++++++++++++++++++++------- ml4h/models/train.py | 20 +++++-- 3 files changed, 102 insertions(+), 24 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index f41d6c113..01c20e804 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -227,6 +227,19 @@ def parse_args(): help='For diffusion models, when U-Net representation size is smaller than attention_window ' 'Cross-Attention is applied', ) + parser.add_argument( + '--attention_modulo', default=3, type=int, + help='For diffusion models, this controls how frequently Cross-Attention is applied. ' + '2 means every other residual block, 3 would mean every third.', + ) + parser.add_argument( + '--diffusion_loss', default='sigmoid', + help='Loss function to use for diffusion models. Can be sigmoid, mean_absolute_error, or mean_squared_error', + ) + parser.add_argument( + '--sigmoid_beta', default=-3, type=float, + help='Beta to use with sigmoid loss for diffusion models.', + ) parser.add_argument( '--transformer_size', default=32, type=int, help='Number of output neurons in Transformer encoders and decoders, ' diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 1b1794b8a..3ef1819dd 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -102,7 +102,53 @@ def apply(x): return apply -def get_control_network(input_shape, widths, block_depth, kernel_size, control_size, attention_start, attention_heads): +def residual_block_control(width, conv, kernel_size, attention_heads): + def apply(x): + x, control = x + input_width = x.shape[-1] + if input_width == width: + residual = x + else: + residual = conv(width, kernel_size=1)(x) + x = layers.BatchNormalization(center=False, scale=False)(x) + x = conv( + width, kernel_size=kernel_size, padding="same", activation=keras.activations.swish + )(x) + x = keras.layers.MultiHeadAttention(num_heads = attention_heads, key_dim = width)(x, control) + x = conv(width, kernel_size=kernel_size, padding="same")(x) + x = layers.Add()([x, residual]) + return x + + return apply + + +def down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads): + def apply(x): + x, skips, control = x + for _ in range(block_depth): + x = residual_block_control(width, conv, kernel_size, attention_heads)([x, control]) + skips.append(x) + x = pool(pool_size=2)(x) + return x + + return apply + + +def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads): + def apply(x): + x, skips, control = x + # x = upsample(size=2, interpolation="bilinear")(x) + x = upsample(size=2)(x) + for _ in range(block_depth): + x = layers.Concatenate()([x, skips.pop()]) + x = residual_block_control(width, conv, kernel_size, attention_heads)([x, control]) + return x + + return apply + + +def get_control_network(input_shape, widths, block_depth, kernel_size, control_size, + attention_start, attention_heads, attention_modulo): noisy_images = keras.Input(shape=input_shape) noise_variances = keras.Input(shape=[1] * len(input_shape)) @@ -124,34 +170,35 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s x = layers.Concatenate()([x, e]) skips = [] - for width in widths[:-1]: - if x.shape[1] < attention_start: + for i, width in enumerate(widths[:-1]): + if (i + 1) % attention_modulo == 0: if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) - x = keras.layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, c2) - x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) + x = down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads)([x, skips, c2]) + else: + x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) - for _ in range(block_depth): - x = keras.layers.MultiHeadAttention(num_heads=attention_heads, key_dim=widths[-1])(x, c2) - x = residual_block(widths[-1], conv, kernel_size)(x) + for i in range(block_depth): + x = residual_block_control(widths[-1], conv, kernel_size, attention_heads)([x, c2]) - for width in reversed(widths[:-1]): - if x.shape[1] < attention_start: + for i, width in enumerate(reversed(widths[:-1])): + if i % attention_modulo == 0: if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) - x = keras.layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, c2) - x = up_block(width, block_depth, conv, upsample, kernel_size)([x, skips]) + x = up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads)([x, skips, c2]) + else: + x = up_block(width, block_depth, conv, upsample, kernel_size)([x, skips]) - x = conv(input_shape[-1], kernel_size=1, kernel_initializer="zeros")(x) + x = conv(input_shape[-1], kernel_size=1, activation="linear", kernel_initializer="zeros")(x) return keras.Model([noisy_images, noise_variances, control], x, name="control_unet") @@ -169,7 +216,7 @@ def get_control_embed_model(output_maps, control_size): class DiffusionController(keras.Model): def __init__( self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size, - attention_start, attention_heads, + attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta ): super().__init__() @@ -178,11 +225,11 @@ def __init__( self.output_maps = output_maps self.control_embed_model = get_control_embed_model(self.output_maps, control_size) self.normalizer = layers.Normalization() - self.network = get_control_network( - self.input_map.shape, widths, block_depth, conv_x, control_size, - attention_start, attention_heads, - ) + self.network = get_control_network(self.input_map.shape, widths, block_depth, conv_x, control_size, + attention_start, attention_heads, attention_modulo) self.ema_network = keras.models.clone_model(self.network) + self.use_sigmoid_loss = diffusion_loss == 'sigmoid' + self.beta = sigmoid_beta def compile(self, **kwargs): @@ -306,6 +353,14 @@ def train_step(self, batch): noise_loss = self.loss(noises, pred_noises) # used for training image_loss = self.loss(images, pred_images) # only used as metric + if self.use_sigmoid_loss: + signal_rates_squared = tf.square(signal_rates) + noise_rates_squared = tf.square(noise_rates) + + # Compute log-SNR (lambda_t) + lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared) + weight = tf.math.sigmoid(self.beta - lambda_t) + noise_loss = weight * noise_loss gradients = tape.gradient(noise_loss, self.network.trainable_weights) self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights)) @@ -562,7 +617,7 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, plt.axis("off") plt.tight_layout() now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') - figure_path = os.path.join(prefix, f'diffusion_image_generations_{now_string}{IMAGE_EXT}') + figure_path = os.path.join(prefix, f'diffusion_ecg_generations_{now_string}{IMAGE_EXT}') if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 363b4c149..ae979fe80 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -2,6 +2,7 @@ import os import logging +from functools import partial from typing import List, Tuple, Iterable, Union import numpy as np @@ -272,16 +273,17 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba def train_diffusion_control_model(args): generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) model = DiffusionController( - args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, - args.dense_blocks, args.block_size, args.conv_x, args.dense_layers[0], - args.attention_window, args.attention_heads, + args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, + args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, + args.sigmoid_beta, ) + loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error model.compile( optimizer=tfa.optimizers.AdamW( learning_rate=args.learning_rate, weight_decay=1e-4, ), - loss=keras.losses.mean_absolute_error, + loss=loss, ) batch = next(generate_train) for k in batch[0]: @@ -297,6 +299,7 @@ def train_diffusion_control_model(args): mode="min", save_best_only=True, ) + callbacks = [checkpoint_callback] # calculate mean and variance of training dataset for normalization model.normalizer.adapt(feature_batch) @@ -326,6 +329,13 @@ def train_diffusion_control_model(args): layer_range=None, show_layer_activations=True, ) + prefix_value = f'{args.output_folder}{args.id}/learning_generations/' + # Create a partial function with reseed and prefix pre-filled + if model.input_map.axes() == 2: + plot_partial = partial(model.plot_ecgs, reseed=args.random_seed, prefix=prefix_value) + else: + plot_partial = partial(model.plot_images, reseed=args.random_seed, prefix=prefix_value) + callbacks.append(keras.callbacks.LambdaCallback(on_epoch_end=plot_partial)) if os.path.exists(checkpoint_path+'.index'): model.load_weights(checkpoint_path) @@ -339,7 +349,7 @@ def train_diffusion_control_model(args): epochs=args.epochs, validation_data=generate_valid, validation_steps=args.validation_steps, - callbacks=[checkpoint_callback], + callbacks=callbacks, ) plot_metric_history(history, args.training_steps, args.id, os.path.dirname(checkpoint_path)) model.load_weights(checkpoint_path) From 8c6b2033c166867ffaa5615cb29a778ed57fdffb Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 10 Dec 2024 18:32:51 -0500 Subject: [PATCH 064/119] fix --- ml4h/models/diffusion_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 3ef1819dd..cb0a625d4 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -171,7 +171,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s skips = [] for i, width in enumerate(widths[:-1]): - if (i + 1) % attention_modulo == 0: + if attention_modulo > 1 and (i + 1) % attention_modulo == 0: if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: @@ -189,7 +189,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s x = residual_block_control(widths[-1], conv, kernel_size, attention_heads)([x, c2]) for i, width in enumerate(reversed(widths[:-1])): - if i % attention_modulo == 0: + if attention_modulo > 1 and i % attention_modulo == 0: if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: From ab95a1ec659152b858a838e3f46b471941b8dfb4 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 10 Dec 2024 18:53:05 -0500 Subject: [PATCH 065/119] fix --- ml4h/models/diffusion_blocks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index cb0a625d4..c974def66 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -451,20 +451,21 @@ def test_step(self, batch): return {m.name: m.result() for m in self.metrics} - def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, prefix='./figures/'): + def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'): control_batch = {} for cm in self.output_maps: control_batch[cm.output_name()] = np.zeros((max(self.batch_size, num_rows * num_cols),) + cm.shape) if 'Sex' in cm.name: control_batch[cm.output_name()][:, 0] = 1 # all female - logging.info(f'\nControl batch keys: {list(control_batch.keys())}') + print(f'\nControl batch keys: {list(control_batch.keys())}') control_embed = self.control_embed_model(control_batch) # plot random generated images for visual evaluation of generation quality generated_images = self.generate( control_embed, num_images=max(self.batch_size, num_rows * num_cols), diffusion_steps=plot_diffusion_steps, + reseed=reseed, ) plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) for row in range(num_rows): From 3beef1b0b11efcd57281df1398604417d398b9e7 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 10 Dec 2024 20:08:02 -0500 Subject: [PATCH 066/119] fix --- ml4h/models/diffusion_blocks.py | 2 +- ml4h/models/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index c974def66..10c0c4a22 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -209,7 +209,7 @@ def get_control_embed_model(output_maps, control_size): for cm in output_maps: control_ins.append(keras.Input(shape=cm.shape, name=cm.output_name())) c = layers.Concatenate()(control_ins) - c = layers.Dense(control_size, activation='linear')(c) + #c = layers.Dense(control_size, activation='linear')(c) return keras.Model(control_ins, c, name='control_embed') diff --git a/ml4h/models/train.py b/ml4h/models/train.py index ae979fe80..b6c9babc7 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -367,6 +367,6 @@ def train_diffusion_control_model(args): args.model_file = model_file eval_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, - 0.0,2.5,f'{args.output_folder}/{args.id}/') + 0.0,2.0,f'{args.output_folder}/{args.id}/') return model From 2ab67e14b409dd06c8351f4fa16853c62729d914 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 11 Dec 2024 10:32:39 -0500 Subject: [PATCH 067/119] fix --- ml4h/models/diffusion_blocks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 10c0c4a22..e0160f922 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -110,11 +110,12 @@ def apply(x): residual = x else: residual = conv(width, kernel_size=1)(x) + x = keras.layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control) x = layers.BatchNormalization(center=False, scale=False)(x) x = conv( width, kernel_size=kernel_size, padding="same", activation=keras.activations.swish )(x) - x = keras.layers.MultiHeadAttention(num_heads = attention_heads, key_dim = width)(x, control) + x = conv(width, kernel_size=kernel_size, padding="same")(x) x = layers.Add()([x, residual]) return x From 43381d5ab142f7f75ebfdabc9320243b98f57466 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 11 Dec 2024 10:35:22 -0500 Subject: [PATCH 068/119] fix --- ml4h/models/diffusion_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index e0160f922..437b397b8 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -210,7 +210,7 @@ def get_control_embed_model(output_maps, control_size): for cm in output_maps: control_ins.append(keras.Input(shape=cm.shape, name=cm.output_name())) c = layers.Concatenate()(control_ins) - #c = layers.Dense(control_size, activation='linear')(c) + c = layers.Dense(control_size, activation='linear')(c) return keras.Model(control_ins, c, name='control_embed') From 414aafa44a3069ca4ad2877a6b96773e4b7e6131 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 11 Dec 2024 16:59:18 -0500 Subject: [PATCH 069/119] fix --- ml4h/tensormap/ukb/mri.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ml4h/tensormap/ukb/mri.py b/ml4h/tensormap/ukb/mri.py index b56ffb4d6..f9f1d4c63 100755 --- a/ml4h/tensormap/ukb/mri.py +++ b/ml4h/tensormap/ukb/mri.py @@ -420,6 +420,16 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}): normalization=ZeroMeanStd1(), tensor_from_file=_random_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/2/instance_0'), ) +lax_3ch_random_slice_3d = TensorMap( + 'lax_3ch_random_slice_3d', Interpretation.CONTINUOUS, shape=(224, 160, 1), + normalization=ZeroMeanStd1(), + tensor_from_file=_random_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_3ch/2/instance_0'), +) +lax_2ch_random_slice_3d = TensorMap( + 'lax_2ch_random_slice_3d', Interpretation.CONTINUOUS, shape=(224, 224, 1), + normalization=ZeroMeanStd1(), + tensor_from_file=_random_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0'), +) lax_4ch_diastole_slice0_224_3d_augmented = TensorMap( 'lax_4ch_diastole_slice0_224_3d_augmented', Interpretation.CONTINUOUS, shape=(160, 224, 1), From 6921c4a62442d7e7e4d803be03ed360e6568ea5f Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 12 Dec 2024 13:02:38 -0500 Subject: [PATCH 070/119] fix --- ml4h/models/diffusion_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 437b397b8..d90048ff7 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -32,7 +32,7 @@ weight_decay = 1e-4 # plotting -plot_diffusion_steps = 20 +plot_diffusion_steps = 50 def sinusoidal_embedding(x, dims=1): From 1eee793ee4728efad78b67bd1f7bdcc91c33e7b1 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 13 Dec 2024 13:10:17 -0500 Subject: [PATCH 071/119] ecg 512 --- ml4h/tensormap/ukb/ecg.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ml4h/tensormap/ukb/ecg.py b/ml4h/tensormap/ukb/ecg.py index 15d71f292..a11bcb1c0 100755 --- a/ml4h/tensormap/ukb/ecg.py +++ b/ml4h/tensormap/ukb/ecg.py @@ -581,6 +581,12 @@ def ecg_rest_section_to_segment(tm, hd5, dependents={}): normalization=Standardize(mean=0, std=10), ) +ecg_rest_median_512 = TensorMap( + 'ecg_rest_median_512', Interpretation.CONTINUOUS, path_prefix='ukb_ecg_rest', shape=(512, 12), loss='logcosh', + activation='linear', tensor_from_file=_make_ecg_rest(), channel_map=ECG_REST_MEDIAN_LEADS, + normalization=ZeroMeanStd1(), +) + ecg_rest_median_raw_10_no_poor = TensorMap( 'ecg_rest_median_raw_10', Interpretation.CONTINUOUS, path_prefix='ukb_ecg_rest', shape=(600, 12), loss='logcosh', activation='linear', tensor_from_file=_make_ecg_rest(skip_poor=True), metrics=['mse', 'mae'], channel_map=ECG_REST_MEDIAN_LEADS, normalization=Standardize(mean=0, std=10), From 5928cd76bb40f4874b14640e0de5ada99b54f16a Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 13 Dec 2024 13:44:38 -0500 Subject: [PATCH 072/119] condition strategy --- ml4h/arguments.py | 4 +++ ml4h/models/diffusion_blocks.py | 46 ++++++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index 01c20e804..f1111f3c0 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -232,6 +232,10 @@ def parse_args(): help='For diffusion models, this controls how frequently Cross-Attention is applied. ' '2 means every other residual block, 3 would mean every third.', ) + parser.add_argument( + '--diffusion_condition_strategy', default='concat', choices=['cross_attention', 'concat', 'film'], + help='For diffusion models, this controls conditional embeddings are integrated into the U-NET', + ) parser.add_argument( '--diffusion_loss', default='sigmoid', help='Loss function to use for diffusion models. Can be sigmoid, mean_absolute_error, or mean_squared_error', diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index d90048ff7..f1afea6d8 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -102,7 +102,20 @@ def apply(x): return apply -def residual_block_control(width, conv, kernel_size, attention_heads): +def condition_layer_film(input_tensor, control_vector, filters): + # Transform control into gamma and beta + gamma = layers.Dense(filters, activation="linear")(control_vector) + beta = layers.Dense(filters, activation="linear")(control_vector) + + # Reshape gamma and beta to match the spatial dimensions + gamma = tf.reshape(gamma, (-1, 1, 1, filters)) + beta = tf.reshape(beta, (-1, 1, 1, filters)) + + # Apply FiLM (Feature-wise Linear Modulation) + return input_tensor * gamma + beta + + +def residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy): def apply(x): x, control = x input_width = x.shape[-1] @@ -110,7 +123,14 @@ def apply(x): residual = x else: residual = conv(width, kernel_size=1)(x) - x = keras.layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control) + + if 'cross_attention' == condition_strategy: + x = layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control) + elif 'concat' == condition_strategy: + x = layers.Concatenate()([x, control]) + elif 'film' == condition_strategy: + x = condition_layer_film(x, control, width) + x = layers.BatchNormalization(center=False, scale=False)(x) x = conv( width, kernel_size=kernel_size, padding="same", activation=keras.activations.swish @@ -123,11 +143,11 @@ def apply(x): return apply -def down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads): +def down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads, condition_strategy): def apply(x): x, skips, control = x for _ in range(block_depth): - x = residual_block_control(width, conv, kernel_size, attention_heads)([x, control]) + x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control]) skips.append(x) x = pool(pool_size=2)(x) return x @@ -135,21 +155,21 @@ def apply(x): return apply -def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads): +def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads, condition_strategy): def apply(x): x, skips, control = x # x = upsample(size=2, interpolation="bilinear")(x) x = upsample(size=2)(x) for _ in range(block_depth): x = layers.Concatenate()([x, skips.pop()]) - x = residual_block_control(width, conv, kernel_size, attention_heads)([x, control]) + x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control]) return x return apply def get_control_network(input_shape, widths, block_depth, kernel_size, control_size, - attention_start, attention_heads, attention_modulo): + attention_window, attention_heads, attention_modulo, condition_strategy): noisy_images = keras.Input(shape=input_shape) noise_variances = keras.Input(shape=[1] * len(input_shape)) @@ -177,7 +197,8 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) - x = down_block_control(width, block_depth, conv, pool, kernel_size, attention_heads)([x, skips, c2]) + x = down_block_control(width, block_depth, conv, pool, + kernel_size, attention_heads, condition_strategy)([x, skips, c2]) else: x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) @@ -187,7 +208,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s c2 = upsample(size=x.shape[-2])(control[control_idxs]) for i in range(block_depth): - x = residual_block_control(widths[-1], conv, kernel_size, attention_heads)([x, c2]) + x = residual_block_control(widths[-1], conv, kernel_size, attention_heads, condition_strategy)([x, c2]) for i, width in enumerate(reversed(widths[:-1])): if attention_modulo > 1 and i % attention_modulo == 0: @@ -195,7 +216,8 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) - x = up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads)([x, skips, c2]) + x = up_block_control(width, block_depth, conv, upsample, + kernel_size, attention_heads, condition_strategy)([x, skips, c2]) else: x = up_block(width, block_depth, conv, upsample, kernel_size)([x, skips]) @@ -217,7 +239,7 @@ def get_control_embed_model(output_maps, control_size): class DiffusionController(keras.Model): def __init__( self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size, - attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta + attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy, ): super().__init__() @@ -227,7 +249,7 @@ def __init__( self.control_embed_model = get_control_embed_model(self.output_maps, control_size) self.normalizer = layers.Normalization() self.network = get_control_network(self.input_map.shape, widths, block_depth, conv_x, control_size, - attention_start, attention_heads, attention_modulo) + attention_start, attention_heads, attention_modulo, condition_strategy) self.ema_network = keras.models.clone_model(self.network) self.use_sigmoid_loss = diffusion_loss == 'sigmoid' self.beta = sigmoid_beta From c53e449690f85dd32a336477ec646eeb472e7b0a Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 13 Dec 2024 13:49:43 -0500 Subject: [PATCH 073/119] condition strategy --- ml4h/models/diffusion_blocks.py | 2 +- ml4h/models/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index f1afea6d8..087b28199 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -123,7 +123,7 @@ def apply(x): residual = x else: residual = conv(width, kernel_size=1)(x) - + if 'cross_attention' == condition_strategy: x = layers.MultiHeadAttention(num_heads=attention_heads, key_dim=width)(x, control) elif 'concat' == condition_strategy: diff --git a/ml4h/models/train.py b/ml4h/models/train.py index b6c9babc7..00065dcb8 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -275,7 +275,7 @@ def train_diffusion_control_model(args): model = DiffusionController( args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, - args.sigmoid_beta, + args.sigmoid_beta, args.diffusion_condition_strategy, ) loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error From 82a9d67c7ffbf5c937533ca7e0efc67db1c24087 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Sun, 15 Dec 2024 08:21:42 -0500 Subject: [PATCH 074/119] condition strategy --- ml4h/models/diffusion_blocks.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 087b28199..974fc984a 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -193,7 +193,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s skips = [] for i, width in enumerate(widths[:-1]): if attention_modulo > 1 and (i + 1) % attention_modulo == 0: - if len(input_shape) > 2: + if condition_strategy == 'film': + c2 = control + elif len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) @@ -202,7 +204,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s else: x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) - if len(input_shape) > 2: + if condition_strategy == 'film': + c2 = control + elif len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) @@ -212,7 +216,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s for i, width in enumerate(reversed(widths[:-1])): if attention_modulo > 1 and i % attention_modulo == 0: - if len(input_shape) > 2: + if condition_strategy == 'film': + c2 = control + elif len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) From d85a97df02b7c48e33bc68797a375c51459da883 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:04:36 -0500 Subject: [PATCH 075/119] condition strategy --- ml4h/models/diffusion_blocks.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 974fc984a..6698adf9b 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -108,9 +108,12 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - gamma = tf.reshape(gamma, (-1, 1, 1, filters)) - beta = tf.reshape(beta, (-1, 1, 1, filters)) - + if 4 == len(input_tensor.shape): + gamma = tf.reshape(gamma, (-1, 1, 1, filters)) + beta = tf.reshape(beta, (-1, 1, 1, filters)) + elif 3 == len(input_tensor.shape): + gamma = tf.reshape(gamma, (-1, 1, filters)) + beta = tf.reshape(beta, (-1, 1, filters)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta From 0931fd7a7f1e6ef1a35036b431d2c81e586eb80c Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:26:51 -0500 Subject: [PATCH 076/119] condition strategy --- ml4h/models/diffusion_blocks.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 6698adf9b..1c9e18e89 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -28,8 +28,6 @@ # optimization ema = 0.999 -learning_rate = 5e-4 -weight_decay = 1e-4 # plotting plot_diffusion_steps = 50 @@ -196,9 +194,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s skips = [] for i, width in enumerate(widths[:-1]): if attention_modulo > 1 and (i + 1) % attention_modulo == 0: - if condition_strategy == 'film': - c2 = control - elif len(input_shape) > 2: + if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) @@ -219,9 +215,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s for i, width in enumerate(reversed(widths[:-1])): if attention_modulo > 1 and i % attention_modulo == 0: - if condition_strategy == 'film': - c2 = control - elif len(input_shape) > 2: + if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) @@ -659,7 +653,7 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, def get_network(input_shape, widths, block_depth, kernel_size): noisy_images = keras.Input(shape=input_shape) - conv, upsample, pool = layers_from_shape(input_shape) + conv, upsample, pool, _ = layers_from_shape_control(input_shape) noise_variances = keras.Input(shape=[1] * len(input_shape)) e = layers.Lambda(sinusoidal_embedding)(noise_variances) From 7b49ea013b82ba2a46e13dd07e72c57e257198d0 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:29:37 -0500 Subject: [PATCH 077/119] condition strategy --- ml4h/models/diffusion_blocks.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 1c9e18e89..5f6855876 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -106,12 +106,12 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - if 4 == len(input_tensor.shape): - gamma = tf.reshape(gamma, (-1, 1, 1, filters)) - beta = tf.reshape(beta, (-1, 1, 1, filters)) - elif 3 == len(input_tensor.shape): - gamma = tf.reshape(gamma, (-1, 1, filters)) - beta = tf.reshape(beta, (-1, 1, filters)) + # if 4 == len(input_tensor.shape): + # gamma = tf.reshape(gamma, (-1, 1, 1, filters)) + # beta = tf.reshape(beta, (-1, 1, 1, filters)) + # elif 3 == len(input_tensor.shape): + # gamma = tf.reshape(gamma, (-1, 1, filters)) + # beta = tf.reshape(beta, (-1, 1, filters)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta @@ -203,9 +203,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s else: x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) - if condition_strategy == 'film': - c2 = control - elif len(input_shape) > 2: + if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) From ab3c844522f53ea01e9bd506b94f97dc1ab93494 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:35:18 -0500 Subject: [PATCH 078/119] condition strategy --- ml4h/models/diffusion_blocks.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 5f6855876..bfd466352 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -106,12 +106,8 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - # if 4 == len(input_tensor.shape): - # gamma = tf.reshape(gamma, (-1, 1, 1, filters)) - # beta = tf.reshape(beta, (-1, 1, 1, filters)) - # elif 3 == len(input_tensor.shape): - # gamma = tf.reshape(gamma, (-1, 1, filters)) - # beta = tf.reshape(beta, (-1, 1, filters)) + gamma = tf.reshape(gamma, input_tensor.input_shape[:-1] + (filters,)) + beta = tf.reshape(beta, input_tensor.input_shape[:-1] + (filters,)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta From 7f082beab9ddee8eceefd563971ebc1fdd7186e2 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:36:17 -0500 Subject: [PATCH 079/119] condition strategy --- ml4h/models/diffusion_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index bfd466352..6d0ebc2aa 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -106,8 +106,8 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - gamma = tf.reshape(gamma, input_tensor.input_shape[:-1] + (filters,)) - beta = tf.reshape(beta, input_tensor.input_shape[:-1] + (filters,)) + gamma = tf.reshape(gamma, input_tensor.shape[:-1] + (filters,)) + beta = tf.reshape(beta, input_tensor.shape[:-1] + (filters,)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta From ce66b0ea3e5e1a11684088e17a795ef616711b7e Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:37:51 -0500 Subject: [PATCH 080/119] condition strategy --- ml4h/models/diffusion_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 6d0ebc2aa..6cd5835be 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -106,8 +106,8 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - gamma = tf.reshape(gamma, input_tensor.shape[:-1] + (filters,)) - beta = tf.reshape(beta, input_tensor.shape[:-1] + (filters,)) + gamma = tf.reshape(gamma, (-1,) + input_tensor.shape[1:-1] + (filters,)) + beta = tf.reshape(beta, (-1,) + input_tensor.shape[1:-1] + (filters,)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta From 3e942598c71f411e739b0fa8ddf95d03c6386502 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:39:04 -0500 Subject: [PATCH 081/119] condition strategy --- ml4h/models/diffusion_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 6cd5835be..2ac518d62 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -106,8 +106,8 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - gamma = tf.reshape(gamma, (-1,) + input_tensor.shape[1:-1] + (filters,)) - beta = tf.reshape(beta, (-1,) + input_tensor.shape[1:-1] + (filters,)) + #gamma = tf.reshape(gamma, (-1,) + input_tensor.shape[1:-1] + (filters,)) + #beta = tf.reshape(beta, (-1,) + input_tensor.shape[1:-1] + (filters,)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta From 4e7cebb192d1ac3724998d6ca84822406b91741f Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:44:38 -0500 Subject: [PATCH 082/119] condition strategy --- ml4h/models/diffusion_blocks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 2ac518d62..4007bd61c 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -10,6 +10,7 @@ from tensorflow import keras from keras import layers +from torch.nn.quantized.functional import upsample from ml4h.defines import IMAGE_EXT from ml4h.models.Block import Block @@ -157,6 +158,7 @@ def apply(x): x, skips, control = x # x = upsample(size=2, interpolation="bilinear")(x) x = upsample(size=2)(x) + control = upsample(size=2)(control) for _ in range(block_depth): x = layers.Concatenate()([x, skips.pop()]) x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control]) From a3165b6b0a7bccb57e5273d1a592ae2fa7db8924 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:46:26 -0500 Subject: [PATCH 083/119] condition strategy --- ml4h/models/diffusion_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 4007bd61c..6a8f6c198 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -103,8 +103,8 @@ def apply(x): def condition_layer_film(input_tensor, control_vector, filters): # Transform control into gamma and beta - gamma = layers.Dense(filters, activation="linear")(control_vector) - beta = layers.Dense(filters, activation="linear")(control_vector) + gamma = layers.Dense(filters*2, activation="linear")(control_vector) + beta = layers.Dense(filters*2, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions #gamma = tf.reshape(gamma, (-1,) + input_tensor.shape[1:-1] + (filters,)) From 306888aad5a7ca54ccbadb638ca0e8f6654770c9 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:47:33 -0500 Subject: [PATCH 084/119] condition strategy --- ml4h/models/diffusion_blocks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 6a8f6c198..e6e856d07 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -157,11 +157,12 @@ def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_ def apply(x): x, skips, control = x # x = upsample(size=2, interpolation="bilinear")(x) - x = upsample(size=2)(x) - control = upsample(size=2)(control) + + #control = upsample(size=2)(control) for _ in range(block_depth): x = layers.Concatenate()([x, skips.pop()]) x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control]) + x = upsample(size=2)(x) return x return apply From b146aa6d4f99b9b344969355f9ae506211bfcebb Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:48:29 -0500 Subject: [PATCH 085/119] condition strategy --- ml4h/models/diffusion_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index e6e856d07..45ca9b279 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -103,8 +103,8 @@ def apply(x): def condition_layer_film(input_tensor, control_vector, filters): # Transform control into gamma and beta - gamma = layers.Dense(filters*2, activation="linear")(control_vector) - beta = layers.Dense(filters*2, activation="linear")(control_vector) + gamma = layers.Dense(filters, activation="linear")(control_vector) + beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions #gamma = tf.reshape(gamma, (-1,) + input_tensor.shape[1:-1] + (filters,)) From b7d5afb736a40fd27c9b3cdf7e186eb18d871240 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 11:51:57 -0500 Subject: [PATCH 086/119] condition strategy --- ml4h/models/diffusion_blocks.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 45ca9b279..0589bb0b9 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -10,7 +10,6 @@ from tensorflow import keras from keras import layers -from torch.nn.quantized.functional import upsample from ml4h.defines import IMAGE_EXT from ml4h.models.Block import Block @@ -107,8 +106,12 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - #gamma = tf.reshape(gamma, (-1,) + input_tensor.shape[1:-1] + (filters,)) - #beta = tf.reshape(beta, (-1,) + input_tensor.shape[1:-1] + (filters,)) + if 4 == len(input_tensor.shape): + gamma = tf.reshape(gamma, (-1, 1, 1, filters)) + beta = tf.reshape(beta, (-1, 1, 1, filters)) + elif 3 == len(input_tensor.shape): + gamma = tf.reshape(gamma, (-1, 1, filters)) + beta = tf.reshape(beta, (-1, 1, filters)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta @@ -157,12 +160,10 @@ def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_ def apply(x): x, skips, control = x # x = upsample(size=2, interpolation="bilinear")(x) - - #control = upsample(size=2)(control) + x = upsample(size=2)(x) for _ in range(block_depth): x = layers.Concatenate()([x, skips.pop()]) x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control]) - x = upsample(size=2)(x) return x return apply @@ -193,7 +194,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s skips = [] for i, width in enumerate(widths[:-1]): if attention_modulo > 1 and (i + 1) % attention_modulo == 0: - if len(input_shape) > 2: + if condition_strategy == 'film': + c2 = control + elif len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) @@ -202,7 +205,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s else: x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) - if len(input_shape) > 2: + if condition_strategy == 'film': + c2 = control + elif len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) @@ -212,7 +217,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s for i, width in enumerate(reversed(widths[:-1])): if attention_modulo > 1 and i % attention_modulo == 0: - if len(input_shape) > 2: + if condition_strategy == 'film': + c2 = control + elif len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) From e8e5228ee845c5a3daa9eb77a394cb35f052c5d1 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 12:16:58 -0500 Subject: [PATCH 087/119] condition strategy --- ml4h/models/diffusion_blocks.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 0589bb0b9..5f6855876 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -106,12 +106,12 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - if 4 == len(input_tensor.shape): - gamma = tf.reshape(gamma, (-1, 1, 1, filters)) - beta = tf.reshape(beta, (-1, 1, 1, filters)) - elif 3 == len(input_tensor.shape): - gamma = tf.reshape(gamma, (-1, 1, filters)) - beta = tf.reshape(beta, (-1, 1, filters)) + # if 4 == len(input_tensor.shape): + # gamma = tf.reshape(gamma, (-1, 1, 1, filters)) + # beta = tf.reshape(beta, (-1, 1, 1, filters)) + # elif 3 == len(input_tensor.shape): + # gamma = tf.reshape(gamma, (-1, 1, filters)) + # beta = tf.reshape(beta, (-1, 1, filters)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta @@ -194,9 +194,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s skips = [] for i, width in enumerate(widths[:-1]): if attention_modulo > 1 and (i + 1) % attention_modulo == 0: - if condition_strategy == 'film': - c2 = control - elif len(input_shape) > 2: + if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) @@ -205,9 +203,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s else: x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) - if condition_strategy == 'film': - c2 = control - elif len(input_shape) > 2: + if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) @@ -217,9 +213,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s for i, width in enumerate(reversed(widths[:-1])): if attention_modulo > 1 and i % attention_modulo == 0: - if condition_strategy == 'film': - c2 = control - elif len(input_shape) > 2: + if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: c2 = upsample(size=x.shape[-2])(control[control_idxs]) From 46c5298f0b74de34214b79b08bd51af7d2d075a3 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 12:19:04 -0500 Subject: [PATCH 088/119] condition strategy --- ml4h/models/diffusion_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 5f6855876..2fedf73cc 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -212,7 +212,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s x = residual_block_control(widths[-1], conv, kernel_size, attention_heads, condition_strategy)([x, c2]) for i, width in enumerate(reversed(widths[:-1])): - if attention_modulo > 1 and i % attention_modulo == 0: + if False and attention_modulo > 1 and i % attention_modulo == 0: if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) else: From 7a4f445448b3542b0b0b7a4fe5a107ca892d6de3 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 16 Dec 2024 12:22:31 -0500 Subject: [PATCH 089/119] condition strategy --- ml4h/models/diffusion_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 2fedf73cc..613878dc7 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -214,9 +214,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s for i, width in enumerate(reversed(widths[:-1])): if False and attention_modulo > 1 and i % attention_modulo == 0: if len(input_shape) > 2: - c2 = upsample(size=x.shape[1:-1])(control[control_idxs]) + c2 = upsample(size=x.shape[1:-1]*2)(control[control_idxs]) else: - c2 = upsample(size=x.shape[-2])(control[control_idxs]) + c2 = upsample(size=x.shape[-2]*2)(control[control_idxs]) x = up_block_control(width, block_depth, conv, upsample, kernel_size, attention_heads, condition_strategy)([x, skips, c2]) else: From db3530a42f034825f399010dd687ea71f732511f Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 19 Dec 2024 14:00:57 -0500 Subject: [PATCH 090/119] condition strategy --- ml4h/models/diffusion_blocks.py | 729 +++++++++++++++----------------- ml4h/models/train.py | 5 +- 2 files changed, 355 insertions(+), 379 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 613878dc7..ac15a9c05 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -100,6 +100,51 @@ def apply(x): return apply +def get_network(input_shape, widths, block_depth, kernel_size): + noisy_images = keras.Input(shape=input_shape) + conv, upsample, pool, _ = layers_from_shape_control(input_shape) + noise_variances = keras.Input(shape=[1] * len(input_shape)) + + e = layers.Lambda(sinusoidal_embedding)(noise_variances) + if len(input_shape) == 2: + e = upsample(size=input_shape[-2])(e) + else: + e = upsample(size=input_shape[:-1], interpolation="nearest")(e) + + print(f'e shape: {e.shape} len {len(input_shape)}') + x = conv(widths[0], kernel_size=1)(noisy_images) + x = layers.Concatenate()([x, e]) + + skips = [] + for width in widths[:-1]: + x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) + + for _ in range(block_depth): + x = residual_block(widths[-1], conv, kernel_size)(x) + + for width in reversed(widths[:-1]): + x = up_block(width, block_depth, conv, upsample, kernel_size)([x, skips]) + + x = conv(input_shape[-1], kernel_size=1, kernel_initializer="zeros")(x) + + return keras.Model([noisy_images, noise_variances], x, name="residual_unet") + + +def layers_from_shape_control(input_shape): + if len(input_shape) == 2: + return layers.Conv1D, layers.UpSampling1D, layers.AveragePooling1D, tuple( + [slice(None), np.newaxis, slice(None)], + ) + elif len(input_shape) == 3: + return layers.Conv2D, layers.UpSampling2D, layers.AveragePooling2D, tuple( + [slice(None), np.newaxis, np.newaxis, slice(None)], + ) + elif len(input_shape) == 4: + return layers.Conv3D, layers.UpSampling3D, layers.AveragePooling3D, tuple( + [slice(None), np.newaxis, np.newaxis, np.newaxis, slice(None)], + ) + + def condition_layer_film(input_tensor, control_vector, filters): # Transform control into gamma and beta gamma = layers.Dense(filters, activation="linear")(control_vector) @@ -212,7 +257,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s x = residual_block_control(widths[-1], conv, kernel_size, attention_heads, condition_strategy)([x, c2]) for i, width in enumerate(reversed(widths[:-1])): - if False and attention_modulo > 1 and i % attention_modulo == 0: + if attention_modulo > 1 and ((len(widths) - 1) - i) % attention_modulo == 0: if len(input_shape) > 2: c2 = upsample(size=x.shape[1:-1]*2)(control[control_idxs]) else: @@ -237,24 +282,20 @@ def get_control_embed_model(output_maps, control_size): return keras.Model(control_ins, c, name='control_embed') -class DiffusionController(keras.Model): - def __init__( - self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size, - attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy, - ): +class DiffusionModel(keras.Model): + def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta): super().__init__() - self.input_map = tensor_map + self.tensor_map = tensor_map self.batch_size = batch_size - self.output_maps = output_maps - self.control_embed_model = get_control_embed_model(self.output_maps, control_size) self.normalizer = layers.Normalization() - self.network = get_control_network(self.input_map.shape, widths, block_depth, conv_x, control_size, - attention_start, attention_heads, attention_modulo, condition_strategy) + self.network = get_network(self.tensor_map.shape, widths, block_depth, kernel_size) self.ema_network = keras.models.clone_model(self.network) self.use_sigmoid_loss = diffusion_loss == 'sigmoid' self.beta = sigmoid_beta + def can_apply(self): + return self.tensor_map.axes() > 1 def compile(self, **kwargs): super().compile(**kwargs) @@ -288,7 +329,7 @@ def diffusion_schedule(self, diffusion_times): return noise_rates, signal_rates - def denoise(self, control_embed, noisy_images, noise_rates, signal_rates, training): + def denoise(self, noisy_images, noise_rates, signal_rates, training): # the exponential moving average weights are used at evaluation if training: network = self.network @@ -296,12 +337,12 @@ def denoise(self, control_embed, noisy_images, noise_rates, signal_rates, traini network = self.ema_network # predict noise component and calculate the image component using it - pred_noises = network([noisy_images, noise_rates ** 2, control_embed], training=training) + pred_noises = network([noisy_images, noise_rates ** 2], training=training) pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates return pred_noises, pred_images - def reverse_diffusion(self, control_embed, initial_noise, diffusion_steps): + def reverse_diffusion(self, initial_noise, diffusion_steps): # reverse diffusion = sampling num_images = initial_noise.shape[0] step_size = 1.0 / diffusion_steps @@ -310,15 +351,14 @@ def reverse_diffusion(self, control_embed, initial_noise, diffusion_steps): # at the first sampling step, the "noisy image" is pure noise # but its signal rate is assumed to be nonzero (min_signal_rate) next_noisy_images = initial_noise - pred_images = None for step in range(diffusion_steps): noisy_images = next_noisy_images # separate the current noisy image to its components - diffusion_times = tf.ones([num_images] + [1] * self.input_map.axes()) - step * step_size + diffusion_times = tf.ones([num_images] + [1] * self.tensor_map.axes()) - step * step_size noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) pred_noises, pred_images = self.denoise( - control_embed, noisy_images, noise_rates, signal_rates, training=False, + noisy_images, noise_rates, signal_rates, training=False, ) # network used in eval mode @@ -331,38 +371,28 @@ def reverse_diffusion(self, control_embed, initial_noise, diffusion_steps): next_signal_rates * pred_images + next_noise_rates * pred_noises ) # this new noisy image will be used in the next step + return pred_images - def generate(self, control_embed, num_images, diffusion_steps, reseed=None, renoise=None): + def generate(self, num_images, diffusion_steps): # noise -> images -> denormalized images - - if reseed is not None: - tf.random.set_seed(reseed) - - initial_noise = tf.random.normal(shape=(num_images,) + self.input_map.shape) - - generated_images = self.reverse_diffusion(control_embed, initial_noise, diffusion_steps) - generated_images = self.denormalize(generated_images) - return generated_images - - def generate_from_noise(self, control_embed, num_images, diffusion_steps, initial_noise): - generated_images = self.reverse_diffusion(control_embed, initial_noise, diffusion_steps) + initial_noise = tf.random.normal(shape=(num_images,) + self.tensor_map.shape) + generated_images = self.reverse_diffusion(initial_noise, diffusion_steps) generated_images = self.denormalize(generated_images) return generated_images - def train_step(self, batch): + def train_step(self, images_original): # normalize images to have standard deviation of 1, like the noises - images = batch[0][self.input_map.input_name()] + images = images_original[0][self.tensor_map.input_name()] self.normalizer.update_state(images) + # images = images['input_lax_4ch_diastole_slice0_224_3d_continuous'] images = self.normalizer(images, training=True) - - control_embed = self.control_embed_model(batch[1]) - - noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) + # images = images.numpy() - images.numpy().mean() / images.numpy().std() + noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape) # sample uniform random diffusion times diffusion_times = tf.random.uniform( - shape=[self.batch_size] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0, + shape=[self.batch_size] + [1] * self.tensor_map.axes(), minval=0.0, maxval=1.0, ) noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) # mix the images with noises accordingly @@ -372,7 +402,7 @@ def train_step(self, batch): with tf.GradientTape() as tape: # train the network to separate noisy images to their components pred_noises, pred_images = self.denoise( - control_embed, noisy_images, noise_rates, signal_rates, training=True, + noisy_images, noise_rates, signal_rates, training=True, ) noise_loss = self.loss(noises, pred_noises) # used for training @@ -399,56 +429,17 @@ def train_step(self, batch): # KID is not measured during the training phase for computational efficiency return {m.name: m.result() for m in self.metrics[:-1]} - # def call(self, inputs): - # # normalize images to have standard deviation of 1, like the noises - # images = inputs[self.input_map.input_name()] - # self.normalizer.update_state(images) - # images = self.normalizer(images, training=False) - - # control_embed = self.control_embed_model(inputs) - - # noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) - - # # sample uniform random diffusion times - # diffusion_times = tf.random.uniform( - # shape=[self.batch_size, ] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0 - # ) - # noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) - # # mix the images with noises accordingly - # noisy_images = signal_rates * images + noise_rates * noises - - # # use the network to separate noisy images to their components - # pred_noises, pred_images = self.denoise( - # control_embed, noisy_images, noise_rates, signal_rates, training=False - # ) - - # noise_loss = self.loss(noises, pred_noises) - # image_loss = self.loss(images, pred_images) - - # self.image_loss_tracker.update_state(image_loss) - # self.noise_loss_tracker.update_state(noise_loss) - - # # measure KID between real and generated images - # # this is computationally demanding, kid_diffusion_steps has to be small - # images = self.denormalize(images) - # generated_images = self.generate( - # control_embed, num_images=self.batch_size, diffusion_steps=20 - # ) - # return generated_images - - def test_step(self, batch): + def test_step(self, images_original): # normalize images to have standard deviation of 1, like the noises - images = batch[0][self.input_map.input_name()] + images = images_original[0][self.tensor_map.input_name()] self.normalizer.update_state(images) images = self.normalizer(images, training=False) - - control_embed = self.control_embed_model(batch[1]) - - noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) + # images = images - tf.math.reduce_mean(images) / tf.math.reduce_std(images) + noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape) # sample uniform random diffusion times diffusion_times = tf.random.uniform( - shape=[self.batch_size] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0, + shape=[self.batch_size] + [1] * self.tensor_map.axes(), minval=0.0, maxval=1.0, ) noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) # mix the images with noises accordingly @@ -456,47 +447,49 @@ def test_step(self, batch): # use the network to separate noisy images to their components pred_noises, pred_images = self.denoise( - control_embed, noisy_images, noise_rates, signal_rates, training=False, + noisy_images, noise_rates, signal_rates, training=False, ) noise_loss = self.loss(noises, pred_noises) image_loss = self.loss(images, pred_images) + if self.use_sigmoid_loss: + signal_rates_squared = tf.square(signal_rates) + noise_rates_squared = tf.square(noise_rates) + + # Compute log-SNR (lambda_t) + lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared) + weight = tf.math.sigmoid(self.beta - lambda_t) + noise_loss = weight * noise_loss self.image_loss_tracker.update_state(image_loss) self.noise_loss_tracker.update_state(noise_loss) # measure KID between real and generated images # this is computationally demanding, kid_diffusion_steps has to be small - images = self.denormalize(images) - generated_images = self.generate( - control_embed, num_images=self.batch_size, diffusion_steps=20, - ) + # images = self.denormalize(images) + # generated_images = self.generate( + # num_images=self.batch_size, diffusion_steps=20, + # ) # self.kid.update_state(images, generated_images) return {m.name: m.result() for m in self.metrics} - def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'): - control_batch = {} - for cm in self.output_maps: - control_batch[cm.output_name()] = np.zeros((max(self.batch_size, num_rows * num_cols),) + cm.shape) - if 'Sex' in cm.name: - control_batch[cm.output_name()][:, 0] = 1 # all female - - print(f'\nControl batch keys: {list(control_batch.keys())}') - control_embed = self.control_embed_model(control_batch) + def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, prefix='./figures/'): # plot random generated images for visual evaluation of generation quality generated_images = self.generate( - control_embed, - num_images=max(self.batch_size, num_rows * num_cols), + num_images=num_rows * num_cols, diffusion_steps=plot_diffusion_steps, - reseed=reseed, ) + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) for row in range(num_rows): for col in range(num_cols): index = row * num_cols + col plt.subplot(num_rows, num_cols, index + 1) - plt.imshow(generated_images[index], cmap='gray') + if len(generated_images[index].shape) == 3 and generated_images[index].shape[-1] > 1: + plt.imshow(generated_images[index, ..., 0], cmap='gray') # just plot first frame + else: + plt.imshow(generated_images[index], cmap='gray') plt.axis("off") plt.tight_layout() now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') @@ -504,50 +497,12 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") - plt.close() - - def plot_reconstructions( - self, batch, diffusion_amount=0, - epoch=None, logs=None, num_rows=4, num_cols=4, - ): - images = batch[0][self.input_map.input_name()] - self.normalizer.update_state(images) - images = self.normalizer(images, training=False) - noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) - diffusion_times = diffusion_amount * tf.ones(shape=[self.batch_size] + [1] * self.input_map.axes()) - noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) - # mix the images with noises accordingly - noisy_images = signal_rates * images + noise_rates * noises - - control_embed = self.control_embed_model(batch[1]) - - # use the network to separate noisy images to their components - pred_noises, generated_images = self.denoise( - control_embed, noisy_images, noise_rates, signal_rates, training=False, - ) - plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) - for row in range(num_rows): - for col in range(num_cols): - index = row * num_cols + col - plt.subplot(num_rows, num_cols, index + 1) - plt.imshow(generated_images[index], cmap='gray') - plt.axis("off") - plt.tight_layout() - plt.show() - plt.close() - def control_plot_images( - self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, - renoise=None, - ): - control_embed = self.control_embed_model(control_batch) + def plot_ecgs(self, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, prefix='./figures/'): # plot random generated images for visual evaluation of generation quality generated_images = self.generate( - control_embed, num_images=max(self.batch_size, num_rows * num_cols), diffusion_steps=plot_diffusion_steps, - reseed=reseed, - renoise=renoise, ) plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) @@ -555,24 +510,33 @@ def control_plot_images( for col in range(num_cols): index = row * num_cols + col plt.subplot(num_rows, num_cols, index + 1) - plt.imshow(generated_images[index], cmap='gray') + plt.plot(generated_images[index, ..., 0]) plt.axis("off") plt.tight_layout() - plt.show() - plt.close() + now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'diffusion_ecg_generations_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path, bbox_inches="tight") - return generated_images + def plot_reconstructions( + self, images_original, diffusion_amount=0, + epoch=None, logs=None, num_rows=3, num_cols=6, + ): + images = images_original[0][self.tensor_map.input_name()] + self.normalizer.update_state(images) + images = self.normalizer(images, training=False) + noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape) - def control_plot_images_noise(self, control_batch, initial_noise, epoch=None, logs=None, num_rows=2, num_cols=8): - control_embed = self.control_embed_model(control_batch) - # plot random generated images for visual evaluation of generation quality - generated_images = self.generate_from_noise( - control_embed, - num_images=max(self.batch_size, num_rows * num_cols), - diffusion_steps=plot_diffusion_steps, - initial_noise=initial_noise, - ) + diffusion_times = diffusion_amount * tf.ones(shape=[self.batch_size] + [1] * self.tensor_map.axes()) + noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) + # mix the images with noises accordingly + noisy_images = signal_rates * images + noise_rates * noises + # use the network to separate noisy images to their components + pred_noises, generated_images = self.denoise( + noisy_images, noise_rates, signal_rates, training=False, + ) plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) for row in range(num_rows): for col in range(num_cols): @@ -584,131 +548,70 @@ def control_plot_images_noise(self, control_batch, initial_noise, epoch=None, lo plt.show() plt.close() - return generated_images - - def control_plot_ecgs( - self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, - renoise=None, - ): - control_embed = self.control_embed_model(control_batch) - # plot random generated images for visual evaluation of generation quality - generated_images = self.generate( - control_embed, - num_images=max(self.batch_size, num_rows * num_cols), - diffusion_steps=plot_diffusion_steps, - reseed=reseed, - renoise=renoise, - ) + def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_cols=6): + images = images_original[0][self.tensor_map.input_name()] + noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape) + # reverse diffusion = sampling + num_images = images.shape[0] + step_size = 1.0 / max(0.0001, diffusion_steps) - plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) - for row in range(num_rows): - for col in range(num_cols): - index = row * num_cols + col - plt.subplot(num_rows, num_cols, index + 1) - for lead in range(generated_images.shape[-1]): - plt.plot(generated_images[index, :, lead], label=lead) - plt.axis("off") - plt.tight_layout() - plt.show() - plt.close() + # important line: + # at the first sampling step, the "noisy image" is pure noise + # but its signal rate is assumed to be nonzero (min_signal_rate) + next_noisy_images = images * masks + noises * (1 - masks) + for step in range(diffusion_steps): + noisy_images = next_noisy_images - return generated_images + # separate the current noisy image to its components + diffusion_times = tf.ones([num_images] + [1] * self.tensor_map.axes()) - step * step_size + noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) + pred_noises, generated_images = self.denoise( + noisy_images, noise_rates, signal_rates, training=False, + ) - def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'): - control_batch = {} - for cm in self.output_maps: - control_batch[cm.output_name()] = np.zeros((max(self.batch_size, num_rows * num_cols),) + cm.shape) - if 'Sex' in cm.name: - control_batch[cm.output_name()][:, 0] = 1 # all female + # apply the mask + generated_images = generated_images * (1 - masks) + images * masks - print(f'\nControl batch keys: {list(control_batch.keys())}') - control_embed = self.control_embed_model(control_batch) + # remix the predicted components using the next signal and noise rates + next_diffusion_times = diffusion_times - step_size + next_noise_rates, next_signal_rates = self.diffusion_schedule( + next_diffusion_times, + ) + next_noisy_images = ( + next_signal_rates * generated_images + next_noise_rates * pred_noises + ) + # this new noisy image will be used in the next step - # plot random generated images for visual evaluation of generation quality - generated_images = self.generate( - control_embed, - num_images=max(self.batch_size, num_rows * num_cols), - diffusion_steps=plot_diffusion_steps, - reseed=reseed, - ) - logging.info(f'Generated ECGs with shape:{generated_images.shape}') plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) for row in range(num_rows): for col in range(num_cols): index = row * num_cols + col plt.subplot(num_rows, num_cols, index + 1) - for lead in range(generated_images.shape[-1]): - plt.plot(generated_images[index, :, lead], label=lead) + plt.imshow(generated_images[index], cmap='gray') plt.axis("off") plt.tight_layout() - now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') - figure_path = os.path.join(prefix, f'diffusion_ecg_generations_{now_string}{IMAGE_EXT}') - if not os.path.exists(os.path.dirname(figure_path)): - os.makedirs(os.path.dirname(figure_path)) - plt.savefig(figure_path, bbox_inches="tight") + plt.show() plt.close() -def get_network(input_shape, widths, block_depth, kernel_size): - noisy_images = keras.Input(shape=input_shape) - conv, upsample, pool, _ = layers_from_shape_control(input_shape) - noise_variances = keras.Input(shape=[1] * len(input_shape)) - - e = layers.Lambda(sinusoidal_embedding)(noise_variances) - if len(input_shape) == 2: - e = upsample(size=input_shape[-2])(e) - else: - e = upsample(size=input_shape[:-1], interpolation="nearest")(e) - - print(f'e shape: {e.shape} len {len(input_shape)}') - x = conv(widths[0], kernel_size=1)(noisy_images) - x = layers.Concatenate()([x, e]) - - skips = [] - for width in widths[:-1]: - x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips]) - - for _ in range(block_depth): - x = residual_block(widths[-1], conv, kernel_size)(x) - - for width in reversed(widths[:-1]): - x = up_block(width, block_depth, conv, upsample, kernel_size)([x, skips]) - - x = conv(input_shape[-1], kernel_size=1, kernel_initializer="zeros")(x) - - return keras.Model([noisy_images, noise_variances], x, name="residual_unet") - - -def layers_from_shape_control(input_shape): - if len(input_shape) == 2: - return layers.Conv1D, layers.UpSampling1D, layers.AveragePooling1D, tuple( - [slice(None), np.newaxis, slice(None)], - ) - elif len(input_shape) == 3: - return layers.Conv2D, layers.UpSampling2D, layers.AveragePooling2D, tuple( - [slice(None), np.newaxis, np.newaxis, slice(None)], - ) - elif len(input_shape) == 4: - return layers.Conv3D, layers.UpSampling3D, layers.AveragePooling3D, tuple( - [slice(None), np.newaxis, np.newaxis, np.newaxis, slice(None)], - ) - - - - - -class DiffusionModel(keras.Model): - def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size): +class DiffusionController(keras.Model): + def __init__( + self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size, + attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy, + ): super().__init__() - self.tensor_map = tensor_map + self.input_map = tensor_map self.batch_size = batch_size + self.output_maps = output_maps + self.control_embed_model = get_control_embed_model(self.output_maps, control_size) self.normalizer = layers.Normalization() - self.network = get_network(self.tensor_map.shape, widths, block_depth, kernel_size) + self.network = get_control_network(self.input_map.shape, widths, block_depth, conv_x, control_size, + attention_start, attention_heads, attention_modulo, condition_strategy) self.ema_network = keras.models.clone_model(self.network) + self.use_sigmoid_loss = diffusion_loss == 'sigmoid' + self.beta = sigmoid_beta - def can_apply(self): - return self.tensor_map.axes() > 1 def compile(self, **kwargs): super().compile(**kwargs) @@ -742,7 +645,7 @@ def diffusion_schedule(self, diffusion_times): return noise_rates, signal_rates - def denoise(self, noisy_images, noise_rates, signal_rates, training): + def denoise(self, control_embed, noisy_images, noise_rates, signal_rates, training): # the exponential moving average weights are used at evaluation if training: network = self.network @@ -750,12 +653,12 @@ def denoise(self, noisy_images, noise_rates, signal_rates, training): network = self.ema_network # predict noise component and calculate the image component using it - pred_noises = network([noisy_images, noise_rates ** 2], training=training) + pred_noises = network([noisy_images, noise_rates ** 2, control_embed], training=training) pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates return pred_noises, pred_images - def reverse_diffusion(self, initial_noise, diffusion_steps): + def reverse_diffusion(self, control_embed, initial_noise, diffusion_steps): # reverse diffusion = sampling num_images = initial_noise.shape[0] step_size = 1.0 / diffusion_steps @@ -764,14 +667,15 @@ def reverse_diffusion(self, initial_noise, diffusion_steps): # at the first sampling step, the "noisy image" is pure noise # but its signal rate is assumed to be nonzero (min_signal_rate) next_noisy_images = initial_noise + pred_images = None for step in range(diffusion_steps): noisy_images = next_noisy_images # separate the current noisy image to its components - diffusion_times = tf.ones([num_images] + [1] * self.tensor_map.axes()) - step * step_size + diffusion_times = tf.ones([num_images] + [1] * self.input_map.axes()) - step * step_size noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) pred_noises, pred_images = self.denoise( - noisy_images, noise_rates, signal_rates, training=False, + control_embed, noisy_images, noise_rates, signal_rates, training=False, ) # network used in eval mode @@ -784,28 +688,38 @@ def reverse_diffusion(self, initial_noise, diffusion_steps): next_signal_rates * pred_images + next_noise_rates * pred_noises ) # this new noisy image will be used in the next step - return pred_images - def generate(self, num_images, diffusion_steps): + def generate(self, control_embed, num_images, diffusion_steps, reseed=None, renoise=None): # noise -> images -> denormalized images - initial_noise = tf.random.normal(shape=(num_images,) + self.tensor_map.shape) - generated_images = self.reverse_diffusion(initial_noise, diffusion_steps) + + if reseed is not None: + tf.random.set_seed(reseed) + + initial_noise = tf.random.normal(shape=(num_images,) + self.input_map.shape) + + generated_images = self.reverse_diffusion(control_embed, initial_noise, diffusion_steps) generated_images = self.denormalize(generated_images) return generated_images - def train_step(self, images_original): + def generate_from_noise(self, control_embed, num_images, diffusion_steps, initial_noise): + generated_images = self.reverse_diffusion(control_embed, initial_noise, diffusion_steps) + generated_images = self.denormalize(generated_images) + return generated_images + + def train_step(self, batch): # normalize images to have standard deviation of 1, like the noises - images = images_original[0][self.tensor_map.input_name()] + images = batch[0][self.input_map.input_name()] self.normalizer.update_state(images) - # images = images['input_lax_4ch_diastole_slice0_224_3d_continuous'] images = self.normalizer(images, training=True) - # images = images.numpy() - images.numpy().mean() / images.numpy().std() - noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape) + + control_embed = self.control_embed_model(batch[1]) + + noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) # sample uniform random diffusion times diffusion_times = tf.random.uniform( - shape=[self.batch_size] + [1] * self.tensor_map.axes(), minval=0.0, maxval=1.0, + shape=[self.batch_size] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0, ) noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) # mix the images with noises accordingly @@ -815,11 +729,19 @@ def train_step(self, images_original): with tf.GradientTape() as tape: # train the network to separate noisy images to their components pred_noises, pred_images = self.denoise( - noisy_images, noise_rates, signal_rates, training=True, + control_embed, noisy_images, noise_rates, signal_rates, training=True, ) noise_loss = self.loss(noises, pred_noises) # used for training image_loss = self.loss(images, pred_images) # only used as metric + if self.use_sigmoid_loss: + signal_rates_squared = tf.square(signal_rates) + noise_rates_squared = tf.square(noise_rates) + + # Compute log-SNR (lambda_t) + lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared) + weight = tf.math.sigmoid(self.beta - lambda_t) + noise_loss = weight * noise_loss gradients = tape.gradient(noise_loss, self.network.trainable_weights) self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights)) @@ -834,17 +756,56 @@ def train_step(self, images_original): # KID is not measured during the training phase for computational efficiency return {m.name: m.result() for m in self.metrics[:-1]} - def test_step(self, images_original): + # def call(self, inputs): + # # normalize images to have standard deviation of 1, like the noises + # images = inputs[self.input_map.input_name()] + # self.normalizer.update_state(images) + # images = self.normalizer(images, training=False) + + # control_embed = self.control_embed_model(inputs) + + # noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) + + # # sample uniform random diffusion times + # diffusion_times = tf.random.uniform( + # shape=[self.batch_size, ] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0 + # ) + # noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) + # # mix the images with noises accordingly + # noisy_images = signal_rates * images + noise_rates * noises + + # # use the network to separate noisy images to their components + # pred_noises, pred_images = self.denoise( + # control_embed, noisy_images, noise_rates, signal_rates, training=False + # ) + + # noise_loss = self.loss(noises, pred_noises) + # image_loss = self.loss(images, pred_images) + + # self.image_loss_tracker.update_state(image_loss) + # self.noise_loss_tracker.update_state(noise_loss) + + # # measure KID between real and generated images + # # this is computationally demanding, kid_diffusion_steps has to be small + # images = self.denormalize(images) + # generated_images = self.generate( + # control_embed, num_images=self.batch_size, diffusion_steps=20 + # ) + # return generated_images + + def test_step(self, batch): # normalize images to have standard deviation of 1, like the noises - images = images_original[0][self.tensor_map.input_name()] + images = batch[0][self.input_map.input_name()] self.normalizer.update_state(images) images = self.normalizer(images, training=False) - # images = images - tf.math.reduce_mean(images) / tf.math.reduce_std(images) - noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape) + + control_embed = self.control_embed_model(batch[1]) + + noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) # sample uniform random diffusion times diffusion_times = tf.random.uniform( - shape=[self.batch_size] + [1] * self.tensor_map.axes(), minval=0.0, maxval=1.0, + shape=[self.batch_size] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0, ) noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) # mix the images with noises accordingly @@ -852,7 +813,7 @@ def test_step(self, images_original): # use the network to separate noisy images to their components pred_noises, pred_images = self.denoise( - noisy_images, noise_rates, signal_rates, training=False, + control_embed, noisy_images, noise_rates, signal_rates, training=False, ) noise_loss = self.loss(noises, pred_noises) @@ -865,74 +826,61 @@ def test_step(self, images_original): # this is computationally demanding, kid_diffusion_steps has to be small images = self.denormalize(images) generated_images = self.generate( - num_images=self.batch_size, diffusion_steps=20, + control_embed, num_images=self.batch_size, diffusion_steps=20, ) # self.kid.update_state(images, generated_images) return {m.name: m.result() for m in self.metrics} - def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, prefix='./figures/'): - # plot random generated images for visual evaluation of generation quality - generated_images = self.generate( - num_images=num_rows * num_cols, - diffusion_steps=plot_diffusion_steps, - ) - - plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) - for row in range(num_rows): - for col in range(num_cols): - index = row * num_cols + col - plt.subplot(num_rows, num_cols, index + 1) - if len(generated_images[index].shape) == 3 and generated_images[index].shape[-1] > 1: - plt.imshow(generated_images[index, ..., 0], cmap='gray') # just plot first frame - else: - plt.imshow(generated_images[index], cmap='gray') - plt.axis("off") - plt.tight_layout() - now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') - figure_path = os.path.join(prefix, f'diffusion_image_generations_{now_string}{IMAGE_EXT}') - if not os.path.exists(os.path.dirname(figure_path)): - os.makedirs(os.path.dirname(figure_path)) - plt.savefig(figure_path, bbox_inches="tight") + def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'): + control_batch = {} + for cm in self.output_maps: + control_batch[cm.output_name()] = np.zeros((max(self.batch_size, num_rows * num_cols),) + cm.shape) + if 'Sex' in cm.name: + control_batch[cm.output_name()][:, 0] = 1 # all female - def plot_ecgs(self, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, prefix='./figures/'): + print(f'\nControl batch keys: {list(control_batch.keys())}') + control_embed = self.control_embed_model(control_batch) # plot random generated images for visual evaluation of generation quality generated_images = self.generate( + control_embed, num_images=max(self.batch_size, num_rows * num_cols), diffusion_steps=plot_diffusion_steps, + reseed=reseed, ) - plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) for row in range(num_rows): for col in range(num_cols): index = row * num_cols + col plt.subplot(num_rows, num_cols, index + 1) - plt.plot(generated_images[index, ..., 0]) + plt.imshow(generated_images[index], cmap='gray') plt.axis("off") plt.tight_layout() now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') - figure_path = os.path.join(prefix, f'diffusion_ecg_generations_{now_string}{IMAGE_EXT}') + figure_path = os.path.join(prefix, f'diffusion_image_generations_{now_string}{IMAGE_EXT}') if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") + plt.close() def plot_reconstructions( - self, images_original, diffusion_amount=0, - epoch=None, logs=None, num_rows=3, num_cols=6, + self, batch, diffusion_amount=0, + epoch=None, logs=None, num_rows=4, num_cols=4, ): - images = images_original[0][self.tensor_map.input_name()] + images = batch[0][self.input_map.input_name()] self.normalizer.update_state(images) images = self.normalizer(images, training=False) - noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape) - - diffusion_times = diffusion_amount * tf.ones(shape=[self.batch_size] + [1] * self.tensor_map.axes()) + noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) + diffusion_times = diffusion_amount * tf.ones(shape=[self.batch_size] + [1] * self.input_map.axes()) noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) # mix the images with noises accordingly noisy_images = signal_rates * images + noise_rates * noises + control_embed = self.control_embed_model(batch[1]) + # use the network to separate noisy images to their components pred_noises, generated_images = self.denoise( - noisy_images, noise_rates, signal_rates, training=False, + control_embed, noisy_images, noise_rates, signal_rates, training=False, ) plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) for row in range(num_rows): @@ -945,39 +893,42 @@ def plot_reconstructions( plt.show() plt.close() - def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_cols=6): - images = images_original[0][self.tensor_map.input_name()] - noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape) - # reverse diffusion = sampling - num_images = images.shape[0] - step_size = 1.0 / max(0.0001, diffusion_steps) - - # important line: - # at the first sampling step, the "noisy image" is pure noise - # but its signal rate is assumed to be nonzero (min_signal_rate) - next_noisy_images = images * masks + noises * (1 - masks) - for step in range(diffusion_steps): - noisy_images = next_noisy_images + def control_plot_images( + self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, + renoise=None, + ): + control_embed = self.control_embed_model(control_batch) + # plot random generated images for visual evaluation of generation quality + generated_images = self.generate( + control_embed, + num_images=max(self.batch_size, num_rows * num_cols), + diffusion_steps=plot_diffusion_steps, + reseed=reseed, + renoise=renoise, + ) - # separate the current noisy image to its components - diffusion_times = tf.ones([num_images] + [1] * self.tensor_map.axes()) - step * step_size - noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) - pred_noises, generated_images = self.denoise( - noisy_images, noise_rates, signal_rates, training=False, - ) + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) + for row in range(num_rows): + for col in range(num_cols): + index = row * num_cols + col + plt.subplot(num_rows, num_cols, index + 1) + plt.imshow(generated_images[index], cmap='gray') + plt.axis("off") + plt.tight_layout() + plt.show() + plt.close() - # apply the mask - generated_images = generated_images * (1 - masks) + images * masks + return generated_images - # remix the predicted components using the next signal and noise rates - next_diffusion_times = diffusion_times - step_size - next_noise_rates, next_signal_rates = self.diffusion_schedule( - next_diffusion_times, - ) - next_noisy_images = ( - next_signal_rates * generated_images + next_noise_rates * pred_noises - ) - # this new noisy image will be used in the next step + def control_plot_images_noise(self, control_batch, initial_noise, epoch=None, logs=None, num_rows=2, num_cols=8): + control_embed = self.control_embed_model(control_batch) + # plot random generated images for visual evaluation of generation quality + generated_images = self.generate_from_noise( + control_embed, + num_images=max(self.batch_size, num_rows * num_cols), + diffusion_steps=plot_diffusion_steps, + initial_noise=initial_noise, + ) plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) for row in range(num_rows): @@ -990,42 +941,66 @@ def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_c plt.show() plt.close() + return generated_images -class DiffusionBlock(Block): - def __init__( - self, - *, - tensor_map: TensorMap, - dense_blocks: List[int] = [32, 32, 32], - dense_layers: List[int] = [256], - batch_size: int = 16, - block_size: int = 3, - conv_x: int = 3, - activation: str = 'swish', - **kwargs, + def control_plot_ecgs( + self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, + renoise=None, ): - self.tensor_map = tensor_map - self.batch_size = batch_size - if not self.can_apply(): - return - - self.diffusion_model = DiffusionModel(tensor_map, dense_blocks, block_size, conv_x) - import tensorflow_addons as tfa - self.diffusion_model.compile( - optimizer=tfa.optimizers.AdamW( - learning_rate=learning_rate, weight_decay=weight_decay, - ), - loss=keras.losses.mean_absolute_error, + control_embed = self.control_embed_model(control_batch) + # plot random generated images for visual evaluation of generation quality + generated_images = self.generate( + control_embed, + num_images=max(self.batch_size, num_rows * num_cols), + diffusion_steps=plot_diffusion_steps, + reseed=reseed, + renoise=renoise, ) - def can_apply(self): - return self.tensor_map.axes() > 1 + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) + for row in range(num_rows): + for col in range(num_cols): + index = row * num_cols + col + plt.subplot(num_rows, num_cols, index + 1) + for lead in range(generated_images.shape[-1]): + plt.plot(generated_images[index, :, lead], label=lead) + plt.axis("off") + plt.tight_layout() + plt.show() + plt.close() - def __call__(self, x: Tensor, intermediates: Dict[TensorMap, List[Tensor]] = None) -> Tensor: - if not self.can_apply(): - return x - times = tf.ones([self.batch_size]+[1]*self.tensor_map.axes()) - x = self.diffusion_model([x, times]) - #x = self.loss_layer(x) - intermediates[self.tensor_map].append(x) - return x + return generated_images + + def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'): + control_batch = {} + for cm in self.output_maps: + control_batch[cm.output_name()] = np.zeros((max(self.batch_size, num_rows * num_cols),) + cm.shape) + if 'Sex' in cm.name: + control_batch[cm.output_name()][:, 0] = 1 # all female + + print(f'\nControl batch keys: {list(control_batch.keys())}') + control_embed = self.control_embed_model(control_batch) + + # plot random generated images for visual evaluation of generation quality + generated_images = self.generate( + control_embed, + num_images=max(self.batch_size, num_rows * num_cols), + diffusion_steps=plot_diffusion_steps, + reseed=reseed, + ) + logging.info(f'Generated ECGs with shape:{generated_images.shape}') + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) + for row in range(num_rows): + for col in range(num_cols): + index = row * num_cols + col + plt.subplot(num_rows, num_cols, index + 1) + for lead in range(generated_images.shape[-1]): + plt.plot(generated_images[index, :, lead], label=lead) + plt.axis("off") + plt.tight_layout() + now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'diffusion_ecg_generations_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path, bbox_inches="tight") + plt.close() diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 00065dcb8..979f61d3e 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -108,13 +108,14 @@ def _get_callbacks( def train_diffusion_model(args): generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) - model = DiffusionModel(args.tensor_maps_in[0], args.batch_size, args.dense_blocks, args.block_size, args.conv_x) + model = DiffusionModel(args.tensor_maps_in[0], args.batch_size, args.dense_blocks, args.block_size, args.conv_x, + args.diffusion_loss, args.sigmoid_beta) model.compile( optimizer=tfa.optimizers.AdamW( learning_rate=args.learning_rate, weight_decay=1e-4, ), - loss=keras.losses.mean_absolute_error, + loss=keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error, ) batch = next(generate_train) for k in batch[0]: From baf60a37bdd76ef82116cea3dba81a221405691b Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 19 Dec 2024 14:08:58 -0500 Subject: [PATCH 091/119] sigmoid loss unconditioned --- ml4h/models/model_factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ml4h/models/model_factory.py b/ml4h/models/model_factory.py index eddc05fc7..ca97b7a0b 100755 --- a/ml4h/models/model_factory.py +++ b/ml4h/models/model_factory.py @@ -16,7 +16,6 @@ from ml4h.models.Block import Block from ml4h.TensorMap import TensorMap from ml4h.metrics import get_metric_dict -from ml4h.models.diffusion_blocks import DiffusionBlock from ml4h.optimizers import NON_KERAS_OPTIMIZERS, get_optimizer from ml4h.models.perceiver_blocks import PerceiverEncoder, PerceiverLatentLayer from ml4h.models.layer_wrappers import ACTIVATION_FUNCTIONS, NORMALIZATION_CLASSES @@ -64,7 +63,6 @@ 'resnet_encoder': ResNetEncoder, 'movinet_encoder': MoviNetEncoder, 'bert_encoder': BertEncoder, - 'diffusion': DiffusionBlock, 'identity_decode': IdentityDecoderBlock, 'identity_encode': IdentityEncoderBlock, } From 2df321c7898ead35411bce0be5ac9df027a243d6 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 19 Dec 2024 14:20:36 -0500 Subject: [PATCH 092/119] sigmoid loss unconditioned --- ml4h/models/diffusion_blocks.py | 10 ++++++++-- ml4h/models/train.py | 22 +++++++++++++++------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index ac15a9c05..d17fc19be 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -374,8 +374,10 @@ def reverse_diffusion(self, initial_noise, diffusion_steps): return pred_images - def generate(self, num_images, diffusion_steps): + def generate(self, num_images, diffusion_steps, reseed=None): # noise -> images -> denormalized images + if reseed is not None: + tf.random.set_seed(reseed) initial_noise = tf.random.normal(shape=(num_images,) + self.tensor_map.shape) generated_images = self.reverse_diffusion(initial_noise, diffusion_steps) generated_images = self.denormalize(generated_images) @@ -474,11 +476,12 @@ def test_step(self, images_original): return {m.name: m.result() for m in self.metrics} - def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, prefix='./figures/'): + def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, reseed=None, prefix='./figures/'): # plot random generated images for visual evaluation of generation quality generated_images = self.generate( num_images=num_rows * num_cols, diffusion_steps=plot_diffusion_steps, + reseed=reseed, ) plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) @@ -497,12 +500,14 @@ def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, prefix='./f if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") + plt.close() def plot_ecgs(self, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, prefix='./figures/'): # plot random generated images for visual evaluation of generation quality generated_images = self.generate( num_images=max(self.batch_size, num_rows * num_cols), diffusion_steps=plot_diffusion_steps, + reseed=reseed, ) plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) @@ -518,6 +523,7 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") + plt.close() def plot_reconstructions( self, images_original, diffusion_amount=0, diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 979f61d3e..18a9c3fe0 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -124,6 +124,11 @@ def train_diffusion_model(args): for k in batch[1]: logging.info(f"label {k} {batch[1][k].shape}") checkpoint_path = f"{args.output_folder}{args.id}/{args.id}" + if os.path.exists(checkpoint_path+'.index'): + model.load_weights(checkpoint_path) + logging.info(f'Loaded weights from model checkpoint at: {checkpoint_path}') + else: + logging.info(f'No checkpoint at: {checkpoint_path}') checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( filepath=checkpoint_path, save_weights_only=True, @@ -132,8 +137,11 @@ def train_diffusion_model(args): save_best_only=True, ) + callbacks = [checkpoint_callback] + # calculate mean and variance of training dataset for normalization model.normalizer.adapt(feature_batch) + if args.inspect_model: model.network.summary(print_fn=logging.info, expand_nested=True) tf.keras.utils.plot_model( @@ -148,12 +156,12 @@ def train_diffusion_model(args): layer_range=None, show_layer_activations=False, ) - - if os.path.exists(checkpoint_path+'.index'): - model.load_weights(checkpoint_path) - logging.info(f'Loaded weights from model checkpoint at: {checkpoint_path}') - else: - logging.info(f'No checkpoint at: {checkpoint_path}') + prefix_value = f'{args.output_folder}{args.id}/learning_generations/' + if model.tensor_map.axes() == 2: + plot_partial = partial(model.plot_ecgs, reseed=args.random_seed, prefix=prefix_value) + else: + plot_partial = partial(model.plot_images, reseed=args.random_seed, prefix=prefix_value) + callbacks.append(keras.callbacks.LambdaCallback(on_epoch_end=plot_partial)) history = model.fit( generate_train, @@ -161,7 +169,7 @@ def train_diffusion_model(args): epochs=args.epochs, validation_data=generate_valid, validation_steps=args.validation_steps, - callbacks=[checkpoint_callback], + callbacks=callbacks, ) model.load_weights(checkpoint_path) #diffusion_model.compile(optimizer='adam', loss='mse') From bb5feaaf3ee64df34a544983c8d38d07e18952f7 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 19 Dec 2024 14:34:59 -0500 Subject: [PATCH 093/119] sigmoid loss unconditioned --- ml4h/models/diffusion_blocks.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index d17fc19be..b5bfcb088 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -476,7 +476,7 @@ def test_step(self, images_original): return {m.name: m.result() for m in self.metrics} - def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, reseed=None, prefix='./figures/'): + def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'): # plot random generated images for visual evaluation of generation quality generated_images = self.generate( num_images=num_rows * num_cols, @@ -502,7 +502,7 @@ def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, reseed=None plt.savefig(figure_path, bbox_inches="tight") plt.close() - def plot_ecgs(self, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, prefix='./figures/'): + def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'): # plot random generated images for visual evaluation of generation quality generated_images = self.generate( num_images=max(self.batch_size, num_rows * num_cols), @@ -515,8 +515,10 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, for col in range(num_cols): index = row * num_cols + col plt.subplot(num_rows, num_cols, index + 1) - plt.plot(generated_images[index, ..., 0]) + for lead in range(generated_images.shape[-1]): + plt.plot(generated_images[index, :, lead], label=lead) plt.axis("off") + plt.tight_layout() now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') figure_path = os.path.join(prefix, f'diffusion_ecg_generations_{now_string}{IMAGE_EXT}') From 44807ae26e2e24394de0e3891252e74462ae497d Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 19 Dec 2024 14:42:29 -0500 Subject: [PATCH 094/119] sigmoid loss unconditioned --- ml4h/models/diffusion_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index b5bfcb088..6da5d3e3d 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -518,7 +518,7 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, for lead in range(generated_images.shape[-1]): plt.plot(generated_images[index, :, lead], label=lead) plt.axis("off") - + plt.tight_layout() now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') figure_path = os.path.join(prefix, f'diffusion_ecg_generations_{now_string}{IMAGE_EXT}') From 6b3ddfc3c6f7a9c1dc6e80e4256157ff674763c4 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 20 Dec 2024 09:49:24 -0500 Subject: [PATCH 095/119] sigmoid loss unconditioned --- ml4h/tensormap/ukb/mri_brain.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index 45cf6d152..da766da2b 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -465,3 +465,14 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): normalization=ZeroMeanStd1(), dependent_map=axial_index_map, ) + +t2_flair_random_slice = TensorMap( + 't1_mni_random_slice', + Interpretation.CONTINUOUS, + shape=(192, 256, 1), + path_prefix='ukb_brain_mri/T2_FLAIR_orig_defaced/', + tensor_from_file=random_mni_slice_tensor, + normalization=ZeroMeanStd1(), + dependent_map=axial_index_map, +) + From 9bf044a13509c3389a7c8e56cc37887a2282bc8b Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 20 Dec 2024 12:48:59 -0500 Subject: [PATCH 096/119] sigmoid loss unconditioned --- ml4h/tensormap/ukb/mri_brain.py | 45 ++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index da766da2b..a518d2960 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -439,20 +439,21 @@ def _masked_brain_tensor(tm, hd5, dependents={}): tensor_from_file=_mni_label_masked({'Left_accumbens': 19, 'Right_accumbens': 70}), normalization=ZeroMeanStd1(), ) - -def random_mni_slice_tensor(tm, hd5, dependents={}): - slice_index = np.random.randint(182) - tensor = pad_or_crop_array_to_shape( - tm.shape, np.array( - tm.hd5_first_dataset_in_group(hd5, f'{tm.path_prefix}axial_{slice_index}/'), dtype=np.float32, - ), - ) - dependents[tm.dependent_map] = np.zeros( - tm.dependent_map.shape, - dtype=np.float32, - ) - dependents[tm.dependent_map][0] = float(slice_index) / 182.0 - return tensor +def _random_slice_bounded(low=0, high=182): + def random_mni_slice_tensor(tm, hd5, dependents={}): + slice_index = np.random.randint(low, high) + tensor = pad_or_crop_array_to_shape( + tm.shape, np.array( + tm.hd5_first_dataset_in_group(hd5, f'{tm.path_prefix}axial_{slice_index}/'), dtype=np.float32, + ), + ) + dependents[tm.dependent_map] = np.zeros( + tm.dependent_map.shape, + dtype=np.float32, + ) + dependents[tm.dependent_map][0] = float(slice_index) / 182.0 + return tensor + return random_mni_slice_tensor axial_index_map = TensorMap('axial_index', Interpretation.CONTINUOUS, shape=(1,), channel_map={'axial_index':0}) @@ -461,17 +462,27 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): Interpretation.CONTINUOUS, shape=(192, 192, 1), path_prefix='ukb_brain_mri/T1_brain_to_MNI/', - tensor_from_file=random_mni_slice_tensor, + tensor_from_file=_random_slice_bounded(), + normalization=ZeroMeanStd1(), + dependent_map=axial_index_map, +) + +t1_random_slice = TensorMap( + 't1_random_slice', + Interpretation.CONTINUOUS, + shape=(224, 256, 1), + path_prefix='ukb_brain_mri/T1/', + tensor_from_file=_random_slice_bounded(6, 200), normalization=ZeroMeanStd1(), dependent_map=axial_index_map, ) t2_flair_random_slice = TensorMap( - 't1_mni_random_slice', + 't2_flair_random_slice', Interpretation.CONTINUOUS, shape=(192, 256, 1), path_prefix='ukb_brain_mri/T2_FLAIR_orig_defaced/', - tensor_from_file=random_mni_slice_tensor, + tensor_from_file=_random_slice_bounded(6, 200), normalization=ZeroMeanStd1(), dependent_map=axial_index_map, ) From f95ee8623a30e9fbc791c632ade2627f553172eb Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 2 Jan 2025 12:22:32 -0500 Subject: [PATCH 097/119] sigmoid loss unconditioned --- ml4h/tensormap/ukb/mri_brain.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ml4h/tensormap/ukb/mri_brain.py b/ml4h/tensormap/ukb/mri_brain.py index a518d2960..6c19ecc88 100755 --- a/ml4h/tensormap/ukb/mri_brain.py +++ b/ml4h/tensormap/ukb/mri_brain.py @@ -477,6 +477,16 @@ def random_mni_slice_tensor(tm, hd5, dependents={}): dependent_map=axial_index_map, ) +t1_random_slice_256 = TensorMap( + 't1_random_slice_256', + Interpretation.CONTINUOUS, + shape=(256, 256, 1), + path_prefix='ukb_brain_mri/T1/', + tensor_from_file=_random_slice_bounded(16, 192), + normalization=ZeroMeanStd1(), + dependent_map=axial_index_map, +) + t2_flair_random_slice = TensorMap( 't2_flair_random_slice', Interpretation.CONTINUOUS, From 1d402707a74a98492a79ddb59653384d7ae08a5c Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 3 Jan 2025 12:02:32 -0500 Subject: [PATCH 098/119] sigmoid loss unconditioned --- ml4h/metrics.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/ml4h/metrics.py b/ml4h/metrics.py index d4444f174..d0349024b 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -1,5 +1,7 @@ # metrics.py import logging + +import keras import numpy as np import tensorflow as tf import tensorflow.keras.backend as K @@ -714,3 +716,67 @@ def concordance_index_censored(event_indicator, event_time, estimate, tied_tol=1 """ w = np.ones_like(estimate) return _estimate_concordance_index(event_indicator, event_time, estimate, w, tied_tol) + + +class KernelInceptionDistance(keras.metrics.Metric): + def __init__(self, name, input_shape, kernel_image_size, **kwargs): + super().__init__(name=name, **kwargs) + + # KID is estimated per batch and is averaged across batches + self.kid_tracker = keras.metrics.Mean(name="kid_tracker") + + # a pretrained InceptionV3 is used without its classification layer + # transform the pixel values to the 0-255 range, then use the same + # preprocessing as during pretraining + self.encoder = keras.Sequential( + [ + keras.Input(shape=input_shape), # TODO: handle multi-channel + layers.Lambda(lambda x: tf.tile(x, [1, 1, 1, 3])), + layers.Rescaling(255.0), + layers.Resizing(height=kernel_image_size, width=kernel_image_size), + layers.Lambda(keras.applications.inception_v3.preprocess_input), + keras.applications.InceptionV3( + include_top=False, + input_shape=(kernel_image_size, kernel_image_size, 3), + weights="imagenet", + ), + layers.GlobalAveragePooling2D(), + ], + name="inception_encoder", + ) + + def polynomial_kernel(self, features_1, features_2): + feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32) + return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0 + + def update_state(self, real_images, generated_images, sample_weight=None): + real_features = self.encoder(real_images, training=False) + generated_features = self.encoder(generated_images, training=False) + + # compute polynomial kernels using the two sets of features + kernel_real = self.polynomial_kernel(real_features, real_features) + kernel_generated = self.polynomial_kernel( + generated_features, generated_features + ) + kernel_cross = self.polynomial_kernel(real_features, generated_features) + + # estimate the squared maximum mean discrepancy using the average kernel values + batch_size = tf.shape(real_features)[0] + batch_size_f = tf.cast(batch_size, dtype=tf.float32) + mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / ( + batch_size_f * (batch_size_f - 1.0) + ) + mean_kernel_generated = tf.reduce_sum( + kernel_generated * (1.0 - tf.eye(batch_size)) + ) / (batch_size_f * (batch_size_f - 1.0)) + mean_kernel_cross = tf.reduce_mean(kernel_cross) + kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross + + # update the average KID estimate + self.kid_tracker.update_state(kid) + + def result(self): + return self.kid_tracker.result() + + def reset_state(self): + self.kid_tracker.reset_state() From 2f0602d7c960dc9963bf24de05022d5f42b43225 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 3 Jan 2025 12:08:21 -0500 Subject: [PATCH 099/119] sigmoid loss unconditioned --- ml4h/models/diffusion_blocks.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 6da5d3e3d..32cda7659 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -12,6 +12,7 @@ from keras import layers from ml4h.defines import IMAGE_EXT +from ml4h.metrics import KernelInceptionDistance from ml4h.models.Block import Block from ml4h.TensorMap import TensorMap @@ -302,11 +303,11 @@ def compile(self, **kwargs): self.noise_loss_tracker = keras.metrics.Mean(name="n_loss") self.image_loss_tracker = keras.metrics.Mean(name="i_loss") - # self.kid = KID(name = "kid", input_shape = self.tensor_map.shape) + self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75) @property def metrics(self): - return [self.noise_loss_tracker, self.image_loss_tracker] + return [self.noise_loss_tracker, self.image_loss_tracker, self.kid] def denormalize(self, images): # convert the pixel values back to 0-1 range @@ -468,11 +469,11 @@ def test_step(self, images_original): # measure KID between real and generated images # this is computationally demanding, kid_diffusion_steps has to be small - # images = self.denormalize(images) - # generated_images = self.generate( - # num_images=self.batch_size, diffusion_steps=20, - # ) - # self.kid.update_state(images, generated_images) + images = self.denormalize(images) + generated_images = self.generate( + num_images=self.batch_size, diffusion_steps=20 + ) + self.kid.update_state(images, generated_images) return {m.name: m.result() for m in self.metrics} From a5d518dc6bc3b669d48688252e80f7c656ea73ac Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 3 Jan 2025 12:12:10 -0500 Subject: [PATCH 100/119] sigmoid loss unconditioned --- ml4h/metrics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ml4h/metrics.py b/ml4h/metrics.py index d0349024b..7e2247114 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -731,16 +731,16 @@ def __init__(self, name, input_shape, kernel_image_size, **kwargs): self.encoder = keras.Sequential( [ keras.Input(shape=input_shape), # TODO: handle multi-channel - layers.Lambda(lambda x: tf.tile(x, [1, 1, 1, 3])), - layers.Rescaling(255.0), - layers.Resizing(height=kernel_image_size, width=kernel_image_size), - layers.Lambda(keras.applications.inception_v3.preprocess_input), + keras.layers.Lambda(lambda x: tf.tile(x, [1, 1, 1, 3])), + keras.layers.Rescaling(255.0), + keras.layers.Resizing(height=kernel_image_size, width=kernel_image_size), + keras.layers.Lambda(keras.applications.inception_v3.preprocess_input), keras.applications.InceptionV3( include_top=False, input_shape=(kernel_image_size, kernel_image_size, 3), weights="imagenet", ), - layers.GlobalAveragePooling2D(), + keras.layers.GlobalAveragePooling2D(), ], name="inception_encoder", ) From ff81dee23880e3992aa64f5e604a209ccfef91ee Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 3 Jan 2025 12:32:43 -0500 Subject: [PATCH 101/119] kernel inception distance --- ml4h/models/diffusion_blocks.py | 54 ++++++++++++++++++++++++++++----- ml4h/models/train.py | 20 ++++++++---- ml4h/recipes.py | 2 ++ 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 32cda7659..74f4e04cd 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -303,11 +303,15 @@ def compile(self, **kwargs): self.noise_loss_tracker = keras.metrics.Mean(name="n_loss") self.image_loss_tracker = keras.metrics.Mean(name="i_loss") - self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75) + if self.tensor_map.axes() == 3: + self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75) @property def metrics(self): - return [self.noise_loss_tracker, self.image_loss_tracker, self.kid] + m = [self.noise_loss_tracker, self.image_loss_tracker] + if self.tensor_map.axes() == 3: + m.append(self.kid) + return m def denormalize(self, images): # convert the pixel values back to 0-1 range @@ -469,11 +473,12 @@ def test_step(self, images_original): # measure KID between real and generated images # this is computationally demanding, kid_diffusion_steps has to be small - images = self.denormalize(images) - generated_images = self.generate( - num_images=self.batch_size, diffusion_steps=20 - ) - self.kid.update_state(images, generated_images) + if self.tensor_map.axes() == 3: + images = self.denormalize(images) + generated_images = self.generate( + num_images=self.batch_size, diffusion_steps=20 + ) + self.kid.update_state(images, generated_images) return {m.name: m.result() for m in self.metrics} @@ -607,6 +612,7 @@ class DiffusionController(keras.Model): def __init__( self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size, attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy, + supervisor = None, ): super().__init__() @@ -620,6 +626,7 @@ def __init__( self.ema_network = keras.models.clone_model(self.network) self.use_sigmoid_loss = diffusion_loss == 'sigmoid' self.beta = sigmoid_beta + self.supervisor = supervisor def compile(self, **kwargs): @@ -627,11 +634,16 @@ def compile(self, **kwargs): self.noise_loss_tracker = keras.metrics.Mean(name="n_loss") self.image_loss_tracker = keras.metrics.Mean(name="i_loss") + if self.supervisor is not None: + self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss") # self.kid = KID(name = "kid", input_shape = self.tensor_map.shape) @property def metrics(self): - return [self.noise_loss_tracker, self.image_loss_tracker] + m = [self.noise_loss_tracker, self.image_loss_tracker] + if self.supervisor is not None: + m.append(self.supervised_loss_tracker) + return m def denormalize(self, images): # convert the pixel values back to 0-1 range @@ -751,6 +763,17 @@ def train_step(self, batch): lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared) weight = tf.math.sigmoid(self.beta - lambda_t) noise_loss = weight * noise_loss + if self.supervisor is not None: + loss_fn = tf.keras.losses.MeanSquaredError() + supervised_preds = self.supervisor(pred_images, training=True) + supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds) + self.supervised_loss_tracker.update_state(supervised_loss) + # Combine losses: add noise_loss and supervised_loss + noise_loss += 0.01 * supervised_loss + + # Gradients for self.supervised_model + supervised_gradients = tape.gradient(supervised_loss, self.supervisor.trainable_weights) + self.optimizer.apply_gradients(zip(supervised_gradients, self.supervisor.trainable_weights)) gradients = tape.gradient(noise_loss, self.network.trainable_weights) self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights)) @@ -827,6 +850,21 @@ def test_step(self, batch): noise_loss = self.loss(noises, pred_noises) image_loss = self.loss(images, pred_images) + if self.use_sigmoid_loss: + signal_rates_squared = tf.square(signal_rates) + noise_rates_squared = tf.square(noise_rates) + + # Compute log-SNR (lambda_t) + lambda_t = tf.math.log(signal_rates_squared / noise_rates_squared) + weight = tf.math.sigmoid(self.beta - lambda_t) + noise_loss = weight * noise_loss + if self.supervisor is not None: + loss_fn = tf.keras.losses.MeanSquaredError() + supervised_preds = self.supervisor(pred_images, training=True) + supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds) + self.supervised_loss_tracker.update_state(supervised_loss) + # Combine losses: add noise_loss and supervised_loss + noise_loss += 0.01*supervised_loss self.image_loss_tracker.update_state(image_loss) self.noise_loss_tracker.update_state(noise_loss) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 18a9c3fe0..9c9f4f3a9 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -279,13 +279,21 @@ def interpolate_controlled_generations(diffuser, tensor_maps_out, control_tm, ba plt.close() -def train_diffusion_control_model(args): +def train_diffusion_control_model(args, supervised=False): generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) - model = DiffusionController( - args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, - args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, - args.sigmoid_beta, args.diffusion_condition_strategy, - ) + if supervised: + supervised_model, _, _, _ = make_multimodal_multitask_model(**args.__dict__) + model = DiffusionController( + args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, + args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, + args.sigmoid_beta, args.diffusion_condition_strategy, supervised_model, + ) + else: + model = DiffusionController( + args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, + args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, + args.sigmoid_beta, args.diffusion_condition_strategy, + ) loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error model.compile( diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 753b27ad4..4642303b7 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -116,6 +116,8 @@ def run(args): train_diffusion_model(args) elif 'train_diffusion_control' == args.mode: train_diffusion_control_model(args) + elif 'train_diffusion_supervised' == args.mode: + train_diffusion_control_model(args, supervised=True) elif 'train_siamese' == args.mode: train_siamese_model(args) elif 'write_tensor_maps' == args.mode: From 78279094919608dfa27800c26433ab3d3f020b11 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 3 Jan 2025 12:39:32 -0500 Subject: [PATCH 102/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 74f4e04cd..ec30425b7 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -747,7 +747,7 @@ def train_step(self, batch): print(f'noises.shape {noises.shape} images.shape {images.shape}') noisy_images = signal_rates * images + noise_rates * noises - with tf.GradientTape() as tape: + with tf.GradientTape(persistent=True) if self.supervisor else tf.GradientTape() as tape: # train the network to separate noisy images to their components pred_noises, pred_images = self.denoise( control_embed, noisy_images, noise_rates, signal_rates, training=True, From 3229037863be631d04dc396e9dfcba34e4559711 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 8 Jan 2025 10:28:06 -0500 Subject: [PATCH 103/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 2 +- ml4h/plots.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index ec30425b7..1f57be313 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -786,7 +786,7 @@ def train_step(self, batch): ema_weight.assign(ema * ema_weight + (1 - ema) * weight) # KID is not measured during the training phase for computational efficiency - return {m.name: m.result() for m in self.metrics[:-1]} + return {m.name: m.result() for m in self.metrics} # def call(self, inputs): # # normalize images to have standard deviation of 1, like the noises diff --git a/ml4h/plots.py b/ml4h/plots.py index b8285f3b8..29f47538f 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -439,7 +439,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu ) for k in sorted(history.history.keys()): - if not k.startswith("val_"): + if not k.startswith("val_") or k == 'val_supervised_loss': if isinstance(history.history[k][0], LearningRateSchedule): history.history[k] = [ history.history[k][0](i * training_steps) @@ -470,7 +470,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path) - for log_label in ['loss', 'val_loss', 'n_loss', 'val_n_loss']: + for log_label in ['loss', 'val_loss', 'n_loss', 'val_n_loss', 'val_supervised_loss']: if log_label not in history.history: continue logging.info( From c9c16ddd8a4a4d2d09c61223ab2c10260d346ce4 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Wed, 8 Jan 2025 11:32:06 -0500 Subject: [PATCH 104/119] condition and supervise --- ml4h/plots.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index 29f47538f..5b2a9afe4 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -439,7 +439,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu ) for k in sorted(history.history.keys()): - if not k.startswith("val_") or k == 'val_supervised_loss': + if not k.startswith("val_") or k in ['val_kid', 'val_supervised_loss']: if isinstance(history.history[k][0], LearningRateSchedule): history.history[k] = [ history.history[k][0](i * training_steps) @@ -470,7 +470,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path) - for log_label in ['loss', 'val_loss', 'n_loss', 'val_n_loss', 'val_supervised_loss']: + for log_label in ['loss', 'val_loss', 'n_loss', 'val_n_loss', 'val_supervised_loss', 'val_kid']: if log_label not in history.history: continue logging.info( From 1632042747bb22a8f8db3984400726d4ccb50a0d Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Thu, 9 Jan 2025 12:57:35 -0500 Subject: [PATCH 105/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 1f57be313..ced212d76 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -303,12 +303,14 @@ def compile(self, **kwargs): self.noise_loss_tracker = keras.metrics.Mean(name="n_loss") self.image_loss_tracker = keras.metrics.Mean(name="i_loss") + self.mse_metric = tf.keras.metrics.MeanSquaredError(name="mse") + self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae") if self.tensor_map.axes() == 3: self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75) @property def metrics(self): - m = [self.noise_loss_tracker, self.image_loss_tracker] + m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric] if self.tensor_map.axes() == 3: m.append(self.kid) return m @@ -428,13 +430,15 @@ def train_step(self, images_original): self.noise_loss_tracker.update_state(noise_loss) self.image_loss_tracker.update_state(image_loss) + self.mse_metric.update_state(noises, pred_noises) + self.mae_metric.update_state(noises, pred_noises) # track the exponential moving averages of weights for weight, ema_weight in zip(self.network.weights, self.ema_network.weights): ema_weight.assign(ema * ema_weight + (1 - ema) * weight) # KID is not measured during the training phase for computational efficiency - return {m.name: m.result() for m in self.metrics[:-1]} + return {m.name: m.result() for m in self.metrics} def test_step(self, images_original): # normalize images to have standard deviation of 1, like the noises @@ -470,6 +474,8 @@ def test_step(self, images_original): self.image_loss_tracker.update_state(image_loss) self.noise_loss_tracker.update_state(noise_loss) + self.mse_metric.update_state(noises, pred_noises) + self.mae_metric.update_state(noises, pred_noises) # measure KID between real and generated images # this is computationally demanding, kid_diffusion_steps has to be small @@ -631,16 +637,17 @@ def __init__( def compile(self, **kwargs): super().compile(**kwargs) - self.noise_loss_tracker = keras.metrics.Mean(name="n_loss") self.image_loss_tracker = keras.metrics.Mean(name="i_loss") + self.mse_metric = tf.keras.metrics.MeanSquaredError(name="mse") + self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae") if self.supervisor is not None: self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss") # self.kid = KID(name = "kid", input_shape = self.tensor_map.shape) @property def metrics(self): - m = [self.noise_loss_tracker, self.image_loss_tracker] + m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric] if self.supervisor is not None: m.append(self.supervised_loss_tracker) return m From 1f6bc031c674f403121e56c398f0a4304182134c Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Fri, 10 Jan 2025 13:33:40 -0500 Subject: [PATCH 106/119] condition and supervise --- ml4h/arguments.py | 4 +++ ml4h/models/diffusion_blocks.py | 48 ++++++--------------------------- ml4h/models/train.py | 2 +- ml4h/recipes.py | 2 +- 4 files changed, 14 insertions(+), 42 deletions(-) diff --git a/ml4h/arguments.py b/ml4h/arguments.py index f1111f3c0..525a05c44 100755 --- a/ml4h/arguments.py +++ b/ml4h/arguments.py @@ -244,6 +244,10 @@ def parse_args(): '--sigmoid_beta', default=-3, type=float, help='Beta to use with sigmoid loss for diffusion models.', ) + parser.add_argument( + '--supervision_scalar', default=0.01, type=float, + help='For `train_diffusion_supervise` mode, this weights the supervision loss from phenotype prediction on denoised data.', + ) parser.add_argument( '--transformer_size', default=32, type=int, help='Number of output neurons in Transformer encoders and decoders, ' diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index ced212d76..c1c7fde67 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -618,7 +618,7 @@ class DiffusionController(keras.Model): def __init__( self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size, attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy, - supervisor = None, + supervisor = None, supervision_scalar = 0.01, ): super().__init__() @@ -633,6 +633,7 @@ def __init__( self.use_sigmoid_loss = diffusion_loss == 'sigmoid' self.beta = sigmoid_beta self.supervisor = supervisor + self.supervision_scalar = supervision_scalar def compile(self, **kwargs): @@ -776,7 +777,7 @@ def train_step(self, batch): supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds) self.supervised_loss_tracker.update_state(supervised_loss) # Combine losses: add noise_loss and supervised_loss - noise_loss += 0.01 * supervised_loss + noise_loss += self.supervision_scalar * supervised_loss # Gradients for self.supervised_model supervised_gradients = tape.gradient(supervised_loss, self.supervisor.trainable_weights) @@ -787,6 +788,8 @@ def train_step(self, batch): self.noise_loss_tracker.update_state(noise_loss) self.image_loss_tracker.update_state(image_loss) + self.mse_metric.update_state(noises, pred_noises) + self.mae_metric.update_state(noises, pred_noises) # track the exponential moving averages of weights for weight, ema_weight in zip(self.network.weights, self.ema_network.weights): @@ -795,43 +798,6 @@ def train_step(self, batch): # KID is not measured during the training phase for computational efficiency return {m.name: m.result() for m in self.metrics} - # def call(self, inputs): - # # normalize images to have standard deviation of 1, like the noises - # images = inputs[self.input_map.input_name()] - # self.normalizer.update_state(images) - # images = self.normalizer(images, training=False) - - # control_embed = self.control_embed_model(inputs) - - # noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape) - - # # sample uniform random diffusion times - # diffusion_times = tf.random.uniform( - # shape=[self.batch_size, ] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0 - # ) - # noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) - # # mix the images with noises accordingly - # noisy_images = signal_rates * images + noise_rates * noises - - # # use the network to separate noisy images to their components - # pred_noises, pred_images = self.denoise( - # control_embed, noisy_images, noise_rates, signal_rates, training=False - # ) - - # noise_loss = self.loss(noises, pred_noises) - # image_loss = self.loss(images, pred_images) - - # self.image_loss_tracker.update_state(image_loss) - # self.noise_loss_tracker.update_state(noise_loss) - - # # measure KID between real and generated images - # # this is computationally demanding, kid_diffusion_steps has to be small - # images = self.denormalize(images) - # generated_images = self.generate( - # control_embed, num_images=self.batch_size, diffusion_steps=20 - # ) - # return generated_images - def test_step(self, batch): # normalize images to have standard deviation of 1, like the noises images = batch[0][self.input_map.input_name()] @@ -871,10 +837,12 @@ def test_step(self, batch): supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds) self.supervised_loss_tracker.update_state(supervised_loss) # Combine losses: add noise_loss and supervised_loss - noise_loss += 0.01*supervised_loss + noise_loss += self.supervision_scalar*supervised_loss self.image_loss_tracker.update_state(image_loss) self.noise_loss_tracker.update_state(noise_loss) + self.mse_metric.update_state(noises, pred_noises) + self.mae_metric.update_state(noises, pred_noises) # measure KID between real and generated images # this is computationally demanding, kid_diffusion_steps has to be small diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 9c9f4f3a9..9380e52a9 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -286,7 +286,7 @@ def train_diffusion_control_model(args, supervised=False): model = DiffusionController( args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss, - args.sigmoid_beta, args.diffusion_condition_strategy, supervised_model, + args.sigmoid_beta, args.diffusion_condition_strategy, supervised_model, args.supervision_scalar, ) else: model = DiffusionController( diff --git a/ml4h/recipes.py b/ml4h/recipes.py index 4642303b7..928f6054e 100755 --- a/ml4h/recipes.py +++ b/ml4h/recipes.py @@ -116,7 +116,7 @@ def run(args): train_diffusion_model(args) elif 'train_diffusion_control' == args.mode: train_diffusion_control_model(args) - elif 'train_diffusion_supervised' == args.mode: + elif 'train_diffusion_supervise' == args.mode: train_diffusion_control_model(args, supervised=True) elif 'train_siamese' == args.mode: train_siamese_model(args) From d2566d7dcc5323052fb6afbfef335f87222b4ebd Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Sat, 11 Jan 2025 14:45:11 -0500 Subject: [PATCH 107/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 10 ++++++++-- ml4h/plots.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index c1c7fde67..3d65943e6 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -772,7 +772,10 @@ def train_step(self, batch): weight = tf.math.sigmoid(self.beta - lambda_t) noise_loss = weight * noise_loss if self.supervisor is not None: - loss_fn = tf.keras.losses.MeanSquaredError() + if self.output_maps[0].is_categorical(): + loss_fn = tf.keras.losses.CategoricalCrossentropy() + else: + loss_fn = tf.keras.losses.MeanSquaredError() supervised_preds = self.supervisor(pred_images, training=True) supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds) self.supervised_loss_tracker.update_state(supervised_loss) @@ -832,7 +835,10 @@ def test_step(self, batch): weight = tf.math.sigmoid(self.beta - lambda_t) noise_loss = weight * noise_loss if self.supervisor is not None: - loss_fn = tf.keras.losses.MeanSquaredError() + if self.output_maps[0].is_categorical(): + loss_fn = tf.keras.losses.CategoricalCrossentropy() + else: + loss_fn = tf.keras.losses.MeanSquaredError() supervised_preds = self.supervisor(pred_images, training=True) supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds) self.supervised_loss_tracker.update_state(supervised_loss) diff --git a/ml4h/plots.py b/ml4h/plots.py index 5b2a9afe4..72ac552ea 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -470,7 +470,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path) - for log_label in ['loss', 'val_loss', 'n_loss', 'val_n_loss', 'val_supervised_loss', 'val_kid']: + for log_label in ['loss', 'mse', 'mae', 'val_loss', 'n_loss', 'val_n_loss', 'val_supervised_loss', 'val_kid']: if log_label not in history.history: continue logging.info( From 2617ff7bbbbe7bfe71ac3fbd017202c62604b4a5 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Sun, 12 Jan 2025 15:36:50 -0500 Subject: [PATCH 108/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 3d65943e6..0ab00d1f6 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -260,7 +260,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s for i, width in enumerate(reversed(widths[:-1])): if attention_modulo > 1 and ((len(widths) - 1) - i) % attention_modulo == 0: if len(input_shape) > 2: - c2 = upsample(size=x.shape[1:-1]*2)(control[control_idxs]) + c2 = upsample(size=(x.shape[1]*2, x.shape[2]*2))(control[control_idxs]) else: c2 = upsample(size=x.shape[-2]*2)(control[control_idxs]) x = up_block_control(width, block_depth, conv, upsample, From 6a90d2652fa7577ed3f4dfb5af0e30de139ef4b9 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 11:59:59 -0500 Subject: [PATCH 109/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 11 +++++++---- ml4h/models/train.py | 6 +++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 0ab00d1f6..fcb8126cc 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -259,7 +259,7 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s for i, width in enumerate(reversed(widths[:-1])): if attention_modulo > 1 and ((len(widths) - 1) - i) % attention_modulo == 0: - if len(input_shape) > 2: + if len(input_shape) == 3: c2 = upsample(size=(x.shape[1]*2, x.shape[2]*2))(control[control_idxs]) else: c2 = upsample(size=x.shape[-2]*2)(control[control_idxs]) @@ -540,8 +540,7 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, plt.close() def plot_reconstructions( - self, images_original, diffusion_amount=0, - epoch=None, logs=None, num_rows=3, num_cols=6, + self, images_original, diffusion_amount=0, epoch=None, logs=None, num_rows=4, num_cols=4, prefix='./figures/', ): images = images_original[0][self.tensor_map.input_name()] self.normalizer.update_state(images) @@ -565,7 +564,11 @@ def plot_reconstructions( plt.imshow(generated_images[index], cmap='gray') plt.axis("off") plt.tight_layout() - plt.show() + now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'diffusion_reconstructions_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path, bbox_inches="tight") plt.close() def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_cols=6): diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 9380e52a9..612c40c8f 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -10,6 +10,7 @@ from datetime import datetime import tensorflow as tf +from ml4h.explorations import predictions_to_pngs from tensorflow import keras import tensorflow_addons as tfa from tensorflow.keras.callbacks import History @@ -23,7 +24,7 @@ from ml4h.defines import IMAGE_EXT, MODEL_EXT from ml4h.models.inspect import plot_and_time_model from ml4h.models.model_factory import get_custom_objects, make_multimodal_multitask_model -from ml4h.tensor_generators import test_train_valid_tensor_generators +from ml4h.tensor_generators import test_train_valid_tensor_generators, big_batch_from_minibatch_generator def train_model_from_generators( @@ -372,6 +373,8 @@ def train_diffusion_control_model(args, supervised=False): model.load_weights(checkpoint_path) if args.inspect_model: + data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) + predictions_to_pngs(data, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: @@ -386,4 +389,5 @@ def train_diffusion_control_model(args, supervised=False): regress_on_controlled_generations(model, eval_model, tm_out, args.test_steps, args.batch_size, 0.0,2.0,f'{args.output_folder}/{args.id}/') + return model From 44ac1833ae1c82bce10b57b99b11116e5748cdc6 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 12:02:52 -0500 Subject: [PATCH 110/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 8 ++++++-- ml4h/models/train.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index fcb8126cc..cc4f35ff2 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -896,7 +896,7 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None def plot_reconstructions( self, batch, diffusion_amount=0, - epoch=None, logs=None, num_rows=4, num_cols=4, + epoch=None, logs=None, num_rows=4, num_cols=4, prefix='./figures/', ): images = batch[0][self.input_map.input_name()] self.normalizer.update_state(images) @@ -921,7 +921,11 @@ def plot_reconstructions( plt.imshow(generated_images[index], cmap='gray') plt.axis("off") plt.tight_layout() - plt.show() + now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'diffusion_image_generations_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path, bbox_inches="tight") plt.close() def control_plot_images( diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 612c40c8f..683565c75 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -375,6 +375,7 @@ def train_diffusion_control_model(args, supervised=False): if args.inspect_model: data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) predictions_to_pngs(data, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/') + model.plot_reconstructions(data, prefix=f'{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: From 75e86e66c29f32a647cea874f4538137b04b3ce6 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 12:13:41 -0500 Subject: [PATCH 111/119] condition and supervise --- ml4h/models/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 683565c75..a40f2edb9 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -374,8 +374,10 @@ def train_diffusion_control_model(args, supervised=False): if args.inspect_model: data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) - predictions_to_pngs(data, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/') + model.plot_reconstructions(data, prefix=f'{args.output_folder}/{args.id}/') + images = data[args.tensor_maps_in[0].input_name()] + predictions_to_pngs(images, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: From d65c3e94a78cc1ca6aa932a381882cfa929a03b6 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 12:17:06 -0500 Subject: [PATCH 112/119] condition and supervise --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index a40f2edb9..f5d041396 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -375,7 +375,7 @@ def train_diffusion_control_model(args, supervised=False): if args.inspect_model: data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) - model.plot_reconstructions(data, prefix=f'{args.output_folder}/{args.id}/') + model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/') images = data[args.tensor_maps_in[0].input_name()] predictions_to_pngs(images, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, From 0a32e0bd65a33b9df77002382a0ea3f2f44edfd1 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 12:22:17 -0500 Subject: [PATCH 113/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 17 ++++++++++++++++- ml4h/models/train.py | 2 -- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index cc4f35ff2..a7bf28fe2 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -922,12 +922,27 @@ def plot_reconstructions( plt.axis("off") plt.tight_layout() now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') - figure_path = os.path.join(prefix, f'diffusion_image_generations_{now_string}{IMAGE_EXT}') + figure_path = os.path.join(prefix, f'diffusion_image_reconstructions_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path, bbox_inches="tight") + plt.close() + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) + for row in range(num_rows): + for col in range(num_cols): + index = row * num_cols + col + plt.subplot(num_rows, num_cols, index + 1) + plt.imshow(images[index], cmap='gray') + plt.axis("off") + plt.tight_layout() + now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'input_images_{now_string}{IMAGE_EXT}') if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") plt.close() + def control_plot_images( self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, renoise=None, diff --git a/ml4h/models/train.py b/ml4h/models/train.py index f5d041396..e4c03f752 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -376,8 +376,6 @@ def train_diffusion_control_model(args, supervised=False): data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/') - images = data[args.tensor_maps_in[0].input_name()] - predictions_to_pngs(images, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: From 4f6f0620b299aab1141f17e4ada57d327ac06119 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 12:31:05 -0500 Subject: [PATCH 114/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 1 + ml4h/models/train.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index a7bf28fe2..780246e32 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -941,6 +941,7 @@ def plot_reconstructions( os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") plt.close() + return generated_images def control_plot_images( diff --git a/ml4h/models/train.py b/ml4h/models/train.py index e4c03f752..b6f910a95 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -375,7 +375,10 @@ def train_diffusion_control_model(args, supervised=False): if args.inspect_model: data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) - model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/') + preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/') + #images = data[args.tensor_maps_in[0].input_name()] + image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]} + predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: From 5cd6e08d2f3bd61ae1d73bb42a753fba55b72ce4 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 12:41:56 -0500 Subject: [PATCH 115/119] condition and supervise --- ml4h/models/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index b6f910a95..4da267824 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -374,9 +374,7 @@ def train_diffusion_control_model(args, supervised=False): if args.inspect_model: data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) - - preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/') - #images = data[args.tensor_maps_in[0].input_name()] + preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/') image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]} predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, From 1ca98bee2ab3a185e21f373f1d308c833977421f Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 12:49:59 -0500 Subject: [PATCH 116/119] condition and supervise --- ml4h/models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 4da267824..fbb807a80 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -376,7 +376,7 @@ def train_diffusion_control_model(args, supervised=False): data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/') image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]} - predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/') + predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: From fc307465f4eafadd6a1a31d9bdc49531badd4815 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 13:52:18 -0500 Subject: [PATCH 117/119] condition and supervise --- ml4h/plots.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml4h/plots.py b/ml4h/plots.py index 72ac552ea..85185ef81 100755 --- a/ml4h/plots.py +++ b/ml4h/plots.py @@ -470,7 +470,8 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path) - for log_label in ['loss', 'mse', 'mae', 'val_loss', 'n_loss', 'val_n_loss', 'val_supervised_loss', 'val_kid']: + for log_label in ['loss', 'mse', 'mae', 'val_loss', 'val_mse', 'val_mae', 'n_loss', 'val_n_loss', + 'val_supervised_loss', 'val_kid']: if log_label not in history.history: continue logging.info( From efa75274fd14ec30d8ffbc1840d617c94f17c14d Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 14 Jan 2025 16:01:33 -0500 Subject: [PATCH 118/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 9 +++++---- ml4h/models/train.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 780246e32..cb03c41ba 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -284,7 +284,7 @@ def get_control_embed_model(output_maps, control_size): class DiffusionModel(keras.Model): - def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta): + def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta, inspect_model): super().__init__() self.tensor_map = tensor_map @@ -294,6 +294,7 @@ def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, dif self.ema_network = keras.models.clone_model(self.network) self.use_sigmoid_loss = diffusion_loss == 'sigmoid' self.beta = sigmoid_beta + self.inspect_model = inspect_model def can_apply(self): return self.tensor_map.axes() > 1 @@ -305,13 +306,13 @@ def compile(self, **kwargs): self.image_loss_tracker = keras.metrics.Mean(name="i_loss") self.mse_metric = tf.keras.metrics.MeanSquaredError(name="mse") self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae") - if self.tensor_map.axes() == 3: + if self.tensor_map.axes() == 3 and self.inspect_model: self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75) @property def metrics(self): m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric] - if self.tensor_map.axes() == 3: + if self.tensor_map.axes() == 3 and self.inspect_model: m.append(self.kid) return m @@ -479,7 +480,7 @@ def test_step(self, images_original): # measure KID between real and generated images # this is computationally demanding, kid_diffusion_steps has to be small - if self.tensor_map.axes() == 3: + if self.tensor_map.axes() == 3 and self.inspect_model: images = self.denormalize(images) generated_images = self.generate( num_images=self.batch_size, diffusion_steps=20 diff --git a/ml4h/models/train.py b/ml4h/models/train.py index fbb807a80..4f362d792 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -110,7 +110,7 @@ def _get_callbacks( def train_diffusion_model(args): generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__) model = DiffusionModel(args.tensor_maps_in[0], args.batch_size, args.dense_blocks, args.block_size, args.conv_x, - args.diffusion_loss, args.sigmoid_beta) + args.diffusion_loss, args.sigmoid_beta, args.inspect_model) model.compile( optimizer=tfa.optimizers.AdamW( @@ -173,9 +173,15 @@ def train_diffusion_model(args): callbacks=callbacks, ) model.load_weights(checkpoint_path) - #diffusion_model.compile(optimizer='adam', loss='mse') plot_metric_history(history, args.training_steps, args.id, os.path.dirname(checkpoint_path)) if args.inspect_model: + metrics = model.evaluate(generate_test, batch_size=args.batch_size, steps=args.test_steps, return_dict=True) + logging.info(f'Test metrics: {metrics}') + + data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) + preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/') + image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]} + predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/') if model.tensor_map.axes() == 2: model.plot_ecgs(num_rows=4, prefix=os.path.dirname(checkpoint_path)) else: From 6d8d1d419e17d04158613763645d3939481f39cf Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Tue, 14 Jan 2025 16:12:59 -0500 Subject: [PATCH 119/119] condition and supervise --- ml4h/models/diffusion_blocks.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index cb03c41ba..55e9d3c48 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -571,6 +571,21 @@ def plot_reconstructions( os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") plt.close() + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) + for row in range(num_rows): + for col in range(num_cols): + index = row * num_cols + col + plt.subplot(num_rows, num_cols, index + 1) + plt.imshow(images[index], cmap='gray') + plt.axis("off") + plt.tight_layout() + now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'input_images_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path, bbox_inches="tight") + plt.close() + return generated_images def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_cols=6): images = images_original[0][self.tensor_map.input_name()]