Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support different number of input channels to YOLOX backbone #1239

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

weiji14
Copy link

@weiji14 weiji14 commented Apr 13, 2022

Allow for arbitrary number of input channels to the YOLOX neural network. Previously, number of input channels were hardcoded to 3 (i.e. RGB only).

Allow for arbitrary number of input channels to the YOLOX neural network. Previously, number of input channels were hardcoded to 3 (i.e. RGB only).
@CLAassistant
Copy link

CLAassistant commented Apr 13, 2022

CLA assistant check
All committers have signed the CLA.

@weiji14
Copy link
Author

weiji14 commented Apr 14, 2022

Seems like there's still something hardcoded somewhere in the model architecture. When I try to pass in a 4-channel input, there is a RuntimeError about incorrect number of channels

import torch

model = torch.hub.load(
    repo_or_dir="weiji14/YOLOX:44c2f3149946758c620df077b845e25bee010edf",
    model="yolox_nano",
    backbone_in_channels=4,
    num_classes=1,
    device="cuda",
    pretrained=False,
)
model(torch.randn(2, 4, 480, 480))

Full error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [7], in <module>
----> 1 model(torch.randn(2, 4, 480, 480))

File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/yolox.py:30, in YOLOX.forward(self, x, targets)
     28 def forward(self, x, targets=None):
     29     # fpn output content features of [dark3, dark4, dark5]
---> 30     fpn_outs = self.backbone(x)
     32     if self.training:
     33         assert targets is not None

File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/yolo_pafpn.py:96, in YOLOPAFPN.forward(self, input)
     87 """
     88 Args:
     89     inputs: input images.
   (...)
     92     Tuple[Tensor]: FPN feature.
     93 """
     95 #  backbone
---> 96 out_features = self.backbone(input)
     97 features = [out_features[f] for f in self.in_features]
     98 [x2, x1, x0] = features

File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/darknet.py:170, in CSPDarknet.forward(self, x)
    168 def forward(self, x):
    169     outputs = {}
--> 170     x = self.stem(x)
    171     outputs["stem"] = x
    172     x = self.dark2(x)

File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/network_blocks.py:210, in Focus.forward(self, x)
    200 patch_bot_right = x[..., 1::2, 1::2]
    201 x = torch.cat(
    202     (
    203         patch_top_left,
   (...)
    208     dim=1,
    209 )
--> 210 return self.conv(x)

File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/network_blocks.py:51, in BaseConv.forward(self, x)
     50 def forward(self, x):
---> 51     return self.act(self.bn(self.conv(x)))

File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/conv.py:446, in Conv2d.forward(self, input)
    445 def forward(self, input: Tensor) -> Tensor:
--> 446     return self._conv_forward(input, self.weight, self.bias)

File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/conv.py:442, in Conv2d._conv_forward(self, input, weight, bias)
    438 if self.padding_mode != 'zeros':
    439     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    440                     weight, bias, self.stride,
    441                     _pair(0), self.dilation, self.groups)
--> 442 return F.conv2d(input, weight, bias, self.stride,
    443                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [16, 12, 3, 3], expected input[2, 16, 240, 240] to have 12 channels, but got 16 channels instead

It looks similar to the problem mentioned at https://stackoverflow.com/questions/65719005/runtimeerror-given-groups-1-weight-of-size-16-1-3-3-expected-input16-3, and might have something to do with this line:

self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)

Not sure why the layer is doing in_channels * 4, will need to do a bit more debugging.

@weiji14
Copy link
Author

weiji14 commented Apr 14, 2022

Ok, found the solution. Needed to modify exps/default/yolox_nano.py too, done in d264836. Now 4-channel inputs work properly!

import torch

model = torch.hub.load(
    repo_or_dir="weiji14/YOLOX:d264836e86eefea58cd37865a134d63ebcf80fbf",
    model="yolox_nano",
    backbone_in_channels=4,
    num_classes=1,
    device="cpu",
    pretrained=False,
)
model.eval()
y_hat = model(x=torch.randn(2, 4, 480, 480))
print(y_hat.shape)
# torch.Size([2, 4725, 6])

@weiji14 weiji14 marked this pull request as ready for review April 14, 2022 03:18
exp.num_classes = num_classes
yolox_model = exp.get_model()
if pretrained and num_classes == 80:
if pretrained and backbone_in_channels == 3 and num_classes == 80:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the architecture fully supports it, but in the past, when using N>3 for FRCNN or MRCNN w/ resnet backbone, I had better luck adapting weights to the extra channels. Certainly beats training from scratch.

Rather than ignoring the weights in the case of N_Channels != 3, is it possible to randomize the extra weights or duplicate weights from a different channel?

At the very least, might be nice to log a warning that the weights are being ignored, despite the "pretrained" input being true.

Copy link
Author

@weiji14 weiji14 Apr 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the architecture fully supports it, but in the past, when using N>3 for FRCNN or MRCNN w/ resnet backbone, I had better luck adapting weights to the extra channels. Certainly beats training from scratch.

Rather than ignoring the weights in the case of N_Channels != 3, is it possible to randomize the extra weights or duplicate weights from a different channel?

Hmm, I'm not sure how to randomize weights for extra channels, do you some example code to do that? Maybe this can be done in a follow up Pull Request so as not to overcomplicate things.

At the very least, might be nice to log a warning that the weights are being ignored, despite the "pretrained" input being true.

Good idea. Or maybe it should just be an error? Edit: decided to just let it raise an error, done in commit 4e42e61

Copy link

@dcyoung dcyoung Apr 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm. For resnet50 channel additions I've mainly used tensorflow where I did something like:

tensor = params["conv0/W"]
assert tensor.shape == (7, 7, 3, 64)
# Create a replacement with 4 channels, using the existing first 3 and a copy of the 1st
replacement = np.zeros((7, 7, 4, 64), tensor.dtype)
replacement[:, :, :3, :] = tensor
replacement[:, :, 3, :] = tensor[:, :, 0, :]
params[target_layer_name] = replacement

I'm not sure how well that translates to architecture used here.

For pytorch, I believe you can simply modify the state dict before loading. You could do this to avoid loading any tensors with mismtached sizes. That is, attempt to use all weights which CAN be used. For example, a model trained on a different number of classes could still be used to populate weights of the backbone, omitting just the weights from the model head. . Here is an example from Huggingface: https://github.com/huggingface/transformers/blob/v4.18.0/src/transformers/modeling_utils.py#L1989

In the case of N channels != 3, you might need to manipulate the weights. I've had success manipulating weights directly like so:

# Load from a PyTorch checkpoint
state_dict = torch.load(archive_file, map_location="cpu")

# Manually manipulate the weights relevant to 4 channel model
wip = state_dict["param_name"]
// manipulate
wip = ...
# update
state_dict["param_name"] = wip

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants