Implementation of Probabilistic U-Net. #46
base: main
Conversation
@@ -67,53 +70,38 @@
     "length_3D": [128, 128, 128],
     "stride_3D": [64, 64, 64],
     "attention": true,
-    "n_filters": 8
+    "n_filters": 16
I would actually be very interested to know how much difference this makes; we should definitely do an experiment on this!
Sure, definitely! From my experience, increasing the number of feature maps/filters helps improve performance, so I try to fit as many as possible into memory.
"early_stopping_epsilon": 0.001 | ||
}, | ||
"scheduler": { | ||
"initial_lr": 5e-05, | ||
"initial_lr": 2e-04, |
These might be old experiments and highly dependent on the specific hyperparameters (e.g. batch size), but in my experience with the challenge, lowering the learning rate was very helpful. I have seen cases where lowering the learning rate from `5e-05` to `3e-05` made a huge difference!
I second this!
Thanks for this! I really didn't know that lowering the learning rate could make that much of a difference. I am struggling to get PU-Net working for this dataset, so maybe this will help to some extent! :)
Looks beautiful 😍, and thanks for introducing us to PU-Net! I know we are focusing on more critical stuff now, but I truly think PU-Net addresses a problem/scenario that is not addressed by other models. Looking forward to a PU-Net with amazing results, and I hope my comments can help along the way!
I just want to make a note of the comment I made in your `modeling/datasets.py`: this is the only bug that I could spot, and I'm so sorry for introducing it in the first place 🙏!
assert len(ses01_subvolumes) == len(ses02_subvolumes) == len(gts_subvolumes[0])

for i in range(len(ses01_subvolumes)):
    subvolumes_ = {
(Please ignore this review if it doesn't apply!)
I remember doing a similar thing (i.e. reading all GTs from the experts), and it completely blew up the (CPU) memory, limiting my experiments severely! This was in the context of multi-GPU training though, so it might not apply here. If this is ever a problem, we could also move the GT-reading part to `__getitem__()` so that it's called for one subvolume at a time. It would slow things down considerably, but I just wanted to mention it as an alternative 🙂.
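If this ever becomes a problem, here is a minimal sketch of what I mean (the class name and the `gt_paths` attribute are hypothetical, just to frame the idea; the real `datasets.py` will look different):

```python
import random

import nibabel as nib
import torch
from torch.utils.data import Dataset


class MSSegDataset(Dataset):  # hypothetical minimal frame, just to place the snippet
    def __getitem__(self, index):
        # Load a single expert's GT lazily here, instead of caching all GTs in __init__.
        # `self.gt_paths[index]` is assumed to hold the 4 experts' GT file paths for
        # this subvolume (hypothetical attribute; adapt to the real datasets.py).
        gt_path = random.choice(self.gt_paths[index])
        gt = torch.from_numpy(nib.load(gt_path).get_fdata()).float()
        # ... load the ses01/ses02 subvolumes and preprocess as before ...
        return gt
```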
in_dim = self.in_channels if i == 0 else out_dim
out_dim = num_feat_maps[i]
if i != 0:
    layers.append(nn.AvgPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=True))
I most likely don't know what I am talking about, but I am used to seeing `MaxPool3d` here instead of the current average-pooling layer. The smoothing effect that average-pooling introduces might cause errors at the edges of the segmentation. Please feel free to ignore this if it is what the paper suggests!
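Concretely, the swap I have in mind would be something like this (a sketch only, keeping your current kernel/stride/ceil_mode settings; whether it actually helps is an empirical question):

```python
import torch.nn as nn

layers = []  # the encoder layers being assembled, as in the surrounding code

# Current downsampling layer:
# layers.append(nn.AvgPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=True))

# Possible alternative with max-pooling (same spatial behaviour, no smoothing):
layers.append(nn.MaxPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=True))
```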
From my experience with GANs, MaxPooling abruptly eliminates certain elements, which can adversely affect learning, hence AveragePooling is used. However, I am carrying this over from GANs and I don't know if the same holds for segmentation models. Also, I thought that since U-Net has an upsampling path, it might help NOT to abruptly lose voxels through the MaxPool op. I don't know if that makes sense.
I guess you don't directly lose voxels per se, but I do see what you're saying! Also, `ivadomed.models.Modified3DUNet` seems not to use any pooling layers, so that might also be something to experiment with if applicable. Of course, sometimes we don't have any choice though.
layers.append(nn.AvgPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=True))

layers.append(nn.Conv3d(in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=int(padding)))
layers.append(nn.ReLU(inplace=True))
Looking at `ModifiedUNet3D` and other segmentation models, there seems to be a preference for `LeakyReLU` over `ReLU`. The way `LeakyReLU` was introduced in the field poses a hard-to-dispute hypothesis: it retains the properties of `ReLU` while also preventing "dead" neurons. However, I understand that in practice this might not hold, and discussions over individual DL components are hardly ever productive. Quite frankly, I just like bringing these up because it is fun 😄! (Also, in some rare cases these small changes could help!)
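For concreteness, the swap would look something like this (just a sketch; the negative slope below is PyTorch's default, not a value taken from `ModifiedUNet3D`):

```python
import torch.nn as nn

layers = []  # the encoder layers being assembled, as in the surrounding code

# Current activation:
# layers.append(nn.ReLU(inplace=True))

# Possible alternative that keeps a small gradient for negative inputs:
layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
```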
encoding = self.encoder(inp)
self.show_enc = encoding
# for getting the mean of the resulting volume --> (batch-size x Channels x Depth x Height x Width)
encoding = torch.mean(encoding, dim=2, keepdim=True)
Two comments:
- Can you describe what lines 96-98 achieve? It seems like this could be done in a single line as well!
- Also, to someone like me who has never heard of `AxisAlignedConvGaussian`, this part seems like a huge bottleneck. My experience is that taking the mean of features (or aggregating them in any other way) inside a neural network always ends up being a bottleneck. Instead, modifying hyperparameters (e.g. kernel size, number of filters, etc.) so that the encoder outputs features in the desired shape works better. Again, keep in mind that I have not read the paper and am not very knowledgeable about PU-Net 🙂 (that is to say, feel free to ignore).
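On the first point, if the three consecutive means are each over one spatial axis (which is my assumption here), a single call should be equivalent:

```python
import torch

# Dummy encoder output with shape (B x C x D x H x W)
encoding = torch.randn(2, 32, 16, 16, 16)

# Equivalent to averaging over dims 2, 3 and 4 one at a time:
encoding = torch.mean(encoding, dim=(2, 3, 4), keepdim=True)  # (B x C x 1 x 1 x 1)
```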
# Squeeze all the singleton dimensions
mu_log_sigma = torch.squeeze(mu_log_sigma)  # shape: (B x (2*latent_dim))

mu = mu_log_sigma[:, :self.latent_dim]  # take the first "latent_dim" samples as mu
This part is very interesting! The way I have been implementing VAEs / convolutional VAEs is to always have two separate layers for extracting the `mu` and the `logvar` in parallel. (Also, a quick note: should it be called `log_var` instead of `log_sigma`, since `log_sigma` would mean log standard deviation instead?)
Now, I can't answer the following question: why can't we achieve the same thing with a single layer, as you do here? (You probably can!) But I would still make sure that this is indeed how the authors implement, or suggest implementing, this part!
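For reference, this is roughly how I've been doing it in my VAEs; the channel sizes below are made up, and I'm not suggesting this is how the PU-Net authors do it:

```python
import torch
import torch.nn as nn

latent_dim, enc_channels = 6, 32  # illustrative values only

# Two parallel 1x1x1 convs on the pooled encoder features, instead of one conv
# producing 2 * latent_dim channels that gets split into mu / log_sigma afterwards.
mu_layer = nn.Conv3d(enc_channels, latent_dim, kernel_size=1)
log_var_layer = nn.Conv3d(enc_channels, latent_dim, kernel_size=1)

encoding = torch.randn(2, enc_channels, 1, 1, 1)        # pooled features (B x C x 1 x 1 x 1)
mu = mu_layer(encoding).flatten(start_dim=1)            # (B x latent_dim)
log_var = log_var_layer(encoding).flatten(start_dim=1)  # (B x latent_dim)
```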
# self.reconstruction_loss = torch.sum(reconstruction_loss)
# self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

# # TODO: use DiceLoss as the criterion instead? --> Uncomment below lines
Here I would actually vote for `DiceLoss`, and the only reason is that Dice is a metric we know for sure will be used in the test phase of the challenge! I would try to compare the validation performance (SOFT and HARD Dice scores) you get when training with `BCEWithLogitsLoss()` vs. `DiceLoss()`. Or better yet, we can now compare ANIMA metrics (e.g. F1 score) as well 😉!
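In case it helps for a quick comparison, here is a generic soft Dice loss sketch (not `ivadomed`'s exact implementation, just the usual formulation on sigmoid probabilities):

```python
import torch


def soft_dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """1 - soft Dice, computed on sigmoid probabilities over the whole batch."""
    probs = torch.sigmoid(logits)
    intersection = (probs * target).sum()
    denominator = probs.sum() + target.sum()
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)
```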
for _ in range(num_predictions):
    mask_pred = model.sample(testing=True)
    # TODO: this line below gets hard predictions. Use just sigmoid and see how Dice performs.
    mask_pred = torch.sigmoid(mask_pred)  # getting a soft pred.; shape: (B x 1 x P x P x P)
I know we have discussed this so many times, but something else to try is removing the sigmoid and using a normalized ReLU (i.e. ReLU but in the 0-1 range) instead. This is what gave the best results for me!
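To spell out what I mean by normalized ReLU (a sketch; per-sample max-normalization is one way to squash the outputs into the 0-1 range, there are other variants):

```python
import torch
import torch.nn.functional as F


def normalized_relu(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """ReLU followed by per-sample division by the max, so outputs lie in [0, 1]."""
    x = F.relu(x)
    # Max over all non-batch dimensions, reshaped so it broadcasts back over x.
    max_per_sample = x.flatten(start_dim=1).max(dim=1).values.view(-1, *([1] * (x.dim() - 1)))
    return x / (max_per_sample + eps)
```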
    # NOTE: This will also help with discarding empty inputs!
-   if ses01_patches.std() < 1e-5 or ses02_patches.std() < 1e-5:
+   if ses01_subvolumes.std() < 1e-5 or ses02_subvolumes.std() < 1e-5:
        if self.train:
            return self.__getitem__(random.randint(0, self.__len__() - 1))
I noticed this very late, but this line actually causes information leakage: during training it can fetch a validation index, and during validation it can fetch a training index. Really sorry about this! In the new version of `datasets.py`, which you can check in #44 (the `um/transunet_setup` branch), I repeat what I do for validation (as seen in the `else` statement) for training as well.
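As a rough illustration of one way to avoid crossing the split boundary (the `split_indices` attribute and the `_load_subvolumes` helper are hypothetical; the actual fix in #44 instead handles this case the same way the validation branch does):

```python
import random

from torch.utils.data import Dataset


class MSSegDataset(Dataset):  # hypothetical minimal frame, just to place the snippet
    def __getitem__(self, index):
        ses01_subvolumes, ses02_subvolumes = self._load_subvolumes(index)  # hypothetical helper
        if ses01_subvolumes.std() < 1e-5 or ses02_subvolumes.std() < 1e-5:
            if self.train:
                # Resample only from indices belonging to the *training* split,
                # instead of random.randint(0, len(self) - 1) over the whole dataset.
                # `self.split_indices` is a hypothetical attribute; see #44 for the real fix.
                return self.__getitem__(random.choice(self.split_indices))
        ...
```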
The primary contribution of this PR is the implementation of Probabilistic U-Net (in 3D) for the MSSeg challenge. Architectural hints and suggestions have been taken from the paper's appendix. Here is a summary of the major changes:

1. The `datasets.py` file was modified to load subvolumes of size `128x128x128` directly and, importantly, this version does away with the consensus GT (for the time being) and instead loads one of the 4 experts' segmentations at random during training.
2. `unet.py` and `probabilistic_unet.py` introduce the PU-Net architecture. As mentioned in the paper, 3 networks are involved during training: the prior net, the posterior net, and the standard U-Net.
3. `utils.py` just adds some basic utility functions for initializing the PU-Net model correctly.

Review is primarily required for points 1 and 2 (`punet.py` in particular), just to ensure that the code is logical. Suggestions and areas for improvement are welcome!
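For reviewers less familiar with PU-Net, here is a rough sketch of how the three networks come together in the training loss. It follows the loss structure described in the paper; the function and variable names (and the beta value) are illustrative and do not match this PR's actual API:

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def punet_training_loss(unet_logits, prior_mu, prior_log_sigma,
                        post_mu, post_log_sigma, target, beta=10.0):
    """ELBO-style loss: reconstruction term + beta-weighted KL(posterior || prior).

    `unet_logits` are the U-Net outputs after injecting a latent sample drawn from
    the posterior net; the prior/posterior statistics come from the two
    AxisAlignedConvGaussian encoders. Names and the beta value are illustrative.
    """
    prior = Normal(prior_mu, torch.exp(prior_log_sigma))
    posterior = Normal(post_mu, torch.exp(post_log_sigma))
    kl = kl_divergence(posterior, prior).sum(dim=1).mean()
    recon = F.binary_cross_entropy_with_logits(unet_logits, target)
    return recon + beta * kl
```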
Things To-Do:

- `BCEWithLogitsLoss`: Try using `ivadomed`'s readily available DiceLoss to see how the training fares.

Note: The branch is inappropriately named `run_baselines`, which is what it originally set out to be, but it then evolved completely into PU-Net.