-
Notifications
You must be signed in to change notification settings - Fork 347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Load Segmentation Trainer Weights #1379
base: main
Are you sure you want to change the base?
Conversation
Need to think about this one. I wonder if there isn't an easier way to tell Lightning to save the encoder in a separate file. Then it's trainer-dependent instead of having to if-statement every possible naming scheme in one function. |
I'm not sure I understand. You are trying to load a model trained using which trainer into the segmentation trainer? Can you give the error message? |
If you need to load a checkpoint in the same trainer you can just use pytorch Lightning's loading method.
|
I believe the checkpoint comes from MoCo and/or SimCLR and looks very different than the checkpoint expected by SemanticSegmentationTask. |
I don't think so. MoCo and SimCLR tasks named the backbone "backbone" not "encoder. |
Ah I see what's happening. Right now we don't support loading an entire pretrained segmentation checkpoint using the weights option. We only support loading the encoder part of the weights, not the decoder. So the solution would be to just load the checkpoint directly using pytorch lighting's checkpoint option. Edit: so in this case, the PR changes are the correct solution. |
Is this still needed or is this superseded by #1403? |
I believe this is separate. This has to do with loading an entire UNet checkpoint using the weights argument. Right now the weights argument only loads the weights into the encoder (not the decoder). |
I guess one can load a full checkpoint involving encoder and decoder after model initialization via lightning, but currently, the "weights" string is still just capable of loading encoder weights, so maybe that needs to be made more explicit in doc string? |
This PR changes the loading of a checkpoint in the segmentation model. Given that I have trained a model with a torchgeo trainer, I might want to do with the
weights
argument:Not sure which of all should be supported by default, or whether there should be suggestions on how to do each of these (or other things).