
About how multimodal upstream tasks migrate to unimodal downstream tasks #18

Open

zzzjjj98 opened this issue Jul 22, 2024 · 9 comments

@zzzjjj98

Hello! I am very inspired by your work. Following it, I have some questions about pre-training on MRI data.
I want to pre-train on brain tumor MRI containing four modalities. What should I do if the downstream task is single-modality or has missing modalities?
I have a few ideas:

  1. During pre-training, treat the four MRI modalities of the same patient as four independent inputs.
  2. Pre-train with four-channel inputs, and when a modality is missing in the downstream task, replace it with an all-black input.
  3. In the downstream task, load only a portion of the pre-trained parameters and re-train the remaining layers that depend on the number of channels.

Which strategy is more reasonable? Or do you have a better way to handle this? Looking forward to your reply.

@Luffy03
Owner

Luffy03 commented Jul 22, 2024

Hi, many thanks for your attention to our work! It encourages me a lot!
The third way is the best: only the first layer with "in_channels=4" is discarded. The code can be written as follows:

def load(model, model_dict):
    # Checkpoints may store the weights under different keys depending on the trainer.
    if "state_dict" in model_dict.keys():
        state_dict = model_dict["state_dict"]
    elif "network_weights" in model_dict.keys():
        state_dict = model_dict["network_weights"]
    elif "net" in model_dict.keys():
        state_dict = model_dict["net"]
    else:
        state_dict = model_dict

    # Strip the "module." prefix left over from DataParallel/DistributedDataParallel wrappers.
    if "module." in list(state_dict.keys())[0]:
        print("Tag 'module.' found in state dict - fixing!")
        for key in list(state_dict.keys()):
            state_dict[key.replace("module.", "")] = state_dict.pop(key)

    # Strip the "backbone." prefix used by some pre-training frameworks.
    if "backbone." in list(state_dict.keys())[0]:
        print("Tag 'backbone.' found in state dict - fixing!")
        for key in list(state_dict.keys()):
            state_dict[key.replace("backbone.", "")] = state_dict.pop(key)

    # Rename "swin_vit" keys to match the "swinViT" attribute of the downstream model.
    if "swin_vit" in list(state_dict.keys())[0]:
        print("Tag 'swin_vit' found in state dict - fixing!")
        for key in list(state_dict.keys()):
            state_dict[key.replace("swin_vit", "swinViT")] = state_dict.pop(key)

    # Copy a pre-trained tensor only when the name exists and the shape matches;
    # otherwise keep the downstream model's own (randomly initialized) tensor.
    # This is what discards the first layer with in_channels=4 when the
    # downstream model has a different number of input channels.
    current_model_dict = model.state_dict()
    new_state_dict = {
        k: state_dict[k] if (k in state_dict.keys()) and (state_dict[k].size() == current_model_dict[k].size()) else current_model_dict[k]
        for k in current_model_dict.keys()}

    model.load_state_dict(new_state_dict, strict=True)
    print("Using VoCo pretrained backbone weights !!!!!!!")

    return model
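
For illustration, here is a minimal sketch of how this loader could be called when adapting a four-channel pre-trained backbone to a single-channel downstream model. The MONAI SwinUNETR backbone and the checkpoint path are assumptions for this sketch, not part of the repo:

import torch
from monai.networks.nets import SwinUNETR  # assumed backbone for this sketch

# Hypothetical checkpoint path; replace with your own pre-trained weights.
checkpoint = torch.load("voco_pretrained_4ch.pt", map_location="cpu")

# The downstream model takes a single MRI modality (in_channels=1), so its first
# convolution has a different shape from the 4-channel pre-trained one and will
# simply keep its random initialization inside load().
downstream_model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2)
downstream_model = load(downstream_model, checkpoint)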

I will release the implementation for MRI in the next version soon. Stay tuned!

@zzzjjj98
Author

Thank you for your reply! Looking forward to your work on MRI!

@zzzjjj98
Author

Hi! I have been thinking about the pre-training data format recently, and there is a question I would like to discuss with you.
Among the publicly available MRI datasets, only BraTS provides all four modalities (t1c, t1n, t2f, t2w). Most other datasets are either single-modality or missing some modalities. If pre-training a model for MRI, which of the following strategies is better?

  1. The pre-training stage uses four-channel inputs, and the downstream task does not load the first layer if it is single-channel (as you said before).
  2. Convert the four-modality data into four single modalities and train a single-channel pre-trained model.
  3. Pre-train on both unimodal and multimodal data, with missing modalities (channels) set to 0.

Looking forward to your reply! Thank you.

@Luffy03
Owner

Luffy03 commented Sep 12, 2024

Good question. The third way is not feasible. Personally, I prefer the second way, since it is more flexible, more extensible, and easier to implement.
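
As a rough illustration of the second way, here is a minimal sketch of listing every available modality volume as an independent single-channel pre-training sample. The BraTS-style file layout, directory names, and function name are assumptions for this sketch, not part of the repo:

import glob
import os

# Hypothetical BraTS-style layout: each case folder contains one NIfTI file per
# modality, e.g. BraTS_0001/BraTS_0001_t1c.nii.gz, ..._t1n, ..._t2f, ..._t2w.
MODALITIES = ("t1c", "t1n", "t2f", "t2w")

def build_single_channel_list(root):
    # Treat every available modality volume as an independent single-channel
    # sample, so cases with missing modalities still contribute what they have.
    samples = []
    for case_dir in sorted(glob.glob(os.path.join(root, "*"))):
        case_id = os.path.basename(case_dir)
        for m in MODALITIES:
            path = os.path.join(case_dir, f"{case_id}_{m}.nii.gz")
            if os.path.exists(path):  # skip missing modalities
                samples.append({"image": path})
    return samples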

@zzzjjj98
Author

Thank you so much for your continued patience in responding! What you said makes a lot of sense.
Regarding the second strategy, if the four modalities are split into single modalities, will there be a problem of duplicated training data? (Although the modalities do not share the same gray values.)

@Luffy03
Owner

Luffy03 commented Sep 12, 2024

Yes, the information will be duplicated to some extent and result in redundancy, but you can consider it a kind of data augmentation.
We have not yet evaluated the first way for MRI; it may be worth trying.

@zzzjjj98
Author

Thank you for your reply~(●'◡'●)

@Luffy03
Owner

Luffy03 commented Sep 12, 2024

You are welcome! If you have any further questions or progress on this, feel free to contact me!

@Luffy03
Owner

Luffy03 commented Oct 14, 2024

Dear researchers, if you are still interested in this topic, our work is now available at Large-Scale-Medical. Thank you very much for your attention to our work; it encourages me a lot!
