diff --git a/generation/maisi/configs/config_infer.json b/generation/maisi/configs/config_infer.json
index 9586081af..fc08a7bda 100644
--- a/generation/maisi/configs/config_infer.json
+++ b/generation/maisi/configs/config_infer.json
@@ -18,5 +18,10 @@
         2.0
     ],
     "autoencoder_sliding_window_infer_size": [48,48,48],
-    "autoencoder_sliding_window_infer_overlap": 0.25
+    "autoencoder_sliding_window_infer_overlap": 0.25,
+    "controlnet": "$@controlnet_def",
+    "diffusion_unet": "$@diffusion_unet_def",
+    "autoencoder": "$@autoencoder_def",
+    "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+    "mask_generation_diffusion": "$@mask_generation_diffusion_def"
 }
diff --git a/generation/maisi/configs/config_trt.json b/generation/maisi/configs/config_trt.json
new file mode 100644
index 000000000..fc52486e3
--- /dev/null
+++ b/generation/maisi/configs/config_trt.json
@@ -0,0 +1,22 @@
+{
+    "+imports": [
+        "$from monai.networks import trt_compile"
+    ],
+    "c_trt_args": {
+        "export_args": {
+            "dynamo": "$False",
+            "report": "$True"
+        },
+        "output_lists": [
+            [
+                -1
+            ],
+            [
+            ]
+        ]
+    },
+    "device": "cuda",
+    "controlnet": "$trt_compile(@controlnet_def.to(@device), @trained_controlnet_path, @c_trt_args)",
+    "diffusion_unet": "$trt_compile(@diffusion_unet_def.to(@device), @trained_diffusion_path)",
+    "mask_generation_diffusion": "$trt_compile(@mask_generation_diffusion_def.to(@device), @trained_mask_generation_diffusion_path)"
+}
diff --git a/generation/maisi/scripts/inference.py b/generation/maisi/scripts/inference.py
index 8220f200c..968d5bf49 100644
--- a/generation/maisi/scripts/inference.py
+++ b/generation/maisi/scripts/inference.py
@@ -48,6 +48,12 @@ def main():
         default="./configs/config_infer.json",
         help="config json file that stores inference hyper-parameters",
     )
+    parser.add_argument(
+        "-x",
+        "--extra-config-file",
+        default=None,
+        help="config json file that stores inference extra parameters",
+    )
     parser.add_argument(
         "-s",
         "--random-seed",
@@ -140,6 +146,16 @@ def main():
         setattr(args, k, v)
         print(f"{k}: {v}")
 
+    #
+    # ## Read in optional extra configuration setting - typically acceleration options (TRT)
+    #
+    #
+    if args.extra_config_file is not None:
+        extra_config_dict = json.load(open(args.extra_config_file, "r"))
+        for k, v in extra_config_dict.items():
+            setattr(args, k, v)
+            print(f"{k}: {v}")
+
     check_input(
         args.body_region,
         args.anatomy_list,
@@ -158,25 +174,25 @@ def main():
     device = torch.device("cuda")
 
-    autoencoder = define_instance(args, "autoencoder_def").to(device)
+    autoencoder = define_instance(args, "autoencoder").to(device)
     checkpoint_autoencoder = torch.load(args.trained_autoencoder_path)
     autoencoder.load_state_dict(checkpoint_autoencoder)
 
-    diffusion_unet = define_instance(args, "diffusion_unet_def").to(device)
+    diffusion_unet = define_instance(args, "diffusion_unet").to(device)
     checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path)
     diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"], strict=True)
     scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)
 
-    controlnet = define_instance(args, "controlnet_def").to(device)
+    controlnet = define_instance(args, "controlnet").to(device)
     checkpoint_controlnet = torch.load(args.trained_controlnet_path)
     monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
     controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)
 
-    mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder_def").to(device)
+    mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder").to(device)
     checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path)
     mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)
 
-    mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion_def").to(device)
+    mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion").to(device)
     checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path)
     mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"])
     mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"]
 
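
Usage sketch (not part of the patch): with these changes, the network definitions in config_infer.json are resolved through indirection keys ("controlnet": "$@controlnet_def", etc.), so running without the new flag keeps the baseline behavior, while passing config_trt.json via -x/--extra-config-file overrides "controlnet", "diffusion_unet", and "mask_generation_diffusion" with trt_compile-wrapped versions. A hypothetical invocation from the generation/maisi directory (the entry point style, working directory, and leaving every other flag at its default are assumptions, not something the patch specifies):

    python -m scripts.inference -x ./configs/config_trt.json

Omitting -x skips the extra-config block in inference.py entirely, so the "$@*_def" entries simply evaluate to the plain PyTorch modules.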