diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index e1b6678ed73..7a4cb00cb63 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -46,6 +46,14 @@ def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]: state_dict = OrderedDict( {k.replace("model.backbone.model.", ""): v for k, v in state_dict.items()} ) + elif checkpoint["model"] in ["deeplabv3+", "unet"]: + state_dict = OrderedDict( + { + k.replace("encoder.", ""): v + for k, v in state_dict.items() + if "encoder" in k + } + ) else: raise ValueError( "Unknown checkpoint task. Only backbone or model extraction is supported"