Skip to content

Commit

Permalink
head reset will use the same class of head
Browse files Browse the repository at this point in the history
  • Loading branch information
fffffgggg54 committed Dec 26, 2024
1 parent f6bc034 commit abbdf05
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion timm/layers/ml_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def reset(self, num_classes, global_pool=None):
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
num_pooled_features = self.in_features * self.global_pool.feat_mult()
# TODO fix this it is incorrect, need to impl a reset for mldecoder itself i think
self.head = MLDecoder(in_features=in_features, num_classes=num_classes)
self.head = type(self.head)(in_features=in_features, num_classes=num_classes)


def forward(self, x, pre_logits: bool = False):
Expand Down

0 comments on commit abbdf05

Please sign in to comment.