Add MedNext implementation #8004
base: dev
Conversation
Signed-off-by: Suraj Pai <[email protected]>
Thanks for the work! I think it's a great addition to MONAI!
return self.conv_out(x)

class LayerNorm(nn.Module):
Do we need this LayerNorm implementation? Or can we just use the torch.nn.LayerNorm? I don't see the channels_last used anywhere.
This is a copy from the original codebase and can probably be replaced with torch.nn.LayerNorm as 'channels_first' is always assumed.
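For illustration, a minimal sketch (not code from this PR) of how the built-in torch.nn.LayerNorm can normalize a channels-first 3D tensor over its channel dimension, by moving the channel axis last and back:

import torch
import torch.nn as nn

# assumed shapes for illustration: (N, C, D, H, W) with C = 8 channels
x = torch.randn(2, 8, 4, 4, 4)
norm = nn.LayerNorm(8)                      # normalized_shape = channel count
y = norm(x.movedim(1, -1)).movedim(-1, 1)   # move C last for LayerNorm, then back to channels-first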
self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)

def forward(self, x, dummy_tensor=None):
Do we know what the dummy_tensor is used for? I don't see it being used anywhere. This applies to the other forward functions as well
This also comes from a copy of the original codebase and should be removed, as it's never supposed to be used.
self,
spatial_dims: int = 3,
init_filters: int = 32,
in_channels: int = 1,
How about setting this to None by default and using LazyConv if in_channels is set to None?
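As a rough sketch of that idea (the helper name and kernel size are illustrative, not from this PR), the stem convolution could fall back to a lazy module when in_channels is None:

import torch.nn as nn

def make_stem(in_channels, init_filters=32, kernel_size=1):
    # hypothetical helper: with in_channels=None, the channel count is inferred
    # from the first forward pass via LazyConv3d
    if in_channels is None:
        return nn.LazyConv3d(init_filters, kernel_size=kernel_size)
    return nn.Conv3d(in_channels, init_filters, kernel_size=kernel_size)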
blocks_down: list = [2, 2, 2, 2],
blocks_bottleneck: int = 2,
blocks_up: list = [2, 2, 2, 2],
norm_type: str = "group",
Should we make this a StrEnum? It makes it easy to see the options. Of course, this also applies to the Block.
@johnzielke Can you give me an example of its usage in one of the network implementations? It's not straightforward for me how to use it.
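As a hedged example of the idea (the enum name and members below are illustrative, and it assumes MONAI's StrEnum helper from monai.utils):

from monai.utils import StrEnum

class MedNextNormType(StrEnum):  # hypothetical enum, for illustration only
    GROUP = "group"
    LAYER = "layer"

# members are plain strings, so existing call sites keep working,
# while unknown values fail fast:
norm_type = MedNextNormType("group")   # returns MedNextNormType.GROUP
assert norm_type == "group"            # str comparison still holds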
monai/networks/nets/mednext.py
Outdated
enc_exp_r: int = 2,
dec_exp_r: int = 2,
bottlenec_exp_r: int = 2,
IMO encoder_expansion_ratio (and similarly for decoder and bottleneck) is better here. It's a public API; the full name makes it easier to see what it refers to without having to look at the docs, and the few extra letters don't matter that much.
I've kept this convention similar to the original implementation, which uses a joint exp_r argument. Happy to rename it; I agree with the comment.
monai/networks/nets/mednext.py
Outdated
bottlenec_exp_r: int = 2,
kernel_size: int = 7,
deep_supervision: bool = False,
do_res: bool = False,
maybe use_residual_connections?
I think use_res would probably suffice? Similar to args in other APIs like SegResNet, such as use_conv_final.
Yeah, my only concern is that res can be something like resolution or result.
monai/networks/nets/mednext.py
Outdated
kernel_size: int = 7,
deep_supervision: bool = False,
do_res: bool = False,
do_res_up_down: bool = False,
maybe use_residual_connections_up_down_blocks?
This sounds very verbose to me and may not be necessary. I think the contraction res is easily understood. If playing with these params, users are encouraged to read the code nonetheless to figure out where it goes.
How about use_residual_up_down?
if dim == "2d":
    self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
    self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
else:
    self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
    self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
I think this is easier to understand, but the current one is fine as well.
Suggested change:

if dim == "2d":
    grn_parameter_shape = (1, 1)
elif dim == "3d":
    grn_parameter_shape = (1, 1, 1)
else:
    raise ValueError()
grn_parameter_shape = (1, exp_r * in_channels) + grn_parameter_shape
self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
Couldn't it be useful to adapt the factory functions of the create_mednext_v1 script inside your implementation?
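As a hedged sketch of what such variant factories might look like (the class name, preset values, and keyword names below are placeholders, not the published MedNeXt-S/B/M/L settings):

from monai.networks.nets import MedNeXt  # assumed class name from this PR

# placeholder presets for illustration only, not the actual variant hyperparameters
_PRESETS = {
    "S": dict(init_filters=32, enc_exp_r=2, dec_exp_r=2, bottlenec_exp_r=2),
    "B": dict(init_filters=32, enc_exp_r=3, dec_exp_r=3, bottlenec_exp_r=3),
}

def create_mednext(variant: str = "S", spatial_dims: int = 3, in_channels: int = 1, **kwargs):
    # hypothetical factory: pick a preset and allow keyword overrides
    config = {**_PRESETS[variant], **kwargs}
    return MedNeXt(spatial_dims=spatial_dims, in_channels=in_channels, **config)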
…edNext variants (S, B, M, L) + integration of remarks from @johnzilke (Project-MONAI#8004 (review)) for renaming class arguments - removal of self defined LayerNorm - linked residual connection for encoder and decoder Signed-off-by: Robin CREMESE <[email protected]>
Thanks for the comments @johnzielke and thank you for the updates @rcremese. I'll update my branch to dev and look through the changes in your PR asap.
Signed-off-by: Suraj Pai <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Suraj Pai <[email protected]>
@KumoLiu This implementation should be ready now. Please let me know if you have any comments.
Signed-off-by: Suraj Pai <[email protected]>
@KumoLiu Added. Do you think there would be interest in adding this as a candidate for Auto3DSeg? I refer to this paper for its performance benchmarking: https://arxiv.org/abs/2404.09556
Thank you for bringing this up, it's an interesting suggestion. I believe it could be worthwhile to consider this as a potential candidate for Auto3DSeg. However, before moving forward, I would appreciate hearing others' thoughts and insights on whether this aligns with the current goals and roadmap for Auto3DSeg. cc @mingxin-zheng @dongyang0122 @Nic-Ma @myron
Great work @surajpaib! Any plans on when we're going to be able to get this into main?
Fixes #7786
Description
Added an implementation of the MedNext architectures for MONAI.
Since a lot of the code is heavily sourced from the original MedNext repo, https://github.com/MIC-DKFZ/MedNeXt, I wanted to check if there is an attribution policy with regard to borrowed source code. I've added a derivative notice below the MONAI copyright comment. Let me know if this needs to be changed.
The blocks have been taken almost as-is, but the network implementation has been changed substantially to allow flexible block configurations and to follow MONAI's SegResNet styling.
Types of changes
./runtests.sh -f -u --net --coverage
./runtests.sh --quick --unittests --disttests
make html command in the docs/ folder.