Skip to content

Commit

Permalink
Support regular expression in the mapping arg of copy_model_state (#…
Browse files Browse the repository at this point in the history
…6917)

Part of #6552.

### Description
After PR #6835, we have added `copy_model_args` in the `load` API which
can help us update the state_dict flexibly.

https://github.com/KumoLiu/MONAI/blob/93a149a611b66153cf804b31a7b36a939e2e593a/monai/bundle/scripts.py#L397

Given this [issue](#6552),
we need to be able to filter the model's weights flexibly.
In `copy_model_state`, we already have a "mapping" arg, the filter will
be more flexible if we can support regular expression in the mapping.
This PR mainly added the support for regular expression for "mapping"
arg.

In the
[example](#6552 (comment))
in this [issue](#6552),
after this PR, we can do something like:
```
exclude_vars = "encoder.mask_token|encoder.norm.weight|encoder.norm.bias|out.conv.conv.weight|out.conv.conv.bias"
mapping={"encoder.layers(.*).0.0.": "swinViT.layers(.*).0."}
dst_dict, updated_keys, unchanged_keys = copy_model_state(
       model, ssl_weights, exclude_vars=exclude_vars, mapping=mapping
)
```

Additionally, based on the comments of Eric
[here](#6552 (comment)),
I totally agree, we could add a handler to make the pipeline easier to
implement, but perhaps this task is no need to set as a "BundleTodo" for
MONAIv1.3 but as an enhancement for MONAI near future.
What do you think? @ericspod @wyli @Nic-Ma 

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
  • Loading branch information
KumoLiu authored Sep 12, 2023
1 parent 66f42c1 commit 392c5c1
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 2 deletions.
49 changes: 49 additions & 0 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,3 +1055,52 @@ def forward(self, x, normalize=True):
x4 = self.layers4[0](x3.contiguous())
x4_out = self.proj_out(x4, normalize)
return [x0_out, x1_out, x2_out, x3_out, x4_out]


def filter_swinunetr(key, value):
"""
A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model.
This function is typically used with `monai.networks.copy_model_state`
[1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training
<https://arxiv.org/abs/2307.16896>"
Args:
key: the key in the source state dict used for the update.
value: the value in the source state dict used for the update.
Examples::
import torch
from monai.apps import download_url
from monai.networks.utils import copy_model_state
from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
resource = (
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
)
ssl_weights_path = "./ssl_pretrained_weights.pth"
download_url(resource, ssl_weights_path)
ssl_weights = torch.load(ssl_weights_path)["model"]
dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)
"""
if key in [
"encoder.mask_token",
"encoder.norm.weight",
"encoder.norm.bias",
"out.conv.conv.weight",
"out.conv.conv.bias",
]:
return None

if key[:8] == "encoder.":
if key[8:19] == "patch_embed":
new_key = "swinViT." + key[8:]
else:
new_key = "swinViT." + key[8:18] + key[20:]

return new_key, value
else:
return None
11 changes: 10 additions & 1 deletion monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ def copy_model_state(
mapping=None,
exclude_vars=None,
inplace=True,
filter_func=None,
):
"""
Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten
Expand All @@ -490,7 +491,7 @@ def copy_model_state(
Args:
dst: a pytorch module or state dict to be updated.
src: a pytorch module or state dist used to get the values used for the update.
src: a pytorch module or state dict used to get the values used for the update.
dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]`
will be assigned to the value of `src[src_key]`.
mapping: a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]`
Expand All @@ -499,6 +500,8 @@ def copy_model_state(
so that their values are not overwritten by `src`.
inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`.
This option is only available when `dst` is a `torch.nn.Module`.
filter_func: a filter function used to filter the weights to be loaded.
See 'filter_swinunetr' in "monai.networks.nets.swin_unetr.py".
Examples:
.. code-block:: python
Expand Down Expand Up @@ -536,6 +539,12 @@ def copy_model_state(
warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.")
dst_dict[dst_key] = src_dict[s]
updated_keys.append(dst_key)
if filter_func is not None:
for key, value in src_dict.items():
new_pair = filter_func(key, value)
if new_pair is not None and new_pair[0] not in to_skip:
dst_dict[new_pair[0]] = new_pair[1]
updated_keys.append(new_pair[0])

updated_keys = sorted(set(updated_keys))
unchanged_keys = sorted(set(all_keys).difference(updated_keys))
Expand Down
33 changes: 32 additions & 1 deletion tests/test_swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@

from __future__ import annotations

import os
import tempfile
import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.apps import download_url
from monai.networks import eval_mode
from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr
from monai.networks.utils import copy_model_state
from monai.utils import optional_import
from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_quick, testing_data_config

einops, has_einops = optional_import("einops")

Expand Down Expand Up @@ -51,6 +56,14 @@
case_idx += 1
TEST_CASE_SWIN_UNETR.append(test_case)

TEST_CASE_FILTER = [
[
{"img_size": (96, 96, 96), "in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True},
"swinViT.layers1.0.blocks.0.norm1.weight",
torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]),
]
]


class TestSWINUNETR(unittest.TestCase):
@parameterized.expand(TEST_CASE_SWIN_UNETR)
Expand Down Expand Up @@ -93,6 +106,24 @@ def test_patch_merging(self):
t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim)))
self.assertEqual(t.shape, torch.Size([1, 11, 10, 10, 20]))

@parameterized.expand(TEST_CASE_FILTER)
@skip_if_quick
def test_filter_swinunetr(self, input_param, key, value):
with skip_if_downloading_fails():
with tempfile.TemporaryDirectory() as tempdir:
file_name = "ssl_pretrained_weights.pth"
data_spec = testing_data_config("models", f"{file_name.split('.', 1)[0]}")
weight_path = os.path.join(tempdir, file_name)
download_url(
data_spec["url"], weight_path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"]
)

ssl_weight = torch.load(weight_path)["model"]
net = SwinUNETR(**input_param)
dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr)
assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False)
self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2)


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions tests/testing_data/data_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext50_32x4d-a260b3a4.pth",
"hash_type": "sha256",
"hash_val": "a260b3a40f82dfe37c58d26a612bcf7bef0d27c6fed096226b0e4e9fb364168e"
},
"ssl_pretrained_weights": {
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth",
"hash_type": "sha256",
"hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8"
}
},
"configs": {
Expand Down

0 comments on commit 392c5c1

Please sign in to comment.