Skip to content

Commit

Permalink
load_state_dict does not return the model (#1503)
Browse files Browse the repository at this point in the history
* Update pretrained_weights.ipynb

Fixed an error in the state dict loading of the tutorial and added a comment on the num_classes parameter when creating timm models.

* Update docs/tutorials/pretrained_weights.ipynb

* Update utils.py

* Import Tuple from typing
* Change return of `load_state_dict` from `model` to `Tuple[List[str], List[str]]`, matching the return of the standard PyTorch builtin function.

* Update pretrained_weights.ipynb

Remove example of loading pretrained model without prediction head (`num_classes=0`).

* Update README.md

Adapt to the new `load_state_dict` function.

* Mimic return type of builtin load_state_dict

* Modern type hints

* Blacken

* Try being explicit

---------

Co-authored-by: Caleb Robinson <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
3 people authored and isaaccorley committed Mar 2, 2024
1 parent 89d4dae commit e63d358
Show file tree
Hide file tree
Showing 9 changed files with 18 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ from torchgeo.models import ResNet18_Weights

weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
```

These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/pretrained_weights.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
"source": [
"in_chans = weights.meta[\"in_chans\"]\n",
"model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n",
"model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
"model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions tests/trainers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_get_input_layer_name_and_module() -> None:

def test_load_state_dict(checkpoint: str, model: Module) -> None:
_, state_dict = extract_backbone(checkpoint)
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)


def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None:
Expand All @@ -58,7 +58,7 @@ def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module)
f" model {expected_in_channels}. Overriding with new input channels"
)
with pytest.warns(UserWarning, match=warning):
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)


def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None:
Expand All @@ -74,7 +74,7 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None
f" {expected_num_classes}. Overriding with new num classes"
)
with pytest.warns(UserWarning, match=warning):
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)


def test_reinit_initial_conv_layer() -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
backbone = utils.load_state_dict(backbone, state_dict)
utils.load_state_dict(backbone, state_dict)

self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224))

Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)
utils.load_state_dict(self.model, state_dict)

# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.backbone = utils.load_state_dict(self.backbone, state_dict)
utils.load_state_dict(self.backbone, state_dict)

# Create projection (and prediction) head
batch_norm = version == 3
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)
utils.load_state_dict(self.model, state_dict)

# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def configure_models(self) -> None:
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.backbone = utils.load_state_dict(self.backbone, state_dict)
utils.load_state_dict(self.backbone, state_dict)

# Create projection head
input_dim = self.backbone.num_features
Expand Down
12 changes: 8 additions & 4 deletions torchgeo/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]:
return key, module


def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module:
def load_state_dict(
model: Module, state_dict: "OrderedDict[str, Tensor]"
) -> tuple[list[str], list[str]]:
"""Load pretrained resnet weights to a model.
Args:
model: model to load the pretrained weights to
state_dict: dict containing tensor parameters
Returns:
the model with pretrained weights
The missing and unexpected keys
Warns:
If input channels in model != pretrained model input channels
Expand Down Expand Up @@ -115,8 +117,10 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
state_dict[output_module_key + ".bias"],
)

model.load_state_dict(state_dict, strict=False)
return model
missing_keys: list[str]
unexpected_keys: list[str]
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
return missing_keys, unexpected_keys


def reinit_initial_conv_layer(
Expand Down

0 comments on commit e63d358

Please sign in to comment.