From 0616c5ebe69fe533072587865d1efe897b49af53 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 14:47:49 +0900 Subject: [PATCH 01/12] Add yolo-fastest-v1 config --- .../model/yolo/yolo-fastest-v1-detection.yaml | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 config/model/yolo/yolo-fastest-v1-detection.yaml diff --git a/config/model/yolo/yolo-fastest-v1-detection.yaml b/config/model/yolo/yolo-fastest-v1-detection.yaml new file mode 100644 index 000000000..9e306fb41 --- /dev/null +++ b/config/model/yolo/yolo-fastest-v1-detection.yaml @@ -0,0 +1,76 @@ +model: + task: detection + name: yolo_fastest + checkpoint: + use_pretrained: False + load_head: False + path: ~ + fx_model_path: ~ + optimizer_path: ~ + freeze_backbone: False + architecture: + full: ~ # auto + backbone: + name: darknet + params: + # stage_stem_block_type: Bottleneck + stage_stem_block_type: darknetblock + depthwise: True + norm_type: batch_norm + act_type: relu6 + num_feat_layers: 2 + stem_stride: 2 + stem_out_channels: 8 + stage_params: + - # stage_1 + out_channels: 4 + num_blocks: 1 # not the stage stem block + darknet_expansions: [2, 2] + - # stage_2 + out_channels: 8 + num_blocks: 2 + darknet_expansions: [3, 4] + - # stage_3 + out_channels: 8 + num_blocks: 2 + darknet_expansions: [4, 6] + - # stage_4 + out_channels: 16 + num_blocks: 4 + darknet_expansions: [3, 6] + - # stage_5 + out_channels: 24 + num_blocks: 4 + darknet_expansions: [4, 5.67] # expands to 136 channels + - # stage_6 + out_channels: 48 + num_blocks: 5 + darknet_expansions: [2.84, 4.67] # expands to 224 channels + neck: + name: yolov3fpn + params: + out_channels: [48, 24] + kernel_size: 5 + double_channel: False + share_fpn_block: False + depthwise: True + norm_type: batch_norm + act_type: leaky_relu + head: + name: yolo_fastest_head + params: + anchors: + - [12.64,19.39, 37.88,51.48, 55.71,138.31] # P4/16 + - [126.91,78.23, 131.57,214.55, 279.92,258.87] # P5/32] + # strides: [16, 32] + norm_type: batch_norm + topk_candidates: 1000 + # postprocessor - decode + score_thresh: 0.0 + # postprocessor - nms + nms_thresh: 0.45 + class_agnostic: False + # Temporary loss to test the full YOLOFastest model to work right + losses: + - criterion: retinanet_loss + weight: ~ From 883268a0bfca33a7cf323e49c31b548d6465a375 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 14:51:52 +0900 Subject: [PATCH 02/12] [refactor] SeparableConvLayer --- src/netspresso_trainer/models/op/custom.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/netspresso_trainer/models/op/custom.py b/src/netspresso_trainer/models/op/custom.py index 093cab68b..5c91f2017 100644 --- a/src/netspresso_trainer/models/op/custom.py +++ b/src/netspresso_trainer/models/op/custom.py @@ -152,6 +152,7 @@ def __init__( norm_type: Optional[str] = None, use_act: bool = True, act_type: Optional[str] = None, + no_out_act: bool = False, ) -> None: super().__init__() self.depthwise = ConvLayer(in_channels=in_channels, out_channels=in_channels, @@ -159,11 +160,15 @@ def __init__( padding=padding, groups=in_channels, bias=bias, padding_mode=padding_mode, use_norm=use_norm, norm_type=norm_type, use_act=use_act, act_type=act_type,) self.pointwise = ConvLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, - use_norm=use_norm, norm_type=norm_type, use_act=use_act, act_type=act_type,) + use_norm=use_norm, norm_type=norm_type, use_act=False, act_type=act_type,) + self.final_act = ( + nn.Identity() if no_out_act else ACTIVATION_REGISTRY[act_type]() + ) def forward(self, x: Union[Tensor, Proxy]) -> Union[Tensor, Proxy]: x = self.depthwise(x) x = self.pointwise(x) + x = self.final_act(x) return x From 2c314cad76a5914759027bff98041219784839c0 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 14:53:17 +0900 Subject: [PATCH 03/12] [refactor] make DarknetBlock compatible with YOLOFastest --- src/netspresso_trainer/models/op/custom.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/netspresso_trainer/models/op/custom.py b/src/netspresso_trainer/models/op/custom.py index 5c91f2017..1c1b7d095 100644 --- a/src/netspresso_trainer/models/op/custom.py +++ b/src/netspresso_trainer/models/op/custom.py @@ -651,7 +651,6 @@ def forward(self, x): return x -# Newly defined because of slight difference with Bottleneck of custom.py class DarknetBlock(nn.Module): # Standard bottleneck def __init__( @@ -662,17 +661,20 @@ def __init__( expansion=0.5, depthwise=False, act_type="silu", + norm_type: Optional[str] = None, + no_out_act=False, + depthwise_stride: Optional[int] = None, ): super().__init__() hidden_channels = int(out_channels * expansion) self.conv1 = ConvLayer(in_channels=in_channels, out_channels=hidden_channels, - kernel_size=1, stride=1, act_type=act_type) + kernel_size=1, stride=1, act_type=act_type, norm_type=norm_type) if depthwise: self.conv2 = SeparableConvLayer(in_channels=hidden_channels, out_channels=out_channels, - kernel_size=3, stride=1, act_type=act_type) + kernel_size=3, stride=depthwise_stride if depthwise_stride else 1, act_type=act_type, norm_type=norm_type, no_out_act=no_out_act) else: self.conv2 = ConvLayer(in_channels=hidden_channels, out_channels=out_channels, - kernel_size=3, stride=1, act_type=act_type) + kernel_size=3, stride=1, act_type=act_type, norm_type=norm_type) self.use_add = shortcut and in_channels == out_channels def forward(self, x): From 65f4674905b60a696c0b8ebe67fbcf1f1eda51bb Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 14:55:01 +0900 Subject: [PATCH 04/12] [feat] Implement DarkNet --- .../models/backbones/experimental/darknet.py | 171 +++++++++++++++++- 1 file changed, 169 insertions(+), 2 deletions(-) diff --git a/src/netspresso_trainer/models/backbones/experimental/darknet.py b/src/netspresso_trainer/models/backbones/experimental/darknet.py index 645e193b2..b7cc1202a 100644 --- a/src/netspresso_trainer/models/backbones/experimental/darknet.py +++ b/src/netspresso_trainer/models/backbones/experimental/darknet.py @@ -18,18 +18,20 @@ Based on the Darknet implementation of Megvii. https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/darknet.py """ -from typing import Dict, Optional, List +from typing import Dict, Optional, List, Type from omegaconf import DictConfig import torch from torch import nn -from ...op.custom import ConvLayer, CSPLayer, Focus, SPPBottleneck, SeparableConvLayer +from ...op.custom import ConvLayer, CSPLayer, Focus, SPPBottleneck, SeparableConvLayer, DarknetBlock from ...utils import BackboneOutput from ..registry import USE_INTERMEDIATE_FEATURES_TASK_LIST __all__ = ['cspdarknet'] SUPPORTING_TASK = ['classification', 'segmentation', 'detection', 'pose_estimation'] +DARKNET_SUPPORTED_BLOCKS = ["darknetblock"] +BLOCK_FROM_LITERAL: Dict[str, Type[nn.Module]] = {"darknetblock": DarknetBlock} class CSPDarknet(nn.Module): @@ -181,3 +183,168 @@ def task_support(self, task): def cspdarknet(task, conf_model_backbone) -> CSPDarknet: return CSPDarknet(task, conf_model_backbone.params, conf_model_backbone.stage_params) + + +class Darknet(nn.Module): + """ + Consists of a stem layer and multiple stage layers. + Stage layers are named as stage_{i} starting from stage_1 + """ + + num_stages: int + + def __init__( + self, + task: str, + params: Optional[DictConfig] = None, + stage_params: Optional[List] = None, + ) -> None: + self.task = task.lower() + assert ( + self.task in SUPPORTING_TASK + ), f"Darknet is not supported on {self.task} task now." + assert stage_params, "please provide stage params of Darknet" + assert len(stage_params) >= 2 + assert ( + params.stage_stem_block_type.lower() in DARKNET_SUPPORTED_BLOCKS + ), "Block type not supported" + self.use_intermediate_features = ( + self.task in USE_INTERMEDIATE_FEATURES_TASK_LIST + ) + + self.num_stages = len(stage_params) + + super().__init__() + + # TODO: Check if inplace activation should be used + act_type = params.act_type + norm_type = params.norm_type + stage_stem_block_type = params.stage_stem_block_type + stem_stride = params.stem_stride + stem_out_channels = params.stem_out_channels + depthwise = params.depthwise + + StageStemBlock = BLOCK_FROM_LITERAL[stage_stem_block_type.lower()] + predefined_out_features = dict() + + # build the stem layer + self.stem = ConvLayer( + in_channels=3, + out_channels=stem_out_channels, + kernel_size=3, + stride=stem_stride, + act_type=act_type, + norm_type=norm_type, + ) + + prev_out_channels = stem_out_channels + + # build rest of the layers + # TODO: make it compatiable with Yolov3 + for i, stage_param in enumerate(stage_params): + + layers = [] + hidden_expansions = stage_param.darknet_expansions + out_channels = stage_param.out_channels + + if len(hidden_expansions) == 2: + # stage_stem_expansion is defined as hidden_ch // output_ch + stage_stem_expansion = hidden_expansions[0] + block_expansion = hidden_expansions[1] + + # TODO: Implement + else: + raise NotImplementedError + + stage_stem_block = StageStemBlock( + in_channels=prev_out_channels, + out_channels=out_channels, + shortcut=False, + expansion=stage_stem_expansion, + depthwise=depthwise, + act_type=act_type, + norm_type=norm_type, + no_out_act=False, + depthwise_stride=2, + ) + + layers.append(stage_stem_block) + prev_out_channels = out_channels + + for _ in range(stage_param.num_blocks): + + in_ch = prev_out_channels + out_ch = in_ch + darknet_block = DarknetBlock( + in_channels=in_ch, + out_channels=out_ch, + shortcut=True, + expansion=block_expansion, + depthwise=depthwise, + norm_type=norm_type, + act_type=act_type, + no_out_act=True, + ) + + layers.append(darknet_block) + setattr(self, f"stage_{i+1}", nn.Sequential(*layers)) + predefined_out_features[f"stage_{i+1}"] = stage_param.out_channels + + # feature layers + self.out_features = [] + first_feat_layer = self.num_stages - params.num_feat_layers + 1 + for i in range(params.num_feat_layers): + layer_str = f"stage_{first_feat_layer + i}" + self.out_features.append(layer_str) + + self._feature_dim = predefined_out_features[f"stage_{self.num_stages-1}"] + + self._intermediate_features_dim = [ + predefined_out_features[out_feature] for out_feature in self.out_features + ] + + # Initialize + def init_bn(M): + for m in M.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + self.apply(init_bn) + return + + def forward(self, x): + outputs_dict = {} + x = self.stem(x) + outputs_dict["stem"] = x + + for i in range(1, self.num_stages + 1): + x = getattr(self, f"stage_{i}")(x) + outputs_dict[f"stage_{i}"] = x + + if self.use_intermediate_features: + all_hidden_states = [ + outputs_dict[out_name] for out_name in self.out_features + ] + return BackboneOutput(intermediate_features=all_hidden_states) + + # TODO: Check if classification head is needed + x = self.avgpool(x) + x = torch.flatten(x, 1) + + return BackboneOutput(last_feature=x) + + @property + def feature_dim(self): + return self._feature_dim + + @property + def intermediate_features_dim(self): + return self._intermediate_features_dim + + def task_support(self, task): + return task.lower() in SUPPORTING_TASK + + +def darknet(task, conf_model_backbone) -> Darknet: + return Darknet(task, conf_model_backbone.params, conf_model_backbone.stage_params) \ No newline at end of file From 17b8d7717a8bb1d192bb990fcf171bf19c144e73 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 14:57:11 +0900 Subject: [PATCH 05/12] init darknet --- src/netspresso_trainer/models/backbones/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/netspresso_trainer/models/backbones/__init__.py b/src/netspresso_trainer/models/backbones/__init__.py index 03507f8f6..035dbb4ce 100644 --- a/src/netspresso_trainer/models/backbones/__init__.py +++ b/src/netspresso_trainer/models/backbones/__init__.py @@ -16,7 +16,7 @@ # from .core import * from .core.resnet import resnet -from .experimental.darknet import cspdarknet +from .experimental.darknet import cspdarknet, darknet from .experimental.efficientformer import efficientformer from .experimental.mixnet import mixnet from .experimental.mixtransformer import mixtransformer From ce25e6074010b2fe67d8f1f685d991431b7b76f3 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 14:57:20 +0900 Subject: [PATCH 06/12] init yolov3fpn --- src/netspresso_trainer/models/necks/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/netspresso_trainer/models/necks/__init__.py b/src/netspresso_trainer/models/necks/__init__.py index 23a863307..f5558423a 100644 --- a/src/netspresso_trainer/models/necks/__init__.py +++ b/src/netspresso_trainer/models/necks/__init__.py @@ -16,3 +16,4 @@ from .experimental.fpn import fpn from .experimental.yolopafpn import yolopafpn +from .experimental.yolov3fpn import yolov3fpn From 9ab837b146720411baee59493c150578ba7ca8e5 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 14:57:33 +0900 Subject: [PATCH 07/12] [feat] implement yolov3fpn --- .../models/necks/experimental/yolov3fpn.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 src/netspresso_trainer/models/necks/experimental/yolov3fpn.py diff --git a/src/netspresso_trainer/models/necks/experimental/yolov3fpn.py b/src/netspresso_trainer/models/necks/experimental/yolov3fpn.py new file mode 100644 index 000000000..360e3e9d6 --- /dev/null +++ b/src/netspresso_trainer/models/necks/experimental/yolov3fpn.py @@ -0,0 +1,148 @@ +""" +This code is modified version of mmdetection. +https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/necks/yolo_neck.py +""" +from typing import List + +from omegaconf import DictConfig +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...op.custom import ConvLayer +from ...utils import BackboneOutput + + +class YOLOv3FPNBlock(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + double_channel: bool, + kernel_size: int, + norm_type: str, + act_type: str, + depthwise: bool, + share_fpn_block: bool, + ) -> None: + super().__init__() + self.share_fpn_block = share_fpn_block + double_out_channels = out_channels * 2 if double_channel else out_channels + groups = double_out_channels if depthwise else 1 + + # shortcut + if self.share_fpn_block: + self.conv1 = ConvLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, + norm_type=norm_type, act_type=act_type) + self.conv2 = ConvLayer(in_channels=out_channels, out_channels=double_out_channels, kernel_size=kernel_size, + groups=groups, norm_type=norm_type, act_type=act_type) + self.conv3 = ConvLayer(in_channels=double_out_channels, out_channels=out_channels, kernel_size=1, + norm_type=norm_type, act_type=act_type) + self.conv4 = ConvLayer(in_channels=out_channels, out_channels=double_out_channels, kernel_size=kernel_size, + groups=groups, norm_type=norm_type, act_type=act_type) + self.conv5 = ConvLayer(in_channels=double_out_channels, out_channels=out_channels, kernel_size=1, + norm_type=norm_type, act_type=act_type) + + def forward(self, x): + if self.share_fpn_block: + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + out = self.conv5(x) + return out + + +class YOLOv3FPN(nn.Module): + + def __init__( + self, + intermediate_features_dim: List[int], + params: DictConfig, + ): + super().__init__() + self.input_channels = intermediate_features_dim + self.out_channels = params.out_channels + self.double_channel = params.double_channel + self.kernel_size = params.kernel_size + self.norm_type = params.norm_type + self.act_type = params.act_type + self.share_fpn_block = params.share_fpn_block + self.depthwise = params.depthwise + + self._intermediate_features_dim = self.out_channels + + self.input_channels = self.input_channels[::-1] + self.out_channels = self.out_channels[::-1] + + self.conv_blocks = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + + self.build_conv_blocks() + self.build_fpn_blocks() + + def build_conv_blocks(self): + if self.share_fpn_block: + for i in range(1, len(self.input_channels)): + in_c, out_c = self.input_channels[i], self.out_channels[i] + inter_c = self.out_channels[i - 1] + self.conv_blocks.append(self.build_1x1_conv(in_channels=inter_c, out_channels=out_c)) + else: + in_c, out_c = self.input_channels[0], self.out_channels[0] + self.conv_blocks.append(self.build_1x1_conv(in_channels=in_c, out_channels=out_c)) + + def build_fpn_blocks(self): + for i in range(len(self.input_channels)): + in_c, out_c = self.input_channels[i], self.out_channels[i] + in_c = in_c if i == 0 else in_c + out_c + self.fpn_blocks.append(self.build_fpn_block(in_channels=in_c, out_channels=out_c)) + + def build_1x1_conv(self, in_channels, out_channels): + return ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + norm_type=self.norm_type, + act_type=self.act_type) + + def build_fpn_block(self, in_channels, out_channels): + return YOLOv3FPNBlock(in_channels=in_channels, + out_channels=out_channels, + double_channel=self.double_channel, + kernel_size=self.kernel_size, + norm_type=self.norm_type, + act_type=self.act_type, + depthwise=self.depthwise, + share_fpn_block=self.share_fpn_block) + + def forward(self, inputs): + outputs = [] + + feat = inputs[-1] if self.share_fpn_block else self.conv_blocks[0](inputs[-1]) + tmp = self.fpn_blocks[0](feat) + outputs.append(tmp) + + if self.share_fpn_block: feat = tmp + + for i, x in enumerate(reversed(inputs[:-1])): + if self.share_fpn_block: + feat = self.conv_blocks[i](feat) + + # Cat with low-lvl feats + feat = F.interpolate(feat, scale_factor=2) + feat = torch.cat((feat, x), 1) + + tmp = self.fpn_blocks[i+1](feat) + outputs.append(tmp) + + if self.share_fpn_block: feat = tmp + + return BackboneOutput(intermediate_features=outputs[::-1]) + + @property + def intermediate_features_dim(self): + return self._intermediate_features_dim + + +def yolov3fpn(intermediate_features_dim, conf_model_neck, **kwargs): + return YOLOv3FPN(intermediate_features_dim=intermediate_features_dim, params=conf_model_neck.params) \ No newline at end of file From 12be4be794b7ff853b0813f3520d82102d9f08e3 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 15:00:17 +0900 Subject: [PATCH 08/12] init yolo_fastest_head --- src/netspresso_trainer/models/heads/detection/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/netspresso_trainer/models/heads/detection/__init__.py b/src/netspresso_trainer/models/heads/detection/__init__.py index 80e5b4f84..6918ac09b 100644 --- a/src/netspresso_trainer/models/heads/detection/__init__.py +++ b/src/netspresso_trainer/models/heads/detection/__init__.py @@ -15,4 +15,5 @@ # ---------------------------------------------------------------------------- from .experimental.anchor_free_decoupled_head import anchor_free_decoupled_head -from .experimental.anchor_decoupled_head import anchor_decoupled_head \ No newline at end of file +from .experimental.anchor_decoupled_head import anchor_decoupled_head +from .experimental.yolo_fastest_head import yolo_fastest_head \ No newline at end of file From c889b955b4426eff4889ebde17b1d7495546fe33 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 15:00:38 +0900 Subject: [PATCH 09/12] [feat] implement yolo_fastest_head --- .../experimental/yolo_fastest_head.py | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 src/netspresso_trainer/models/heads/detection/experimental/yolo_fastest_head.py diff --git a/src/netspresso_trainer/models/heads/detection/experimental/yolo_fastest_head.py b/src/netspresso_trainer/models/heads/detection/experimental/yolo_fastest_head.py new file mode 100644 index 000000000..b2d8124a2 --- /dev/null +++ b/src/netspresso_trainer/models/heads/detection/experimental/yolo_fastest_head.py @@ -0,0 +1,106 @@ +from typing import List +from omegaconf import DictConfig +import torch +import torch.nn as nn + +from ....op.custom import ConvLayer +from ....utils import AnchorBasedDetectionModelOutput +from .detection import AnchorGenerator + + +class YoloFastestHead(nn.Module): + + num_layers: int + + def __init__( + self, + num_classes: int, + intermediate_features_dim: List[int], + params: DictConfig, + ): + super().__init__() + + anchors = params.anchors + num_layers = len(anchors) + self.anchors = anchors + tmp_cell_anchors = [] + for a in self.anchors: + a = torch.tensor(a).view(-1, 2) + wa = a[:, 0:1] + ha = a[:, 1:] + base_anchors = torch.cat([-wa, -ha, wa, ha], dim=-1)/2 + tmp_cell_anchors.append(base_anchors) + self.anchor_generator = AnchorGenerator(sizes=((128),)) # TODO: dynamic image_size, and anchor_size as a parameters + self.anchor_generator.cell_anchors = tmp_cell_anchors + num_anchors = self.anchor_generator.num_anchors_per_location()[0] + self.num_anchors = num_anchors + self.num_layers = num_layers + self.num_classes = num_classes + out_channels = num_anchors * (4 + num_classes) # TODO: Add confidence score dim + norm_type = params.norm_type + use_act = False + kernel_size = 1 + + for i in range(num_layers): + + in_channels = intermediate_features_dim[i] + + conv_norm = ConvLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + norm_type=norm_type, + use_act=use_act, + ) + conv = ConvLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + use_norm=False, + use_act=use_act, + ) + + layer = nn.Sequential(conv_norm, conv) + + setattr(self, f"layer_{i+1}", layer) + + def init_bn(M): + for m in M.modules(): + + if isinstance(m, nn.BatchNorm2d): + m.eps = 1e-3 + m.momentum = 0.03 + + self.apply(init_bn) + + def forward(self, inputs: List[torch.Tensor]): + x1, x2 = inputs + out1 = self.layer_1(x1) + out2 = self.layer_2(x2) + output = [out1, out2] + all_cls_logits = [] + all_bbox_regression = [] + anchors = torch.cat(self.anchor_generator(output), dim=0) + for idx, o in enumerate(output): + N, _, H, W = o.shape + o = o.view(N, self.num_anchors, -1, H, W).permute(0, 3, 4, 1, 2) + bbox_regression = o[..., :4] + cls_logits = o[..., 4:] + bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4) + cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, K) + all_bbox_regression.append(bbox_regression) + all_cls_logits.append(cls_logits) + return AnchorBasedDetectionModelOutput({"anchors": anchors, + "bbox_regression": all_bbox_regression, + "cls_logits": all_cls_logits, + }) + + +def yolo_fastest_head( + num_classes, intermediate_features_dim, conf_model_head, **kwargs +) -> YoloFastestHead: + return YoloFastestHead( + num_classes=num_classes, + intermediate_features_dim=intermediate_features_dim, + params=conf_model_head.params, + ) \ No newline at end of file From 50a164c9af95725dc20ae5c1245a1c5146ddf1c7 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 15:01:22 +0900 Subject: [PATCH 10/12] Register Darknet, yolov3fpn, yolo_fastest_head --- src/netspresso_trainer/models/registry.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/netspresso_trainer/models/registry.py b/src/netspresso_trainer/models/registry.py index b946c5d89..90c3249ba 100644 --- a/src/netspresso_trainer/models/registry.py +++ b/src/netspresso_trainer/models/registry.py @@ -19,14 +19,24 @@ import torch.nn as nn -from .backbones import cspdarknet, efficientformer, mixnet, mixtransformer, mobilenetv3, mobilevit, resnet, vit +from .backbones import ( + cspdarknet, + darknet, + efficientformer, + mixnet, + mixtransformer, + mobilenetv3, + mobilevit, + resnet, + vit, +) from .base import ClassificationModel, DetectionModel, PoseEstimationModel, SegmentationModel, TaskModel from .full import pidnet from .heads.classification import fc -from .heads.detection import anchor_decoupled_head, anchor_free_decoupled_head +from .heads.detection import anchor_decoupled_head, anchor_free_decoupled_head, yolo_fastest_head from .heads.pose_estimation import rtmcc from .heads.segmentation import all_mlp_decoder -from .necks import fpn, yolopafpn +from .necks import fpn, yolopafpn, yolov3fpn MODEL_BACKBONE_DICT: Dict[str, Callable[..., nn.Module]] = { 'resnet': resnet, @@ -36,12 +46,14 @@ 'vit': vit, 'efficientformer': efficientformer, 'cspdarknet': cspdarknet, + 'darknet': darknet, 'mixnet': mixnet, } MODEL_NECK_DICT: Dict[str, Callable[..., nn.Module]] = { 'fpn': fpn, 'yolopafpn': yolopafpn, + 'yolov3fpn': yolov3fpn, } MODEL_HEAD_DICT: Dict[str, Callable[..., nn.Module]] = { @@ -54,6 +66,7 @@ 'detection': { 'anchor_free_decoupled_head': anchor_free_decoupled_head, 'anchor_decoupled_head': anchor_decoupled_head, + 'yolo_fastest_head': yolo_fastest_head, }, 'pose_estimation': { 'rtmcc': rtmcc, From 44638e401b13fd3c2f6bd8746552aeba4e576725 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 15:02:50 +0900 Subject: [PATCH 11/12] Register PostProcessor --- src/netspresso_trainer/postprocessors/detection.py | 2 +- src/netspresso_trainer/postprocessors/registry.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/netspresso_trainer/postprocessors/detection.py b/src/netspresso_trainer/postprocessors/detection.py index 1ba1ca818..9c41bdb8f 100644 --- a/src/netspresso_trainer/postprocessors/detection.py +++ b/src/netspresso_trainer/postprocessors/detection.py @@ -169,7 +169,7 @@ def __init__(self, conf_model): if head_name == 'anchor_free_decoupled_head': self.decode_outputs = partial(anchor_free_decoupled_head_decode, score_thresh=params.score_thresh) self.postprocess = partial(nms, nms_thresh=params.nms_thresh, class_agnostic=params.class_agnostic) - elif head_name == 'anchor_decoupled_head': + elif head_name == 'anchor_decoupled_head' or head_name == 'yolo_fastest_head': self.decode_outputs = partial(anchor_decoupled_head_decode, topk_candidates=params.topk_candidates, score_thresh=params.score_thresh) self.postprocess = partial(nms, nms_thresh=params.nms_thresh, class_agnostic=params.class_agnostic) else: diff --git a/src/netspresso_trainer/postprocessors/registry.py b/src/netspresso_trainer/postprocessors/registry.py index e5d288bb0..2ef16ff5d 100644 --- a/src/netspresso_trainer/postprocessors/registry.py +++ b/src/netspresso_trainer/postprocessors/registry.py @@ -27,5 +27,6 @@ 'anchor_free_decoupled_head': DetectionPostprocessor, 'pidnet': SegmentationPostprocessor, 'anchor_decoupled_head': DetectionPostprocessor, + 'yolo_fastest_head': DetectionPostprocessor, 'rtmcc': PoseEstimationPostprocessor, } From 7bfa252ff56567cfca37778f548384af295085a4 Mon Sep 17 00:00:00 2001 From: hglee98 Date: Mon, 24 Jun 2024 06:14:41 +0000 Subject: [PATCH 12/12] Register ReLU6 activation --- src/netspresso_trainer/models/op/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/netspresso_trainer/models/op/registry.py b/src/netspresso_trainer/models/op/registry.py index 19150b6b6..77e63efc3 100644 --- a/src/netspresso_trainer/models/op/registry.py +++ b/src/netspresso_trainer/models/op/registry.py @@ -32,4 +32,5 @@ 'silu': nn.SiLU, 'swish': nn.SiLU, 'hard_swish': nn.Hardswish, + 'relu6': nn.ReLU6, }