From a245a950fc505256a2fdc4afabf65fbe2d927892 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:03:55 +0000 Subject: [PATCH 01/19] refactor to monai1.4, add quality check, correct typo Signed-off-by: Can-Zhao --- .../configs/image_median_statistics.json | 72 ++ .../configs/inference.json | 113 ++- models/maisi_ct_generative/configs/train.json | 32 +- models/maisi_ct_generative/large_files.yml | 12 +- .../maisi_ct_generative/scripts/__init__.py | 2 +- .../scripts/download_files.py | 11 + .../maisi_ct_generative/scripts/find_masks.py | 148 ++-- .../scripts/quality_check.py | 147 ++++ models/maisi_ct_generative/scripts/sample.py | 712 +++++++++++++----- models/maisi_ct_generative/scripts/trainer.py | 12 +- models/maisi_ct_generative/scripts/utils.py | 491 ++++++++++-- 11 files changed, 1339 insertions(+), 413 deletions(-) create mode 100644 models/maisi_ct_generative/configs/image_median_statistics.json create mode 100644 models/maisi_ct_generative/scripts/download_files.py create mode 100644 models/maisi_ct_generative/scripts/quality_check.py diff --git a/models/maisi_ct_generative/configs/image_median_statistics.json b/models/maisi_ct_generative/configs/image_median_statistics.json new file mode 100644 index 00000000..df966538 --- /dev/null +++ b/models/maisi_ct_generative/configs/image_median_statistics.json @@ -0,0 +1,72 @@ +{ + "liver": { + "min_median": -14.0, + "max_median": 1000.0, + "percentile_0_5": 9.530000000000001, + "percentile_99_5": 162.0, + "sigma_6_low": -21.596463547885904, + "sigma_6_high": 156.27881534763367, + "sigma_12_low": -110.53410299564568, + "sigma_12_high": 245.21645479539342 + }, + "spleen": { + "min_median": -69.0, + "max_median": 1000.0, + "percentile_0_5": 16.925000000000004, + "percentile_99_5": 184.07500000000073, + "sigma_6_low": -43.133891656525165, + "sigma_6_high": 177.40494997185993, + "sigma_12_low": -153.4033124707177, + "sigma_12_high": 287.6743707860525 + }, + "pancreas": { + "min_median": -124.0, + "max_median": 1000.0, + "percentile_0_5": -29.0, + "percentile_99_5": 145.92000000000007, + "sigma_6_low": -56.59382515620725, + "sigma_6_high": 149.50627399318438, + "sigma_12_low": -159.64387473090306, + "sigma_12_high": 252.5563235678802 + }, + "kidney": { + "min_median": -165.5, + "max_median": 819.0, + "percentile_0_5": -40.0, + "percentile_99_5": 254.61999999999898, + "sigma_6_low": -130.56375604853028, + "sigma_6_high": 267.28163511081016, + "sigma_12_low": -329.4864516282005, + "sigma_12_high": 466.20433069048045 + }, + "lung": { + "min_median": -1000.0, + "max_median": 65.0, + "percentile_0_5": -937.0, + "percentile_99_5": -366.9500000000007, + "sigma_6_low": -1088.5583843889117, + "sigma_6_high": -551.8503346949108, + "sigma_12_low": -1356.912409235912, + "sigma_12_high": -283.4963098479103 + }, + "bone": { + "min_median": 77.5, + "max_median": 1000.0, + "percentile_0_5": 136.45499999999998, + "percentile_99_5": 551.6350000000002, + "sigma_6_low": 71.39901958080469, + "sigma_6_high": 471.9957615639765, + "sigma_12_low": -128.8993514107812, + "sigma_12_high": 672.2941325555623 + }, + "brain": { + "min_median": -1000.0, + "max_median": 238.0, + "percentile_0_5": -951.0, + "percentile_99_5": 126.25, + "sigma_6_low": -304.8208236135867, + "sigma_6_high": 369.5118535139189, + "sigma_12_low": -641.9871621773394, + "sigma_12_high": 706.6781920776717 + } +} diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index d4268fc7..dbc19f6c 100644 --- 
a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -19,7 +19,6 @@ "all_anatomy_size_condtions_json": "$@bundle_root + '/configs/all_anatomy_size_condtions.json'", "label_dict_json": "$@bundle_root + '/configs/label_dict.json'", "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'", - "quality_check_args": null, "num_output_samples": 1, "body_region": [ "abdomen" @@ -62,8 +61,10 @@ 64, 64 ], + "autoencoder_sliding_window_infer_size": [96, 96, 96], + "autoencoder_sliding_window_infer_overlap": 0.6667, "autoencoder_def": { - "_target_": "scripts.custom_network_tp.AutoencoderKlckModifiedTp", + "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", "spatial_dims": "@spatial_dims", "in_channels": "@image_channels", "out_channels": "@image_channels", @@ -73,11 +74,7 @@ 128, 256 ], - "num_res_blocks": [ - 2, - 2, - 2 - ], + "num_res_blocks": [2,2,2], "norm_num_groups": 32, "norm_eps": 1e-06, "attention_levels": [ @@ -88,10 +85,13 @@ "with_encoder_nonlocal_attn": false, "with_decoder_nonlocal_attn": false, "use_checkpointing": false, - "use_convtranspose": false + "use_convtranspose": false, + "norm_float16": true, + "num_splits": 8, + "dim_split": 1 }, - "difusion_unet_def": { - "_target_": "scripts.custom_network_diffusion.CustomDiffusionModelUNet", + "diffusion_unet_def": { + "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi", "spatial_dims": "@spatial_dims", "in_channels": "@latent_channels", "out_channels": "@latent_channels", @@ -115,12 +115,12 @@ ], "num_res_blocks": 2, "use_flash_attention": true, - "input_top_region_index": true, - "input_bottom_region_index": true, - "input_spacing": true + "include_top_region_index_input": true, + "include_bottom_region_index_input": true, + "include_spacing_input": true }, "controlnet_def": { - "_target_": "scripts.custom_network_controlnet.CustomControlNet", + "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi", "spatial_dims": "@spatial_dims", "in_channels": "@latent_channels", "num_channels": [ @@ -144,28 +144,20 @@ "num_res_blocks": 2, "use_flash_attention": true, "conditioning_embedding_in_channels": 8, - "conditioning_embedding_num_channels": [ - 8, - 32, - 64 - ] + "conditioning_embedding_num_channels": [8, 32, 64] }, "mask_generation_autoencoder_def": { - "_target_": "generative.networks.nets.AutoencoderKL", - "spatial_dims": 3, + "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", + "spatial_dims": "@spatial_dims", "in_channels": 8, "out_channels": 125, - "latent_channels": 4, + "latent_channels": "@latent_channels", "num_channels": [ 32, 64, 128 ], - "num_res_blocks": [ - 1, - 2, - 2 - ], + "num_res_blocks": [1, 2, 2], "norm_num_groups": 32, "norm_eps": 1e-06, "attention_levels": [ @@ -177,31 +169,19 @@ "with_decoder_nonlocal_attn": false, "use_flash_attention": false, "use_checkpointing": true, - "use_convtranspose": true + "use_convtranspose": true, + "norm_float16": true, + "num_splits": 8, + "dim_split": 1 }, "mask_generation_diffusion_def": { - "_target_": "generative.networks.nets.DiffusionModelUNet", - "spatial_dims": 3, - "in_channels": 4, - "out_channels": 4, - "num_channels": [ - 64, - 128, - 256, - 512 - ], - "attention_levels": [ - false, - false, - true, - true - ], - "num_head_channels": [ - 0, - 0, - 32, - 32 - ], + "_target_": 
"monai.networks.nets.diffusion_model_unet.DiffusionModelUNet", + "spatial_dims": "@spatial_dims", + "in_channels": "@latent_channels", + "out_channels": "@latent_channels", + "channels":[64, 128, 256, 512], + "attention_levels":[false, false, true, true], + "num_head_channels":[0, 0, 32, 32], "num_res_blocks": 2, "use_flash_attention": true, "with_conditioning": true, @@ -209,25 +189,25 @@ "cross_attention_dim": 10 }, "autoencoder": "$@autoencoder_def.to(@device)", - "checkpoint_autoencoder": "$scripts.utils.load_autoencoder_ckpt(@trained_autoencoder_path)", + "checkpoint_autoencoder": "$torch.load(@trained_autoencoder_path)", "load_autoencoder": "$@autoencoder.load_state_dict(@checkpoint_autoencoder)", - "difusion_unet": "$@difusion_unet_def.to(@device)", - "checkpoint_difusion_unet": "$torch.load(@trained_diffusion_path)", - "load_diffusion": "$@difusion_unet.load_state_dict(@checkpoint_difusion_unet['unet_state_dict'])", + "diffusion_unet": "$@diffusion_unet_def.to(@device)", + "checkpoint_diffusion_unet": "$torch.load(@trained_diffusion_path)", + "load_diffusion": "$@diffusion_unet.load_state_dict(@checkpoint_diffusion_unet['unet_state_dict'])", "controlnet": "$@controlnet_def.to(@device)", - "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @difusion_unet.state_dict())", + "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())", "checkpoint_controlnet": "$torch.load(@trained_controlnet_path)", "load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)", - "scale_factor": "$@checkpoint_difusion_unet['scale_factor'].to(@device)", + "scale_factor": "$@checkpoint_diffusion_unet['scale_factor'].to(@device)", "mask_generation_autoencoder": "$@mask_generation_autoencoder_def.to(@device)", "checkpoint_mask_generation_autoencoder": "$torch.load(@trained_mask_generation_autoencoder_path)", "load_mask_generation_autoencoder": "$@mask_generation_autoencoder.load_state_dict(@checkpoint_mask_generation_autoencoder, strict=True)", - "mask_generation_difusion_unet": "$@mask_generation_diffusion_def.to(@device)", - "checkpoint_mask_generation_difusion_unet": "$torch.load(@trained_mask_generation_diffusion_path)", - "load_mask_generation_diffusion": "$@mask_generation_difusion_unet.load_state_dict(@checkpoint_mask_generation_difusion_unet, strict=True)", - "mask_generation_scale_factor": 1.0055984258651733, + "mask_generation_diffusion_unet": "$@mask_generation_diffusion_def.to(@device)", + "checkpoint_mask_generation_diffusion_unet": "$torch.load(@trained_mask_generation_diffusion_path)", + "load_mask_generation_diffusion": "$@mask_generation_diffusion_unet.load_state_dict(@checkpoint_mask_generation_diffusion_unet['unet_state_dict'], strict=True)", + "mask_generation_scale_factor": "$@checkpoint_mask_generation_diffusion_unet['scale_factor']", "noise_scheduler": { - "_target_": "generative.networks.schedulers.DDPMScheduler", + "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", "num_train_timesteps": 1000, "beta_start": 0.0015, "beta_end": 0.0195, @@ -235,7 +215,7 @@ "clip_sample": false }, "mask_generation_noise_scheduler": { - "_target_": "generative.networks.schedulers.DDPMScheduler", + "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", "num_train_timesteps": 1000, "beta_start": 0.0015, "beta_end": 0.0195, @@ -263,12 +243,12 @@ "label_dict_json": "@label_dict_json", "label_dict_remap_json": "@label_dict_remap_json", "autoencoder": "@autoencoder", - 
"difusion_unet": "@difusion_unet", + "diffusion_unet": "@diffusion_unet", "controlnet": "@controlnet", "scale_factor": "@scale_factor", "noise_scheduler": "@noise_scheduler", "mask_generation_autoencoder": "@mask_generation_autoencoder", - "mask_generation_difusion_unet": "@mask_generation_difusion_unet", + "mask_generation_diffusion_unet": "@mask_generation_diffusion_unet", "mask_generation_scale_factor": "@mask_generation_scale_factor", "mask_generation_noise_scheduler": "@mask_generation_noise_scheduler", "controllable_anatomy_size": "@controllable_anatomy_size", @@ -278,12 +258,13 @@ "latent_shape": "@latent_shape", "mask_generation_latent_shape": "@mask_generation_latent_shape", "output_size": "@output_size", - "quality_check_args": "@quality_check_args", "spacing": "@spacing", "output_dir": "@output_dir", "num_inference_steps": "@num_inference_steps", "mask_generation_num_inference_steps": "@mask_generation_num_inference_steps", - "random_seed": "@random_seed" + "random_seed": "@random_seed", + "autoencoder_sliding_window_infer_size": "@autoencoder_sliding_window_infer_size", + "autoencoder_sliding_window_infer_overlap": "@autoencoder_sliding_window_infer_overlap" }, "run": [ "$@ldm_sampler.sample_multiple_images(@num_output_samples)" diff --git a/models/maisi_ct_generative/configs/train.json b/models/maisi_ct_generative/configs/train.json index 33ba088e..51f6a236 100644 --- a/models/maisi_ct_generative/configs/train.json +++ b/models/maisi_ct_generative/configs/train.json @@ -28,8 +28,8 @@ "spatial_dims": 3, "image_channels": 1, "latent_channels": 4, - "difusion_unet_def": { - "_target_": "scripts.custom_network_diffusion.CustomDiffusionModelUNet", + "diffusion_unet_def": { + "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi", "spatial_dims": "@spatial_dims", "in_channels": "@latent_channels", "out_channels": "@latent_channels", @@ -53,12 +53,12 @@ ], "num_res_blocks": 2, "use_flash_attention": true, - "input_top_region_index": true, - "input_bottom_region_index": true, - "input_spacing": true + "include_top_region_index_input": true, + "include_bottom_region_index_input": true, + "include_spacing_input": true }, "controlnet_def": { - "_target_": "scripts.custom_network_controlnet.CustomControlNet", + "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi", "spatial_dims": "@spatial_dims", "in_channels": "@latent_channels", "num_channels": [ @@ -82,26 +82,22 @@ "num_res_blocks": 2, "use_flash_attention": true, "conditioning_embedding_in_channels": 8, - "conditioning_embedding_num_channels": [ - 8, - 32, - 64 - ] + "conditioning_embedding_num_channels": [8, 32, 64] }, "noise_scheduler": { - "_target_": "generative.networks.schedulers.DDPMScheduler", + "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", "num_train_timesteps": 1000, "beta_start": 0.0015, "beta_end": 0.0195, "schedule": "scaled_linear_beta", "clip_sample": false }, - "unzip_dataset": "scripts.utils.unzip_dataset(@dataset_dir)", - "difusion_unet": "$@difusion_unet_def.to(@device)", - "checkpoint_difusion_unet": "$torch.load(@trained_diffusion_path)", - "load_diffusion": "$@difusion_unet.load_state_dict(@checkpoint_difusion_unet['unet_state_dict'])", + "unzip_dataset": "$scripts.utils.unzip_dataset(@dataset_dir)", + "diffusion_unet": "$@diffusion_unet_def.to(@device)", + "checkpoint_diffusion_unet": "$torch.load(@trained_diffusion_path)", + "load_diffusion": 
"$@diffusion_unet.load_state_dict(@checkpoint_diffusion_unet['unet_state_dict'])", "controlnet": "$@controlnet_def.to(@device)", - "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @difusion_unet.state_dict())", + "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())", "checkpoint_controlnet": "$torch.load(@trained_controlnet_path)", "load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)", "scale_factor": "$@checkpoint_controlnet['scale_factor'].to(@device)", @@ -244,7 +240,7 @@ "max_epochs": "@epochs", "device": "@device", "train_data_loader": "@train#dataloader", - "difusion_unet": "@difusion_unet", + "diffusion_unet": "@diffusion_unet", "controlnet": "@controlnet", "noise_scheduler": "@noise_scheduler", "loss_function": "@loss", diff --git a/models/maisi_ct_generative/large_files.yml b/models/maisi_ct_generative/large_files.yml index 9111cee2..bac9076a 100644 --- a/models/maisi_ct_generative/large_files.yml +++ b/models/maisi_ct_generative/large_files.yml @@ -1,20 +1,20 @@ large_files: - path: "models/autoencoder_epoch273.pt" - url: "https://drive.google.com/file/d/1jQefG0yJPzSvTG5rIJVHNqDReBTvVmZ0/view?usp=drive_link" + url: "https://drive.google.com/file/d/1Ojw25lFO8QbHkxazdK4CgZTyp3GFNZGz/view?usp=sharing" - path: "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt" - url: "https://drive.google.com/file/d/1FtOHBGUF5dLZNHtiuhf5EH448EQGGs-_/view?usp=sharing" + url: "https://drive.google.com/file/d/1lklNv4MTdI_9bwFRMd98QQ7JLerR5gC_/view?usp=drive_link" - path: "models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt" - url: "https://drive.google.com/file/d/1izr52Whkk56OevNTk2QzI86eJV9TTaLk/view?usp=sharing" + url: "https://drive.google.com/file/d/1mLYeqeZ819_WpZPlAInhcWuCIHgn3QNT/view?usp=drive_link" - path: "models/mask_generation_autoencoder.pt" - url: "https://drive.google.com/file/d/1FzWrpv6ornYUaPiAWGOOxhRx2P9Wnynm/view?usp=drive_link" + url: "https://drive.google.com/file/d/19JnX-C6QAg4RfghTwpPnj4KEWhtawpCy/view?usp=drive_link" - path: "models/mask_generation_diffusion_unet.pt" - url: "https://drive.google.com/file/d/11SA9RUZ6XmCOJr5v6w6UW1kDzr6hlymw/view?usp=drive_link" + url: "https://drive.google.com/file/d/1yOQvlhXFGY1ZYavADM3N34vgg5AEitda/view?usp=drive_link" - path: "configs/candidate_masks_flexible_size_and_spacing_3000.json" url: "https://drive.google.com/file/d/1yMkH-lrAsn2YUGoTuVKNMpicziUmU-1J/view?usp=sharing" - path: "configs/all_anatomy_size_condtions.json" url: "https://drive.google.com/file/d/1AJyt1DSoUd2x2AOQOgM7IxeSyo4MXNX0/view?usp=sharing" - path: "datasets/all_masks_flexible_size_and_spacing_3000.zip" - url: "https://drive.google.com/file/d/16MKsDKkHvDyF2lEir4dzlxwex_GHStUf/view?usp=sharing" + url: "https://drive.google.com/file/d/1AJyt1DSoUd2x2AOQOgM7IxeSyo4MXNX0/view?usp=sharing" - path: "datasets/IntegrationTest-AbdomenCT.nii.gz" url: "https://drive.google.com/file/d/1OTgt_dyBgvP52krKRXWXD3u0L5Zbj5JR/view?usp=share_link" - path: "datasets/C4KC-KiTS_subset.zip" diff --git a/models/maisi_ct_generative/scripts/__init__.py b/models/maisi_ct_generative/scripts/__init__.py index 41d37723..bf77e0f5 100644 --- a/models/maisi_ct_generative/scripts/__init__.py +++ b/models/maisi_ct_generative/scripts/__init__.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . 
import custom_network_diffusion, custom_network_tp, sample, utils
+from . import sample, utils
diff --git a/models/maisi_ct_generative/scripts/download_files.py b/models/maisi_ct_generative/scripts/download_files.py
new file mode 100644
index 00000000..aee1a416
--- /dev/null
+++ b/models/maisi_ct_generative/scripts/download_files.py
@@ -0,0 +1,11 @@
+import yaml
+import os
+from monai.apps import download_url
+
+# Load YAML file
+with open('large_files.yml', 'r') as file:
+    data = yaml.safe_load(file)
+
+# Iterate over each file in the YAML and download it
+for file in data['large_files']:
+    download_url(url=file["url"], filepath=file["path"])
diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py
index 078ae394..c919d393 100644
--- a/models/maisi_ct_generative/scripts/find_masks.py
+++ b/models/maisi_ct_generative/scripts/find_masks.py
@@ -12,109 +12,143 @@
 import json
 import os
-import zipfile
+from typing import Sequence
+from monai.apps.utils import extractall
+from monai.utils import ensure_tuple_rep

-def convert_body_region(body_region: list[int]):
-    body_region_indices = []
-
-    for _k in range(len(body_region)):
-        region = body_region[_k].lower()
-        idx = None
-        if "head" in region:
-            idx = 0
-        elif "chest" in region or "thorax" in region or "chest/thorax" in region:
-            idx = 1
-        elif "abdomen" in region:
-            idx = 2
-        elif "pelvis" in region or "lower" in region or "pelvis/lower" in region:
-            idx = 3
-        else:
-            raise ValueError("Input region information is incorrect.")
+def convert_body_region(body_region: str | Sequence[str]) -> Sequence[int]:
+    """
+    Convert body region strings to body region indices.
+    Args:
+        body_region: list of input body region strings. If a single str is given, it will be converted to a list of str.
+    Return:
+        body_region_indices, list of body region indices.
+    """
+    if type(body_region) is str:
+        body_region = [body_region]

-        body_region_indices.append(idx)
+    # body region mapping for maisi
+    region_mapping_maisi = {
+        "head": 0,
+        "chest": 1,
+        "thorax": 1,
+        "chest/thorax": 1,
+        "abdomen": 2,
+        "pelvis": 3,
+        "lower": 3,
+        "pelvis/lower": 3,
+    }
+
+    # perform mapping
+    body_region_indices = []
+    for region in body_region:
+        normalized_region = region.lower()  # normalize string to lower case
+        if normalized_region not in region_mapping_maisi:
+            raise ValueError(f"Invalid region: {normalized_region}")
+        body_region_indices.append(region_mapping_maisi[normalized_region])

     return body_region_indices


 def find_masks(
-    body_region: str | list[str],
-    anatomy_list: int | list[int],
-    spacing: list[float],
-    output_size: list[int],
+    body_region: str | Sequence[str],
+    anatomy_list: int | Sequence[int],
+    spacing: Sequence[float] | float = 1.0,
+    output_size: Sequence[int] = [512, 512, 512],
     check_spacing_and_output_size: bool = False,
-    database_filepath: str = "./database.json",
-    mask_foldername: str = "./masks",
+    database_filepath: str = "./configs/database.json",
+    mask_foldername: str = "./datasets/masks/",
 ):
-    if type(body_region) is str:
-        body_region = [body_region]
-
+    """
+    Find candidate masks that fulfill all the requirements.
+    They should contain all the body regions in `body_region` and all the anatomies in `anatomy_list`.
+    If there is no tumor specified in `anatomy_list`, we also expect the candidate masks to be tumor-free.
+    If check_spacing_and_output_size is True, the candidate masks need to have the expected `spacing` and `output_size`.
+    Args:
+        body_region: list of input body region strings. If a single str is given, it will be converted to a list of str.
+            The found candidate mask will include these body regions.
+        anatomy_list: list of input anatomy. The found candidate mask will include these anatomies.
+        spacing: list of three floats, voxel spacing. If providing a single number, will use it for all three dimensions.
+        output_size: list of three int, expected candidate mask spatial size.
+        check_spacing_and_output_size: whether we expect candidate mask to have spatial size of `output_size` and voxel size of `spacing`.
+        database_filepath: path for the json file that stores the information of all the candidate masks.
+        mask_foldername: directory that saves all the candidate masks.
+    Return:
+        candidate_masks, list of dict, each dict contains information of one candidate mask that fulfills all the requirements.
+    """
+    # check and preprocess input
     body_region = convert_body_region(body_region)
-    if type(anatomy_list) is int:
+    if isinstance(anatomy_list, int):
         anatomy_list = [anatomy_list]

-    if not os.path.isfile(database_filepath):
-        raise ValueError(f"Please download {database_filepath}.")
+    spacing = ensure_tuple_rep(spacing, 3)

     if not os.path.exists(mask_foldername):
         zip_file_path = mask_foldername + ".zip"

         if not os.path.isfile(zip_file_path):
-            raise ValueError(f"Please downloaded {zip_file_path}.")
+            raise ValueError(f"Please download {zip_file_path} following the instructions in ./datasets/README.md.")

-        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
-            print(mask_foldername)
-            zip_ref.extractall(path="./datasets")
+        print(f"Extracting {zip_file_path} to {os.path.dirname(zip_file_path)}")
+        extractall(filepath=zip_file_path, output_dir=os.path.dirname(zip_file_path), file_type="zip")
         print(f"Unzipped {zip_file_path} to {mask_foldername}.")

+    if not os.path.isfile(database_filepath):
+        raise ValueError(f"Please download {database_filepath} following the instructions in ./datasets/README.md.")
     with open(database_filepath, "r") as f:
         db = json.load(f)

-    candidates = []
-    for _i in range(len(db)):
-        _item = db[_i]
+    # select candidate_masks
+    candidate_masks = []
+    for _item in db:
         if not set(anatomy_list).issubset(_item["label_list"]):
             continue
+        # extract region indices (top_index and bottom_index) for candidate mask
         top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0]
         top_index = top_index[0]
         bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0]
         bottom_index = bottom_index[0]

-        flag = False
+        # whether to keep this mask, defaults to True
+        keep_mask = True
+
+        # if candidate mask does not contain all the body_region, skip it
         for _idx in body_region:
             if _idx > bottom_index or _idx < top_index:
-                flag = True
+                keep_mask = False

-        # check if candiate mask contains tumors
         for tumor_label in [23, 24, 26, 27, 128]:
             # we skip those mask with tumors if users do not provide tumor label in anatomy_list
             if tumor_label not in anatomy_list and tumor_label in _item["label_list"]:
-                flag = True
+                keep_mask = False

         if check_spacing_and_output_size:
-            # check if the output_size and spacing are same as user's input
+            # if the output_size and spacing are different from the user's input, skip it
             for axis in range(3):
                 if _item["dim"][axis] != output_size[axis] or _item["spacing"][axis] != spacing[axis]:
-                    flag = True
+                    keep_mask = False

-        if flag is True:
-            continue
+        if keep_mask:
+            # if we decide to keep this mask, pack its information and add it to the final output.
+ candidate = { + "pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]), + "spacing": _item["spacing"], + "dim": _item["dim"], + "top_region_index": _item["top_region_index"], + "bottom_region_index": _item["bottom_region_index"], + } - candidate = {} - if "label_filename" in _item: - candidate["label"] = os.path.join(mask_foldername, _item["label_filename"]) - candidate["pseudo_label"] = os.path.join(mask_foldername, _item["pseudo_label_filename"]) - candidate["spacing"] = _item["spacing"] - candidate["dim"] = _item["dim"] - candidate["top_region_index"] = _item["top_region_index"] - candidate["bottom_region_index"] = _item["bottom_region_index"] + # Conditionally add the label to the candidate dictionary + if "label_filename" in _item: + candidate["label"] = os.path.join(mask_foldername, _item["label_filename"]) - candidates.append(candidate) + candidate_masks.append(candidate) - if len(candidates) == 0 and not check_spacing_and_output_size: - raise ValueError("Cannot find body region with given organ list.") + if len(candidate_masks) == 0 and not check_spacing_and_output_size: + raise ValueError("Cannot find body region with given anatomy list.") - return candidates + return candidate_masks diff --git a/models/maisi_ct_generative/scripts/quality_check.py b/models/maisi_ct_generative/scripts/quality_check.py new file mode 100644 index 00000000..22373276 --- /dev/null +++ b/models/maisi_ct_generative/scripts/quality_check.py @@ -0,0 +1,147 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nibabel as nib +import numpy as np + + +def get_masked_data(label_data, image_data, labels): + """ + Extracts and returns the image data corresponding to specified labels within a 3D volume. + + This function efficiently masks the `image_data` array based on the provided `labels` in the `label_data` array. + The function handles cases with both a large and small number of labels, optimizing performance accordingly. + + Args: + label_data (np.ndarray): A NumPy array containing label data, representing different anatomical + regions or classes in a 3D medical image. + image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions + will be extracted. + labels (list of int): A list of integers representing the label values to be used for masking. + + Returns: + np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified + labels in `label_data`. If no labels are provided, an empty array is returned. + + Raises: + ValueError: If `image_data` and `label_data` do not have the same shape. + + Example: + label_int_dict = {"liver": [1], "kidney": [5, 14]} + masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"]) + """ + + # Check if the shapes of image_data and label_data match + if image_data.shape != label_data.shape: + raise ValueError( + f"Shape mismatch: image_data has shape {image_data.shape}, " + f"but label_data has shape {label_data.shape}. 
They must be the same." + ) + + if not labels: + return np.array([]) # Return an empty array if no labels are provided + + labels = list(set(labels)) # remove duplicate items + + # Optimize performance based on the number of labels + num_label_acceleration_thresh = 3 + if len(labels) >= num_label_acceleration_thresh: + # if many labels, np.isin is faster + mask = np.isin(label_data, labels) + else: + # Use logical OR to combine masks if the number of labels is small + mask = np.zeros_like(label_data, dtype=bool) + for label in labels: + mask = np.logical_or(mask, label_data == label) + + # Retrieve the masked data + masked_data = image_data[mask.astype(bool)] + + return masked_data + + +def is_outlier(statistics, image_data, label_data, label_int_dict): + """ + Perform a quality check on the generated image by comparing its statistics with precomputed thresholds. + + Args: + statistics (dict): Dictionary containing precomputed statistics including mean +/- 3sigma ranges. + image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array. + label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest. + label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists. + e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]} + + Returns: + dict: A dictionary with labels as keys, each containing the quality check result, + including whether it's an outlier, the median value, and the thresholds used. + If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`. + + Example: + # Example input data + statistics = { + "liver": { + "sigma_6_low": -21.596463547885904, + "sigma_6_high": 156.27881534763367 + }, + "kidney": { + "sigma_6_low": -15.0, + "sigma_6_high": 120.0 + } + } + label_int_dict = { + "liver": [1], + "kidney": [5, 14] + } + image_data = np.random.rand(100, 100, 100) # Replace with actual image data + label_data = np.zeros((100, 100, 100)) # Replace with actual label data + label_data[40:60, 40:60, 40:60] = 1 # Example region for liver + label_data[70:90, 70:90, 70:90] = 5 # Example region for kidney + result = is_outlier(statistics, image_data, label_data, label_int_dict) + """ + outlier_results = {} + + for label_name, stats in statistics.items(): + # Get the thresholds from the statistics + low_thresh = stats["sigma_6_low"] # or "sigma_12_low" depending on your needs + high_thresh = stats["sigma_6_high"] # or "sigma_12_high" depending on your needs + + # Retrieve the corresponding label integers + labels = label_int_dict.get(label_name, []) + masked_data = get_masked_data(label_data, image_data, labels) + masked_data = masked_data[~np.isnan(masked_data)] + + if len(masked_data) == 0 or masked_data.size == 0: + outlier_results[label_name] = { + "is_outlier": False, + "median_value": None, + "low_thresh": low_thresh, + "high_thresh": high_thresh, + } + continue + + # Compute the median of the masked region + median_value = np.nanmedian(masked_data) + + if np.isnan(median_value): + median_value = None + is_outlier = False + else: + # Determine if the median value is an outlier + is_outlier = median_value < low_thresh or median_value > high_thresh + + outlier_results[label_name] = { + "is_outlier": is_outlier, + "median_value": median_value, + "low_thresh": low_thresh, + "high_thresh": high_thresh, + } + + return outlier_results diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 9616df41..3c78236b 
100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -10,12 +10,16 @@ # limitations under the License. import json +import logging +import math +import os import random +import time from datetime import datetime import monai import torch -from generative.inferers import LatentDiffusionInferer +from monai.inferers.inferer import DiffusionInferer from monai.data import MetaTensor from monai.inferers import sliding_window_inference from monai.transforms import Compose, SaveImage @@ -24,23 +28,64 @@ from .augmentation import augmentation from .find_masks import find_masks -from .utils import MapLabelValue, binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask +from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels +from .quality_check import is_outlier class ReconModel(torch.nn.Module): + """ + A PyTorch module for reconstructing images from latent representations. + + Attributes: + autoencoder: The autoencoder model used for decoding. + scale_factor: Scaling factor applied to the input before decoding. + """ + def __init__(self, autoencoder, scale_factor): super().__init__() self.autoencoder = autoencoder self.scale_factor = scale_factor def forward(self, z): + """ + Decode the input latent representation to an image. + + Args: + z (torch.Tensor): The input latent representation. + + Returns: + torch.Tensor: The reconstructed image. + """ recon_pt_nda = self.autoencoder.decode_stage_2_outputs(z / self.scale_factor) return recon_pt_nda +def initialize_noise_latents(latent_shape, device): + """ + Initialize random noise latents for image generation with float16. + + Args: + latent_shape (tuple): The shape of the latent space. + device (torch.device): The device to create the tensor on. + + Returns: + torch.Tensor: Initialized noise latents. + """ + return ( + torch.randn( + [ + 1, + ] + + list(latent_shape) + ) + .half() + .to(device) + ) + + def ldm_conditional_sample_one_mask( autoencoder, - difusion_unet, + diffusion_unet, noise_scheduler, scale_factor, anatomy_size, @@ -48,67 +93,96 @@ def ldm_conditional_sample_one_mask( latent_shape, label_dict_remap_json, num_inference_steps=1000, + autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_overlap=0.6667, ): - with torch.no_grad(): - with torch.cuda.amp.autocast(): - - # Generate random noise - latents = torch.randn([1] + list(latent_shape)).half().to(device) - anatomy_size = torch.FloatTensor(anatomy_size).unsqueeze(0).unsqueeze(0).half().to(device) - # synthesize masks - noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) - inferer_ddpm = LatentDiffusionInferer(noise_scheduler, scale_factor=scale_factor) - synthetic_mask = inferer_ddpm.sample( - input_noise=latents, - autoencoder_model=autoencoder, - diffusion_model=difusion_unet, - scheduler=noise_scheduler, - verbose=True, - conditioning=anatomy_size.to(device), - ) - synthetic_mask = torch.softmax(synthetic_mask, dim=1) - synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True) - # mapping raw index to 132 labels - with open(label_dict_remap_json, "r") as f: - mapping_dict = json.load(f) - mapping = [v for _, v in mapping_dict.items()] - mapper = MapLabelValue( - orig_labels=[pair[0] for pair in mapping], - target_labels=[pair[1] for pair in mapping], - dtype=torch.uint8, + """ + Generate a single synthetic mask using a latent diffusion model. 
+ + Args: + autoencoder (nn.Module): The autoencoder model. + diffusion_unet (nn.Module): The diffusion U-Net model. + noise_scheduler: The noise scheduler for the diffusion process. + scale_factor (float): Scaling factor for the latent space. + anatomy_size (torch.Tensor): Tensor specifying the desired anatomy sizes. + device (torch.device): The device to run the computation on. + latent_shape (tuple): The shape of the latent space. + label_dict_remap_json (str): Path to the JSON file for label remapping. + num_inference_steps (int): Number of inference steps for the diffusion process. + autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96]. + autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667. + + Returns: + torch.Tensor: The generated synthetic mask. + """ + recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) + + with torch.no_grad(), torch.cuda.amp.autocast(): + # Generate random noise + latents = initialize_noise_latents(latent_shape, device) + anatomy_size = torch.FloatTensor(anatomy_size).unsqueeze(0).unsqueeze(0).half().to(device) + # synthesize latents + noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) + inferer_ddpm = DiffusionInferer(noise_scheduler) + latents = inferer_ddpm.sample( + input_noise=latents, + diffusion_model=diffusion_unet, + scheduler=noise_scheduler, + verbose=True, + conditioning=anatomy_size.to(device), + ) + # decode latents to synthesized masks + if math.prod(latent_shape[1:]) < math.prod(autoencoder_sliding_window_infer_size): + synthetic_mask = recon_model(latents).cpu().detach() + else: + synthetic_mask = ( + sliding_window_inference( + inputs=latents, + roi_size=( + autoencoder_sliding_window_infer_size[0], + autoencoder_sliding_window_infer_size[1], + autoencoder_sliding_window_infer_size[2], + ), + sw_batch_size=1, + predictor=recon_model, + mode="gaussian", + overlap=autoencoder_sliding_window_infer_overlap, + sw_device=device, + device=torch.device("cpu"), + progress=True, + ) + .cpu() + .detach() ) - synthetic_mask = mapper(synthetic_mask[0, ...])[None, ...].to(device) - - # post process - data = synthetic_mask.squeeze().cpu().detach().numpy() - if anatomy_size[0, 0, 5].item() != -1.0: - target_tumor_label = 23 - elif anatomy_size[0, 0, 6].item() != -1.0: - target_tumor_label = 24 - elif anatomy_size[0, 0, 7].item() != -1.0: - target_tumor_label = 26 - elif anatomy_size[0, 0, 8].item() != -1.0: - target_tumor_label = 27 - elif anatomy_size[0, 0, 9].item() != -1.0: - target_tumor_label = 128 - else: - target_tumor_label = None + synthetic_mask = torch.softmax(synthetic_mask, dim=1) + synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True) + # mapping raw index to 132 labels + synthetic_mask = remap_labels(synthetic_mask, label_dict_remap_json) + + ###### post process ##### + data = synthetic_mask.squeeze().cpu().detach().numpy() + + labels = [23, 24, 26, 27, 128] + target_tumor_label = None + for index, size in enumerate(anatomy_size[0, 0, 5:10]): + if size.item() != -1.0: + target_tumor_label = labels[index] - print("target_tumor_label for postprocess:", target_tumor_label) - data = general_mask_generation_post_process(data, target_tumor_label=target_tumor_label, device=device) - synthetic_mask = torch.from_numpy(data).unsqueeze(0).unsqueeze(0).to(device) + logging.info(f"target_tumor_label for postprocess:{target_tumor_label}") + data = 
general_mask_generation_post_process(data, target_tumor_label=target_tumor_label, device=device) + synthetic_mask = torch.from_numpy(data).unsqueeze(0).unsqueeze(0).to(device) return synthetic_mask def ldm_conditional_sample_one_image( autoencoder, - difusion_unet, + diffusion_unet, controlnet, noise_scheduler, scale_factor, device, - comebine_label_or, + combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, @@ -116,7 +190,33 @@ def ldm_conditional_sample_one_image( output_size, noise_factor, num_inference_steps=1000, + autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_overlap=0.6667, ): + """ + Generate a single synthetic image using a latent diffusion model with controlnet. + + Args: + autoencoder (nn.Module): The autoencoder model. + diffusion_unet (nn.Module): The diffusion U-Net model. + controlnet (nn.Module): The controlnet model. + noise_scheduler: The noise scheduler for the diffusion process. + scale_factor (float): Scaling factor for the latent space. + device (torch.device): The device to run the computation on. + combine_label_or (torch.Tensor): The combined label tensor. + top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. + bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. + spacing_tensor (torch.Tensor): Tensor specifying the spacing. + latent_shape (tuple): The shape of the latent space. + output_size (tuple): The desired output size of the image. + noise_factor (float): Factor to scale the initial noise. + num_inference_steps (int): Number of inference steps for the diffusion process. + autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96]. + autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667. + + Returns: + tuple: A tuple containing the synthetic image and its corresponding label. + """ # CT image intensity range a_min = -1000 a_max = 1000 @@ -126,104 +226,157 @@ def ldm_conditional_sample_one_image( recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) - with torch.no_grad(): - with torch.cuda.amp.autocast(): - # generate segmentation mask - comebine_label = comebine_label_or.to(device) - if ( - output_size[0] != comebine_label.shape[2] - or output_size[1] != comebine_label.shape[3] - or output_size[2] != comebine_label.shape[4] - ): - print( - "output_size is not a desired value. Need to interpolate the mask to " - "match with output_size. The result image will be very low quality." - ) - comebine_label = torch.nn.functional.interpolate(comebine_label, size=output_size, mode="nearest") - - controlnet_cond_vis = binarize_labels(comebine_label.as_tensor().long()).half() + with torch.no_grad(), torch.cuda.amp.autocast(): + logging.info("---- Start generating latent features... ----") + start_time = time.time() + # generate segmentation mask + combine_label = combine_label_or.to(device) + if ( + output_size[0] != combine_label.shape[2] + or output_size[1] != combine_label.shape[3] + or output_size[2] != combine_label.shape[4] + ): + logging.info( + "output_size is not a desired value. Need to interpolate the mask to match with output_size. The result image will be very low quality." 
+ ) + combine_label = torch.nn.functional.interpolate(combine_label, size=output_size, mode="nearest") - # Generate random noise - latents = torch.randn([1] + list(latent_shape)).half().to(device) * noise_factor + controlnet_cond_vis = binarize_labels(combine_label.as_tensor().long()).half() - # synthesize latents - noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) - for t in tqdm(noise_scheduler.timesteps, ncols=110): - # Get controlnet output - down_block_res_samples, mid_block_res_sample = controlnet( - x=latents, timesteps=torch.Tensor((t,)).to(device), controlnet_cond=controlnet_cond_vis - ) - latent_model_input = latents - noise_pred = difusion_unet( - x=latent_model_input, - timesteps=torch.Tensor((t,)).to(device), - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) - latents, _ = noise_scheduler.step(noise_pred, t, latents) + # Generate random noise + latents = initialize_noise_latents(latent_shape, device) * noise_factor - # decode latents to synthesized images + # synthesize latents + noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) + for t in tqdm(noise_scheduler.timesteps, ncols=110): + # Get controlnet output + down_block_res_samples, mid_block_res_sample = controlnet( + x=latents, + timesteps=torch.Tensor((t,)).to(device), + controlnet_cond=controlnet_cond_vis, + ) + latent_model_input = latents + noise_pred = diffusion_unet( + x=latent_model_input, + timesteps=torch.Tensor((t,)).to(device), + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + latents, _ = noise_scheduler.step(noise_pred, t, latents) + end_time = time.time() + logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----") + del noise_pred + torch.cuda.empty_cache() + + # decode latents to synthesized images + logging.info("---- Start decoding latent features into images... 
----") + start_time = time.time() + if math.prod(latent_shape[1:]) < math.prod(autoencoder_sliding_window_infer_size): + synthetic_images = recon_model(latents) + else: synthetic_images = sliding_window_inference( inputs=latents, roi_size=( - min(output_size[0] // 4 // 4 * 3, 96), - min(output_size[1] // 4 // 4 * 3, 96), - min(output_size[2] // 4 // 4 * 3, 96), + min(output_size[0] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[0]), + min(output_size[1] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[1]), + min(output_size[2] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[2]), ), sw_batch_size=1, predictor=recon_model, mode="gaussian", - overlap=2.0 / 3.0, + overlap=autoencoder_sliding_window_infer_overlap, sw_device=device, - device=device, + device=torch.device("cpu"), + progress=True, ) - synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu() + end_time = time.time() + logging.info(f"---- Image decoding time: {end_time - start_time} seconds ----") - # post processing: + ## post processing: # project output to [0, 1] synthetic_images = (synthetic_images - b_min) / (b_max - b_min) # project output to [-1000, 1000] synthetic_images = synthetic_images * (a_max - a_min) + a_min # regularize background intensities - synthetic_images = crop_img_body_mask(synthetic_images, comebine_label) + synthetic_images = crop_img_body_mask(synthetic_images, combine_label) + torch.cuda.empty_cache() + + return synthetic_images, combine_label - return synthetic_images, comebine_label +def filter_mask_with_organs(combine_label, anatomy_list): + """ + Filter a mask to only include specified organs. -def filter_mask_with_organs(comebine_label, anatomy_list): - # final output mask file has shape of output_size, contaisn labels in anatomy_list + Args: + combine_label (torch.Tensor): The input mask. + anatomy_list (list): List of organ labels to keep. + + Returns: + torch.Tensor: The filtered mask. + """ + # final output mask file has shape of output_size, contains labels in anatomy_list # it is already interpolated to target size - comebine_label = comebine_label.long() + combine_label = combine_label.long() # filter out the organs that are not in anatomy_list for i in range(len(anatomy_list)): organ = anatomy_list[i] # replace it with a negative value so it will get mixed - comebine_label[comebine_label == organ] = -(i + 1) + combine_label[combine_label == organ] = -(i + 1) # zero-out voxels with value not in anatomy_list - comebine_label[comebine_label > 0] = 0 + combine_label[combine_label > 0] = 0 # output positive values - comebine_label = -comebine_label - return comebine_label + combine_label = -combine_label + return combine_label + + +def crop_img_body_mask(synthetic_images, combine_label): + """ + Crop the synthetic image using a body mask. + Args: + synthetic_images (torch.Tensor): The synthetic images. + combine_label (torch.Tensor): The body mask. -def crop_img_body_mask(synthetic_images, comebine_label): - synthetic_images[comebine_label == 0] = -1000 + Returns: + torch.Tensor: The cropped synthetic images. + """ + synthetic_images[combine_label == 0] = -1000 return synthetic_images -def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing, controllable_anatomy_size): +def check_input( + body_region, + anatomy_list, + label_dict_json, + output_size, + spacing, + controllable_anatomy_size=[("pancreas", 0.5)], +): + """ + Validate input parameters for image generation. + + Args: + body_region (list): List of body regions. 
+        anatomy_list (list): List of anatomical structures.
+        label_dict_json (str): Path to the label dictionary JSON file.
+        output_size (tuple): Desired output size of the image.
+        spacing (tuple): Desired voxel spacing.
+        controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes.
+
+    Raises:
+        ValueError: If any input parameter is invalid.
+    """
     # check output_size and spacing format
     if output_size[0] != output_size[1]:
         raise ValueError(f"The first two components of output_size need to be equal, yet got {output_size}.")
     if (output_size[0] not in [256, 384, 512]) or (output_size[2] not in [128, 256, 384, 512, 640, 768]):
         raise ValueError(
-            "The output_size[0] have to be chosen from [256, 384, 512], and "
-            "output_size[2] have to be chosen from [128, 256, 384, 512, 640, 768], "
-            f"yet got {output_size}."
+            f"The output_size[0] has to be chosen from [256, 384, 512], and output_size[2] has to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}."
         )

     if spacing[0] != spacing[1]:
@@ -233,13 +386,24 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing
             f"spacing[0] have to be between 0.5 and 3.0 mm, spacing[2] have to be between 0.5 and 5.0 mm, yet got {spacing}."
         )

+    if output_size[0] * spacing[0] < 256:
+        FOV = [output_size[axis] * spacing[axis] for axis in range(3)]
+        raise ValueError(
+            f"'spacing'({spacing}mm) and 'output_size'({output_size}) together decide the output field of view (FOV). The FOV will be {FOV}mm. We recommend the FOV in the x and y axes to be at least 256mm for head, and at least 384mm for other body regions like abdomen. There is no such restriction for the z-axis."
+        )
+
     # check controllable_anatomy_size format
     if len(controllable_anatomy_size) > 10:
         raise ValueError(
-            f"The length of list controllable_anatomy_size has to be less than 10. "
-            f"Yet got length equal to {len(controllable_anatomy_size)}."
+            f"The length of list controllable_anatomy_size has to be less than 10. Yet got length equal to {len(controllable_anatomy_size)}."
         )
-    available_controllable_organ = ["liver", "gallbladder", "stomach", "pancreas", "colon"]
+    available_controllable_organ = [
+        "liver",
+        "gallbladder",
+        "stomach",
+        "pancreas",
+        "colon",
+    ]
     available_controllable_tumor = [
         "hepatic tumor",
         "bone lesion",
@@ -253,9 +417,7 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing
     for controllable_anatomy_size_pair in controllable_anatomy_size:
         if controllable_anatomy_size_pair[0] not in available_controllable_anatomy:
             raise ValueError(
-                f"The controllable_anatomy have to be chosen from "
-                f"{available_controllable_anatomy}, yet got "
-                f"{controllable_anatomy_size_pair[0]}."
+                f"The controllable_anatomy has to be chosen from {available_controllable_anatomy}, yet got {controllable_anatomy_size_pair[0]}."
             )
         if controllable_anatomy_size_pair[0] in available_controllable_tumor:
             controllable_tumor += [controllable_anatomy_size_pair[0]]
@@ -265,8 +427,7 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing
             continue
         if controllable_anatomy_size_pair[1] < 0 or controllable_anatomy_size_pair[1] > 1.0:
             raise ValueError(
-                f"The controllable size scale have to be between 0 and 1,0, or equal to -1, "
-                f"yet got {controllable_anatomy_size_pair[1]}."
+                f"The controllable size scale has to be between 0 and 1.0, or equal to -1, yet got {controllable_anatomy_size_pair[1]}."
) if len(controllable_tumor + controllable_organ) != len(list(set(controllable_tumor + controllable_organ))): raise ValueError(f"Please do not repeat controllable_anatomy. Got {controllable_tumor + controllable_organ}.") @@ -274,14 +435,22 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing raise ValueError(f"Only one controllable tumor is supported. Yet got {controllable_tumor}.") if len(controllable_anatomy_size) > 0: - print( - "controllable_anatomy_size is not empty. We will ignore body_region and " - "anatomy_list and synthesize based on controllable_anatomy_size." + logging.info( + f"`controllable_anatomy_size` is not empty.\nWe will ignore `body_region` and `anatomy_list` and synthesize based on `controllable_anatomy_size`: ({controllable_anatomy_size})." ) else: - print("controllable_anatomy_size is empty. We will synthesize based on body_region and anatomy_list.") + logging.info( + f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `body_region`: ({body_region}) and `anatomy_list`: ({anatomy_list})." + ) # check body_region format - available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"] + available_body_region = [ + "head", + "chest", + "thorax", + "abdomen", + "pelvis", + "lower", + ] for region in body_region: if region not in available_body_region: raise ValueError( @@ -296,11 +465,19 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing raise ValueError( f"The components in anatomy_list have to be chosen from {label_dict.keys()}, yet got {anatomy}." ) + logging.info(f"The generate results will have voxel size to be {spacing}mm, volume size to be {output_size}.") return class LDMSampler: + """ + A sampler class for generating synthetic medical images and masks using latent diffusion models. + + Attributes: + Various attributes related to model configuration, input parameters, and generation settings. + """ + def __init__( self, body_region, @@ -311,12 +488,12 @@ def __init__( label_dict_json, label_dict_remap_json, autoencoder, - difusion_unet, + diffusion_unet, controlnet, noise_scheduler, scale_factor, mask_generation_autoencoder, - mask_generation_difusion_unet, + mask_generation_diffusion_unet, mask_generation_scale_factor, mask_generation_noise_scheduler, device, @@ -327,13 +504,20 @@ def __init__( controllable_anatomy_size, image_output_ext=".nii.gz", label_output_ext=".nii.gz", - quality_check_args=None, - spacing=(1, 1, 1), + real_img_median_statistics="./configs/image_median_statistics.json", + spacing=[1, 1, 1], num_inference_steps=None, mask_generation_num_inference_steps=None, random_seed=None, + autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_overlap=0.6667, ) -> None: + """ + Initialize the LDMSampler with various parameters and models. + Args: + Various parameters related to model configuration, input settings, and output specifications. 
+ """ if random_seed is not None: set_determinism(seed=random_seed) @@ -348,12 +532,12 @@ def __init__( self.data_root = all_mask_files_base_dir self.label_dict_remap_json = label_dict_remap_json self.autoencoder = autoencoder - self.difusion_unet = difusion_unet + self.diffusion_unet = diffusion_unet self.controlnet = controlnet self.noise_scheduler = noise_scheduler self.scale_factor = scale_factor self.mask_generation_autoencoder = mask_generation_autoencoder - self.mask_generation_difusion_unet = mask_generation_difusion_unet + self.mask_generation_diffusion_unet = mask_generation_diffusion_unet self.mask_generation_scale_factor = mask_generation_scale_factor self.mask_generation_noise_scheduler = mask_generation_noise_scheduler self.device = device @@ -364,7 +548,7 @@ def __init__( self.noise_factor = 1.0 self.controllable_anatomy_size = controllable_anatomy_size if len(self.controllable_anatomy_size): - print("controllable_anatomy_size is given, mask generation is triggered!") + logging.info("controllable_anatomy_size is given, mask generation is triggered!") # overwrite the anatomy_list by given organs in self.controllable_anatomy_size self.anatomy_list = [label_dict[organ_and_size[0]] for organ_and_size in self.controllable_anatomy_size] self.image_output_ext = image_output_ext @@ -375,14 +559,42 @@ def __init__( mask_generation_num_inference_steps if mask_generation_num_inference_steps is not None else 1000 ) - # quality check disabled for this version - self.quality_check_args = quality_check_args + if any(size % 16 != 0 for size in autoencoder_sliding_window_infer_size): + raise ValueError( + f"autoencoder_sliding_window_infer_size must be divisible by 16.\n Got {autoencoder_sliding_window_infer_size}" + ) + if not (0 <= autoencoder_sliding_window_infer_overlap <= 1): + raise ValueError( + f"Value of autoencoder_sliding_window_infer_overlap must be between 0 and 1.\n Got {autoencoder_sliding_window_infer_overlap}" + ) + self.autoencoder_sliding_window_infer_size = autoencoder_sliding_window_infer_size + self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap + + # quality check args + self.max_try_time = 5 # if not pass quality check, will try self.max_try_time times + with open(real_img_median_statistics, "r") as json_file: + self.median_statistics = json.load(json_file) + self.label_int_dict = { + "liver": [1], + "spleen": [3], + "pancreas": [4], + "kidney": [5, 14], + "lung": [28, 29, 30, 31, 31], + "brain": [22], + "hepatic tumor": [26], + "bone lesion": [128], + "lung tumor": [23], + "colon cancer primaries": [27], + "pancreatic tumor": [24], + "bone": list(range(33, 57)) + list(range(63, 98)) + [120, 122, 127], + } + # networks self.autoencoder.eval() - self.difusion_unet.eval() + self.diffusion_unet.eval() self.controlnet.eval() self.mask_generation_autoencoder.eval() - self.mask_generation_difusion_unet.eval() + self.mask_generation_diffusion_unet.eval() self.spacing = spacing @@ -400,9 +612,16 @@ def __init__( monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2), ] ) - print("LDM sampler initialized.") + logging.info("LDM sampler initialized.") def sample_multiple_images(self, num_img): + """ + Generate multiple synthetic images and masks. + + Args: + num_img (int): Number of images to generate. 
+ """ + output_filenames = [] if len(self.controllable_anatomy_size) > 0: # we will use mask generation instead of finding candidate masks # create a dummy selected_mask_files for placeholder @@ -424,47 +643,60 @@ def sample_multiple_images(self, num_img): if len(candidate_mask_files) < num_img: # if we cannot find enough masks based on the exact match of anatomy list, spacing, and output size, # then we will try to find the closest mask in terms of spacing, and output size. - print("Resample to get desired output size and spacing") + logging.info("Resample mask file to get desired output size and spacing") candidate_mask_files = self.find_closest_masks(num_img) need_resample = True selected_mask_files = self.select_mask(candidate_mask_files, num_img) - print(selected_mask_files) + logging.info(f"Images will be generated based on {selected_mask_files}.") if len(selected_mask_files) != num_img: raise ValueError( - f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img " - f"({num_img}). This should not happen. Please revisit function " - f"select_mask(self, candidate_mask_files, num_img)." + f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img ({num_img}). This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)." ) for item in selected_mask_files: + logging.info("---- Start preparing masks... ----") + start_time = time.time() if len(self.controllable_anatomy_size) > 0: # generate a synthetic mask - (comebine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( - self.prepare_one_mask_and_meta_info(anatomy_size_condtion) - ) + ( + combine_label_or, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + ) = self.prepare_one_mask_and_meta_info(anatomy_size_condtion) else: # read in mask file mask_file = item["mask_file"] if_aug = item["if_aug"] - (comebine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( - self.read_mask_information(mask_file) - ) + ( + combine_label_or, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + ) = self.read_mask_information(mask_file) if need_resample: - comebine_label_or = self.ensure_output_size_and_spacing(comebine_label_or) + combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) # mask augmentation - if if_aug is True: - comebine_label_or = augmentation(comebine_label_or, self.output_size) + if if_aug: + combine_label_or = augmentation(combine_label_or, self.output_size) + end_time = time.time() + logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----") torch.cuda.empty_cache() # generate image/label pairs to_generate = True try_time = 0 while to_generate: synthetic_images, synthetic_labels = self.sample_one_pair( - comebine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + combine_label_or, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + ) + # synthetic image quality check + pass_quality_check = self.quality_check( + synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy() ) - # current quality always return True - pass_quality_check = self.quality_check(synthetic_images) - if pass_quality_check or try_time > 3: + if pass_quality_check or try_time > self.max_try_time: # save image/label pairs output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" @@ -476,6 +708,9 @@ def 
sample_multiple_images(self, num_img): separate_folder=False, ) img_saver(synthetic_images[0]) + synthetic_images_filename = os.path.join( + self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext + ) # filter out the organs that are not in anatomy_list synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) label_saver = SaveImage( @@ -485,13 +720,29 @@ def sample_multiple_images(self, num_img): separate_folder=False, ) label_saver(synthetic_labels[0]) + synthetic_labels_filename = os.path.join( + self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext + ) + output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) to_generate = False else: - print("Generated image/label pair did not pass quality check, will re-generate another pair.") + logging.info( + "Generated image/label pair did not pass quality check, will re-generate another pair." + ) try_time += 1 - return + return output_filenames def select_mask(self, candidate_mask_files, num_img): + """ + Select mask files for image generation. + + Args: + candidate_mask_files (list): List of candidate mask files. + num_img (int): Number of images to generate. + + Returns: + list: Selected mask files with augmentation flags. + """ selected_mask_files = [] random.shuffle(candidate_mask_files) @@ -501,17 +752,33 @@ def select_mask(self, candidate_mask_files, num_img): return selected_mask_files def sample_one_pair( - self, comebine_label_or_aug, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + self, + combine_label_or_aug, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, ): + """ + Generate a single pair of synthetic image and mask. + + Args: + combine_label_or_aug (torch.Tensor): Combined label tensor or augmented label. + top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. + bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. + spacing_tensor (torch.Tensor): Tensor specifying the spacing. + + Returns: + tuple: A tuple containing the synthetic image and its corresponding label. + """ # generate image/label pairs synthetic_images, synthetic_labels = ldm_conditional_sample_one_image( autoencoder=self.autoencoder, - difusion_unet=self.difusion_unet, + diffusion_unet=self.diffusion_unet, controlnet=self.controlnet, noise_scheduler=self.noise_scheduler, scale_factor=self.scale_factor, device=self.device, - comebine_label_or=comebine_label_or_aug, + combine_label_or=combine_label_or_aug, top_region_index_tensor=top_region_index_tensor, bottom_region_index_tensor=bottom_region_index_tensor, spacing_tensor=spacing_tensor, @@ -519,10 +786,24 @@ def sample_one_pair( output_size=self.output_size, noise_factor=self.noise_factor, num_inference_steps=self.num_inference_steps, + autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size, + autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap, ) return synthetic_images, synthetic_labels - def prepare_anatomy_size_condtion(self, controllable_anatomy_size): + def prepare_anatomy_size_condtion( + self, + controllable_anatomy_size, + ): + """ + Prepare anatomy size conditions for mask generation. + + Args: + controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes. + + Returns: + list: Prepared anatomy size conditions. 
+ """ anatomy_size_idx = { "gallbladder": 0, "liver": 1, @@ -536,7 +817,7 @@ def prepare_anatomy_size_condtion(self, controllable_anatomy_size): "bone lesion": 9, } provide_anatomy_size = [None for _ in range(10)] - print("controllable_anatomy_size:", controllable_anatomy_size) + logging.info(f"controllable_anatomy_size: {controllable_anatomy_size}") for element in controllable_anatomy_size: anatomy_name, anatomy_size = element provide_anatomy_size[anatomy_size_idx[anatomy_name]] = anatomy_size @@ -555,40 +836,56 @@ def prepare_anatomy_size_condtion(self, controllable_anatomy_size): diff += abs(provide_size - db_size) candidate_list.append((size, diff)) candidate_condition = sorted(candidate_list, key=lambda x: x[1])[0][0] - print("provide_anatomy_size:", provide_anatomy_size) - print("candidate_condition:", candidate_condition) # overwrite the anatomy size provided by users for element in controllable_anatomy_size: anatomy_name, anatomy_size = element candidate_condition[anatomy_size_idx[anatomy_name]] = anatomy_size - print("final candidate_condition:", candidate_condition) + return candidate_condition def prepare_one_mask_and_meta_info(self, anatomy_size_condtion): - comebine_label_or = self.sample_one_mask(anatomy_size=anatomy_size_condtion) + """ + Prepare a single mask and its associated meta information. + + Args: + anatomy_size_condtion (list): Anatomy size conditions. + + Returns: + tuple: A tuple containing the prepared mask and associated tensors. + """ + combine_label_or = self.sample_one_mask(anatomy_size=anatomy_size_condtion) # TODO: current mask generation model only can generate 256^3 volumes with 1.5 mm spacing. affine = torch.zeros((4, 4)) affine[0, 0] = 1.5 affine[1, 1] = 1.5 affine[2, 2] = 1.5 affine[3, 3] = 1.0 # dummy - comebine_label_or = MetaTensor(comebine_label_or, affine=affine) - comebine_label_or = self.ensure_output_size_and_spacing(comebine_label_or) + combine_label_or = MetaTensor(combine_label_or, affine=affine) + combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) - top_region_index, bottom_region_index = get_body_region_index_from_mask(comebine_label_or) + top_region_index, bottom_region_index = get_body_region_index_from_mask(combine_label_or) spacing_tensor = torch.FloatTensor(self.spacing).unsqueeze(0).half().to(self.device) * 1e2 top_region_index_tensor = torch.FloatTensor(top_region_index).unsqueeze(0).half().to(self.device) * 1e2 bottom_region_index_tensor = torch.FloatTensor(bottom_region_index).unsqueeze(0).half().to(self.device) * 1e2 - return comebine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + return combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor def sample_one_mask(self, anatomy_size): + """ + Generate a single synthetic mask. + + Args: + anatomy_size (list): Anatomy size specifications. + + Returns: + torch.Tensor: The generated synthetic mask. 
+ """ # generate one synthetic mask synthetic_mask = ldm_conditional_sample_one_mask( self.mask_generation_autoencoder, - self.mask_generation_difusion_unet, + self.mask_generation_diffusion_unet, self.mask_generation_noise_scheduler, self.mask_generation_scale_factor, anatomy_size, @@ -596,10 +893,25 @@ def sample_one_mask(self, anatomy_size): self.mask_generation_latent_shape, label_dict_remap_json=self.label_dict_remap_json, num_inference_steps=self.mask_generation_num_inference_steps, + autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size, + autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap, ) return synthetic_mask def ensure_output_size_and_spacing(self, labels, check_contains_target_labels=True): + """ + Ensure the output mask has the correct size and spacing. + + Args: + labels (torch.Tensor): Input label tensor. + check_contains_target_labels (bool): Whether to check if the resampled mask contains target labels. + + Returns: + torch.Tensor: Resampled label tensor. + + Raises: + ValueError: If the resampled mask doesn't contain required class labels. + """ current_spacing = [labels.affine[0, 0], labels.affine[1, 1], labels.affine[2, 2]] current_shape = list(labels.squeeze().shape) @@ -614,27 +926,41 @@ def ensure_output_size_and_spacing(self, labels, check_contains_target_labels=Tr need_resample = True # resample to target size and spacing if need_resample: - print("Resampling mask to target shape and sapcing") - print(f"Output size: {current_shape} -> {self.output_size}") - print(f"Sapcing: {current_spacing} -> {self.spacing}") + logging.info("Resampling mask to target shape and spacing") + logging.info(f"Resize Spacing: {current_spacing} -> {self.spacing}") + logging.info(f"Output size: {current_shape} -> {self.output_size}") spacing = monai.transforms.Spacing(pixdim=tuple(self.spacing), mode="nearest") - pad = monai.transforms.SpatialPad(spatial_size=tuple(self.output_size)) - crop = monai.transforms.CenterSpatialCrop(roi_size=tuple(self.output_size)) - labels = crop(pad(spacing(labels.squeeze(0)))).unsqueeze(0) + pad_crop = monai.transforms.ResizeWithPadOrCrop(spatial_size=tuple(self.output_size)) + labels = pad_crop(spacing(labels.squeeze(0))).unsqueeze(0).to(labels.dtype) + contained_labels = torch.unique(labels) if check_contains_target_labels: # check if the resampled mask still contains those target labels for anatomy_label in self.anatomy_list: if anatomy_label not in contained_labels: raise ValueError( - "Resampled mask does not contain required class labels. Please tune spacing and output size" + f"Resampled mask does not contain required class labels {anatomy_label}. Please tune spacing and output size." ) return labels def read_mask_information(self, mask_file): + """ + Read mask information from a file. + + Args: + mask_file (str): Path to the mask file. + + Returns: + tuple: A tuple containing the mask tensor and associated information. + """ val_data = self.val_transforms(mask_file) - for key in ["pseudo_label", "spacing", "top_region_index", "bottom_region_index"]: + for key in [ + "pseudo_label", + "spacing", + "top_region_index", + "bottom_region_index", + ]: val_data[key] = val_data[key].unsqueeze(0).to(self.device) return ( @@ -645,6 +971,18 @@ def read_mask_information(self, mask_file): ) def find_closest_masks(self, num_img): + """ + Find the closest matching masks from the database. + + Args: + num_img (int): Number of images to generate. 
+ + Returns: + list: List of closest matching mask candidates. + + Raises: + ValueError: If suitable candidates cannot be found. + """ # first check the database based on anatomy list candidates = find_masks( self.body_region, @@ -694,6 +1032,20 @@ def find_closest_masks(self, num_img): raise ValueError("Cannot find body region with given organ list.") return final_candidates - def quality_check(self, image): - # This version disabled quality check - return True + def quality_check(self, image_data, label_data): + """ + Perform a quality check on the generated image. + Args: + image_data (np.ndarray): The generated image. + label_data (np.ndarray): The corresponding whole body mask. + Returns: + bool: True if the image passes the quality check, False otherwise. + """ + outlier_results = is_outlier(self.median_statistics, image_data, label_data, self.label_int_dict) + for label, result in outlier_results.items(): + if result.get("is_outlier", False): + logging.info( + f"Generated image quality check for label '{label}' failed: median value {result['median_value']} is outside the acceptable range ({result['low_thresh']} - {result['high_thresh']})." + ) + return False + return True \ No newline at end of file diff --git a/models/maisi_ct_generative/scripts/trainer.py b/models/maisi_ct_generative/scripts/trainer.py index e935e325..bf5876df 100644 --- a/models/maisi_ct_generative/scripts/trainer.py +++ b/models/maisi_ct_generative/scripts/trainer.py @@ -15,7 +15,7 @@ import torch import torch.nn.functional as F -from generative.networks.schedulers import Scheduler +from monai.networks.schedulers import Scheduler from monai.config import IgniteInfo from monai.engines.trainer import Trainer from monai.engines.utils import IterationEvents, PrepareBatchExtraInput, default_metric_cmp_fn @@ -50,7 +50,7 @@ class MAISIControlNetTrainer(Trainer): max_epochs: the total epoch number for trainer to run. train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader. controlnet: controlnet to train in the trainer, should be regular PyTorch `torch.nn.Module`. - difusion_unet: difusion_unet used in the trainer, should be regular PyTorch `torch.nn.Module`. + diffusion_unet: diffusion_unet used in the trainer, should be regular PyTorch `torch.nn.Module`. optimizer: the optimizer associated to the detector, should be regular PyTorch optimizer from `torch.optim` or its subclass. epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. 
@@ -98,7 +98,7 @@ def __init__( max_epochs: int, train_data_loader: Iterable | DataLoader, controlnet: torch.nn.Module, - difusion_unet: torch.nn.Module, + diffusion_unet: torch.nn.Module, optimizer: Optimizer, loss_function: Callable, inferer: Inferer, @@ -143,7 +143,7 @@ def __init__( ) self.controlnet = controlnet - self.difusion_unet = difusion_unet + self.diffusion_unet = diffusion_unet self.optimizer = optimizer self.loss_function = loss_function self.inferer = inferer @@ -151,7 +151,7 @@ def __init__( self.hyper_kwargs = hyper_kwargs self.noise_scheduler = noise_scheduler self.logger.addFilter(RankFilter()) - for p in self.difusion_unet.parameters(): + for p in self.diffusion_unet.parameters(): p.requires_grad = False print("freeze the parameters of the diffusion unet model.") @@ -200,7 +200,7 @@ def _compute_pred_loss(): down_block_res_samples, mid_block_res_sample = engine.controlnet( x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond ) - noise_pred = engine.difusion_unet( + noise_pred = engine.diffusion_unet( x=noisy_latent, timesteps=timesteps, top_region_index_tensor=top_region_index, diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index 5d939aff..f9c843ca 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -9,15 +9,24 @@ # See the License for the specific language governing permissions and import copy import json +import logging +import math import os import zipfile -from typing import Sequence +from argparse import Namespace +from datetime import timedelta +from typing import Any, Sequence import numpy as np import skimage import torch +import torch.distributed as dist import torch.nn.functional as F +from monai.bundle import ConfigParser from monai.config import DtypeLike, NdarrayOrTensor +from monai.data import CacheDataset, DataLoader, partition_dataset +from monai.transforms import Compose, EnsureTyped, Lambdad, LoadImaged, Orientationd +from monai.transforms.utils_morphological_ops import dilate, erode from monai.utils import ( TransformBackends, convert_data_type, @@ -26,6 +35,8 @@ get_equivalent_dtype, ) from scipy import stats +from torch import Tensor + def unzip_dataset(dataset_dir): @@ -36,6 +47,7 @@ def unzip_dataset(dataset_dir): with zipfile.ZipFile(zip_file_path, "r") as zip_ref: zip_ref.extractall(path=os.path.dirname(dataset_dir)) print(f"Unzipped {zip_file_path} to {dataset_dir}.") + return def add_data_dir2path(list_files, data_dir, fold=None): @@ -70,7 +82,42 @@ def maisi_datafold_read(json_list, data_base_dir, fold=None): return train_files, val_files +def remap_labels(mask, label_dict_remap_json): + """ + Remap labels in the mask according to the provided label dictionary. + + This function reads a JSON file containing label mapping information and applies + the mapping to the input mask. + + Args: + mask (Tensor): The input mask tensor to be remapped. + label_dict_remap_json (str): Path to the JSON file containing the label mapping dictionary. + + Returns: + Tensor: The remapped mask tensor. + """ + with open(label_dict_remap_json, "r") as f: + mapping_dict = json.load(f) + mapper = MapLabelValue( + orig_labels=[pair[0] for pair in mapping_dict.values()], + target_labels=[pair[1] for pair in mapping_dict.values()], + dtype=torch.uint8, + ) + return mapper(mask[0, ...])[None, ...].to(mask.device) + + def get_index_arr(img): + """ + Generate an index array for the given image. 
+ + This function creates a 3D array of indices corresponding to the dimensions of the input image. + + Args: + img (ndarray): The input image array. + + Returns: + ndarray: A 3D array containing the indices for each dimension of the input image. + """ return np.moveaxis( np.moveaxis( np.stack(np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), np.arange(img.shape[2]))), 0, 3 @@ -81,7 +128,22 @@ def get_index_arr(img): def supress_non_largest_components(img, target_label, default_val=0): - """As a last step, supress all non largest components""" + """ + Suppress all components except the largest one(s) for specified target labels. + + This function identifies the largest component(s) for each target label and + suppresses all other smaller components. + + Args: + img (ndarray): The input image array. + target_label (list): List of label values to process. + default_val (int, optional): Value to assign to suppressed voxels. Defaults to 0. + + Returns: + tuple: A tuple containing: + - ndarray: Modified image with non-largest components suppressed. + - int: Number of voxels that were changed. + """ index_arr = get_index_arr(img) img_mod = copy.deepcopy(img) new_background = np.zeros(img.shape, dtype=np.bool_) @@ -102,68 +164,289 @@ def supress_non_largest_components(img, target_label, default_val=0): return img_mod, diff -def erode3d(input_tensor, erosion=3, value=0.0): - # Define the structuring element - erosion = ensure_tuple_rep(erosion, 3) - structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device) +def erode_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor: + """ + Erode 2D/3D binary mask with data type as torch tensor. + + Args: + mask_t: input 2D/3D binary mask, [M,N] or [M,N,P] torch tensor. + filter_size: erosion filter size, has to be odd numbers, default to be 3. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value + and not changed. - # Pad the input tensor to handle border pixels - input_padded = F.pad( - input_tensor.float().unsqueeze(0).unsqueeze(0), - (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2), - mode="constant", - value=value, + Return: + Tensor: eroded mask, same shape as input. + """ + return ( + erode( + mask_t.float() + .unsqueeze(0) + .unsqueeze( + 0, + ), + filter_size, + pad_value=pad_value, + ) + .squeeze(0) + .squeeze(0) + ) + + +def dilate_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor: + """ + Dilate 2D/3D binary mask with data type as torch tensor. + + Args: + mask_t: input 2D/3D binary mask, [M,N] or [M,N,P] torch tensor. + filter_size: dilation filter size, has to be odd numbers, default to be 3. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value + and not changed. + + Return: + Tensor: dilated mask, same shape as input. 
+ """ + return ( + dilate( + mask_t.float() + .unsqueeze(0) + .unsqueeze( + 0, + ), + filter_size, + pad_value=pad_value, + ) + .squeeze(0) + .squeeze(0) ) - # Apply erosion operation - output = F.conv3d(input_padded, structuring_element, padding=0) - # Set output values based on the minimum value within the structuring element - output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0) +def binarize_labels(x: Tensor, bits: int = 8) -> Tensor: + """ + Convert input tensor to binary representation. + + This function takes an input tensor and converts it to a binary representation + using the specified number of bits. + + Args: + x (Tensor): Input tensor with shape (B, 1, H, W, D). + bits (int, optional): Number of bits to use for binary representation. Defaults to 8. + + Returns: + Tensor: Binary representation of the input tensor with shape (B, bits, H, W, D). + """ + mask = 2 ** torch.arange(bits).to(x.device, x.dtype) + return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte().squeeze(1).permute(0, 4, 1, 2, 3) + + +def setup_ddp(rank: int, world_size: int) -> torch.device: + """ + Initialize the distributed process group. + + Args: + rank (int): rank of the current process. + world_size (int): number of processes participating in the job. + + Returns: + torch.device: device of the current process. + """ + dist.init_process_group( + backend="nccl", init_method="env://", timeout=timedelta(seconds=36000), rank=rank, world_size=world_size + ) + dist.barrier() + device = torch.device(f"cuda:{rank}") + return device + + +def define_instance(args: Namespace, instance_def_key: str) -> Any: + """ + Define and instantiate an object based on the provided arguments and instance definition key. + + This function uses a ConfigParser to parse the arguments and instantiate an object + defined by the instance_def_key. + + Args: + args: An object containing the arguments to be parsed. + instance_def_key (str): The key used to retrieve the instance definition from the parsed content. + + Returns: + The instantiated object as defined by the instance_def_key in the parsed configuration. + """ + parser = ConfigParser(vars(args)) + parser.parse(True) + return parser.get_parsed_content(instance_def_key, instantiate=True) + + +def add_data_dir2path(list_files: list, data_dir: str, fold: int = None) -> tuple[list, list]: + """ + Read a list of data dictionary. + + Args: + list_files (list): input data to load and transform to generate dataset for model. + data_dir (str): directory of files. + fold (int, optional): fold index for cross validation. Defaults to None. + + Returns: + tuple[list, list]: A tuple of two arrays (training, validation). 
+ """ + new_list_files = copy.deepcopy(list_files) + if fold is not None: + new_list_files_train = [] + new_list_files_val = [] + for d in new_list_files: + d["image"] = os.path.join(data_dir, d["image"]) + + if "label" in d: + d["label"] = os.path.join(data_dir, d["label"]) - return output.squeeze(0).squeeze(0) + if fold is not None: + if d["fold"] == fold: + new_list_files_val.append(copy.deepcopy(d)) + else: + new_list_files_train.append(copy.deepcopy(d)) + if fold is not None: + return new_list_files_train, new_list_files_val + else: + return new_list_files, [] -def dilate3d(input_tensor, erosion=3, value=0.0): - # Define the structuring element - erosion = ensure_tuple_rep(erosion, 3) - structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device) - # Pad the input tensor to handle border pixels - input_padded = F.pad( - input_tensor.float().unsqueeze(0).unsqueeze(0), - (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2), - mode="constant", - value=value, +def prepare_maisi_controlnet_json_dataloader( + json_data_list: list | str, + data_base_dir: list | str, + batch_size: int = 1, + fold: int = 0, + cache_rate: float = 0.0, + rank: int = 0, + world_size: int = 1, +) -> tuple[DataLoader, DataLoader]: + """ + Prepare dataloaders for training and validation. + + Args: + json_data_list (list | str): the name of JSON files listing the data. + data_base_dir (list | str): directory of files. + batch_size (int, optional): how many samples per batch to load . Defaults to 1. + fold (int, optional): fold index for cross validation. Defaults to 0. + cache_rate (float, optional): percentage of cached data in total. Defaults to 0.0. + rank (int, optional): rank of the current process. Defaults to 0. + world_size (int, optional): number of processes participating in the job. Defaults to 1. + + Returns: + tuple[DataLoader, DataLoader]: A tuple of two dataloaders (training, validation). 
+ """ + use_ddp = world_size > 1 + if isinstance(json_data_list, list): + assert isinstance(data_base_dir, list) + list_train = [] + list_valid = [] + for data_list, data_root in zip(json_data_list, data_base_dir): + with open(data_list, "r") as f: + json_data = json.load(f)["training"] + train, val = add_data_dir2path(json_data, data_root, fold) + list_train += train + list_valid += val + else: + with open(json_data_list, "r") as f: + json_data = json.load(f)["training"] + list_train, list_valid = add_data_dir2path(json_data, data_base_dir, fold) + + common_transform = [ + LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True), + Orientationd(keys=["label"], axcodes="RAS"), + EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True), + Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)), + Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)), + Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), + Lambdad(keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2), + ] + train_transforms, val_transforms = Compose(common_transform), Compose(common_transform) + + train_loader = None + + if use_ddp: + list_train = partition_dataset( + data=list_train, + shuffle=True, + num_partitions=world_size, + even_divisible=True, + )[rank] + train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8) + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) + if use_ddp: + list_valid = partition_dataset( + data=list_valid, + shuffle=True, + num_partitions=world_size, + even_divisible=False, + )[rank] + val_ds = CacheDataset( + data=list_valid, + transform=val_transforms, + cache_rate=cache_rate, + num_workers=8, ) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False) + return train_loader, val_loader - # Apply erosion operation - output = F.conv3d(input_padded, structuring_element, padding=0) - # Set output values based on the minimum value within the structuring element - output = torch.where(output > 0, 1.0, 0.0) +def organ_fill_by_closing(data, target_label, device, close_times=2, filter_size=3, pad_value=0.0): + """ + Fill holes in an organ mask using morphological closing operations. - return output.squeeze(0).squeeze(0) + This function performs a series of dilation and erosion operations to fill holes + in the organ mask identified by the target label. + Args: + data (ndarray): The input data containing organ labels. + target_label (int): The label of the organ to be processed. + device (str): The device to perform the operations on (e.g., 'cuda:0'). + close_times (int, optional): Number of times to perform the closing operation. Defaults to 2. + filter_size (int, optional): Size of the filter for dilation and erosion. Defaults to 3. + pad_value (float, optional): Value used for padding in dilation and erosion. Defaults to 0.0. -def organ_fill_by_closing(data, target_label, device): + Returns: + ndarray: Boolean mask of the filled organ. 
+ """ mask = (data == target_label).astype(np.uint8) - mask = dilate3d(torch.from_numpy(mask).to(device), erosion=3, value=0.0) - mask = erode3d(mask, erosion=3, value=0.0) - mask = dilate3d(mask, erosion=3, value=0.0) - mask = erode3d(mask, erosion=3, value=0.0).cpu().numpy() - return mask.astype(np.bool_) + mask = torch.from_numpy(mask).to(device) + for _ in range(close_times): + mask = dilate_one_img(mask, filter_size=filter_size, pad_value=pad_value) + mask = erode_one_img(mask, filter_size=filter_size, pad_value=pad_value) + return mask.cpu().numpy().astype(np.bool_) def organ_fill_by_removed_mask(data, target_label, remove_mask, device): + """ + Fill an organ mask in regions where it was previously removed. + + Args: + data (ndarray): The input data containing organ labels. + target_label (int): The label of the organ to be processed. + remove_mask (ndarray): Boolean mask indicating regions where the organ was removed. + device (str): The device to perform the operations on (e.g., 'cuda:0'). + + Returns: + ndarray: Boolean mask of the filled organ in previously removed regions. + """ mask = (data == target_label).astype(np.uint8) - mask = dilate3d(torch.from_numpy(mask).to(device), erosion=3, value=0.0) - mask = dilate3d(mask, erosion=3, value=0.0) - roi_oragn_mask = dilate3d(mask, erosion=3, value=0.0).cpu().numpy() + mask = dilate_one_img(torch.from_numpy(mask).to(device), filter_size=3, pad_value=0.0) + mask = dilate_one_img(mask, filter_size=3, pad_value=0.0) + roi_oragn_mask = dilate_one_img(mask, filter_size=3, pad_value=0.0).cpu().numpy() return (roi_oragn_mask * remove_mask).astype(np.bool_) def get_body_region_index_from_mask(input_mask): + """ + Determine the top and bottom body region indices from an input mask. + + Args: + input_mask (Tensor): Input mask tensor containing body region labels. + + Returns: + tuple: Two lists representing the top and bottom region indices. + """ region_indices = {} # head and neck region_indices["region_0"] = [22, 120] @@ -177,7 +460,7 @@ def get_body_region_index_from_mask(input_mask): nda = input_mask.cpu().numpy().squeeze() unique_elements = np.lib.arraysetops.unique(nda) unique_elements = list(unique_elements) - print(f"nda: {nda.shape} {unique_elements}.") + # print(f"nda: {nda.shape} {unique_elements}.") overlap_array = np.zeros(len(region_indices), dtype=np.uint8) for _j in range(len(region_indices)): overlap = any(element in region_indices[f"region_{_j}"] for element in unique_elements) @@ -189,20 +472,39 @@ def get_body_region_index_from_mask(input_mask): bottom_region_index = np.eye(len(region_indices), dtype=np.uint8)[np.amax(overlap_array_indices), ...] bottom_region_index = list(bottom_region_index) bottom_region_index = [int(_k) for _k in bottom_region_index] - print(f"{top_region_index} {bottom_region_index}") + # print(f"{top_region_index} {bottom_region_index}") return top_region_index, bottom_region_index def general_mask_generation_post_process(volume_t, target_tumor_label=None, device="cuda:0"): + """ + Perform post-processing on a generated mask volume. + + This function applies various refinement steps to improve the quality of the generated mask, + including body mask refinement, tumor prediction refinement, and organ-specific processing. + + Args: + volume_t (ndarray): Input volume containing organ and tumor labels. + target_tumor_label (int, optional): Label of the target tumor. Defaults to None. + device (str, optional): Device to perform operations on. Defaults to "cuda:0". 
+ + Returns: + ndarray: Post-processed volume with refined organ and tumor labels. + """ # assume volume_t is np array with shape (H,W,D) hepatic_vessel = volume_t == 25 airway = volume_t == 132 # ------------ refine body mask pred - body_region_mask = erode3d(torch.from_numpy((volume_t > 0)).to(device), erosion=3, value=0.0).cpu().numpy() + body_region_mask = ( + erode_one_img(torch.from_numpy((volume_t > 0)).to(device), filter_size=3, pad_value=0.0).cpu().numpy() + ) body_region_mask, _ = supress_non_largest_components(body_region_mask, [1]) body_region_mask = ( - dilate3d(torch.from_numpy(body_region_mask).to(device), erosion=3, value=0.0).cpu().numpy().astype(np.uint8) + dilate_one_img(torch.from_numpy(body_region_mask).to(device), filter_size=3, pad_value=0.0) + .cpu() + .numpy() + .astype(np.uint8) ) volume_t = volume_t * body_region_mask @@ -248,7 +550,9 @@ def general_mask_generation_post_process(volume_t, target_tumor_label=None, devi if target_tumor_label == 23 and np.sum(target_tumor) > 0: # speical process for cases with lung tumor - dia_lung_tumor_mask = dilate3d(torch.from_numpy((data == 23)).to(device), erosion=3, value=0.0).cpu().numpy() + dia_lung_tumor_mask = ( + dilate_one_img(torch.from_numpy((data == 23)).to(device), filter_size=3, pad_value=0.0).cpu().numpy() + ) tmp = ( (data * (dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8))).astype(np.float32).flatten() ) @@ -256,14 +560,16 @@ def general_mask_generation_post_process(volume_t, target_tumor_label=None, devi mode = int(stats.mode(tmp.flatten(), nan_policy="omit")[0]) if mode in [28, 29, 30, 31, 32]: dia_lung_tumor_mask = ( - dilate3d(torch.from_numpy(dia_lung_tumor_mask).to(device), erosion=3, value=0.0).cpu().numpy() + dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0) + .cpu() + .numpy() ) lung_remove_mask = dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8).astype(np.uint8) data[organ_fill_by_removed_mask(data, target_label=mode, remove_mask=lung_remove_mask, device=device)] = ( mode ) dia_lung_tumor_mask = ( - dilate3d(torch.from_numpy(dia_lung_tumor_mask).to(device), erosion=3, value=0.0).cpu().numpy() + dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy() ) data[ organ_fill_by_removed_mask( @@ -284,9 +590,13 @@ def general_mask_generation_post_process(volume_t, target_tumor_label=None, devi data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3 data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3 dia_tumor_mask = ( - dilate3d(torch.from_numpy((data == target_tumor_label)).to(device), erosion=3, value=0.0).cpu().numpy() + dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0) + .cpu() + .numpy() + ) + dia_tumor_mask = ( + dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy() ) - dia_tumor_mask = dilate3d(torch.from_numpy(dia_tumor_mask).to(device), erosion=3, value=0.0).cpu().numpy() data[ organ_fill_by_removed_mask( data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device @@ -311,9 +621,13 @@ def general_mask_generation_post_process(volume_t, target_tumor_label=None, devi if target_tumor_label == 27 and np.sum(target_tumor) > 0: # speical process for cases with colon tumor dia_tumor_mask = ( - dilate3d(torch.from_numpy((data == 
target_tumor_label)).to(device), erosion=3, value=0.0).cpu().numpy() + dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0) + .cpu() + .numpy() + ) + dia_tumor_mask = ( + dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy() ) - dia_tumor_mask = dilate3d(torch.from_numpy(dia_tumor_mask).to(device), erosion=3, value=0.0).cpu().numpy() data[ organ_fill_by_removed_mask( data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device @@ -375,6 +689,15 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray) def __call__(self, img: NdarrayOrTensor): + """ + Apply the label mapping to the input image. + + Args: + img (NdarrayOrTensor): Input image to be remapped. + + Returns: + NdarrayOrTensor: Remapped image. + """ if self.use_numpy: img_np, *_ = convert_data_type(img, np.ndarray) _out_shape = img_np.shape @@ -396,34 +719,44 @@ def __call__(self, img: NdarrayOrTensor): return out -def load_autoencoder_ckpt(load_autoencoder_path): - checkpoint_autoencoder = torch.load(load_autoencoder_path) - new_state_dict = {} - for k, v in checkpoint_autoencoder.items(): - if "decoder" in k and "conv" in k: - new_key = ( - k.replace("conv.weight", "conv.conv.weight") - if "conv.weight" in k - else k.replace("conv.bias", "conv.conv.bias") - ) - new_state_dict[new_key] = v - elif "encoder" in k and "conv" in k: - new_key = ( - k.replace("conv.weight", "conv.conv.weight") - if "conv.weight" in k - else k.replace("conv.bias", "conv.conv.bias") - ) - new_state_dict[new_key] = v - else: - new_state_dict[k] = v - checkpoint_autoencoder = new_state_dict - return checkpoint_autoencoder +def KL_loss(z_mu, z_sigma): + """ + Compute the Kullback-Leibler (KL) divergence loss for a variational autoencoder (VAE). + The KL divergence measures how one probability distribution diverges from a second, expected probability distribution. + In the context of VAEs, this loss term ensures that the learned latent space distribution is close to a standard normal distribution. -def binarize_labels(x, bits=8): + Args: + z_mu (torch.Tensor): Mean of the latent variable distribution, shape [N,C,H,W,D] or [N,C,H,W]. + z_sigma (torch.Tensor): Standard deviation of the latent variable distribution, same shape as 'z_mu'. + + Returns: + torch.Tensor: The computed KL divergence loss, averaged over the batch. """ - x: the input tensor with shape (B, 1, H, W, D) - bits: the num of channel to represent the data. + eps = 1e-10 + kl_loss = 0.5 * torch.sum( + z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2) + eps) - 1, + dim=list(range(1, len(z_sigma.shape))), + ) + return torch.sum(kl_loss) / kl_loss.shape[0] + + +def dynamic_infer(inferer, model, images): """ - mask = 2 ** torch.arange(bits).to(x.device, x.dtype) - return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte().squeeze(1).permute(0, 4, 1, 2, 3) + Perform dynamic inference using a model and an inferer, typically a monai SlidingWindowInferer. + + This function determines whether to use the model directly or to use the provided inferer + (such as a sliding window inferer) based on the size of the input images. + + Args: + inferer: An inference object, typically a monai SlidingWindowInferer, which handles patch-based inference. + model (torch.nn.Module): The model used for inference. 
+ images (torch.Tensor): The input images for inference, shape [N,C,H,W,D] or [N,C,H,W]. + + Returns: + torch.Tensor: The output from the model or the inferer, depending on the input size. + """ + if torch.numel(images[0:1, 0:1, ...]) < math.prod(inferer.roi_size): + return model(images) + else: + return inferer(network=model, inputs=images) \ No newline at end of file From 4be4de89834beedae9a674892d0a532aa6d3d024 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:06:24 +0000 Subject: [PATCH 02/19] add readme about download Signed-off-by: Can-Zhao --- models/maisi_ct_generative/docs/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/models/maisi_ct_generative/docs/README.md b/models/maisi_ct_generative/docs/README.md index bcb7c05f..ab4f0bcb 100644 --- a/models/maisi_ct_generative/docs/README.md +++ b/models/maisi_ct_generative/docs/README.md @@ -23,6 +23,10 @@ The inference requires: - Disk Memory: at least 21GB disk memory ### Execute inference +The model weights can be downloaded with +``` +python -m scripts.download_files +``` The following code generates a synthetic image from a random sampled noise. ``` python -m monai.bundle run --config_file configs/inference.json From cd0efa2472ca40148c1688d650143b85dbb92af8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Aug 2024 18:07:25 +0000 Subject: [PATCH 03/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../configs/inference.json | 45 ++++++++++++++++--- models/maisi_ct_generative/configs/train.json | 6 ++- models/maisi_ct_generative/docs/README.md | 2 +- models/maisi_ct_generative/scripts/sample.py | 2 +- models/maisi_ct_generative/scripts/utils.py | 2 +- 5 files changed, 46 insertions(+), 11 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index dbc19f6c..7cea2c85 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -61,7 +61,11 @@ 64, 64 ], - "autoencoder_sliding_window_infer_size": [96, 96, 96], + "autoencoder_sliding_window_infer_size": [ + 96, + 96, + 96 + ], "autoencoder_sliding_window_infer_overlap": 0.6667, "autoencoder_def": { "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", @@ -74,7 +78,11 @@ 128, 256 ], - "num_res_blocks": [2,2,2], + "num_res_blocks": [ + 2, + 2, + 2 + ], "norm_num_groups": 32, "norm_eps": 1e-06, "attention_levels": [ @@ -144,7 +152,11 @@ "num_res_blocks": 2, "use_flash_attention": true, "conditioning_embedding_in_channels": 8, - "conditioning_embedding_num_channels": [8, 32, 64] + "conditioning_embedding_num_channels": [ + 8, + 32, + 64 + ] }, "mask_generation_autoencoder_def": { "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", @@ -157,7 +169,11 @@ 64, 128 ], - "num_res_blocks": [1, 2, 2], + "num_res_blocks": [ + 1, + 2, + 2 + ], "norm_num_groups": 32, "norm_eps": 1e-06, "attention_levels": [ @@ -179,9 +195,24 @@ "spatial_dims": "@spatial_dims", "in_channels": "@latent_channels", "out_channels": "@latent_channels", - "channels":[64, 128, 256, 512], - "attention_levels":[false, false, true, true], - "num_head_channels":[0, 0, 32, 32], + "channels": [ + 64, + 128, + 256, + 512 + ], + "attention_levels": [ + false, + false, + true, + true + ], + "num_head_channels": [ + 0, + 0, + 32, + 32 + ], "num_res_blocks": 2, 
"use_flash_attention": true, "with_conditioning": true, diff --git a/models/maisi_ct_generative/configs/train.json b/models/maisi_ct_generative/configs/train.json index 51f6a236..ab88ffad 100644 --- a/models/maisi_ct_generative/configs/train.json +++ b/models/maisi_ct_generative/configs/train.json @@ -82,7 +82,11 @@ "num_res_blocks": 2, "use_flash_attention": true, "conditioning_embedding_in_channels": 8, - "conditioning_embedding_num_channels": [8, 32, 64] + "conditioning_embedding_num_channels": [ + 8, + 32, + 64 + ] }, "noise_scheduler": { "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", diff --git a/models/maisi_ct_generative/docs/README.md b/models/maisi_ct_generative/docs/README.md index ab4f0bcb..20a87c19 100644 --- a/models/maisi_ct_generative/docs/README.md +++ b/models/maisi_ct_generative/docs/README.md @@ -23,7 +23,7 @@ The inference requires: - Disk Memory: at least 21GB disk memory ### Execute inference -The model weights can be downloaded with +The model weights can be downloaded with ``` python -m scripts.download_files ``` diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 3c78236b..69a7869a 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -1048,4 +1048,4 @@ def quality_check(self, image_data, label_data): f"Generated image quality check for label '{label}' failed: median value {result['median_value']} is outside the acceptable range ({result['low_thresh']} - {result['high_thresh']})." ) return False - return True \ No newline at end of file + return True diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index f9c843ca..ef66cc82 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -759,4 +759,4 @@ def dynamic_infer(inferer, model, images): if torch.numel(images[0:1, 0:1, ...]) < math.prod(inferer.roi_size): return model(images) else: - return inferer(network=model, inputs=images) \ No newline at end of file + return inferer(network=model, inputs=images) From 31b5e7c4149526c2eba95cfc213becd0d3891949 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:09:19 +0000 Subject: [PATCH 04/19] update meta Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/metadata.json | 1 + 1 file changed, 1 insertion(+) diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index 87c0eff2..11199e4b 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ b/models/maisi_ct_generative/configs/metadata.json @@ -2,6 +2,7 @@ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json", "version": "0.3.6", "changelog": { + "0.4.0": "update to use monai 1.4, model ckpt updated, rm GenerativeAI repo, add quality check", "0.3.6": "first oss version" }, "monai_version": "1.3.1", From eaa974f69bdae0ebf8cd39acbe03b71420a12648 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:14:14 +0000 Subject: [PATCH 05/19] update meta Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/metadata.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index 11199e4b..c9475207 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ 
b/models/maisi_ct_generative/configs/metadata.json @@ -1,6 +1,6 @@ { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json", - "version": "0.3.6", + "version": "0.4.0", "changelog": { "0.4.0": "update to use monai 1.4, model ckpt updated, rm GenerativeAI repo, add quality check", "0.3.6": "first oss version" From 8c14e4ff833ec570cf58b5125dacc79a1b3c528c Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:18:43 +0000 Subject: [PATCH 06/19] isort, balck Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/download_files.py | 7 ++++--- models/maisi_ct_generative/scripts/sample.py | 4 ++-- models/maisi_ct_generative/scripts/trainer.py | 2 +- models/maisi_ct_generative/scripts/utils.py | 1 - 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/models/maisi_ct_generative/scripts/download_files.py b/models/maisi_ct_generative/scripts/download_files.py index aee1a416..4616568d 100644 --- a/models/maisi_ct_generative/scripts/download_files.py +++ b/models/maisi_ct_generative/scripts/download_files.py @@ -1,11 +1,12 @@ -import yaml import os + +import yaml from monai.apps import download_url # Load YAML file -with open('large_files.yml', 'r') as file: +with open("large_files.yml", "r") as file: data = yaml.safe_load(file) # Iterate over each file in the YAML and download it -for file in data['large_files']: +for file in data["large_files"]: download_url(url=file["url"], filepath=file["path"]) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 69a7869a..493fb832 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -19,17 +19,17 @@ import monai import torch -from monai.inferers.inferer import DiffusionInferer from monai.data import MetaTensor from monai.inferers import sliding_window_inference +from monai.inferers.inferer import DiffusionInferer from monai.transforms import Compose, SaveImage from monai.utils import set_determinism from tqdm import tqdm from .augmentation import augmentation from .find_masks import find_masks -from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels from .quality_check import is_outlier +from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels class ReconModel(torch.nn.Module): diff --git a/models/maisi_ct_generative/scripts/trainer.py b/models/maisi_ct_generative/scripts/trainer.py index bf5876df..7c779ef3 100644 --- a/models/maisi_ct_generative/scripts/trainer.py +++ b/models/maisi_ct_generative/scripts/trainer.py @@ -15,11 +15,11 @@ import torch import torch.nn.functional as F -from monai.networks.schedulers import Scheduler from monai.config import IgniteInfo from monai.engines.trainer import Trainer from monai.engines.utils import IterationEvents, PrepareBatchExtraInput, default_metric_cmp_fn from monai.inferers import Inferer +from monai.networks.schedulers import Scheduler from monai.transforms import Transform from monai.utils import RankFilter, min_version, optional_import from monai.utils.enums import CommonKeys as Keys diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index ef66cc82..7c548d6d 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -38,7 +38,6 @@ from torch import Tensor - def 
unzip_dataset(dataset_dir): if not os.path.exists(dataset_dir): zip_file_path = dataset_dir + ".zip" From bf36c983aff981c01851f45ab8658b5971fb4874 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:20:47 +0000 Subject: [PATCH 07/19] update meta Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/metadata.json | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index c9475207..87b8c1ed 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ b/models/maisi_ct_generative/configs/metadata.json @@ -5,15 +5,13 @@ "0.4.0": "update to use monai 1.4, model ckpt updated, rm GenerativeAI repo, add quality check", "0.3.6": "first oss version" }, - "monai_version": "1.3.1", + "monai_version": "1.4.0", "pytorch_version": "2.2.2", "numpy_version": "1.24.4", "optional_packages_version": { "fire": "0.6.0", "nibabel": "5.2.1", - "monai-generative": "0.2.3", "tqdm": "4.66.2", - "xformers": "0.0.26" }, "supported_apps": { "maisi-nim": "" From 68630eeda92bcca8c3a9c01eefa6fe9be3673b50 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:23:08 +0000 Subject: [PATCH 08/19] update meta Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/metadata.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index 87b8c1ed..adfbf72c 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ b/models/maisi_ct_generative/configs/metadata.json @@ -11,7 +11,7 @@ "optional_packages_version": { "fire": "0.6.0", "nibabel": "5.2.1", - "tqdm": "4.66.2", + "tqdm": "4.66.2" }, "supported_apps": { "maisi-nim": "" From ca5bf221ae9b889c2b21a9e2b4220af6b1fd418b Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:30:23 +0000 Subject: [PATCH 09/19] rm custom network Signed-off-by: Can-Zhao --- .../scripts/custom_network_controlnet.py | 177 -- .../scripts/custom_network_diffusion.py | 1993 ----------------- .../scripts/custom_network_tp.py | 1053 --------- 3 files changed, 3223 deletions(-) delete mode 100644 models/maisi_ct_generative/scripts/custom_network_controlnet.py delete mode 100644 models/maisi_ct_generative/scripts/custom_network_diffusion.py delete mode 100644 models/maisi_ct_generative/scripts/custom_network_tp.py diff --git a/models/maisi_ct_generative/scripts/custom_network_controlnet.py b/models/maisi_ct_generative/scripts/custom_network_controlnet.py deleted file mode 100644 index ad36c5b6..00000000 --- a/models/maisi_ct_generative/scripts/custom_network_controlnet.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - -from typing import Sequence - -import torch -from generative.networks.nets.controlnet import ControlNet -from generative.networks.nets.diffusion_model_unet import get_timestep_embedding - - -class CustomControlNet(ControlNet): - """ - Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image - Diffusion Models" (https://arxiv.org/abs/2302.05543) - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. - attention_levels: list of levels to add attention. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - resblock_updown: if True use residual blocks for up/downsampling. - num_head_channels: number of channels in each attention head. - with_conditioning: if True add spatial transformers to perform conditioning. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - conditioning_embedding_in_channels: number of input channels for the conditioning embedding. - conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - resblock_updown: bool = False, - num_head_channels: int | Sequence[int] = 8, - with_conditioning: bool = False, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - num_class_embeds: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - conditioning_embedding_in_channels: int = 1, - conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), - ) -> None: - super().__init__( - spatial_dims, - in_channels, - num_res_blocks, - num_channels, - attention_levels, - norm_num_groups, - norm_eps, - resblock_updown, - num_head_channels, - with_conditioning, - transformer_num_layers, - cross_attention_dim, - num_class_embeds, - upcast_attention, - use_flash_attention, - conditioning_embedding_in_channels, - conditioning_embedding_num_channels, - ) - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - context: torch.Tensor | None = None, - class_labels: torch.Tensor | None = None, - ) -> tuple[list[torch.Tensor], torch.Tensor]: - """ - Args: - x: input tensor (N, C, SpatialDims). - timesteps: timestep tensor (N,). - controlnet_cond: controlnet conditioning tensor (N, C, SpatialDims). - conditioning_scale: conditioning scale. - context: context tensor (N, 1, ContextDim). - class_labels: context tensor (N, ). - """ - # 1. time - t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=x.dtype) - emb = self.time_embed(t_emb) - - # 2. class - if self.num_class_embeds is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - class_emb = self.class_embedding(class_labels) - class_emb = class_emb.to(dtype=x.dtype) - emb = emb + class_emb - - # 3. initial convolution - h = self.conv_in(x) - - # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - controlnet_cond = torch.utils.checkpoint.checkpoint( - self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False - ) - - h += controlnet_cond - - # 4. down - if context is not None and self.with_conditioning is False: - raise ValueError("model should have with_conditioning = True if context is provided") - down_block_res_samples = [h] - for downsample_block in self.down_blocks: - h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) - for residual in res_samples: - down_block_res_samples.append(residual) - - # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) - - # 6. Control net blocks - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples += (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample = self.controlnet_mid_block(h) - - # 6. 
scaling - down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] - mid_block_res_sample *= conditioning_scale - - return down_block_res_samples, mid_block_res_sample diff --git a/models/maisi_ct_generative/scripts/custom_network_diffusion.py b/models/maisi_ct_generative/scripts/custom_network_diffusion.py deleted file mode 100644 index 9e4cbf60..00000000 --- a/models/maisi_ct_generative/scripts/custom_network_diffusion.py +++ /dev/null @@ -1,1993 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# ========================================================================= -# Adapted from https://github.com/huggingface/diffusers -# which has the following license: -# https://github.com/huggingface/diffusers/blob/main/LICENSE -# -# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ========================================================================= - -from __future__ import annotations - -import importlib.util -import math -from collections.abc import Sequence - -import torch -import torch.nn.functional as F -from monai.networks.blocks import Convolution, MLPBlock -from monai.networks.layers.factories import Pool -from monai.utils import ensure_tuple_rep -from torch import nn - -if importlib.util.find_spec("xformers") is not None: - import xformers - import xformers.ops - - has_xformers = True -else: - xformers = None - has_xformers = False - - -# TODO: Use MONAI's optional_import -# from monai.utils import optional_import -# xformers, has_xformers = optional_import("xformers.ops", name="xformers") - -__all__ = ["CustomDiffusionModelUNet"] - - -def zero_module(module: nn.Module) -> nn.Module: - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -class CrossAttention(nn.Module): - """ - A cross attention layer. - - Args: - query_dim: number of channels in the query. - cross_attention_dim: number of channels in the context. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each head. - dropout: dropout probability to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
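    Example:
        A minimal shape sketch with illustrative (non-default) values; it assumes only that
        torch and this class are importable:

            import torch

            attn = CrossAttention(
                query_dim=64, cross_attention_dim=32, num_attention_heads=4, num_head_channels=16
            )
            x = torch.randn(2, 16, 64)        # (batch, num_tokens, query_dim)
            context = torch.randn(2, 77, 32)  # (batch, context_len, cross_attention_dim)
            out = attn(x, context=context)    # (2, 16, 64); with context=None it is self-attention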
- """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: int | None = None, - num_attention_heads: int = 8, - num_head_channels: int = 64, - dropout: float = 0.0, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - inner_dim = num_head_channels * num_attention_heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - - self.scale = 1 / math.sqrt(num_head_channels) - self.num_heads = num_attention_heads - - self.upcast_attention = upcast_attention - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - """ - Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. - """ - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - """Combine the output of the attention heads back into the hidden state dimension.""" - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype=dtype) - - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - query = self.to_q(x) - context = context if context is not None else x - key = self.to_k(context) - value = self.to_v(context) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - return self.to_out(x) - - -class BasicTransformerBlock(nn.Module): - """ - A basic Transformer block. - - Args: - num_channels: number of channels in the input and output. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each attention head. - dropout: dropout probability to use. 
- cross_attention_dim: size of the context vector for cross attention. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - num_channels: int, - num_attention_heads: int, - num_head_channels: int, - dropout: float = 0.0, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.attn1 = CrossAttention( - query_dim=num_channels, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention - self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) - self.attn2 = CrossAttention( - query_dim=num_channels, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention if context is None - self.norm1 = nn.LayerNorm(num_channels) - self.norm2 = nn.LayerNorm(num_channels) - self.norm3 = nn.LayerNorm(num_channels) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - # 1. Self-Attention - x = self.attn1(self.norm1(x)) + x - - # 2. Cross-Attention - x = self.attn2(self.norm2(x), context=context) + x - - # 3. Feed-forward - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of channels in the input and output. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each attention head. - num_layers: number of layers of Transformer blocks to use. - dropout: dropout probability to use. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
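    Example:
        An illustrative 3D sketch with example values; inner_dim = num_attention_heads *
        num_head_channels is projected back to in_channels, so the output keeps the input shape:

            import torch

            st = SpatialTransformer(
                spatial_dims=3, in_channels=32, num_attention_heads=4, num_head_channels=8
            )
            x = torch.randn(1, 32, 8, 8, 8)  # (batch, in_channels, H, W, D)
            out = st(x)                      # (1, 32, 8, 8, 8)
            # set cross_attention_dim and pass context=... to condition on an external embedding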
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - num_attention_heads: int, - num_head_channels: int, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - inner_dim = num_attention_heads * num_head_channels - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - - self.proj_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=inner_dim, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - num_channels=inner_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) - for _ in range(num_layers) - ] - ) - - self.proj_out = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=inner_dim, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - ) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - # note: if no context is given, cross-attention defaults to self-attention - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - residual = x - x = self.norm(x) - x = self.proj_in(x) - - inner_dim = x.shape[1] - - if self.spatial_dims == 2: - x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - if self.spatial_dims == 3: - x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) - - for block in self.transformer_blocks: - x = block(x, context=context) - - if self.spatial_dims == 2: - x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - if self.spatial_dims == 3: - x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() - - x = self.proj_out(x) - return x + residual - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to - compute attention. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon value to use for the normalisation. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - self.spatial_dims = spatial_dims - self.num_channels = num_channels - - self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.scale = 1 / math.sqrt(num_channels / self.num_heads) - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - - self.to_q = nn.Linear(num_channels, num_channels) - self.to_k = nn.Linear(num_channels, num_channels) - self.to_v = nn.Linear(num_channels, num_channels) - - self.proj_attn = nn.Linear(num_channels, num_channels) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - # norm - x = self.norm(x) - - if self.spatial_dims == 2: - x = x.view(batch, channel, height * width).transpose(1, 2) - if self.spatial_dims == 3: - x = x.view(batch, channel, height * width * depth).transpose(1, 2) - - # proj to q, k, v - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - if self.spatial_dims == 2: - x = x.transpose(-1, -2).reshape(batch, channel, height, width) - if self.spatial_dims == 3: - x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) - - return x + residual - - -def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: - """ - Create sinusoidal timestep embeddings following the implementation in Ho et al. 
"Denoising Diffusion Probabilistic - Models" https://arxiv.org/abs/2006.11239. - - Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - embedding_dim: the dimension of the output. - max_period: controls the minimum frequency of the embeddings. - """ - # print(f'max_period: {max_period}; timesteps: {torch.norm(timesteps.float(), p=2)}; embedding_dim: {embedding_dim}') - - if timesteps.ndim != 1: - raise ValueError("Timesteps should be a 1d-array") - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) - freqs = torch.exp(exponent / half_dim) - - args = timesteps[:, None].float() * freqs[None, :] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) - - return embedding - - -class Downsample(nn.Module): - """ - Downsampling layer. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is - False, the number of output channels must be the same as the number of input channels. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points - for each dimension. - """ - - def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.op = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=2, - kernel_size=3, - padding=padding, - conv_only=True, - ) - else: - if self.num_channels != self.out_channels: - raise ValueError("num_channels and out_channels must be equal when use_conv=False") - self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - del emb - if x.shape[1] != self.num_channels: - raise ValueError( - f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " - f"({self.num_channels})" - ) - return self.op(x) - - -class Upsample(nn.Module): - """ - Upsampling layer with an optional convolution. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each - dimension. 
- """ - - def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=padding, - conv_only=True, - ) - else: - self.conv = None - - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: - del emb - if x.shape[1] != self.num_channels: - raise ValueError("Input channels should be equal to num_channels") - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - dtype = x.dtype - if dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - x = x.to(dtype) - - if self.use_conv: - x = self.conv(x) - return x - - -class ResnetBlock(nn.Module): - """ - Residual block with timestep conditioning. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - out_channels: number of output channels. - up: if True, performs upsampling. - down: if True, performs downsampling. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - out_channels: int | None = None, - up: bool = False, - down: bool = False, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.channels = in_channels - self.emb_channels = temb_channels - self.out_channels = out_channels or in_channels - self.up = up - self.down = down - - self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - self.nonlinearity = nn.SiLU() - self.conv1 = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - self.upsample = self.downsample = None - if self.up: - self.upsample = Upsample(spatial_dims, in_channels, use_conv=False) - elif down: - self.downsample = Downsample(spatial_dims, in_channels, use_conv=False) - - self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) - - self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) - self.conv2 = zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=self.out_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - if self.out_channels == in_channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h) - h = self.nonlinearity(h) - - if self.upsample is not None: - if h.shape[0] >= 64: - x = x.contiguous() - h = h.contiguous() - x = self.upsample(x) - h = 
self.upsample(h) - elif self.downsample is not None: - x = self.downsample(x) - h = self.downsample(h) - - h = self.conv1(h) - - if self.spatial_dims == 2: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] - else: - temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] - h = h + temb - - h = self.norm2(h) - h = self.nonlinearity(h) - h = self.conv2(h) - - return self.skip_connection(x) + h - - -class DownBlock(nn.Module): - """ - Unet's down block containing resnet and downsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - del context - output_states = [] - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class AttnDownBlock(nn.Module): - """ - Unet's down block containing resnet, downsamplers and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. 
- use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - del context - output_states = [] - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class CrossAttnDownBlock(nn.Module): - """ - Unet's down block containing resnet, downsamplers and cross-attention blocks. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_downsample: bool = True, - resblock_updown: bool = False, - downsample_padding: int = 1, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - attentions.append( - SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - num_layers=transformer_num_layers, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - if resblock_updown: - self.downsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - down=True, - ) - else: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) - else: - self.downsampler = None - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - output_states = [] - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) - output_states.append(hidden_states) - - if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states, temb) - output_states.append(hidden_states) - - return hidden_states, output_states - - -class AttnMidBlock(nn.Module): - """ - Unet's mid block containing resnet and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
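    Example:
        A minimal sketch with example values; temb stands in for the timestep embedding that the
        surrounding UNet would normally provide, here just a random tensor of width temb_channels:

            import torch

            mid = AttnMidBlock(spatial_dims=3, in_channels=64, temb_channels=128, num_head_channels=32)
            h = torch.randn(1, 64, 4, 4, 4)  # (batch, in_channels, H, W, D)
            temb = torch.randn(1, 128)       # (batch, temb_channels)
            out = mid(h, temb)               # (1, 64, 4, 4, 4)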
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.attention = None - - self.resnet_1 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - self.attention = AttentionBlock( - spatial_dims=spatial_dims, - num_channels=in_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - - self.resnet_2 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - del context - hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states) - hidden_states = self.resnet_2(hidden_states, temb) - - return hidden_states - - -class CrossAttnMidBlock(nn.Module): - """ - Unet's mid block containing resnet and cross-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.attention = None - - self.resnet_1 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - self.attention = SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=in_channels, - num_attention_heads=in_channels // num_head_channels, - num_head_channels=num_head_channels, - num_layers=transformer_num_layers, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - self.resnet_2 = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - - def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None - ) -> torch.Tensor: - hidden_states = self.resnet_1(hidden_states, temb) - hidden_states = self.attention(hidden_states, context=context) - hidden_states = self.resnet_2(hidden_states, temb) - - return hidden_states - - -class UpBlock(nn.Module): - """ - Unet's up block containing resnet and upsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - resnets = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class AttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=out_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - del context - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -class CrossAttnUpBlock(nn.Module): - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int = 1, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - add_upsample: bool = True, - resblock_updown: bool = False, - num_head_channels: int = 1, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - ) -> None: - super().__init__() - self.resblock_updown = resblock_updown - - resnets = [] - attentions = [] - - for i in range(num_res_blocks): - res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock( - spatial_dims=spatial_dims, - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - ) - ) - attentions.append( - SpatialTransformer( - spatial_dims=spatial_dims, - in_channels=out_channels, - num_attention_heads=out_channels // num_head_channels, - num_head_channels=num_head_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout=dropout_cattn, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - if resblock_updown: - self.upsampler = ResnetBlock( - spatial_dims=spatial_dims, - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - up=True, - ) - else: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) - else: - self.upsampler = None - - def forward( - self, - hidden_states: torch.Tensor, - res_hidden_states_list: list[torch.Tensor], - temb: torch.Tensor, - context: torch.Tensor | None = None, - ) -> torch.Tensor: - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_list[-1] - res_hidden_states_list = res_hidden_states_list[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=context) - - if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states, temb) - - return hidden_states - - -def get_down_block( - spatial_dims: int, - in_channels: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_downsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_attn: - return AttnDownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - 
resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - elif with_cross_attn: - return CrossAttnDownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return DownBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=add_downsample, - resblock_updown=resblock_updown, - ) - - -def get_mid_block( - spatial_dims: int, - in_channels: int, - temb_channels: int, - norm_num_groups: int, - norm_eps: float, - with_conditioning: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_conditioning: - return CrossAttnMidBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return AttnMidBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - temb_channels=temb_channels, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - - -def get_up_block( - spatial_dims: int, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - num_res_blocks: int, - norm_num_groups: int, - norm_eps: float, - add_upsample: bool, - resblock_updown: bool, - with_attn: bool, - with_cross_attn: bool, - num_head_channels: int, - transformer_num_layers: int, - cross_attention_dim: int | None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, -) -> nn.Module: - if with_attn: - return AttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, - ) - elif with_cross_attn: - return CrossAttnUpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - num_head_channels=num_head_channels, - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - 
use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - else: - return UpBlock( - spatial_dims=spatial_dims, - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - temb_channels=temb_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=add_upsample, - resblock_updown=resblock_updown, - ) - - -class CustomDiffusionModelUNet(nn.Module): - """ - Unet network with timestep embedding and attention mechanisms for conditioning based on - Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 - and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see ResnetBlock) per level. - num_channels: tuple of block output channels. - attention_levels: list of levels to add attention. - norm_num_groups: number of groups for the normalization. - norm_eps: epsilon for the normalization. - resblock_updown: if True use residual blocks for up/downsampling. - num_head_channels: number of channels in each attention head. - with_conditioning: if True add spatial transformers to perform conditioning. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), - num_channels: Sequence[int] = (32, 64, 64, 64), - attention_levels: Sequence[bool] = (False, False, True, True), - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - resblock_updown: bool = False, - num_head_channels: int | Sequence[int] = 8, - with_conditioning: bool = False, - transformer_num_layers: int = 1, - cross_attention_dim: int | None = None, - num_class_embeds: int | None = None, - upcast_attention: bool = False, - use_flash_attention: bool = False, - dropout_cattn: float = 0.0, - input_top_region_index: bool = False, - input_bottom_region_index: bool = False, - input_spacing: bool = False, - ) -> None: - super().__init__() - if with_conditioning is True and cross_attention_dim is None: - raise ValueError( - "CustomDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " - "when using with_conditioning." - ) - if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "CustomDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." 
- ) - if dropout_cattn > 1.0 or dropout_cattn < 0.0: - raise ValueError("Dropout cannot be negative or >1.0!") - - # All number of channels should be multiple of num_groups - if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): - raise ValueError("CustomDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") - - if len(num_channels) != len(attention_levels): - raise ValueError("CustomDiffusionModelUNet expects num_channels being same size of attention_levels") - - if isinstance(num_head_channels, int): - num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) - - if len(num_head_channels) != len(attention_levels): - raise ValueError( - "num_head_channels should have the same length as attention_levels. For the i levels without attention," - " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." - ) - - if isinstance(num_res_blocks, int): - num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) - - if len(num_res_blocks) != len(num_channels): - raise ValueError( - "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " - "`num_channels`." - ) - - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - - self.in_channels = in_channels - self.block_out_channels = num_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_levels = attention_levels - self.num_head_channels = num_head_channels - self.with_conditioning = with_conditioning - - # input - self.conv_in = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=num_channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - # time - time_embed_dim = num_channels[0] * 4 - self.time_embed = nn.Sequential( - nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) - - # class embedding - self.num_class_embeds = num_class_embeds - if num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - - self.input_top_region_index = input_top_region_index - self.input_bottom_region_index = input_bottom_region_index - self.input_spacing = input_spacing - - new_time_embed_dim = time_embed_dim - if self.input_top_region_index: - # self.top_region_index_layer = nn.Linear(4, time_embed_dim) - self.top_region_index_layer = nn.Sequential( - nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) - new_time_embed_dim += time_embed_dim - if self.input_bottom_region_index: - # self.bottom_region_index_layer = nn.Linear(4, time_embed_dim) - self.bottom_region_index_layer = nn.Sequential( - nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) - new_time_embed_dim += time_embed_dim - if self.input_spacing: - # self.spacing_layer = nn.Linear(3, time_embed_dim) - self.spacing_layer = nn.Sequential( - nn.Linear(3, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) - ) - new_time_embed_dim += time_embed_dim - - # down - self.down_blocks = nn.ModuleList([]) - output_channel = num_channels[0] - for i in range(len(num_channels)): - input_channel = output_channel - output_channel = num_channels[i] - 
is_final_block = i == len(num_channels) - 1 - - down_block = get_down_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=new_time_embed_dim, - num_res_blocks=num_res_blocks[i], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_downsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(attention_levels[i] and not with_conditioning), - with_cross_attn=(attention_levels[i] and with_conditioning), - num_head_channels=num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - self.down_blocks.append(down_block) - - # mid - self.middle_block = get_mid_block( - spatial_dims=spatial_dims, - in_channels=num_channels[-1], - temb_channels=new_time_embed_dim, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - with_conditioning=with_conditioning, - num_head_channels=num_head_channels[-1], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - # up - self.up_blocks = nn.ModuleList([]) - reversed_block_out_channels = list(reversed(num_channels)) - reversed_num_res_blocks = list(reversed(num_res_blocks)) - reversed_attention_levels = list(reversed(attention_levels)) - reversed_num_head_channels = list(reversed(num_head_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] - - is_final_block = i == len(num_channels) - 1 - - up_block = get_up_block( - spatial_dims=spatial_dims, - in_channels=input_channel, - prev_output_channel=prev_output_channel, - out_channels=output_channel, - temb_channels=new_time_embed_dim, - num_res_blocks=reversed_num_res_blocks[i] + 1, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - add_upsample=not is_final_block, - resblock_updown=resblock_updown, - with_attn=(reversed_attention_levels[i] and not with_conditioning), - with_cross_attn=(reversed_attention_levels[i] and with_conditioning), - num_head_channels=reversed_num_head_channels[i], - transformer_num_layers=transformer_num_layers, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - dropout_cattn=dropout_cattn, - ) - - self.up_blocks.append(up_block) - - # out - self.out = nn.Sequential( - nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), - nn.SiLU(), - zero_module( - Convolution( - spatial_dims=spatial_dims, - in_channels=num_channels[0], - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ), - ) - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - context: torch.Tensor | None = None, - class_labels: torch.Tensor | None = None, - down_block_additional_residuals: tuple[torch.Tensor] | None = None, - mid_block_additional_residual: torch.Tensor | None = None, - top_region_index_tensor: torch.Tensor | None = None, - bottom_region_index_tensor: torch.Tensor | None = None, - spacing_tensor: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Args: - x: input tensor (N, C, SpatialDims). 
- timesteps: timestep tensor (N,). - context: context tensor (N, 1, ContextDim). - class_labels: context tensor (N, ). - down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). - mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). - """ - # 1. time - t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=x.dtype) - emb = self.time_embed(t_emb) - # print(f't_emb: {t_emb}; timesteps {timesteps}.') - # print(f'emb: {torch.norm(emb, p=2)}; t_emb: {torch.norm(t_emb, p=2)}') - - # 2. class - if self.num_class_embeds is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - class_emb = self.class_embedding(class_labels) - class_emb = class_emb.to(dtype=x.dtype) - emb = emb + class_emb - - # 3. input - if self.input_top_region_index: - _emb = self.top_region_index_layer(top_region_index_tensor) - # emb = emb + _emb.to(dtype=x.dtype) - emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; top_region_index_tensor: {torch.norm(_emb, p=2)}') - if self.input_bottom_region_index: - _emb = self.bottom_region_index_layer(bottom_region_index_tensor) - # emb = emb + _emb.to(dtype=x.dtype) - emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; bottom_region_index_tensor: {torch.norm(_emb, p=2)}') - if self.input_spacing: - _emb = self.spacing_layer(spacing_tensor) - # emb = emb + _emb.to(dtype=x.dtype) - emb = torch.cat((emb, _emb), dim=1) - # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; spacing_tensor: {torch.norm(spacing_tensor, p=2)}') - - # 3. initial convolution - h = self.conv_in(x) - - # 4. down - if context is not None and self.with_conditioning is False: - raise ValueError("model should have with_conditioning = True if context is provided") - down_block_res_samples: list[torch.Tensor] = [h] - for downsample_block in self.down_blocks: - h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) - for residual in res_samples: - down_block_res_samples.append(residual) - - # Additional residual conections for Controlnets - if down_block_additional_residuals is not None: - new_down_block_res_samples = () - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples += (down_block_res_sample,) - - down_block_res_samples = new_down_block_res_samples - - # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) - - # Additional residual conections for Controlnets - if mid_block_additional_residual is not None: - h = h + mid_block_additional_residual - - # 6. up - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) - - # 7. 
output block - h = self.out(h) - - return h diff --git a/models/maisi_ct_generative/scripts/custom_network_tp.py b/models/maisi_ct_generative/scripts/custom_network_tp.py deleted file mode 100644 index 1fd33fe0..00000000 --- a/models/maisi_ct_generative/scripts/custom_network_tp.py +++ /dev/null @@ -1,1053 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Sequence - -import monai -import torch -import torch.nn as nn -import torch.nn.functional as F -from generative.networks.nets.autoencoderkl import AttentionBlock, AutoencoderKL, ResBlock - -NUM_SPLITS = 16 -# NUM_SPLITS = 32 -SPLIT_PADDING = 3 - - -class InplaceGroupNorm3D(torch.nn.GroupNorm): - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): - super(InplaceGroupNorm3D, self).__init__(num_groups, num_channels, eps, affine) - - def forward(self, input): - # print("InplaceGroupNorm3D in", input.size()) - - # # normalization - # norm = 1e1 - # input /= norm - # # print("normalization2") - - # Ensure the tensor is 5D: (n, c, d, h, w) - if len(input.shape) != 5: - raise ValueError("Expected a 5D tensor") - - n, c, d, h, w = input.shape - - # Reshape to (n, num_groups, c // num_groups, d, h, w) - input = input.view(n, self.num_groups, c // self.num_groups, d, h, w) - - # input = input.to(dtype=torch.float64) - - # # Compute mean and std dev - # mean1 = input.mean([2, 3, 4, 5], keepdim=True) - # std1 = input.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() - # mean1 = mean1.to(dtype=torch.float32) - - if False: - input = input.to(dtype=torch.float64) - mean = input.mean([2, 3, 4, 5], keepdim=True) - # std = input.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() - - input = input.to(dtype=torch.float32) - mean = mean.to(dtype=torch.float32) - # std = mean.to(dtype=torch.float32) - else: - # means, stds = [], [] - inputs = [] - for _i in range(input.size(1)): - array = input[:, _i : _i + 1, ...] 
- array = array.to(dtype=torch.float32) - _mean = array.mean([2, 3, 4, 5], keepdim=True) - _std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() - - # del array - # torch.cuda.empty_cache() - - _mean = _mean.to(dtype=torch.float32) - _std = _std.to(dtype=torch.float32) - - # means.append(_mean) - # stds.append(_std) - - # mean = torch.cat([means[_k] for _k in range(len(means))], dim=1) - # std = torch.cat([stds[_k] for _k in range(len(stds))], dim=1) - # input = input.to(dtype=torch.float32) - - inputs.append(array.sub_(_mean).div_(_std).to(dtype=torch.float16)) - - # Normalize features (in-place) - # input.sub_(mean).div_(std) - - del input - torch.cuda.empty_cache() - - if False: - input = torch.cat([inputs[_k] for _k in range(len(inputs))], dim=1) - else: - if max(inputs[0].size()) < 500: - input = torch.cat([inputs[_k] for _k in range(len(inputs))], dim=1) - else: - import gc - - _type = inputs[0].device.type - if _type == "cuda": - input = inputs[0].clone().to("cpu", non_blocking=True) - else: - input = inputs[0].clone() - inputs[0] = 0 - torch.cuda.empty_cache() - - for _k in range(len(inputs) - 1): - input = torch.cat((input, inputs[_k + 1].cpu()), dim=1) - inputs[_k + 1] = 0 - torch.cuda.empty_cache() - gc.collect() - # print(f'InplaceGroupNorm3D cat: {_k + 1}/{len(inputs) - 1}.') - - if _type == "cuda": - input = input.to("cuda", non_blocking=True) - - # Reshape back to original size - input = input.view(n, c, d, h, w) - - # Apply affine transformation if enabled - if self.affine: - input.mul_(self.weight.view(1, c, 1, 1, 1)).add_(self.bias.view(1, c, 1, 1, 1)) - - # input = input.to(dtype=torch.float32) - # input *= norm - # print("InplaceGroupNorm3D out", input.size()) - - return input - - -class SplitConvolutionV1(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - strides: Sequence[int] | int = 1, - kernel_size: Sequence[int] | int = 3, - adn_ordering: str = "NDA", - act: tuple | str | None = "PRELU", - norm: tuple | str | None = "INSTANCE", - dropout: tuple | str | float | None = None, - dropout_dim: int | None = 1, - dilation: Sequence[int] | int = 1, - groups: int = 1, - bias: bool = True, - conv_only: bool = False, - is_transposed: bool = False, - padding: Sequence[int] | int | None = None, - output_padding: Sequence[int] | int | None = None, - ) -> None: - super(SplitConvolutionV1, self).__init__() - self.conv = monai.networks.blocks.convolutions.Convolution( - spatial_dims, - in_channels, - out_channels, - strides, - kernel_size, - adn_ordering, - act, - norm, - dropout, - dropout_dim, - dilation, - groups, - bias, - conv_only, - is_transposed, - padding, - output_padding, - ) - - self.tp_dim = 1 - self.stride = strides[self.tp_dim] if isinstance(strides, list) else strides - - def forward(self, x): - # Call parent's forward method - # x = super(SplitConvolution, self).forward(x) - - num_splits = NUM_SPLITS - # print("num_splits:", num_splits) - l = x.size(self.tp_dim + 2) - split_size = l // num_splits - - if False: - splits = [x[:, :, i * split_size : (i + 1) * split_size, :, :] for i in range(num_splits)] - else: - # padding = 1 - padding = SPLIT_PADDING - if padding % self.stride > 0: - padding = (padding // self.stride + 1) * self.stride - # print("padding:", padding) - - overlaps = [0] + [padding] * (num_splits - 1) - last_padding = x.size(self.tp_dim + 2) % split_size - - if self.tp_dim == 0: - splits = [ - x[ - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding 
if i != num_splits - 1 else last_padding), - :, - :, - ] - for i in range(num_splits) - ] - elif self.tp_dim == 1: - splits = [ - x[ - :, - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding if i != num_splits - 1 else last_padding), - :, - ] - for i in range(num_splits) - ] - elif self.tp_dim == 2: - splits = [ - x[ - :, - :, - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding if i != num_splits - 1 else last_padding), - ] - for i in range(num_splits) - ] - - # for _j in range(len(splits)): - # print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) - - del x - torch.cuda.empty_cache() - - splits_0_size = list(splits[0].size()) - # print("splits_0_size:", splits_0_size) - - # outputs = [super(SplitConvolution, self).forward(splits[i]) for i in range(num_splits)] - if False: - outputs = [self.conv(splits[i]) for i in range(num_splits)] - else: - outputs = [] - _type = splits[0].device.type - for _i in range(num_splits): - if True: - # if _type == 'cuda': - outputs.append(self.conv(splits[_i])) - else: - _t = splits[_i] - _t1 = self.conv(_t.to("cuda", non_blocking=True)) - del _t - torch.cuda.empty_cache() - _t1 = _t1.to("cpu", non_blocking=True) - outputs.append(_t1) - del _t1 - torch.cuda.empty_cache() - - splits[_i] = 0 - torch.cuda.empty_cache() - - # for _j in range(len(outputs)): - # print(f"outputs before {_j + 1}/{len(outputs)}:", outputs[_j].size()) - - del splits - torch.cuda.empty_cache() - - split_size_out = split_size - padding_s = padding - non_tp_dim = self.tp_dim + 1 if self.tp_dim < 2 else 0 - if outputs[0].size(non_tp_dim + 2) // splits_0_size[non_tp_dim + 2] == 2: - split_size_out *= 2 - padding_s *= 2 - elif splits_0_size[non_tp_dim + 2] // outputs[0].size(non_tp_dim + 2) == 2: - split_size_out = split_size_out // 2 - padding_s = padding_s // 2 - - if self.tp_dim == 0: - outputs[0] = outputs[0][:, :, :split_size_out, :, :] - for i in range(1, num_splits): - outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] - elif self.tp_dim == 1: - # print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") - outputs[0] = outputs[0][:, :, :, :split_size_out, :] - # # print("outputs", outputs[0].size(3), f"padding_s: {padding_s // 2}, {padding_s // 2 + split_size_out}") - # outputs[0] = outputs[0][:, :, :, padding_s // 2:padding_s // 2 + split_size_out, :] - for i in range(1, num_splits): - # print("outputs", outputs[i].size(3), f"padding_s: {padding_s}, {padding_s + split_size_out}") - outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] - elif self.tp_dim == 2: - outputs[0] = outputs[0][:, :, :, :, :split_size_out] - for i in range(1, num_splits): - outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] - - # for i in range(num_splits): - # print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) - - # if max(outputs[0].size()) < 500 or outputs[0].device.type != 'cuda': - # if True: - if max(outputs[0].size()) < 500: - # print(f'outputs[0].device.type: {outputs[0].device.type}.') - x = torch.cat(list(outputs), dim=self.tp_dim + 2) - else: - import gc - - # x = torch.randn(outputs[0].size(), dtype=outputs[0].dtype, pin_memory=True) - # x = outputs[0] - # x = x.to('cpu', non_blocking=True) - - _type = outputs[0].device.type - if _type == "cuda": - x = outputs[0].clone().to("cpu", non_blocking=True) - outputs[0] = 0 - torch.cuda.empty_cache() - for _k in range(len(outputs) - 1): - x = torch.cat((x, outputs[_k + 1].cpu()), dim=self.tp_dim + 
2) - outputs[_k + 1] = 0 - torch.cuda.empty_cache() - gc.collect() - # print(f'SplitConvolutionV1 cat: {_k + 1}/{len(outputs) - 1}.') - if _type == "cuda": - x = x.to("cuda", non_blocking=True) - - del outputs - torch.cuda.empty_cache() - - return x - - -class SplitUpsample1(nn.Module): - """ - Convolution-based upsampling layer. - - Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). - in_channels: number of input channels to the layer. - use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. - """ - - def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: - super().__init__() - if use_convtranspose: - self.conv = SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=2, - kernel_size=3, - padding=1, - conv_only=True, - is_transposed=True, - ) - else: - self.conv = SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - self.use_convtranspose = use_convtranspose - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.use_convtranspose: - return self.conv(x) - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - # dtype = x.dtype - # if dtype == torch.bfloat16: - # x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="trilinear") - torch.cuda.empty_cache() - - # If the input is bfloat16, we cast back to bfloat16 - # if dtype == torch.bfloat16: - # x = x.to(dtype) - - x = self.conv(x) - torch.cuda.empty_cache() - - return x - - -class SplitDownsample(nn.Module): - """ - Convolution-based downsampling layer. - - Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). - in_channels: number of input channels. - """ - - def __init__(self, spatial_dims: int, in_channels: int) -> None: - super().__init__() - self.pad = (0, 1) * spatial_dims - - self.conv = SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=2, - kernel_size=3, - padding=0, - conv_only=True, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) - x = self.conv(x) - return x - - -class SplitResBlock(nn.Module): - """ - Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a - residual connection between input and output. - - Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). - in_channels: input channels to the layer. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon for the normalisation. - out_channels: number of output channels. 
- """ - - def __init__( - self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - - self.norm1 = InplaceGroupNorm3D(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - # self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - self.conv1 = SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=self.in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - self.norm2 = InplaceGroupNorm3D( - num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True - ) - # self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) - self.conv2 = SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=self.out_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - - if self.in_channels != self.out_channels: - self.nin_shortcut = SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=self.in_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - else: - self.nin_shortcut = nn.Identity() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if True: - h = x - h = self.norm1(h) - torch.cuda.empty_cache() - - # if max(x.size()) > 500: - # h = h.to('cpu', non_blocking=True).float() - # torch.cuda.empty_cache() - - h = F.silu(h) - torch.cuda.empty_cache() - h = self.conv1(h) - torch.cuda.empty_cache() - - # if max(x.size()) > 500: - # h = h.half().to('cuda', non_blocking=True) - # torch.cuda.empty_cache() - - h = self.norm2(h) - torch.cuda.empty_cache() - - # if max(x.size()) > 500: - # h = h.to('cpu', non_blocking=True).float() - # torch.cuda.empty_cache() - - h = F.silu(h) - torch.cuda.empty_cache() - h = self.conv2(h) - torch.cuda.empty_cache() - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - torch.cuda.empty_cache() - - # if max(x.size()) > 500: - # h = h.half().to('cuda', non_blocking=True) - # x = x.half().to('cuda', non_blocking=True) - else: - h1 = self.norm1(x) - if max(h1.size()) > 500: - x = x.to("cpu", non_blocking=True).float() - torch.cuda.empty_cache() - if max(h1.size()) > 500: - h1 = h1.to("cpu", non_blocking=True).float() - torch.cuda.empty_cache() - h2 = F.silu(h1) - if max(h2.size()) > 500: - h2 = h2.half().to("cuda", non_blocking=True) - h3 = self.conv1(h2) - del h2 - torch.cuda.empty_cache() - - h4 = self.norm2(h3) - del h3 - torch.cuda.empty_cache() - if max(h4.size()) > 500: - h4 = h4.to("cpu", non_blocking=True).float() - torch.cuda.empty_cache() - h5 = F.silu(h4) - if max(h5.size()) > 500: - h5 = h5.half().to("cuda", non_blocking=True) - h6 = self.conv2(h5) - del h5 - torch.cuda.empty_cache() - - if max(h6.size()) > 500: - h6 = h6.to("cpu", non_blocking=True).float() - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - torch.cuda.empty_cache() - - out = x + h6 - if max(h6.size()) > 500: - out = out.half().to("cuda", non_blocking=True) - - return x + h - # return out - - -class EncoderTp(nn.Module): - """ - Convolutional cascade that downsamples the image into a spatial latent space. - - Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). - in_channels: number of input channels. 
- num_channels: sequence of block output channels. - out_channels: number of channels in the bottom layer (latent space) of the autoencoder. - num_res_blocks: number of residual blocks (see ResBlock) per level. - norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. - norm_eps: epsilon for the normalization. - attention_levels: indicate which level from num_channels contain an attention block. - with_nonlocal_attn: if True use non-local attention block. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - num_channels: Sequence[int], - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.in_channels = in_channels - self.num_channels = num_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.norm_num_groups = norm_num_groups - self.norm_eps = norm_eps - self.attention_levels = attention_levels - - blocks = [] - # Initial convolution - blocks.append( - SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=num_channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - # Residual and downsampling blocks - output_channel = num_channels[0] - for i in range(len(num_channels)): - input_channel = output_channel - output_channel = num_channels[i] - is_final_block = i == len(num_channels) - 1 - - for _ in range(self.num_res_blocks[i]): - blocks.append( - SplitResBlock( - spatial_dims=spatial_dims, - in_channels=input_channel, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=output_channel, - ) - ) - input_channel = output_channel - if attention_levels[i]: - blocks.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=input_channel, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - if not is_final_block: - blocks.append(SplitDownsample(spatial_dims=spatial_dims, in_channels=input_channel)) - - # Non-local attention block - if with_nonlocal_attn is True: - blocks.append( - ResBlock( - spatial_dims=spatial_dims, - in_channels=num_channels[-1], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=num_channels[-1], - ) - ) - - blocks.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=num_channels[-1], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - blocks.append( - ResBlock( - spatial_dims=spatial_dims, - in_channels=num_channels[-1], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=num_channels[-1], - ) - ) - # Normalise and convert to latent size - blocks.append( - InplaceGroupNorm3D(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True) - ) - blocks.append( - SplitConvolutionV1( - spatial_dims=self.spatial_dims, - in_channels=num_channels[-1], - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - self.blocks = nn.ModuleList(blocks) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for block in self.blocks: - x = block(x) - torch.cuda.empty_cache() - return x - - -class DecoderTp1(nn.Module): - """ - Convolutional 
cascade upsampling from a spatial latent space into an image space. - - Args: - spatial_dims: number of spatial dimensions (1D, 2D, 3D). - num_channels: sequence of block output channels. - in_channels: number of channels in the bottom layer (latent space) of the autoencoder. - out_channels: number of output channels. - num_res_blocks: number of residual blocks (see ResBlock) per level. - norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. - norm_eps: epsilon for the normalization. - attention_levels: indicate which level from num_channels contain an attention block. - with_nonlocal_attn: if True use non-local attention block. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. - """ - - def __init__( - self, - spatial_dims: int, - num_channels: Sequence[int], - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int], - norm_num_groups: int, - norm_eps: float, - attention_levels: Sequence[bool], - with_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - use_convtranspose: bool = False, - tp_dim: int = 1, - ) -> None: - super().__init__() - self.spatial_dims = spatial_dims - self.num_channels = num_channels - self.in_channels = in_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.norm_num_groups = norm_num_groups - self.norm_eps = norm_eps - self.attention_levels = attention_levels - self.tp_dim = tp_dim - - reversed_block_out_channels = list(reversed(num_channels)) - - blocks = [] - # Initial convolution - blocks.append( - SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=reversed_block_out_channels[0], - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - # Non-local attention block - if with_nonlocal_attn is True: - blocks.append( - ResBlock( - spatial_dims=spatial_dims, - in_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=reversed_block_out_channels[0], - ) - ) - blocks.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - blocks.append( - ResBlock( - spatial_dims=spatial_dims, - in_channels=reversed_block_out_channels[0], - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=reversed_block_out_channels[0], - ) - ) - - reversed_attention_levels = list(reversed(attention_levels)) - reversed_num_res_blocks = list(reversed(num_res_blocks)) - block_out_ch = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - block_in_ch = block_out_ch - block_out_ch = reversed_block_out_channels[i] - is_final_block = i == len(num_channels) - 1 - - for _ in range(reversed_num_res_blocks[i]): - blocks.append( - SplitResBlock( - spatial_dims=spatial_dims, - in_channels=block_in_ch, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - out_channels=block_out_ch, - ) - ) - block_in_ch = block_out_ch - - if reversed_attention_levels[i]: - blocks.append( - AttentionBlock( - spatial_dims=spatial_dims, - num_channels=block_in_ch, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - use_flash_attention=use_flash_attention, - ) - ) - - if not is_final_block: - blocks.append( - SplitUpsample1( - spatial_dims=spatial_dims, 
in_channels=block_in_ch, use_convtranspose=use_convtranspose - ) - ) - - blocks.append( - InplaceGroupNorm3D(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True) - ) - blocks.append( - SplitConvolutionV1( - spatial_dims=spatial_dims, - in_channels=block_in_ch, - out_channels=out_channels, - strides=1, - kernel_size=3, - padding=1, - conv_only=True, - ) - ) - - self.blocks = nn.ModuleList(blocks) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # if False: - # for block in self.blocks: - # x = block(x) - # else: - for _i in range(len(self.blocks)): - block = self.blocks[_i] - # print(block, type(block), type(type(block))) - - if _i < len(self.blocks) - 0: - # if not isinstance(block, monai.networks.blocks.convolutions.Convolution): - x = block(x) - torch.cuda.empty_cache() - else: - # # print(block, type(block), type(type(block))) - # block = self.blocks[_i] - # # print(f"block {_i + 1}/{len(self.blocks)}") - - num_splits = NUM_SPLITS - # print("num_splits:", num_splits) - - l = x.size(self.tp_dim + 2) - split_size = l // num_splits - - if False: - splits = [x[:, :, i * split_size : (i + 1) * split_size, :, :] for i in range(num_splits)] - else: - # padding = 1 - padding = SPLIT_PADDING - # print("padding:", padding) - - overlaps = [0] + [padding] * (num_splits - 1) - if self.tp_dim == 0: - splits = [ - x[ - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding if i != num_splits - 1 else 0), - :, - :, - ] - for i in range(num_splits) - ] - elif self.tp_dim == 1: - splits = [ - x[ - :, - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding if i != num_splits - 1 else 0), - :, - ] - for i in range(num_splits) - ] - elif self.tp_dim == 2: - splits = [ - x[ - :, - :, - :, - :, - i * split_size - - overlaps[i] : (i + 1) * split_size - + (padding if i != num_splits - 1 else 0), - ] - for i in range(num_splits) - ] - - # for _j in range(len(splits)): - # print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) - - del x - torch.cuda.empty_cache() - - outputs = [block(splits[i]) for i in range(num_splits)] - - del splits - torch.cuda.empty_cache() - - split_size_out = split_size - padding_s = padding - non_tp_dim = self.tp_dim + 1 if self.tp_dim < 2 else 0 - if outputs[0].size(non_tp_dim + 2) // splits[0].size(non_tp_dim + 2) == 2: - split_size_out *= 2 - padding_s *= 2 - # print("split_size_out:", split_size_out) - # print("padding_s:", padding_s) - - if self.tp_dim == 0: - outputs[0] = outputs[0][:, :, :split_size_out, :, :] - for i in range(1, num_splits): - outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] - elif self.tp_dim == 1: - # print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") - outputs[0] = outputs[0][:, :, :, :split_size_out, :] - # # print("outputs", outputs[0].size(3), f"padding_s: {padding_s // 2}, {padding_s // 2 + split_size_out}") - # outputs[0] = outputs[0][:, :, :, padding_s // 2:padding_s // 2 + split_size_out, :] - for i in range(1, num_splits): - # print("outputs", outputs[i].size(3), f"padding_s: {padding_s}, {padding_s + split_size_out}") - outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] - elif self.tp_dim == 2: - outputs[0] = outputs[0][:, :, :, :, :split_size_out] - for i in range(1, num_splits): - outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] - - # for i in range(num_splits): - # print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) - - if max(outputs[0].size()) < 
500: - x = torch.cat(list(outputs), dim=self.tp_dim + 2) - else: - import gc - - # x = torch.randn(outputs[0].size(), dtype=outputs[0].dtype, pin_memory=True) - # x = outputs[0] - # x = x.to('cpu', non_blocking=True) - x = outputs[0].clone().to("cpu", non_blocking=True) - outputs[0] = 0 - torch.cuda.empty_cache() - for _k in range(len(outputs) - 1): - x = torch.cat((x, outputs[_k + 1].cpu()), dim=self.tp_dim + 2) - outputs[_k + 1] = 0 - torch.cuda.empty_cache() - gc.collect() - # print(f'cat: {_k + 1}/{len(outputs) - 1}.') - x = x.to("cuda", non_blocking=True) - - del outputs - torch.cuda.empty_cache() - - return x - - -class AutoencoderKlckModifiedTp(AutoencoderKL): - """ - Override encoder to make it align with original ldm codebase and support activation checkpointing. - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - num_res_blocks: Sequence[int], - num_channels: Sequence[int], - attention_levels: Sequence[bool], - latent_channels: int = 3, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - with_encoder_nonlocal_attn: bool = True, - with_decoder_nonlocal_attn: bool = True, - use_flash_attention: bool = False, - use_checkpointing: bool = False, - use_convtranspose: bool = False, - ) -> None: - super().__init__( - spatial_dims, - in_channels, - out_channels, - num_res_blocks, - num_channels, - attention_levels, - latent_channels, - norm_num_groups, - norm_eps, - with_encoder_nonlocal_attn, - with_decoder_nonlocal_attn, - use_flash_attention, - use_checkpointing, - use_convtranspose, - ) - - self.encoder = EncoderTp( - spatial_dims=spatial_dims, - in_channels=in_channels, - num_channels=num_channels, - out_channels=latent_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - attention_levels=attention_levels, - with_nonlocal_attn=with_encoder_nonlocal_attn, - use_flash_attention=use_flash_attention, - ) - - # Override decoder using transposed conv - self.decoder = DecoderTp1( - spatial_dims=spatial_dims, - num_channels=num_channels, - in_channels=latent_channels, - out_channels=out_channels, - num_res_blocks=num_res_blocks, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - attention_levels=attention_levels, - with_nonlocal_attn=with_decoder_nonlocal_attn, - use_flash_attention=use_flash_attention, - use_convtranspose=use_convtranspose, - ) From fd10f8cb5775216c478d239888c7fd730e8b8fe0 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 29 Aug 2024 18:35:13 +0000 Subject: [PATCH 10/19] refomat Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 493fb832..e34abd97 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -296,7 +296,7 @@ def ldm_conditional_sample_one_image( end_time = time.time() logging.info(f"---- Image decoding time: {end_time - start_time} seconds ----") - ## post processing: + # post processing: # project output to [0, 1] synthetic_images = (synthetic_images - b_min) / (b_max - b_min) # project output to [-1000, 1000] From ead2754b4f7bfc6c954722cd064c867fdf49715f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 30 Aug 2024 08:53:02 +0000 Subject: [PATCH 11/19] auto update Signed-off-by: Yiheng Wang --- .../maisi_ct_generative/configs/metadata.json | 2 +- models/maisi_ct_generative/scripts/sample.py | 80 ++++--------------- 
models/maisi_ct_generative/scripts/utils.py | 56 +++---------- 3 files changed, 26 insertions(+), 112 deletions(-) diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index adfbf72c..d6cbca00 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ b/models/maisi_ct_generative/configs/metadata.json @@ -6,7 +6,7 @@ "0.3.6": "first oss version" }, "monai_version": "1.4.0", - "pytorch_version": "2.2.2", + "pytorch_version": "2.4.0", "numpy_version": "1.24.4", "optional_packages_version": { "fire": "0.6.0", diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index e34abd97..18fa6b86 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -71,16 +71,7 @@ def initialize_noise_latents(latent_shape, device): Returns: torch.Tensor: Initialized noise latents. """ - return ( - torch.randn( - [ - 1, - ] - + list(latent_shape) - ) - .half() - .to(device) - ) + return torch.randn([1] + list(latent_shape)).half().to(device) def ldm_conditional_sample_one_mask( @@ -251,9 +242,7 @@ def ldm_conditional_sample_one_image( for t in tqdm(noise_scheduler.timesteps, ncols=110): # Get controlnet output down_block_res_samples, mid_block_res_sample = controlnet( - x=latents, - timesteps=torch.Tensor((t,)).to(device), - controlnet_cond=controlnet_cond_vis, + x=latents, timesteps=torch.Tensor((t,)).to(device), controlnet_cond=controlnet_cond_vis ) latent_model_input = latents noise_pred = diffusion_unet( @@ -350,12 +339,7 @@ def crop_img_body_mask(synthetic_images, combine_label): def check_input( - body_region, - anatomy_list, - label_dict_json, - output_size, - spacing, - controllable_anatomy_size=[("pancreas", 0.5)], + body_region, anatomy_list, label_dict_json, output_size, spacing, controllable_anatomy_size=[("pancreas", 0.5)] ): """ Validate input parameters for image generation. @@ -397,13 +381,7 @@ def check_input( raise ValueError( f"The length of list controllable_anatomy_size has to be less than 10. Yet got length equal to {len(controllable_anatomy_size)}." ) - available_controllable_organ = [ - "liver", - "gallbladder", - "stomach", - "pancreas", - "colon", - ] + available_controllable_organ = ["liver", "gallbladder", "stomach", "pancreas", "colon"] available_controllable_tumor = [ "hepatic tumor", "bone lesion", @@ -443,14 +421,7 @@ def check_input( f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `body_region`: ({body_region}) and `anatomy_list`: ({anatomy_list})." 
) # check body_region format - available_body_region = [ - "head", - "chest", - "thorax", - "abdomen", - "pelvis", - "lower", - ] + available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"] for region in body_region: if region not in available_body_region: raise ValueError( @@ -658,22 +629,16 @@ def sample_multiple_images(self, num_img): start_time = time.time() if len(self.controllable_anatomy_size) > 0: # generate a synthetic mask - ( - combine_label_or, - top_region_index_tensor, - bottom_region_index_tensor, - spacing_tensor, - ) = self.prepare_one_mask_and_meta_info(anatomy_size_condtion) + (combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( + self.prepare_one_mask_and_meta_info(anatomy_size_condtion) + ) else: # read in mask file mask_file = item["mask_file"] if_aug = item["if_aug"] - ( - combine_label_or, - top_region_index_tensor, - bottom_region_index_tensor, - spacing_tensor, - ) = self.read_mask_information(mask_file) + (combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( + self.read_mask_information(mask_file) + ) if need_resample: combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) # mask augmentation @@ -687,10 +652,7 @@ def sample_multiple_images(self, num_img): try_time = 0 while to_generate: synthetic_images, synthetic_labels = self.sample_one_pair( - combine_label_or, - top_region_index_tensor, - bottom_region_index_tensor, - spacing_tensor, + combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor ) # synthetic image quality check pass_quality_check = self.quality_check( @@ -752,11 +714,7 @@ def select_mask(self, candidate_mask_files, num_img): return selected_mask_files def sample_one_pair( - self, - combine_label_or_aug, - top_region_index_tensor, - bottom_region_index_tensor, - spacing_tensor, + self, combine_label_or_aug, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor ): """ Generate a single pair of synthetic image and mask. @@ -791,10 +749,7 @@ def sample_one_pair( ) return synthetic_images, synthetic_labels - def prepare_anatomy_size_condtion( - self, - controllable_anatomy_size, - ): + def prepare_anatomy_size_condtion(self, controllable_anatomy_size): """ Prepare anatomy size conditions for mask generation. @@ -955,12 +910,7 @@ def read_mask_information(self, mask_file): """ val_data = self.val_transforms(mask_file) - for key in [ - "pseudo_label", - "spacing", - "top_region_index", - "bottom_region_index", - ]: + for key in ["pseudo_label", "spacing", "top_region_index", "bottom_region_index"]: val_data[key] = val_data[key].unsqueeze(0).to(self.device) return ( diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index 7c548d6d..c340944a 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -177,19 +177,7 @@ def erode_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_valu Return: Tensor: eroded mask, same shape as input. 
""" - return ( - erode( - mask_t.float() - .unsqueeze(0) - .unsqueeze( - 0, - ), - filter_size, - pad_value=pad_value, - ) - .squeeze(0) - .squeeze(0) - ) + return erode(mask_t.float().unsqueeze(0).unsqueeze(0), filter_size, pad_value=pad_value).squeeze(0).squeeze(0) def dilate_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor: @@ -206,19 +194,7 @@ def dilate_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_val Return: Tensor: dilated mask, same shape as input. """ - return ( - dilate( - mask_t.float() - .unsqueeze(0) - .unsqueeze( - 0, - ), - filter_size, - pad_value=pad_value, - ) - .squeeze(0) - .squeeze(0) - ) + return dilate(mask_t.float().unsqueeze(0).unsqueeze(0), filter_size, pad_value=pad_value).squeeze(0).squeeze(0) def binarize_labels(x: Tensor, bits: int = 8) -> Tensor: @@ -365,27 +341,16 @@ def prepare_maisi_controlnet_json_dataloader( train_loader = None if use_ddp: - list_train = partition_dataset( - data=list_train, - shuffle=True, - num_partitions=world_size, - even_divisible=True, - )[rank] + list_train = partition_dataset(data=list_train, shuffle=True, num_partitions=world_size, even_divisible=True)[ + rank + ] train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8) train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) if use_ddp: - list_valid = partition_dataset( - data=list_valid, - shuffle=True, - num_partitions=world_size, - even_divisible=False, - )[rank] - val_ds = CacheDataset( - data=list_valid, - transform=val_transforms, - cache_rate=cache_rate, - num_workers=8, - ) + list_valid = partition_dataset(data=list_valid, shuffle=True, num_partitions=world_size, even_divisible=False)[ + rank + ] + val_ds = CacheDataset(data=list_valid, transform=val_transforms, cache_rate=cache_rate, num_workers=8) val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False) return train_loader, val_loader @@ -734,8 +699,7 @@ def KL_loss(z_mu, z_sigma): """ eps = 1e-10 kl_loss = 0.5 * torch.sum( - z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2) + eps) - 1, - dim=list(range(1, len(z_sigma.shape))), + z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2) + eps) - 1, dim=list(range(1, len(z_sigma.shape))) ) return torch.sum(kl_loss) / kl_loss.shape[0] From f6b284749fc95ce636539616bf991577310602de Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 30 Aug 2024 10:08:37 +0000 Subject: [PATCH 12/19] fix type errors Signed-off-by: Yiheng Wang --- ci/unit_tests/test_vista2d.py | 2 +- .../scripts/download_files.py | 2 - .../maisi_ct_generative/scripts/find_masks.py | 5 +- .../scripts/quality_check.py | 1 - models/maisi_ct_generative/scripts/sample.py | 75 +++++++++++++------ models/maisi_ct_generative/scripts/utils.py | 66 +++++----------- models/vista2d/scripts/workflow.py | 8 +- 7 files changed, 80 insertions(+), 79 deletions(-) diff --git a/ci/unit_tests/test_vista2d.py b/ci/unit_tests/test_vista2d.py index 0eb5fb5d..50ac1ab4 100644 --- a/ci/unit_tests/test_vista2d.py +++ b/ci/unit_tests/test_vista2d.py @@ -82,7 +82,7 @@ def test_infer_config(self, override): # check_properties=False, need to add monai service properties later check_workflow(workflow, check_properties=False) - expected_output_file = os.path.join(self.tmp_output_dir, f"image_{self.dataset_size-1}.tif") + expected_output_file = os.path.join(self.tmp_output_dir, f"image_{self.dataset_size - 1}.tif") 
self.assertTrue(os.path.isfile(expected_output_file)) @parameterized.expand([TEST_CASE_TRAIN]) diff --git a/models/maisi_ct_generative/scripts/download_files.py b/models/maisi_ct_generative/scripts/download_files.py index 4616568d..50a7ef4b 100644 --- a/models/maisi_ct_generative/scripts/download_files.py +++ b/models/maisi_ct_generative/scripts/download_files.py @@ -1,5 +1,3 @@ -import os - import yaml from monai.apps import download_url diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py index c919d393..de626552 100644 --- a/models/maisi_ct_generative/scripts/find_masks.py +++ b/models/maisi_ct_generative/scripts/find_masks.py @@ -56,7 +56,7 @@ def find_masks( body_region: str | Sequence[str], anatomy_list: int | Sequence[int], spacing: Sequence[float] | float = 1.0, - output_size: Sequence[int] = [512, 512, 512], + output_size: Sequence[int] = (512, 512, 512), check_spacing_and_output_size: bool = False, database_filepath: str = "./configs/database.json", mask_foldername: str = "./datasets/masks/", @@ -72,7 +72,8 @@ def find_masks( anatomy_list: list of input anatomy. The found candidate mask will include these anatomies. spacing: list of three floats, voxel spacing. If providing a single number, will use it for all the three dimensions. output_size: list of three int, expected candidate mask spatial size. - check_spacing_and_output_size: whether we expect candidate mask to have spatial size of `output_size` and voxel size of `spacing`. + check_spacing_and_output_size: whether we expect candidate mask to have spatial size of `output_size` + and voxel size of `spacing`. database_filepath: path for the json file that stores the information of all the candidate masks. mask_foldername: directory that saves all the candidate masks. Return: diff --git a/models/maisi_ct_generative/scripts/quality_check.py b/models/maisi_ct_generative/scripts/quality_check.py index 22373276..fe34661f 100644 --- a/models/maisi_ct_generative/scripts/quality_check.py +++ b/models/maisi_ct_generative/scripts/quality_check.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import nibabel as nib import numpy as np diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 18fa6b86..9d203b31 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -84,7 +84,7 @@ def ldm_conditional_sample_one_mask( latent_shape, label_dict_remap_json, num_inference_steps=1000, - autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_size=(96, 96, 96), autoencoder_sliding_window_infer_overlap=0.6667, ): """ @@ -150,7 +150,7 @@ def ldm_conditional_sample_one_mask( # mapping raw index to 132 labels synthetic_mask = remap_labels(synthetic_mask, label_dict_remap_json) - ###### post process ##### + # post process data = synthetic_mask.squeeze().cpu().detach().numpy() labels = [23, 24, 26, 27, 128] @@ -181,7 +181,7 @@ def ldm_conditional_sample_one_image( output_size, noise_factor, num_inference_steps=1000, - autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_size=(96, 96, 96), autoencoder_sliding_window_infer_overlap=0.6667, ): """ @@ -228,7 +228,8 @@ def ldm_conditional_sample_one_image( or output_size[2] != combine_label.shape[4] ): logging.info( - "output_size is not a desired value. 
Need to interpolate the mask to match with output_size. The result image will be very low quality." + "output_size is not a desired value. Need to interpolate the mask to match " + "with output_size. The result image will be very low quality." ) combine_label = torch.nn.functional.interpolate(combine_label, size=output_size, mode="nearest") @@ -338,9 +339,7 @@ def crop_img_body_mask(synthetic_images, combine_label): return synthetic_images -def check_input( - body_region, anatomy_list, label_dict_json, output_size, spacing, controllable_anatomy_size=[("pancreas", 0.5)] -): +def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing, controllable_anatomy_size): """ Validate input parameters for image generation. @@ -360,7 +359,10 @@ def check_input( raise ValueError(f"The first two components of output_size need to be equal, yet got {output_size}.") if (output_size[0] not in [256, 384, 512]) or (output_size[2] not in [128, 256, 384, 512, 640, 768]): raise ValueError( - f"The output_size[0] have to be chosen from [256, 384, 512], and output_size[2] have to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}." + ( + "The output_size[0] have to be chosen from [256, 384, 512], and output_size[2] " + f"have to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}." + ) ) if spacing[0] != spacing[1]: @@ -371,15 +373,22 @@ def check_input( ) if output_size[0] * spacing[0] < 256: - FOV = [output_size[axis] * spacing[axis] for axis in range(3)] + fov = [output_size[axis] * spacing[axis] for axis in range(3)] raise ValueError( - f"`'spacing'({spacing}mm) and 'output_size'({output_size}) together decide the output field of view (FOV). The FOV will be {FOV}mm. We recommend the FOV in x and y axis to be at least 256mm for head, and at least 384mm for other body regions like abdomen. There is no such restriction for z-axis." + ( + f"`'spacing'({spacing}mm) and 'output_size'({output_size}) together decide the output field of view (FOV). " + f"The FOV will be {fov}mm. We recommend the FOV in x and y axis to be at least 256mm for head, and at least " + "384mm for other body regions like abdomen. There is no such restriction for z-axis." + ) ) # check controllable_anatomy_size format if len(controllable_anatomy_size) > 10: raise ValueError( - f"The length of list controllable_anatomy_size has to be less than 10. Yet got length equal to {len(controllable_anatomy_size)}." + ( + f"The output_size[0] have to be chosen from [256, 384, 512], and output_size[2] " + f"have to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}." + ) ) available_controllable_organ = ["liver", "gallbladder", "stomach", "pancreas", "colon"] available_controllable_tumor = [ @@ -395,7 +404,10 @@ def check_input( for controllable_anatomy_size_pair in controllable_anatomy_size: if controllable_anatomy_size_pair[0] not in available_controllable_anatomy: raise ValueError( - f"The controllable_anatomy have to be chosen from {available_controllable_anatomy}, yet got {controllable_anatomy_size_pair[0]}." + ( + f"The controllable_anatomy have to be chosen from {available_controllable_anatomy}, " + f"yet got {controllable_anatomy_size_pair[0]}." 
+ ) ) if controllable_anatomy_size_pair[0] in available_controllable_tumor: controllable_tumor += [controllable_anatomy_size_pair[0]] @@ -405,7 +417,10 @@ def check_input( continue if controllable_anatomy_size_pair[1] < 0 or controllable_anatomy_size_pair[1] > 1.0: raise ValueError( - f"The controllable size scale have to be between 0 and 1,0, or equal to -1, yet got {controllable_anatomy_size_pair[1]}." + ( + "The controllable size scale have to be between 0 and 1,0, or equal to -1, " + f"yet got {controllable_anatomy_size_pair[1]}." + ) ) if len(controllable_tumor + controllable_organ) != len(list(set(controllable_tumor + controllable_organ))): raise ValueError(f"Please do not repeat controllable_anatomy. Got {controllable_tumor + controllable_organ}.") @@ -414,11 +429,17 @@ def check_input( if len(controllable_anatomy_size) > 0: logging.info( - f"`controllable_anatomy_size` is not empty.\nWe will ignore `body_region` and `anatomy_list` and synthesize based on `controllable_anatomy_size`: ({controllable_anatomy_size})." + ( + "`controllable_anatomy_size` is not empty.\nWe will ignore `body_region` and `anatomy_list` " + f"and synthesize based on `controllable_anatomy_size`: ({controllable_anatomy_size})." + ) ) else: logging.info( - f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `body_region`: ({body_region}) and `anatomy_list`: ({anatomy_list})." + ( + "`controllable_anatomy_size` is empty.\nWe will synthesize based on `body_region`: " + f"({body_region}) and `anatomy_list`: ({anatomy_list})." + ) ) # check body_region format available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"] @@ -476,11 +497,11 @@ def __init__( image_output_ext=".nii.gz", label_output_ext=".nii.gz", real_img_median_statistics="./configs/image_median_statistics.json", - spacing=[1, 1, 1], + spacing=(1, 1, 1), num_inference_steps=None, mask_generation_num_inference_steps=None, random_seed=None, - autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_size=(96, 96, 96), autoencoder_sliding_window_infer_overlap=0.6667, ) -> None: """ @@ -536,7 +557,10 @@ def __init__( ) if not (0 <= autoencoder_sliding_window_infer_overlap <= 1): raise ValueError( - f"Value of autoencoder_sliding_window_infer_overlap must be between 0 and 1.\n Got {autoencoder_sliding_window_infer_overlap}" + ( + f"Value of autoencoder_sliding_window_infer_overlap must be between 0 " + f"and 1.\n Got {autoencoder_sliding_window_infer_overlap}" + ) ) self.autoencoder_sliding_window_infer_size = autoencoder_sliding_window_infer_size self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap @@ -622,7 +646,10 @@ def sample_multiple_images(self, num_img): logging.info(f"Images will be generated based on {selected_mask_files}.") if len(selected_mask_files) != num_img: raise ValueError( - f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img ({num_img}). This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)." + ( + f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img ({num_img}). " + f"This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)." + ) ) for item in selected_mask_files: logging.info("---- Start preparing masks... 
----") @@ -894,7 +921,10 @@ def ensure_output_size_and_spacing(self, labels, check_contains_target_labels=Tr for anatomy_label in self.anatomy_list: if anatomy_label not in contained_labels: raise ValueError( - f"Resampled mask does not contain required class labels {anatomy_label}. Please tune spacing and output size." + ( + f"Resampled mask does not contain required class labels {anatomy_label}. " + "Please tune spacing and output size." + ) ) return labels @@ -995,7 +1025,10 @@ def quality_check(self, image_data, label_data): for label, result in outlier_results.items(): if result.get("is_outlier", False): logging.info( - f"Generated image quality check for label '{label}' failed: median value {result['median_value']} is outside the acceptable range ({result['low_thresh']} - {result['high_thresh']})." + ( + f"Generated image quality check for label '{label}' failed: median value {result['median_value']} " + f"is outside the acceptable range ({result['low_thresh']} - {result['high_thresh']})." + ) ) return False return True diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index c340944a..b34b9113 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and import copy import json -import logging import math import os import zipfile @@ -21,19 +20,12 @@ import skimage import torch import torch.distributed as dist -import torch.nn.functional as F from monai.bundle import ConfigParser from monai.config import DtypeLike, NdarrayOrTensor from monai.data import CacheDataset, DataLoader, partition_dataset from monai.transforms import Compose, EnsureTyped, Lambdad, LoadImaged, Orientationd from monai.transforms.utils_morphological_ops import dilate, erode -from monai.utils import ( - TransformBackends, - convert_data_type, - convert_to_dst_type, - ensure_tuple_rep, - get_equivalent_dtype, -) +from monai.utils import TransformBackends, convert_data_type, convert_to_dst_type, get_equivalent_dtype from scipy import stats from torch import Tensor @@ -49,7 +41,18 @@ def unzip_dataset(dataset_dir): return -def add_data_dir2path(list_files, data_dir, fold=None): +def add_data_dir2path(list_files: list, data_dir: str, fold: int = None) -> tuple[list, list]: + """ + Read a list of data dictionary. + + Args: + list_files (list): input data to load and transform to generate dataset for model. + data_dir (str): directory of files. + fold (int, optional): fold index for cross validation. Defaults to None. + + Returns: + tuple[list, list]: A tuple of two arrays (training, validation). + """ new_list_files = copy.deepcopy(list_files) if fold is not None: new_list_files_train = [] @@ -253,40 +256,6 @@ def define_instance(args: Namespace, instance_def_key: str) -> Any: return parser.get_parsed_content(instance_def_key, instantiate=True) -def add_data_dir2path(list_files: list, data_dir: str, fold: int = None) -> tuple[list, list]: - """ - Read a list of data dictionary. - - Args: - list_files (list): input data to load and transform to generate dataset for model. - data_dir (str): directory of files. - fold (int, optional): fold index for cross validation. Defaults to None. - - Returns: - tuple[list, list]: A tuple of two arrays (training, validation). 
- """ - new_list_files = copy.deepcopy(list_files) - if fold is not None: - new_list_files_train = [] - new_list_files_val = [] - for d in new_list_files: - d["image"] = os.path.join(data_dir, d["image"]) - - if "label" in d: - d["label"] = os.path.join(data_dir, d["label"]) - - if fold is not None: - if d["fold"] == fold: - new_list_files_val.append(copy.deepcopy(d)) - else: - new_list_files_train.append(copy.deepcopy(d)) - - if fold is not None: - return new_list_files_train, new_list_files_val - else: - return new_list_files, [] - - def prepare_maisi_controlnet_json_dataloader( json_data_list: list | str, data_base_dir: list | str, @@ -683,12 +652,13 @@ def __call__(self, img: NdarrayOrTensor): return out -def KL_loss(z_mu, z_sigma): +def kl_loss(z_mu, z_sigma): """ Compute the Kullback-Leibler (KL) divergence loss for a variational autoencoder (VAE). The KL divergence measures how one probability distribution diverges from a second, expected probability distribution. - In the context of VAEs, this loss term ensures that the learned latent space distribution is close to a standard normal distribution. + In the context of VAEs, + this loss term ensures that the learned latent space distribution is close to a standard normal distribution. Args: z_mu (torch.Tensor): Mean of the latent variable distribution, shape [N,C,H,W,D] or [N,C,H,W]. @@ -698,10 +668,10 @@ def KL_loss(z_mu, z_sigma): torch.Tensor: The computed KL divergence loss, averaged over the batch. """ eps = 1e-10 - kl_loss = 0.5 * torch.sum( + loss = 0.5 * torch.sum( z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2) + eps) - 1, dim=list(range(1, len(z_sigma.shape))) ) - return torch.sum(kl_loss) / kl_loss.shape[0] + return torch.sum(loss) / loss.shape[0] def dynamic_infer(inferer, model, images): diff --git a/models/vista2d/scripts/workflow.py b/models/vista2d/scripts/workflow.py index e987a2f6..d850ea08 100644 --- a/models/vista2d/scripts/workflow.py +++ b/models/vista2d/scripts/workflow.py @@ -746,9 +746,9 @@ def train(self): logger.info( f"Estimated remaining training time for the current model fold {config('fold')} is " - f"{time_remaining_estimate/3600:.2f} hr, " - f"running time {(time.time() - pre_loop_time)/3600:.2f} hr, " - f"est total time {(time.time() - pre_loop_time + time_remaining_estimate)/3600:.2f} hr \n" + f"{time_remaining_estimate / 3600:.2f} hr, " + f"running time {(time.time() - pre_loop_time) / 3600:.2f} hr, " + f"est total time {(time.time() - pre_loop_time + time_remaining_estimate) / 3600:.2f} hr \n" ) # end of main epoch loop @@ -792,7 +792,7 @@ def train(self): logger.info( f"=== DONE: best_metric: {best_metric:.4f} at epoch: {best_metric_epoch} of {report_num_epochs}." - f"Training time {(time.time() - pre_loop_time)/3600:.2f} hr." + f"Training time {(time.time() - pre_loop_time) / 3600:.2f} hr." 
From 21c0e3bac6a07c9cf7c06d2354742b14488ad4ac Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Fri, 30 Aug 2024 10:12:01 +0000
Subject: [PATCH 13/19] remove extra f

Signed-off-by: Yiheng Wang
---
 models/maisi_ct_generative/scripts/sample.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py
index 9d203b31..721ab92c 100644
--- a/models/maisi_ct_generative/scripts/sample.py
+++ b/models/maisi_ct_generative/scripts/sample.py
@@ -386,7 +386,7 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing
     if len(controllable_anatomy_size) > 10:
         raise ValueError(
             (
-                f"The output_size[0] have to be chosen from [256, 384, 512], and output_size[2] "
+                "The output_size[0] have to be chosen from [256, 384, 512], and output_size[2] "
                 f"have to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}."
             )
         )
@@ -558,7 +558,7 @@ def __init__(
         if not (0 <= autoencoder_sliding_window_infer_overlap <= 1):
             raise ValueError(
                 (
-                    f"Value of autoencoder_sliding_window_infer_overlap must be between 0 "
+                    "Value of autoencoder_sliding_window_infer_overlap must be between 0 "
                     f"and 1.\n Got {autoencoder_sliding_window_infer_overlap}"
                 )
             )
@@ -648,7 +648,7 @@ def sample_multiple_images(self, num_img):
             raise ValueError(
                 (
                     f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img ({num_img}). "
-                    f"This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)."
+                    "This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)."
                 )
             )

From 61685a1c1a18814e30a2f8c3b0e04b0820fe2bb6 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Fri, 30 Aug 2024 10:17:04 +0000
Subject: [PATCH 14/19] revert other bundle changes

Signed-off-by: Yiheng Wang
---
 ci/unit_tests/test_vista2d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ci/unit_tests/test_vista2d.py b/ci/unit_tests/test_vista2d.py
index 50ac1ab4..0eb5fb5d 100644
--- a/ci/unit_tests/test_vista2d.py
+++ b/ci/unit_tests/test_vista2d.py
@@ -82,7 +82,7 @@ def test_infer_config(self, override):
         # check_properties=False, need to add monai service properties later
         check_workflow(workflow, check_properties=False)
 
-        expected_output_file = os.path.join(self.tmp_output_dir, f"image_{self.dataset_size - 1}.tif")
+        expected_output_file = os.path.join(self.tmp_output_dir, f"image_{self.dataset_size-1}.tif")
         self.assertTrue(os.path.isfile(expected_output_file))
 
     @parameterized.expand([TEST_CASE_TRAIN])

From 5c42ddaa65a590e82550055dec0ed14699b292fe Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Fri, 30 Aug 2024 10:19:38 +0000
Subject: [PATCH 15/19] revert large file dataset change

Signed-off-by: Yiheng Wang
---
 models/maisi_ct_generative/large_files.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/maisi_ct_generative/large_files.yml b/models/maisi_ct_generative/large_files.yml
index bac9076a..6ab6db86 100644
--- a/models/maisi_ct_generative/large_files.yml
+++ b/models/maisi_ct_generative/large_files.yml
@@ -14,7 +14,7 @@ large_files:
   - path: "configs/all_anatomy_size_condtions.json"
     url: "https://drive.google.com/file/d/1AJyt1DSoUd2x2AOQOgM7IxeSyo4MXNX0/view?usp=sharing"
   - path: "datasets/all_masks_flexible_size_and_spacing_3000.zip"
-    url: "https://drive.google.com/file/d/1AJyt1DSoUd2x2AOQOgM7IxeSyo4MXNX0/view?usp=sharing"
"https://drive.google.com/file/d/16MKsDKkHvDyF2lEir4dzlxwex_GHStUf/view?usp=sharing" - path: "datasets/IntegrationTest-AbdomenCT.nii.gz" url: "https://drive.google.com/file/d/1OTgt_dyBgvP52krKRXWXD3u0L5Zbj5JR/view?usp=share_link" - path: "datasets/C4KC-KiTS_subset.zip" From c4fff2142a2f57a4eaa346fe2600f3abf00d0fac Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 30 Aug 2024 11:05:33 +0000 Subject: [PATCH 16/19] revert vista2d changes Signed-off-by: Yiheng Wang --- models/vista2d/scripts/workflow.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/vista2d/scripts/workflow.py b/models/vista2d/scripts/workflow.py index d850ea08..e987a2f6 100644 --- a/models/vista2d/scripts/workflow.py +++ b/models/vista2d/scripts/workflow.py @@ -746,9 +746,9 @@ def train(self): logger.info( f"Estimated remaining training time for the current model fold {config('fold')} is " - f"{time_remaining_estimate / 3600:.2f} hr, " - f"running time {(time.time() - pre_loop_time) / 3600:.2f} hr, " - f"est total time {(time.time() - pre_loop_time + time_remaining_estimate) / 3600:.2f} hr \n" + f"{time_remaining_estimate/3600:.2f} hr, " + f"running time {(time.time() - pre_loop_time)/3600:.2f} hr, " + f"est total time {(time.time() - pre_loop_time + time_remaining_estimate)/3600:.2f} hr \n" ) # end of main epoch loop @@ -792,7 +792,7 @@ def train(self): logger.info( f"=== DONE: best_metric: {best_metric:.4f} at epoch: {best_metric_epoch} of {report_num_epochs}." - f"Training time {(time.time() - pre_loop_time) / 3600:.2f} hr." + f"Training time {(time.time() - pre_loop_time)/3600:.2f} hr." ) return best_metric From f68e12a8a243940e17838949003208375c234a48 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 2 Sep 2024 12:07:25 +0800 Subject: [PATCH 17/19] add arg Signed-off-by: Yiheng Wang --- models/maisi_ct_generative/configs/inference.json | 2 ++ 1 file changed, 2 insertions(+) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 7cea2c85..2ad6aa08 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -19,6 +19,7 @@ "all_anatomy_size_condtions_json": "$@bundle_root + '/configs/all_anatomy_size_condtions.json'", "label_dict_json": "$@bundle_root + '/configs/label_dict.json'", "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'", + "real_img_median_statistics_file": "$@bundle_root + '/configs/image_median_statistics.json'", "num_output_samples": 1, "body_region": [ "abdomen" @@ -285,6 +286,7 @@ "controllable_anatomy_size": "@controllable_anatomy_size", "image_output_ext": "@image_output_ext", "label_output_ext": "@label_output_ext", + "real_img_median_statistics": "@real_img_median_statistics_file", "device": "@device", "latent_shape": "@latent_shape", "mask_generation_latent_shape": "@mask_generation_latent_shape", From 3a34a59a40e0bc01d67e2f706e291eb4c9667c39 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 2 Sep 2024 04:36:45 +0000 Subject: [PATCH 18/19] update error message Signed-off-by: Yiheng Wang --- ci/unit_tests/test_maisi_ct_generative.py | 4 ++-- models/maisi_ct_generative/scripts/sample.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ci/unit_tests/test_maisi_ct_generative.py b/ci/unit_tests/test_maisi_ct_generative.py index 3a3028c7..e2b9db54 100644 --- a/ci/unit_tests/test_maisi_ct_generative.py +++ b/ci/unit_tests/test_maisi_ct_generative.py @@ -93,7 +93,7 @@ "body_region": 
["head"], "anatomy_list": ["colon cancer primaries"], }, - "Cannot find body region with given organ list.", + "Cannot find body region with given anatomy list.", ] TEST_CASE_INFER_ERROR_2 = [ @@ -156,7 +156,7 @@ "body_region": ["chest"], "anatomy_list": ["colon", "spleen", "trachea", "left humerus", "sacrum", "heart"], }, - "Cannot find body region with given organ list.", + "Cannot find body region with given anatomy list.", ] TEST_CASE_TRAIN = [ diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 721ab92c..30127b80 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -1009,7 +1009,7 @@ def find_closest_masks(self, num_img): final_candidates.append(c) if len(final_candidates) == 0: - raise ValueError("Cannot find body region with given organ list.") + raise ValueError("Cannot find body region with given anatomy list.") return final_candidates def quality_check(self, image_data, label_data): From b486c32be50f7e0e63864e9ede094b67a1264cbb Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 4 Sep 2024 11:27:58 +0800 Subject: [PATCH 19/19] remove kl loss Signed-off-by: Yiheng Wang --- models/maisi_ct_generative/scripts/utils.py | 22 --------------------- 1 file changed, 22 deletions(-) diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index b34b9113..e786fdaa 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -652,28 +652,6 @@ def __call__(self, img: NdarrayOrTensor): return out -def kl_loss(z_mu, z_sigma): - """ - Compute the Kullback-Leibler (KL) divergence loss for a variational autoencoder (VAE). - - The KL divergence measures how one probability distribution diverges from a second, expected probability distribution. - In the context of VAEs, - this loss term ensures that the learned latent space distribution is close to a standard normal distribution. - - Args: - z_mu (torch.Tensor): Mean of the latent variable distribution, shape [N,C,H,W,D] or [N,C,H,W]. - z_sigma (torch.Tensor): Standard deviation of the latent variable distribution, same shape as 'z_mu'. - - Returns: - torch.Tensor: The computed KL divergence loss, averaged over the batch. - """ - eps = 1e-10 - loss = 0.5 * torch.sum( - z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2) + eps) - 1, dim=list(range(1, len(z_sigma.shape))) - ) - return torch.sum(loss) / loss.shape[0] - - def dynamic_infer(inferer, model, images): """ Perform dynamic inference using a model and an inferer, typically a monai SlidingWindowInferer.