From 8bef257de8bf458ba30921dbd7af00163635bd3c Mon Sep 17 00:00:00 2001 From: Keenan Eves Date: Thu, 26 Sep 2024 15:12:14 -0600 Subject: [PATCH 1/2] created load_pretrained function and used for resnet50 --- torchgeo/models/resnet.py | 16 +++++++---- torchgeo/models/utils.py | 60 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 5 deletions(-) create mode 100644 torchgeo/models/utils.py diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index b61e816ea2..39aa36ac83 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -17,6 +17,7 @@ _satlas_sentinel2_transforms, _satlas_transforms, ) +from .utils import load_pretrained # https://github.com/zhu-xlab/DeCUR/blob/f190e9a3895ef645c005c8c2fce287ffa5a937e3/src/transfer_classification_BE/linear_BE_resnet.py#L286 # Normalization by channel-wise band statistics @@ -710,14 +711,19 @@ def resnet50( Returns: A ResNet-50 model. """ - if weights: - kwargs['in_chans'] = weights.meta['in_chans'] - model: ResNet = timm.create_model('resnet50', *args, **kwargs) if weights: - missing_keys, unexpected_keys = model.load_state_dict( - weights.get_state_dict(progress=True), strict=False + pretrained_cfg = {} + pretrained_cfg['url'] = weights.url + pretrained_cfg['first_conv'] = 'conv1' + in_chans = kwargs.get('in_chans', 3) + missing_keys, unexpected_keys = load_pretrained( + model, + weights=weights, + pretrained_cfg=pretrained_cfg, + in_chans=in_chans, + strict=False, ) assert set(missing_keys) <= {'fc.weight', 'fc.bias'} assert not unexpected_keys diff --git a/torchgeo/models/utils.py b/torchgeo/models/utils.py new file mode 100644 index 0000000000..16815d839c --- /dev/null +++ b/torchgeo/models/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
"""Helpers for loading pre-trained weights into a model."""

import logging
import math
from typing import TYPE_CHECKING, Any

import torch
from torch import nn

if TYPE_CHECKING:
    # Only referenced in annotations, so no runtime import is required.
    from torchvision.models._api import Weights

_logger = logging.getLogger(__name__)


def load_pretrained(
    model: nn.Module,
    weights: 'Weights',
    pretrained_cfg: dict[str, Any],
    in_chans: int = 3,
    strict: bool = True,
) -> tuple[list[str], list[str]]:
    """Load pre-trained weights into a model, adapting the input convolution.

    The state dict is obtained from ``weights.get_state_dict(progress=True)``.
    For every layer named in ``pretrained_cfg['first_conv']`` whose pre-trained
    weight has a different number of input channels than *in_chans*, the weight
    is converted with :func:`adapt_input_conv`. If the conversion is not
    supported, that weight is dropped from the state dict (the layer keeps its
    current initialization) and loading falls back to ``strict=False``.

    Args:
        model: Model to load the weights into.
        weights: Object providing ``get_state_dict(progress=True)``.
        pretrained_cfg: Configuration of the pre-trained weights. Only the
            optional ``'first_conv'`` entry is read: the name of the input
            convolution layer, or a sequence of such names.
        in_chans: Number of input channels expected by *model*.
        strict: Passed on to ``model.load_state_dict``. Forced to False when
            an input convolution weight could not be converted.

    Returns:
        The missing keys and unexpected keys reported by
        ``model.load_state_dict``.

    Raises:
        ValueError: If a conversion is needed and *in_chans* is less than 1.
    """
    state_dict = weights.get_state_dict(progress=True)

    input_convs = pretrained_cfg.get('first_conv')
    if isinstance(input_convs, str):
        input_convs = (input_convs,)

    for input_conv_name in input_convs or ():
        weight_name = f'{input_conv_name}.weight'
        if weight_name not in state_dict:
            # Nothing to adapt; load_state_dict below reports the key as
            # missing (or raises when strict) instead of a bare KeyError here.
            continue

        conv_weight = state_dict[weight_name]
        # Weights without a channel dimension cannot match; send them through
        # the conversion so they take the "unable to convert" path.
        weight_in_chans = conv_weight.shape[1] if conv_weight.ndim > 1 else None
        if weight_in_chans == in_chans:
            continue

        try:
            state_dict[weight_name] = adapt_input_conv(in_chans, conv_weight)
        except NotImplementedError:
            del state_dict[weight_name]
            # The dropped key would otherwise make a strict load fail.
            strict = False
            _logger.warning(
                'Unable to convert pretrained %s weights, '
                'using random init for this layer.',
                input_conv_name,
            )
        else:
            _logger.info(
                'Converted input conv %s pretrained weights '
                'from %s to %s channel(s)',
                input_conv_name,
                weight_in_chans,
                in_chans,
            )

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)

    return missing_keys, unexpected_keys


def adapt_input_conv(in_chans: int, conv_weight: torch.Tensor) -> torch.Tensor:
    """Convert an input convolution weight to a new number of input channels.

    For a single input channel the weight is summed over its input-channel
    dimension. Otherwise the input channels are repeated cyclically until
    *in_chans* channels are available, and the result is rescaled by
    ``original_channels / in_chans``.

    Args:
        in_chans: Desired number of input channels.
        conv_weight: 4-D convolution weight whose dimension 1 is the
            input-channel dimension.

    Returns:
        A new tensor with *in_chans* input channels and the same dtype as
        *conv_weight*. The input tensor is not modified.

    Raises:
        ValueError: If *in_chans* is less than 1.
        NotImplementedError: If *conv_weight* is not 4-D or has no input
            channels.
    """
    if in_chans < 1:
        raise ValueError(f'in_chans must be a positive integer, got {in_chans}')
    if conv_weight.ndim != 4 or conv_weight.shape[1] == 0:
        raise NotImplementedError('Weight format not supported by conversion.')

    orig_dtype = conv_weight.dtype
    # Some weights are in torch.half, ensure it's float for sum on CPU
    weight = conv_weight.float()
    weight_in_chans = weight.shape[1]

    if in_chans == 1:
        weight = weight.sum(dim=1, keepdim=True)
    else:
        repeats = math.ceil(in_chans / weight_in_chans)
        weight = weight.repeat(1, repeats, 1, 1)[:, :in_chans, :, :]
        weight = weight * (weight_in_chans / in_chans)

    return weight.to(orig_dtype)


# NOTE: the second commit of this patch series ("removed unnecessary line")
# drops the unused ``pretrained_cfg['url'] = weights.url`` assignment from
# ``resnet50`` in torchgeo/models/resnet.py.