diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 6cd9719611..5c08a3ca8a 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1055,3 +1055,52 @@ def forward(self, x, normalize=True): x4 = self.layers4[0](x3.contiguous()) x4_out = self.proj_out(x4, normalize) return [x0_out, x1_out, x2_out, x3_out, x4_out] + + +def filter_swinunetr(key, value): + """ + A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model. + This function is typically used with `monai.networks.copy_model_state` + [1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training + " + + Args: + key: the key in the source state dict used for the update. + value: the value in the source state dict used for the update. + + Examples:: + + import torch + from monai.apps import download_url + from monai.networks.utils import copy_model_state + from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr + + model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48) + resource = ( + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth" + ) + ssl_weights_path = "./ssl_pretrained_weights.pth" + download_url(resource, ssl_weights_path) + ssl_weights = torch.load(ssl_weights_path)["model"] + + dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr) + + """ + if key in [ + "encoder.mask_token", + "encoder.norm.weight", + "encoder.norm.bias", + "out.conv.conv.weight", + "out.conv.conv.bias", + ]: + return None + + if key[:8] == "encoder.": + if key[8:19] == "patch_embed": + new_key = "swinViT." + key[8:] + else: + new_key = "swinViT." + key[8:18] + key[20:] + + return new_key, value + else: + return None diff --git a/monai/networks/utils.py b/monai/networks/utils.py index e4cdfc6f9b..12533183b1 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -478,6 +478,7 @@ def copy_model_state( mapping=None, exclude_vars=None, inplace=True, + filter_func=None, ): """ Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten @@ -490,7 +491,7 @@ def copy_model_state( Args: dst: a pytorch module or state dict to be updated. - src: a pytorch module or state dist used to get the values used for the update. + src: a pytorch module or state dict used to get the values used for the update. dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]` will be assigned to the value of `src[src_key]`. mapping: a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]` @@ -499,6 +500,8 @@ def copy_model_state( so that their values are not overwritten by `src`. inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`. This option is only available when `dst` is a `torch.nn.Module`. + filter_func: a filter function used to filter the weights to be loaded. + See 'filter_swinunetr' in "monai.networks.nets.swin_unetr.py". Examples: .. code-block:: python @@ -536,6 +539,12 @@ def copy_model_state( warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.") dst_dict[dst_key] = src_dict[s] updated_keys.append(dst_key) + if filter_func is not None: + for key, value in src_dict.items(): + new_pair = filter_func(key, value) + if new_pair is not None and new_pair[0] not in to_skip: + dst_dict[new_pair[0]] = new_pair[1] + updated_keys.append(new_pair[0]) updated_keys = sorted(set(updated_keys)) unchanged_keys = sorted(set(all_keys).difference(updated_keys)) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 636fcc9e31..9197308aa5 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -11,15 +11,20 @@ from __future__ import annotations +import os +import tempfile import unittest from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode -from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR +from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr +from monai.networks.utils import copy_model_state from monai.utils import optional_import +from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_quick, testing_data_config einops, has_einops = optional_import("einops") @@ -51,6 +56,14 @@ case_idx += 1 TEST_CASE_SWIN_UNETR.append(test_case) +TEST_CASE_FILTER = [ + [ + {"img_size": (96, 96, 96), "in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True}, + "swinViT.layers1.0.blocks.0.norm1.weight", + torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]), + ] +] + class TestSWINUNETR(unittest.TestCase): @parameterized.expand(TEST_CASE_SWIN_UNETR) @@ -93,6 +106,24 @@ def test_patch_merging(self): t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim))) self.assertEqual(t.shape, torch.Size([1, 11, 10, 10, 20])) + @parameterized.expand(TEST_CASE_FILTER) + @skip_if_quick + def test_filter_swinunetr(self, input_param, key, value): + with skip_if_downloading_fails(): + with tempfile.TemporaryDirectory() as tempdir: + file_name = "ssl_pretrained_weights.pth" + data_spec = testing_data_config("models", f"{file_name.split('.', 1)[0]}") + weight_path = os.path.join(tempdir, file_name) + download_url( + data_spec["url"], weight_path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"] + ) + + ssl_weight = torch.load(weight_path)["model"] + net = SwinUNETR(**input_param) + dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr) + assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False) + self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 4bdac6abba..c0666119d9 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -128,6 +128,11 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext50_32x4d-a260b3a4.pth", "hash_type": "sha256", "hash_val": "a260b3a40f82dfe37c58d26a612bcf7bef0d27c6fed096226b0e4e9fb364168e" + }, + "ssl_pretrained_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth", + "hash_type": "sha256", + "hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8" } }, "configs": {