Bugfix for cls_layer
In `any_softmax`, all operations are in-place, so `logits.clone()` is passed in to prevent the caller's logits from being changed.
L1aoXingyu committed May 31, 2021
1 parent c3ac4f5 commit 6300bd7
Showing 3 changed files with 19 additions and 9 deletions.
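For readers skimming the diff: the layers in `any_softmax` modify the tensor they are given, so handing them the original `logits` also changes the logits the head still needs afterwards. A minimal sketch of that aliasing problem and of how `.clone()` avoids it (toy code, not taken from fastreid):

```python
import torch

def scale_inplace(logits: torch.Tensor, s: float = 16.0) -> torch.Tensor:
    # In-place multiply: modifies the very tensor the caller passed in.
    return logits.mul_(s)

logits = torch.ones(2, 3)
scaled = scale_inplace(logits)          # the caller's `logits` is now all 16s
assert logits[0, 0].item() == 16.0      # silently changed outside the function

logits = torch.ones(2, 3)
scaled = scale_inplace(logits.clone())  # in-place ops hit the copy only
assert logits[0, 0].item() == 1.0       # the caller's tensor is untouched
```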
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -2,9 +2,15 @@
 
 ### v1.3
 
-#### Improvements
+#### New Features
+- Vision Transformer backbone, see config in `configs/Market1501/bagtricks_vit.yml`
+- Self-Distillation with EMA update
+- Gradient Clip
+
+#### Improvements
 - Faster dataloader with pre-fetch thread and cuda stream
 - Optimize DDP training speed by removing `find_unused_parameters` in DDP
 
 
 ### v1.2 (06/04/2021)

2 changes: 1 addition & 1 deletion fastreid/modeling/heads/clas_head.py
@@ -27,7 +27,7 @@ def forward(self, features, targets=None):
 # Evaluation
 if not self.training: return logits.mul_(self.cls_layer.s)
 
-cls_outputs = self.cls_layer(logits, targets)
+cls_outputs = self.cls_layer(logits.clone(), targets)
 
 return {
     "cls_outputs": cls_outputs,
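To make the one-line fix above concrete, the sketch below mimics the shape of a margin-based softmax layer in the spirit of the `any_softmax` classes; `ToyMarginSoftmax` and its margin/scale values are illustrative assumptions, not fastreid's implementation. Because both the margin subtraction and the scaling run in-place, the head passes `logits.clone()` so the logits it keeps for evaluation and other outputs stay unchanged.

```python
import torch
from torch import nn

class ToyMarginSoftmax(nn.Module):
    """Illustrative stand-in for an `any_softmax`-style layer: the margin
    subtraction and the scaling both operate in-place on the incoming logits."""

    def __init__(self, scale: float = 16.0, margin: float = 0.1):
        super().__init__()
        self.s, self.m = scale, margin

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        target_mask = nn.functional.one_hot(targets, num_classes=logits.size(1)).bool()
        logits[target_mask] -= self.m  # in-place: apply the margin to target-class logits
        return logits.mul_(self.s)     # in-place: scale every logit

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
cls_outputs = ToyMarginSoftmax()(logits.clone(), targets)  # `logits` itself stays pristine
```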
18 changes: 11 additions & 7 deletions fastreid/modeling/heads/embedding_head.py
@@ -4,16 +4,14 @@
 @contact: [email protected]
 """
 
-import math
-
 import torch
 import torch.nn.functional as F
 from torch import nn
 
 from fastreid.config import configurable
 from fastreid.layers import *
 from fastreid.layers import pooling, any_softmax
-from fastreid.utils.weight_init import weights_init_kaiming
+from fastreid.layers.weight_init import weights_init_kaiming
 from .build import REID_HEADS_REGISTRY


@@ -78,14 +76,19 @@ def __init__(
 neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
 
 self.bottleneck = nn.Sequential(*neck)
-self.bottleneck.apply(weights_init_kaiming)
 
-# Linear layer
+# Classification head
 assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
     "but got {}".format(any_softmax.__all__, cls_type)
-self.weight = nn.Parameter(torch.normal(0, 0.01, (num_classes, feat_dim)))
+self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim))
 self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
 
+self.reset_parameters()
+
+def reset_parameters(self) -> None:
+    self.bottleneck.apply(weights_init_kaiming)
+    nn.init.normal_(self.weight, std=0.01)
+
 @classmethod
 def from_config(cls, cfg):
     # fmt: off
@@ -132,7 +135,8 @@ def forward(self, features, targets=None):
 else:
     logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
 
-cls_outputs = self.cls_layer(logits, targets)
+# Pass logits.clone() into cls_layer because its operations are in-place
+cls_outputs = self.cls_layer(logits.clone(), targets)
 
 # fmt: off
 if self.neck_feat == 'before': feat = pool_feat[..., 0, 0]
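A side note on the `embedding_head.py` changes above: beyond the `.clone()` fix, the constructor now allocates `self.weight` uninitialized and defers the actual initialization to `reset_parameters()`, the usual PyTorch convention of separating parameter allocation from value assignment. A simplified sketch of that pattern, with illustrative names rather than the real fastreid head:

```python
import torch
from torch import nn

class ToyEmbeddingHead(nn.Module):
    def __init__(self, num_classes: int, feat_dim: int):
        super().__init__()
        # Allocate the parameter storage only; values are assigned in reset_parameters().
        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Same normal(mean=0, std=0.01) initialization as before, now re-runnable.
        nn.init.normal_(self.weight, std=0.01)
```

Keeping all initialization in one method also makes it possible to re-initialize the head explicitly if the module is reused.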
