diff --git a/README.md b/README.md index 36ae3aa37eb..ec31fcb7753 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ from torchgeo.models import ResNet18_Weights weights = ResNet18_Weights.SENTINEL2_ALL_MOCO model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10) -model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False) +model.load_state_dict(weights.get_state_dict(progress=True), strict=False) ``` These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html). diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb index 26d97fcbc6c..e15dd1ebc16 100644 --- a/docs/tutorials/pretrained_weights.ipynb +++ b/docs/tutorials/pretrained_weights.ipynb @@ -228,7 +228,7 @@ "source": [ "in_chans = weights.meta[\"in_chans\"]\n", "model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n", - "model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" + "model.load_state_dict(weights.get_state_dict(progress=True), strict=False)" ] }, { diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index 52d7a9be25d..06da0a359eb 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -41,7 +41,7 @@ def test_get_input_layer_name_and_module() -> None: def test_load_state_dict(checkpoint: str, model: Module) -> None: _, state_dict = extract_backbone(checkpoint) - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None: @@ -58,7 +58,7 @@ def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) f" model {expected_in_channels}. Overriding with new input channels" ) with pytest.warns(UserWarning, match=warning): - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None: @@ -74,7 +74,7 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None f" {expected_num_classes}. Overriding with new num classes" ) with pytest.warns(UserWarning, match=warning): - model = load_state_dict(model, state_dict) + load_state_dict(model, state_dict) def test_reinit_initial_conv_layer() -> None: diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 68bdb6c9c43..d6c0b62765e 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -343,7 +343,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - backbone = utils.load_state_dict(backbone, state_dict) + utils.load_state_dict(backbone, state_dict) self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224)) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 9ac312051c7..5d8d10c9dce 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -137,7 +137,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model = utils.load_state_dict(self.model, state_dict) + utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head if self.hparams["freeze_backbone"]: diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index d2621a8da74..4dbc1e453c9 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -261,7 +261,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.backbone = utils.load_state_dict(self.backbone, state_dict) + utils.load_state_dict(self.backbone, state_dict) # Create projection (and prediction) head batch_norm = version == 3 diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index c58f033b5bc..9cc2ea56441 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -128,7 +128,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.model = utils.load_state_dict(self.model, state_dict) + utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head if self.hparams["freeze_backbone"]: diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index a889be1c96f..27719eda224 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -172,7 +172,7 @@ def configure_models(self) -> None: _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) - self.backbone = utils.load_state_dict(self.backbone, state_dict) + utils.load_state_dict(self.backbone, state_dict) # Create projection head input_dim = self.backbone.num_features diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index e1b6678ed73..b5cd8f1e923 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -71,7 +71,9 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]: return key, module -def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module: +def load_state_dict( + model: Module, state_dict: "OrderedDict[str, Tensor]" +) -> tuple[list[str], list[str]]: """Load pretrained resnet weights to a model. Args: @@ -79,7 +81,7 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo state_dict: dict containing tensor parameters Returns: - the model with pretrained weights + The missing and unexpected keys Warns: If input channels in model != pretrained model input channels @@ -115,8 +117,10 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo state_dict[output_module_key + ".bias"], ) - model.load_state_dict(state_dict, strict=False) - return model + missing_keys: list[str] + unexpected_keys: list[str] + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + return missing_keys, unexpected_keys def reinit_initial_conv_layer(