Implementation of Probabilistic U-Net. #46

Open
wants to merge 8 commits into main
Conversation

naga-karthik
Member

The primary contribution of this PR is the implementation of Probabilistic U-Net (in 3D) for the MSSeg challenge. Architectural hints and suggestions have been taken from the paper's appendix. Here is a summary of the major changes:

  1. The datasets.py file was modified to load subvolumes of size 128x128x128 directly. Importantly, this version does away with the consensus GT (for the time being) and instead loads one of the 4 experts' segmentations at random during training.
  2. The files unet.py and probabilistic_unet.py introduce the PU-Net architecture. As mentioned in the paper, 3 networks are involved during training - the prior net, posterior net and the standard U-Net.
  3. utils.py just adds some basic utility functions for initializing the PU-Net model correctly.
  4. For validation, 5 segmentation masks are produced for each input. These are concatenated and their mean is taken to obtain the predicted "consensus" mask. Dice is then calculated between this consensus and the actual GT (a rough sketch of this step follows the list).
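
A minimal sketch of the validation scheme described in point 4, assuming the model exposes sample(testing=True) returning logits of shape (B x 1 x D x H x W), as in the diff further down; the helper name is hypothetical:

import torch

def predict_consensus(model, num_samples=5):
    # Draw several plausible segmentations from the PU-Net and average them
    # into a single soft "consensus" mask; Dice is then computed against the GT.
    preds = [torch.sigmoid(model.sample(testing=True)) for _ in range(num_samples)]
    stacked = torch.stack(preds, dim=0)        # (num_samples x B x 1 x D x H x W)
    return stacked.mean(dim=0)                 # soft consensus mask, (B x 1 x D x H x W)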

Review is primarily required for points 1 and 2 (probabilistic_unet.py in particular), just to ensure that the code is logical. Suggestions and areas for improvement are welcome!

Things To-Do:

  1. Currently, the reconstruction loss uses PyTorch's BCEWithLogitsLoss. Try using ivadomed's readily available DiceLoss to see how the training fares (a generic stand-in sketch follows this list).
  2. Maybe look at STAPLE for aggregating predicted segmentation masks before calculating validation Dice?
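
For To-Do item 1, a generic soft Dice loss that could stand in while experimenting; this is a minimal sketch of my own, not ivadomed's actual DiceLoss implementation:

import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    # Generic soft Dice loss on logits; a stand-in sketch, not ivadomed's DiceLoss.
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, logits, target):
        probs = torch.sigmoid(logits)
        dims = tuple(range(1, logits.dim()))          # sum over channel + spatial dims, keep batch
        intersection = (probs * target).sum(dim=dims)
        denom = probs.sum(dim=dims) + target.sum(dim=dims)
        dice = (2.0 * intersection + self.eps) / (denom + self.eps)
        return 1.0 - dice.mean()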

Note: The branch is misleadingly named run_baselines, which is what it originally set out to be, but the work has since evolved entirely into PU-Net.

@@ -67,53 +70,38 @@
"length_3D": [128, 128, 128],
"stride_3D": [64, 64, 64],
"attention": true,
"n_filters": 8
"n_filters": 16
Contributor

I would actually be very interested to know how much difference this makes; we should definitely do an experiment on this!

Member Author

Sure, definitely! In my experience, increasing the number of feature maps/filters helps improve performance, so I try to fit as many as possible into memory.

"early_stopping_epsilon": 0.001
},
"scheduler": {
"initial_lr": 5e-05,
"initial_lr": 2e-04,
Contributor

These might be old experiments and highly dependent on the specific hyperparameters (e.g. batch size), but in my experience with the challenge, lowering the learning rate was very helpful. I have seen cases where lowering it from 5e-05 to 3e-05 made a huge difference!

Member

I second this!

Member Author

Thanks for this! I really didn't know that lowering the learning rate could make that much of a difference. I am struggling to get PU-Net working for this dataset, so maybe this will help to some extent! :)

Contributor

@uzaymacar left a comment

Looks beautiful 😍, and thanks for introducing us to PUNet! I know we are focusing on more critical stuff now, but I truly think PUNet addresses a problem/scenario that other models do not. Looking forward to a PUNet with amazing results, and I hope my comments can help along the way!

I just want to flag the comment I made on your modeling/datasets.py: it is the only bug I could spot, and I'm so sorry for introducing it in the first place 🙏!

assert len(ses01_subvolumes) == len(ses02_subvolumes) == len(gts_subvolumes[0])

for i in range(len(ses01_subvolumes)):
subvolumes_ = {
Contributor

(Please ignore this review if it doesn't apply!)

I remember doing a similar thing (i.e. reading all GTs from the experts) and it completely blew up the (CPU) memory, limiting my experiments severely! That was in the context of multi-GPU training though, so it might not apply here. If this ever becomes a problem, we could also move the GT-reading part into __getitem__() so that GTs are read one subvolume at a time. It would slow things down considerably, but I just wanted to mention it as an alternative 🙂.
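
A hedged sketch of that alternative: keep only file paths in __init__ and read a single, randomly chosen expert GT inside __getitem__. The attribute and helper names (self.gt_paths, load_subvolume) are hypothetical, not the actual datasets.py API:

import random

def __getitem__(self, index):
    ses01 = load_subvolume(self.ses01_paths[index])
    ses02 = load_subvolume(self.ses02_paths[index])
    # read only one of the 4 expert GTs per call, instead of all of them up front
    expert_idx = random.randrange(len(self.gt_paths[index]))
    gt = load_subvolume(self.gt_paths[index][expert_idx])
    return {"ses01": ses01, "ses02": ses02, "gt": gt}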

in_dim = self.in_channels if i == 0 else out_dim
out_dim = num_feat_maps[i]
if i != 0:
layers.append(nn.AvgPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=True))
Contributor

I most likely don't know what I am talking about, but I am used to seeing MaxPool3d here instead of the current average-pooling layer. The smoothing that average pooling introduces might cause errors at the edges in the segmentation output. Please feel free to ignore this if average pooling is what the paper suggests!

Member Author

In my experience with GANs, MaxPooling abruptly discards certain elements, which can adversely affect learning, hence AveragePooling is used there. However, I am carrying this over from GANs and don't know whether the same holds for segmentation models. Also, I thought that since U-Net has an upsampling path, it might help NOT to abruptly lose voxels through the MaxPool op. I'm not sure if that makes sense.

Contributor

I guess you don't directly lose voxels per se, but I do see what you're saying! Also, ivadomed.models.Modified3DUNet seems not to use any pooling layers, so that might be something to experiment with if applicable (one pooling-free possibility is sketched below). Of course, sometimes we don't have any choice.
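
One possible reading of a pooling-free encoder block is to downsample with a strided convolution instead of AvgPool3d; this is only a suggestion sketched against the loop shown below, not the PR's or Modified3DUNet's actual code:

import torch.nn as nn

class DownBlock(nn.Module):
    # Hedged sketch: learnable strided-conv downsampling as a pooling-free
    # alternative to the AvgPool3d used in the encoder loop shown below.
    def __init__(self, in_dim, out_dim, downsample=True):
        super().__init__()
        blocks = []
        if downsample:
            blocks.append(nn.Conv3d(in_dim, in_dim, kernel_size=3, stride=2, padding=1))
        blocks.append(nn.Conv3d(in_dim, out_dim, kernel_size=3, padding=1))
        blocks.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)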

layers.append(nn.AvgPool3d(kernel_size=3, stride=2, padding=0, ceil_mode=True))

layers.append(nn.Conv3d(in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=int(padding)))
layers.append(nn.ReLU(inplace=True))
Contributor

Looking at ModifiedUNet3D and other segmentation models, there seems to be a preference for LeakyReLU over ReLU. The way LeakyReLU was introduced in the field poses a hard-to-dispute hypothesis that it retains the properties of ReLU while also preventing "dead" neurons. However, I understand that in practice this might not be the case, and having these discussions about DL components is hardly ever productive. Quite frankly, I just like bringing these up because it is fun 😄! (Also, in some rare cases these small changes could help!)

encoding = self.encoder(inp)
self.show_enc = encoding
# for getting the mean of the resulting volume --> (batch-size x Channels x Depth x Height x Width)
encoding = torch.mean(encoding, dim=2, keepdim=True)
Contributor

Two comments:

  • Can you describe what lines 96-98 achieve? It seems like this could also be done in a single line (see the sketch below)!
  • Also, to someone like me who has never heard of AxisAlignedConvGaussian, this part seems like a huge bottleneck. My experience is that taking the mean of features (or aggregating them in any other way) inside a neural network always ends up being a bottleneck. Instead, modifying hyperparameters (e.g. kernel size, number of filters, etc.) so that the encoder outputs features in the desired shape works better. Keep in mind, though, that I have not read the paper and am not very knowledgeable about PUNet 🙂 (that is to say, feel free to ignore).
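
A hedged guess at the "single line" version: if lines 96-98 collapse the spatial dimensions one at a time, torch.mean accepts a tuple of dims, e.g.:

# collapse depth, height and width of the encoder output in one call
encoding = torch.mean(encoding, dim=(2, 3, 4), keepdim=True)   # (B x C x 1 x 1 x 1)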

# Squeeze all the singleton dimensions
mu_log_sigma = torch.squeeze(mu_log_sigma) # shape: (B x (2*latent_dim) )

mu = mu_log_sigma[:, :self.latent_dim] # take the first "latent_dim" samples as mu
Contributor

This part is very interesting! The way I have been implementing VAEs / convolutional VAEs was to always have two separate layers extracting the mu and the logvar in parallel. (Also, a quick note: should it be called log_var instead of log_sigma, since log_sigma would be the log of the standard deviation?)

Now, I can't answer the following question: why can't we achieve the same thing with a single layer, as you do here? (You probably can!) But I would still make sure that this is indeed how the authors implement, or suggest implementing, this part! (For reference, a sketch of the two-head variant is below.)
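
A minimal sketch of the two-head variant described above (layer names are hypothetical; the PR's single layer of 2*latent_dim channels, sliced afterwards, may well match the authors' own code):

import torch
import torch.nn as nn

class GaussianHeads(nn.Module):
    # Two parallel 1x1x1 convolutions producing mu and log_var separately,
    # instead of one layer of 2*latent_dim channels that is then sliced.
    def __init__(self, in_channels, latent_dim):
        super().__init__()
        self.mu_head = nn.Conv3d(in_channels, latent_dim, kernel_size=1)
        self.log_var_head = nn.Conv3d(in_channels, latent_dim, kernel_size=1)

    def forward(self, encoding):                                       # encoding: (B x C x 1 x 1 x 1)
        mu = torch.flatten(self.mu_head(encoding), start_dim=1)        # (B x latent_dim)
        log_var = torch.flatten(self.log_var_head(encoding), start_dim=1)
        return mu, log_var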

# self.reconstruction_loss = torch.sum(reconstruction_loss)
# self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

# # TODO: use DiceLoss as the criterion instead? --> Uncomment below lines
Contributor

Here I would actually vote for DiceLoss, and the only reason is that Dice is a metric we know for sure will be used in the test phase of the challenge! I would compare the validation metrics (SOFT and HARD Dice scores) you get when training with BCEWithLogitsLoss() vs. DiceLoss(). Or better yet, we can now compare ANIMA metrics (e.g. F1 score) as well 😉!
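
A small sketch of the soft-vs-hard comparison suggested above, assuming sigmoid probabilities probs and a binary ground truth gt (both names are placeholders):

import torch

def dice(pred, target, eps=1e-8):
    intersection = (pred * target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

soft_dice = dice(probs, gt)                    # Dice on the soft probabilities
hard_dice = dice((probs > 0.5).float(), gt)    # Dice after thresholding at 0.5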

for _ in range(num_predictions):
mask_pred = model.sample(testing=True)
# TODO: this line below gets hard predictions. Use just sigmoid and see how Dice performs.
mask_pred = torch.sigmoid(mask_pred) # getting a soft pred. ; shape: (B x 1 x P x P x P)
Contributor

I know we have discussed this so many times, but something else to try is removing the sigmoid and using a normalized ReLU (i.e. ReLU restricted to the 0-1 range) instead. This is what gave the best results for me!
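
One plausible reading of "normalized ReLU" (an assumption on my part, not necessarily the exact recipe meant above): ReLU followed by rescaling each sample into [0, 1] by its own maximum:

import torch

def normalized_relu(logits, eps=1e-8):
    # ReLU, then divide each sample by its own max so outputs fall in [0, 1]
    x = torch.relu(logits)
    per_sample_max = x.flatten(start_dim=1).max(dim=1).values
    per_sample_max = per_sample_max.view(-1, *([1] * (x.dim() - 1)))
    return x / (per_sample_max + eps)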

# NOTE: This will also help with discarding empty inputs!
- if ses01_patches.std() < 1e-5 or ses02_patches.std() < 1e-5:
+ if ses01_subvolumes.std() < 1e-5 or ses02_subvolumes.std() < 1e-5:
if self.train:
return self.__getitem__(random.randint(0, self.__len__() - 1))
Contributor

I noticed this very late, but this line actually causes information leakage: during training it can draw a validation index, and during validation it could draw a training index. Really sorry about this! In the new version of datasets.py, which you can check in #44 (the um/transunet_setup branch), I repeat for training what I do for validation in the else statement.

@naga-karthik added the bug label on Jul 20, 2021