Support different number of input channels to YOLOX backbone #1239
Allow for an arbitrary number of input channels to the YOLOX neural network. Previously, the number of input channels was hardcoded to 3 (i.e. RGB only).
Seems like there's still something hardcoded somewhere in the model architecture. When I try to pass in a 4-channel input, there is a RuntimeError about an incorrect number of channels:
import torch
model = torch.hub.load(
repo_or_dir="weiji14/YOLOX:44c2f3149946758c620df077b845e25bee010edf",
model="yolox_nano",
backbone_in_channels=4,
num_classes=1,
device="cuda",
pretrained=False,
)
model(torch.randn(2, 4, 480, 480))
Full error message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [7], in <module>
----> 1 model(torch.randn(2, 4, 480, 480))
File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/yolox.py:30, in YOLOX.forward(self, x, targets)
28 def forward(self, x, targets=None):
29 # fpn output content features of [dark3, dark4, dark5]
---> 30 fpn_outs = self.backbone(x)
32 if self.training:
33 assert targets is not None
File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/yolo_pafpn.py:96, in YOLOPAFPN.forward(self, input)
87 """
88 Args:
89 inputs: input images.
(...)
92 Tuple[Tensor]: FPN feature.
93 """
95 # backbone
---> 96 out_features = self.backbone(input)
97 features = [out_features[f] for f in self.in_features]
98 [x2, x1, x0] = features
File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/darknet.py:170, in CSPDarknet.forward(self, x)
168 def forward(self, x):
169 outputs = {}
--> 170 x = self.stem(x)
171 outputs["stem"] = x
172 x = self.dark2(x)
File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/network_blocks.py:210, in Focus.forward(self, x)
200 patch_bot_right = x[..., 1::2, 1::2]
201 x = torch.cat(
202 (
203 patch_top_left,
(...)
208 dim=1,
209 )
--> 210 return self.conv(x)
File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/torch/hub/weiji14_YOLOX_44c2f3149946758c620df077b845e25bee010edf/yolox/models/network_blocks.py:51, in BaseConv.forward(self, x)
50 def forward(self, x):
---> 51 return self.act(self.bn(self.conv(x)))
File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/conv.py:446, in Conv2d.forward(self, input)
445 def forward(self, input: Tensor) -> Tensor:
--> 446 return self._conv_forward(input, self.weight, self.bias)
File /srv/conda/envs/notebook/lib/python3.8/site-packages/torch/nn/modules/conv.py:442, in Conv2d._conv_forward(self, input, weight, bias)
438 if self.padding_mode != 'zeros':
439 return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
440 weight, bias, self.stride,
441 _pair(0), self.dilation, self.groups)
--> 442 return F.conv2d(input, weight, bias, self.stride,
443 self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [16, 12, 3, 3], expected input[2, 16, 240, 240] to have 12 channels, but got 16 channels instead
It looks similar to the problem mentioned at https://stackoverflow.com/questions/65719005/runtimeerror-given-groups-1-weight-of-size-16-1-3-3-expected-input16-3, and might have something to do with this line: YOLOX/yolox/models/network_blocks.py Line 193 in 7ba9fd2
Not sure why the layer is doing |
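For context, the numbers in the error are consistent with a space-to-depth step like the one visible in the `Focus.forward` traceback above: four strided pixel patches are concatenated along the channel axis, so the convolution that follows sees four times the input channels (3 becomes 12, 4 becomes 16). The sketch below is a minimal stand-alone reproduction under that assumption. The function name `space_to_depth`, the patch ordering, and the convolution's padding are illustrative choices, not copied from the repository.

```python
import torch
import torch.nn as nn


def space_to_depth(x: torch.Tensor) -> torch.Tensor:
    """Stack the four 2x2-strided pixel patches along the channel axis.

    Halves height and width and multiplies the channel count by 4.
    """
    patch_top_left = x[..., ::2, ::2]
    patch_bot_left = x[..., 1::2, ::2]
    patch_top_right = x[..., ::2, 1::2]
    patch_bot_right = x[..., 1::2, 1::2]
    return torch.cat(
        (patch_top_left, patch_bot_left, patch_top_right, patch_bot_right),
        dim=1,
    )


in_channels = 4  # e.g. RGB plus one extra band
x = torch.randn(2, in_channels, 480, 480)
stacked = space_to_depth(x)
print(stacked.shape)  # torch.Size([2, 16, 240, 240])

# The following convolution therefore needs in_channels * 4 input channels.
# A layer built for 3 * 4 = 12 channels would reject this 16-channel tensor.
conv = nn.Conv2d(in_channels * 4, 16, kernel_size=3, padding=1)
out = conv(stacked)
print(out.shape)  # torch.Size([2, 16, 240, 240])
```

This is why passing the channel count only to the outer model is not enough: the layer after the patch stacking also has to be sized from it.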
Ok, found the solution. Needed to modify
import torch
model = torch.hub.load(
repo_or_dir="weiji14/YOLOX:d264836e86eefea58cd37865a134d63ebcf80fbf",
model="yolox_nano",
backbone_in_channels=4,
num_classes=1,
device="cpu",
pretrained=False,
)
model.eval()
y_hat = model(x=torch.randn(2, 4, 480, 480))
print(y_hat.shape)
# torch.Size([2, 4725, 6])
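The printed shape can be sanity-checked by hand. The sketch below assumes the detection head predicts on three feature maps with strides 8, 16 and 32, and that each prediction holds 4 box values, 1 objectness score and one score per class. The helper name `expected_yolox_output_shape` is hypothetical and only exists for this illustration.

```python
def expected_yolox_output_shape(batch, height, width, num_classes, strides=(8, 16, 32)):
    """Return the (batch, predictions, values) shape expected in eval mode.

    Assumes one prediction per grid cell at each stride, and
    4 box values + 1 objectness score + num_classes class scores each.
    """
    num_predictions = sum((height // s) * (width // s) for s in strides)
    return (batch, num_predictions, 4 + 1 + num_classes)


# 480 / 8 = 60, 480 / 16 = 30, 480 / 32 = 15
# 60*60 + 30*30 + 15*15 = 3600 + 900 + 225 = 4725
print(expected_yolox_output_shape(2, 480, 480, num_classes=1))  # (2, 4725, 6)
```

Note that the result does not depend on the number of input channels, which matches the output above for a 4-channel input.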
yolox/models/build.py
 exp.num_classes = num_classes
 yolox_model = exp.get_model()
-if pretrained and num_classes == 80:
+if pretrained and backbone_in_channels == 3 and num_classes == 80:
I'm not sure if the architecture fully supports it, but in the past, when using N>3 for FRCNN or MRCNN w/ resnet backbone, I had better luck adapting weights to the extra channels. Certainly beats training from scratch.
Rather than ignoring the weights in the case of N_Channels != 3, is it possible to randomize the extra weights or duplicate weights from a different channel?
At the very least, might be nice to log a warning that the weights are being ignored, despite the "pretrained" input being true.
> I'm not sure if the architecture fully supports it, but in the past, when using N>3 for FRCNN or MRCNN w/ resnet backbone, I had better luck adapting weights to the extra channels. Certainly beats training from scratch.
> Rather than ignoring the weights in the case of N_Channels != 3, is it possible to randomize the extra weights or duplicate weights from a different channel?
Hmm, I'm not sure how to randomize weights for extra channels, do you have some example code to do that? Maybe this can be done in a follow-up Pull Request so as not to overcomplicate things.
> At the very least, might be nice to log a warning that the weights are being ignored, despite the "pretrained" input being true.
Good idea. Or maybe it should just be an error? Edit: decided to just let it raise an error, done in commit 4e42e61
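As an illustration of the "raise an error" option discussed here, a guard of the following kind could run before any checkpoint is loaded. This is only a sketch: the function name `should_load_pretrained`, the exception type and the message are assumptions, and the code in the referenced commit may look different.

```python
def should_load_pretrained(pretrained: bool, backbone_in_channels: int = 3) -> bool:
    """Decide whether pretrained weights can be loaded.

    Hypothetical helper: fails loudly instead of silently skipping the
    checkpoint when the requested input channels do not match RGB.
    """
    if pretrained and backbone_in_channels != 3:
        raise ValueError(
            "pretrained=True requires backbone_in_channels=3, "
            f"got backbone_in_channels={backbone_in_channels}"
        )
    return pretrained


print(should_load_pretrained(True, backbone_in_channels=3))   # True
print(should_load_pretrained(False, backbone_in_channels=4))  # False
```

Failing early avoids the situation raised in the review, where a user asks for pretrained weights and gets a randomly initialized model without being told.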
Hmmm. For resnet50 channel additions I've mainly used tensorflow where I did something like:
tensor = params["conv0/W"]
assert tensor.shape == (7, 7, 3, 64)
# Create a replacement with 4 channels, using the existing first 3 and a copy of the 1st
replacement = np.zeros((7, 7, 4, 64), tensor.dtype)
replacement[:, :, :3, :] = tensor
replacement[:, :, 3, :] = tensor[:, :, 0, :]
params[target_layer_name] = replacement
I'm not sure how well that translates to the architecture used here.
For PyTorch, I believe you can simply modify the state dict before loading. You could do this to avoid loading any tensors with mismatched sizes. That is, attempt to use all weights which CAN be used. For example, a model trained on a different number of classes could still be used to populate the weights of the backbone, omitting just the weights from the model head. Here is an example from Huggingface: https://github.com/huggingface/transformers/blob/v4.18.0/src/transformers/modeling_utils.py#L1989
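A minimal sketch of that idea, using two toy `nn.Sequential` models as stand-ins for a 3-channel checkpoint and a 4-channel target (they are not YOLOX models). The helper name `filter_matching` is hypothetical. It keeps only tensors whose name and shape match the target, then loads them with `strict=False` so the skipped tensors keep their fresh initialization.

```python
import torch.nn as nn


def filter_matching(state_dict, model):
    """Keep only checkpoint tensors whose key and shape match the model."""
    model_state = model.state_dict()
    return {
        key: value
        for key, value in state_dict.items()
        if key in model_state and value.shape == model_state[key].shape
    }


# Toy stand-ins: same layout, but the first conv differs in input channels.
pretrained = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 4, 1))
target = nn.Sequential(nn.Conv2d(4, 8, 3), nn.Conv2d(8, 4, 1))

compatible = filter_matching(pretrained.state_dict(), target)
result = target.load_state_dict(compatible, strict=False)

# Only the first conv's weight was skipped; everything else was reused.
print(result.missing_keys)  # ['0.weight']
```

Inspecting `missing_keys` also gives a natural place to log which tensors were left untrained.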
In the case of N channels != 3, you might need to manipulate the weights. I've had success manipulating weights directly like so:
# Load from a PyTorch checkpoint
state_dict = torch.load(archive_file, map_location="cpu")
# Manually manipulate the weights relevant to 4 channel model
wip = state_dict["param_name"]
# manipulate
wip = ...
# update
state_dict["param_name"] = wip
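To make the "manipulate" step concrete for this architecture, here is a sketch that widens a stem convolution weight from 3 to 4 input channels, reusing the duplicate-channel-0 idea from the TensorFlow snippet above. It assumes the stem sees four pixel patches concatenated along the channel axis (as in the `Focus.forward` traceback), so input index = patch * channels + channel. The function name `expand_focus_weight` is hypothetical, and whether this initialization actually helps training has not been tested in this thread.

```python
import torch


def expand_focus_weight(weight, old_channels=3, new_channels=4, patches=4):
    """Widen a conv weight of shape (out, patches * old_channels, kh, kw).

    Existing channels are copied as-is; each extra channel is filled
    with a copy of channel 0 from the same patch.
    """
    out_ch, in_ch, kh, kw = weight.shape
    assert in_ch == patches * old_channels, "unexpected input channel layout"

    # Separate the patch axis from the per-patch channel axis.
    per_patch = weight.reshape(out_ch, patches, old_channels, kh, kw)
    expanded = torch.zeros(out_ch, patches, new_channels, kh, kw, dtype=weight.dtype)
    expanded[:, :, :old_channels] = per_patch
    for channel in range(old_channels, new_channels):
        expanded[:, :, channel] = per_patch[:, :, 0]

    # Merge the two axes back into a single input-channel axis.
    return expanded.reshape(out_ch, patches * new_channels, kh, kw)


old_weight = torch.randn(16, 12, 3, 3)  # same shape as in the error message
new_weight = expand_focus_weight(old_weight)
print(new_weight.shape)  # torch.Size([16, 16, 3, 3])
```

The result would then be written back into the state dict under the stem convolution's parameter name before calling `load_state_dict`.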