Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support model in_chans not equal to pre-trained weights in_chans #2324

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions torchgeo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_satlas_sentinel2_transforms,
_satlas_transforms,
)
from .utils import load_pretrained

# https://github.com/zhu-xlab/DeCUR/blob/f190e9a3895ef645c005c8c2fce287ffa5a937e3/src/transfer_classification_BE/linear_BE_resnet.py#L286
# Normalization by channel-wise band statistics
Expand Down Expand Up @@ -710,14 +711,18 @@ def resnet50(
Returns:
A ResNet-50 model.
"""
if weights:
kwargs['in_chans'] = weights.meta['in_chans']

model: ResNet = timm.create_model('resnet50', *args, **kwargs)

if weights:
missing_keys, unexpected_keys = model.load_state_dict(
weights.get_state_dict(progress=True), strict=False
pretrained_cfg = {}
pretrained_cfg['first_conv'] = 'conv1'
in_chans = kwargs.get('in_chans', 3)
missing_keys, unexpected_keys = load_pretrained(
model,
weights=weights,
pretrained_cfg=pretrained_cfg,
in_chans=in_chans,
strict=False,
)
assert set(missing_keys) <= {'fc.weight', 'fc.bias'}
assert not unexpected_keys
Expand Down
60 changes: 60 additions & 0 deletions torchgeo/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation. All rights reserved.

Check failure on line 1 in torchgeo/models/utils.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D100)

torchgeo/models/utils.py:1:1: D100 Missing docstring in public module
# Licensed under the MIT License.

import math

import torch
from torch import nn
from torchvision.models._api import Weights


def load_pretrained(
    model: nn.Module,
    weights: Weights,
    pretrained_cfg: dict,
    in_chans: int = 3,
    strict: bool = True,
) -> tuple:
    """Load pre-trained weights into a model, adapting the input convolution(s).

    The state dict is obtained from ``weights.get_state_dict(progress=True)``.
    For every layer named under ``pretrained_cfg['first_conv']`` whose weight
    has a number of input channels (dimension 1) different from ``in_chans``,
    the weight is converted with :func:`adapt_input_conv` before loading.

    Args:
        model: Model that receives the weights via ``model.load_state_dict``.
        weights: Pre-trained weights providing ``get_state_dict``.
        pretrained_cfg: Configuration dict. The optional ``'first_conv'`` entry
            is either a single layer name or an iterable of layer names
            (without the ``'.weight'`` suffix).
        in_chans: Number of input channels of ``model``.
        strict: Forwarded to ``model.load_state_dict``. Forced to ``False``
            if an input convolution could not be converted and was dropped.

    Returns:
        A ``(missing_keys, unexpected_keys)`` tuple as reported by
        ``model.load_state_dict``.
    """
    state_dict = weights.get_state_dict(progress=True)

    input_convs = pretrained_cfg.get('first_conv')
    if input_convs is None:
        input_convs = ()
    elif isinstance(input_convs, str):
        input_convs = (input_convs,)

    for input_conv_name in input_convs:
        weight_name = input_conv_name + '.weight'
        if weight_name not in state_dict:
            # Nothing to adapt; let load_state_dict report the key as missing
            # instead of failing here with a bare KeyError.
            continue

        weight_in_chans = state_dict[weight_name].shape[1]
        if in_chans == weight_in_chans:
            continue

        try:
            adapted = adapt_input_conv(in_chans, state_dict[weight_name])
        except NotImplementedError:
            # Conversion is unsupported: drop the weight so the layer keeps
            # its random initialization, and relax strictness so the now
            # missing key does not raise.
            del state_dict[weight_name]
            strict = False
            print(
                f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.'
            )
        else:
            state_dict[weight_name] = adapted
            print(
                f'Converted input conv {input_conv_name} pretrained weights from {weight_in_chans} to {in_chans} channel(s)'
            )

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)

    return missing_keys, unexpected_keys


def adapt_input_conv(in_chans: int, conv_weight: torch.Tensor) -> torch.Tensor:
    """Adapt an input convolution weight to a different number of input channels.

    If ``in_chans`` is 1, the kernels are summed over the input channel
    dimension. Otherwise the kernels are tiled along the input channel
    dimension, truncated to ``in_chans`` channels, and rescaled by
    ``original_in_chans / in_chans``.

    Args:
        in_chans: Number of input channels the returned weight should have.
        conv_weight: 4D convolution weight; dimension 1 is treated as the
            input channel dimension.

    Returns:
        A new tensor with ``in_chans`` channels in dimension 1 and the same
        dtype as ``conv_weight``. The input tensor is not modified.

    Raises:
        ValueError: If ``in_chans`` is less than 1.
    """
    if in_chans < 1:
        raise ValueError(f'in_chans must be a positive integer, got {in_chans}')

    conv_type = conv_weight.dtype
    # Some weights are in torch.half, ensure it's float for sum on CPU
    conv_weight = conv_weight.float()
    # Unpacking four values also enforces that the weight is 4D.
    _out_chans, weight_in_chans, _kernel_h, _kernel_w = conv_weight.shape

    if in_chans == 1:
        conv_weight = conv_weight.sum(dim=1, keepdim=True)
    else:
        # Tile the existing channels until there are at least in_chans of
        # them, then cut off the surplus.
        repeat = math.ceil(in_chans / weight_in_chans)
        conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        # Out-of-place multiply so the caller's tensor is never written to.
        conv_weight = conv_weight * (weight_in_chans / float(in_chans))

    return conv_weight.to(conv_type)
Loading