
The tutorial accompanying the pull request #7308 for MONAI core, which adds SURE-loss and Conjugate Gradient #1631

Open: wants to merge 25 commits into base: main
Changes shown from 13 of 25 commits.
47ade41
add utils and model files
cxlcl Dec 19, 2023
7fbea11
Merge branch 'Project-MONAI:main' into main
cxlcl Feb 2, 2024
c5f6952
remove smrd from reconstruction; we will move it to generative
cxlcl Feb 3, 2024
28675fd
remove smrd from reconstruction; we will move it to generative
cxlcl Feb 3, 2024
626d14a
added smrd tutorial that relies on SURE-loss and Conjugate Gradient
cxlcl Feb 3, 2024
06ed6cb
added smrd tutorial that relies on SURE-loss and Conjugate Gradient
cxlcl Feb 3, 2024
63ed95d
added smrd tutorial that relies on SURE-loss and Conjugate Gradient
cxlcl Feb 3, 2024
d39b9fa
added smrd tutorial that relies on SURE-loss and Conjugate Gradient
cxlcl Feb 3, 2024
4a49e52
added smrd tutorial that relies on SURE-loss and Conjugate Gradient
cxlcl Feb 3, 2024
f43bdf7
added smrd tutorial that relies on SURE-loss and Conjugate Gradient
cxlcl Feb 4, 2024
04a6ad0
added smrd tutorial that relies on SURE-loss and Conjugate Gradient
cxlcl Feb 4, 2024
aa20c9d
added smrd tutorial that relies on SURE-loss and Conjugate Gradient
cxlcl Feb 4, 2024
7ace6c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2024
477e041
Update generative/smrd/models/ema.py
cxlcl Feb 24, 2024
1d0fd50
Update generative/README.md
cxlcl Feb 24, 2024
7c44130
Update generative/smrd/README.md
cxlcl Feb 24, 2024
2c83658
Update README.md
cxlcl Feb 24, 2024
a9179fc
update notebook; update functions; add torchscript for model config
cxlcl Feb 25, 2024
4420513
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2024
8b20def
add smrd_optimizer.py
cxlcl Feb 28, 2024
d5fc798
Merge branch 'cg_sure' of github.com:cxlcl/monai-tutorials into cg_sure
cxlcl Feb 28, 2024
dcc3c25
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2024
9f3cd49
add smrd_optimizer.py
cxlcl Feb 28, 2024
0ae721f
Merge branch 'cg_sure' of github.com:cxlcl/monai-tutorials into cg_sure
cxlcl Feb 28, 2024
1f59820
Merge branch 'main' into cg_sure
ericspod Apr 26, 2024
4 changes: 4 additions & 0 deletions README.md
@@ -309,3 +309,7 @@ This tutorial shows the use cases of training and validating a 3D Latent Diffusi

##### [2D latent diffusion model](./generative/2d_ldm)
This tutorial shows the use cases of training and validating a 2D Latent Diffusion Model.

##### [SURE-Based Robust MRI Reconstruction with Diffusion Models (SMRD)](./generative/smrd)
This example shows the use case of inference with a pre-trained score function while taking available measurements into account,
using the SURE-Based Robust MRI Reconstruction with Diffusion Models (SMRD) method.
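Not part of the diff: the SURE criterion that SMRD's early stopping relies on can be sketched with a Monte-Carlo divergence estimate. The names below (`monte_carlo_sure`, `denoiser`, `eps`) are illustrative and are not the API added in MONAI core #7308.

```python
import torch


def monte_carlo_sure(denoiser, y, sigma, eps=1e-3):
    """Stein's Unbiased Risk Estimate of the denoiser's MSE at noisy input y.

    sigma is the noise standard deviation; eps scales the finite-difference
    probe used for the divergence term. Illustrative sketch only.
    """
    n = y.numel()
    fy = denoiser(y)
    b = torch.randn_like(y)  # random probe vector
    # Monte-Carlo estimate of div f(y) via a finite difference along b
    div = (b * (denoiser(y + eps * b) - fy)).sum() / eps
    return ((fy - y) ** 2).sum() / n - sigma**2 + (2 * sigma**2 / n) * div
```

Because the estimate needs only the noisy input and the denoiser (no ground truth), it can be monitored during sampling and used to stop once it plateaus.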
cxlcl marked this conversation as resolved.
4 changes: 4 additions & 0 deletions generative/README.md
@@ -14,3 +14,7 @@ Example shows the use cases of training and validating a 3D Latent Diffusion Mod

## [Brats 2D latent diffusion model](./2d_ldm/README.md)
Example shows the use cases of training and validating a 2D Latent Diffusion Model on axial slices from Brats 2016&2017 data.

## [SURE-Based Robust MRI Reconstruction with Diffusion Models (SMRD)](./smrd/README.md)
This example shows the use case of inference with a pre-trained score function while taking available measurements into account,
using the SURE-Based Robust MRI Reconstruction with Diffusion Models (SMRD) method.
cxlcl marked this conversation as resolved.
5 changes: 5 additions & 0 deletions generative/smrd/README.md
@@ -0,0 +1,5 @@
SURE-Based Robust MRI Reconstruction with Diffusion Models (SMRD). MICCAI 2023 (https://link.springer.com/chapter/10.1007/978-3-031-43898-1_20)

![SMRD](figures/SMRD.png)
It showcases how the conjugate gradient method can be used to enforce measurement consistency in diffusion-model-based MRI reconstruction; it also shows
how the SURE-based method can be used to perform early stopping, so that fewer iterations are run and fewer artifacts are introduced while generating the reconstructed image.
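The measurement-consistency step can be sketched as a plain conjugate gradient solve. This is an illustrative implementation, not the `ConjugateGradient` class added in MONAI #7308; `apply_A` stands in for the normal operator of the data-consistency subproblem (e.g. A^H A + lambda * I).

```python
import torch


def conjugate_gradient(apply_A, b, x0, n_iter=5, tol=1e-10):
    """Solve apply_A(x) = b for a symmetric positive-definite operator.

    Illustrative sketch; apply_A would be the normal operator of the
    data-consistency subproblem in SMRD.
    """
    x = x0.clone()
    r = b - apply_A(x)  # residual
    p = r.clone()       # search direction
    rs = (r * r).sum()
    for _ in range(n_iter):
        Ap = apply_A(p)
        alpha = rs / (p * Ap).sum()
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = (r * r).sum()
        if rs_new < tol:  # converged; avoid a 0/0 step
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x
```

A handful of iterations is typically enough inside each sampling step; the default of 5 here mirrors the `num_cg_iter: 5` setting in the demo config.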
cxlcl marked this conversation as resolved.
865 changes: 865 additions & 0 deletions generative/smrd/SMRD.ipynb
Member comment:
I would put the code for SMRDOptimizer into its own file since it's so large in this notebook.

Large diffs are not rendered by default.

133 changes: 133 additions & 0 deletions generative/smrd/configs/demo/SMRD-brain_T2-noise005-R8.yaml
@@ -0,0 +1,133 @@
user: csgm-mri-langevin
model_type: ncsnv2
seed: 42
device: cuda
batch_size: 1
repeat: 1

# The pre-trained NCSNV2 checkpoint
gen_ckpt: checkpoints/mri-unet-smrd.pth

## weights of different losses
mse: 5.

## start from different noise level of langevin
start_iter: 1155

# can be decreased for super-resolution
image_size:
- 384
- 384

## files
input_dir: ./datasets/brain_T2/
maps_dir: ./datasets/brain_T2_maps/
anatomy: brain

early_stop: stop

## Acceleration
R: 8
pattern: equispaced
exp_names: 0
orientation: vertical

## SMRD hyperparameters
num_cg_iter: 5
window_size: 160
lambda_lr: 0.2
init_lambda_update: 1154
last_lambda_update: 1655

## Lambda
lambda_init: 2.0
lambda_end: 2.0
lambda_func: learnable

exp_name: admm-learn-sure_brain_noise_005_R8
learning_loss: SURE

## Input noise
noise_std: 0.005


# logging
save_latent: false
save_images: true
save_dataloader_every: 1000000
save_iter: 100

debug: false
world_size: 1
multiprocessing: false
port: 12345

langevin_config:
training:
batch_size: 4
n_epochs: 500000
n_iters: 320001
snapshot_freq: 10000
snapshot_sampling: true
anneal_power: 2
log_all_sigmas: false

sampling:
batch_size: 4
data_init: false
step_lr: 5e-5
n_steps_each: 4
ckpt_id: 5000
final_only: true
fid: false
denoise: true
num_samples4fid: 10000
inpainting: false
interpolation: false
n_interpolations: 8

fast_fid:
batch_size: 1000
num_samples: 1000
step_lr: 0.0000009
n_steps_each: 3
begin_ckpt: 100000
end_ckpt: 80000
verbose: false
ensemble: false

test:
begin_ckpt: 5000
end_ckpt: 80000
batch_size: 100

data:
dataset: "mri-mvue"
image_size: 384
channels: 2
logit_transform: false
uniform_dequantization: false
gaussian_dequantization: false
random_flip: false
rescaled: false
num_workers: 8

model:
sigma_begin: 232
num_classes: 2311
ema: true
ema_rate: 0.999
spec_norm: false
sigma_dist: geometric
sigma_end: 0.0066
normalization: InstanceNorm++
nonlinearity: elu
ngf: 128

optim:
weight_decay: 0.000
optimizer: "Adam"
lr: 0.0001
beta1: 0.9
amsgrad: false
eps: 0.001
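Not part of the diff: a minimal sketch of how a nested YAML config like the one above can be loaded into attribute-style namespaces. The repository may use its own loader; `to_namespace` is a hypothetical helper, shown here on a small inline excerpt of the config.

```python
import yaml
from types import SimpleNamespace


def to_namespace(obj):
    """Recursively convert parsed-YAML dicts into attribute-accessible namespaces."""
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    return obj


# Inline excerpt of the config above, for illustration only
cfg = to_namespace(yaml.safe_load("""
R: 8
num_cg_iter: 5
langevin_config:
  model:
    sigma_begin: 232
"""))
print(cfg.R, cfg.num_cg_iter, cfg.langevin_config.model.sigma_begin)  # 8 5 232
```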
Binary file added generative/smrd/figures/SMRD.png
10 changes: 10 additions & 0 deletions generative/smrd/models/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
58 changes: 58 additions & 0 deletions generative/smrd/models/ema.py
@@ -0,0 +1,58 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn


class EMAHelper(object):
cxlcl marked this conversation as resolved.
def __init__(self, mu=0.999):
self.mu = mu
self.shadow = {}

def register(self, module):
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()

def update(self, module):
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
self.shadow[name].data = (1.0 - self.mu) * param.data + self.mu * self.shadow[name].data

def ema(self, module):
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
param.data.copy_(self.shadow[name].data)

def ema_copy(self, module):
if isinstance(module, nn.DataParallel):
inner_module = module.module
module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device)
Member comment:
Can we not do inner_module.clone()? This line will work if the type of the inner module is something constructable like this, but the PyTorch pattern is to use the clone method, in which case the load_state_dict() method call isn't needed.

module_copy.load_state_dict(inner_module.state_dict())
module_copy = nn.DataParallel(module_copy)
else:
module_copy = type(module)(module.config).to(module.config.device)
module_copy.load_state_dict(module.state_dict())
# module_copy = copy.deepcopy(module)
self.ema(module_copy)
return module_copy

def state_dict(self):
return self.shadow

def load_state_dict(self, state_dict):
self.shadow = state_dict
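A short, self-contained sketch of how a helper with these semantics is typically used around a training loop. The register/update/ema steps are restated inline so the snippet runs on its own; the model and optimizer are placeholders, not the tutorial's actual network.

```python
import torch
import torch.nn as nn

mu = 0.999  # matches EMAHelper's default
model = nn.Linear(4, 4)

# register(): snapshot the trainable parameters
shadow = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}

opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(5):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # update(): exponential moving average toward the new weights
    for n, p in model.named_parameters():
        if p.requires_grad:
            shadow[n] = (1.0 - mu) * p.data + mu * shadow[n]

# ema(): copy the averaged weights back into the model for sampling
for n, p in model.named_parameters():
    if p.requires_grad:
        p.data.copy_(shadow[n])
```

With mu close to 1 the shadow weights change slowly, which is why score-based models usually sample from the EMA weights rather than the raw training weights.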