Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding YOLOFastestV1 full model except loss #472

Draft
wants to merge 12 commits into
base: dev
Choose a base branch
from
76 changes: 76 additions & 0 deletions config/model/yolo/yolo-fastest-v1-detection.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
model:
task: detection
name: yolo_fastest
checkpoint:
use_pretrained: False
load_head: False
path: ~
fx_model_path: ~
optimizer_path: ~
freeze_backbone: False
architecture:
full: ~ # auto
backbone:
name: darknet
params:
# stage_stem_block_type: Bottleneck
stage_stem_block_type: darknetblock
depthwise: True
norm_type: batch_norm
act_type: relu6
num_feat_layers: 2
stem_stride: 2
stem_out_channels: 8
stage_params:
- # stage_1
out_channels: 4
num_blocks: 1 # not the stage stem block
darknet_expansions: [2, 2]
- # stage_2
out_channels: 8
num_blocks: 2
darknet_expansions: [3, 4]
- # stage_3
out_channels: 8
num_blocks: 2
darknet_expansions: [4, 6]
- # stage_4
out_channels: 16
num_blocks: 4
darknet_expansions: [3, 6]
- # stage_5
out_channels: 24
num_blocks: 4
darknet_expansions: [4, 5.67] # expands to 136 channels
- # stage_6
out_channels: 48
num_blocks: 5
darknet_expansions: [2.84, 4.67] # expands to 224 channels
neck:
name: yolov3fpn
params:
out_channels: [48, 24]
kernel_size: 5
double_channel: False
share_fpn_block: False
depthwise: True
norm_type: batch_norm
act_type: leaky_relu
head:
name: yolo_fastest_head
params:
anchors:
- [12.64,19.39, 37.88,51.48, 55.71,138.31] # P4/16
- [126.91,78.23, 131.57,214.55, 279.92,258.87] # P5/32
# strides: [16, 32]
norm_type: batch_norm
topk_candidates: 1000
# postprocessor - decode
score_thresh: 0.0
# postprocessor - nms
nms_thresh: 0.45
class_agnostic: False
# Temporary loss, used only to test that the full YOLOFastest model works correctly
losses:
- criterion: retinanet_loss
weight: ~
2 changes: 1 addition & 1 deletion src/netspresso_trainer/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# from .core import *
from .core.resnet import resnet
from .experimental.darknet import cspdarknet
from .experimental.darknet import cspdarknet, darknet
from .experimental.efficientformer import efficientformer
from .experimental.mixnet import mixnet
from .experimental.mixtransformer import mixtransformer
Expand Down
171 changes: 169 additions & 2 deletions src/netspresso_trainer/models/backbones/experimental/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@
Based on the Darknet implementation of Megvii.
https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/darknet.py
"""
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Type

from omegaconf import DictConfig
import torch
from torch import nn

from ...op.custom import ConvLayer, CSPLayer, Focus, SPPBottleneck, SeparableConvLayer
from ...op.custom import ConvLayer, CSPLayer, Focus, SPPBottleneck, SeparableConvLayer, DarknetBlock
from ...utils import BackboneOutput
from ..registry import USE_INTERMEDIATE_FEATURES_TASK_LIST

# Public factory functions of this module; `darknet` is defined at the bottom of the file.
__all__ = ['cspdarknet', 'darknet']
# Task names accepted by the backbones in this module.
SUPPORTING_TASK = ['classification', 'segmentation', 'detection', 'pose_estimation']
# Maps the lower-cased `stage_stem_block_type` config value to the block class used
# as the first block of every Darknet stage.
BLOCK_FROM_LITERAL: Dict[str, Type[nn.Module]] = {"darknetblock": DarknetBlock}
# Derived from the mapping above so the two can never drift apart.
DARKNET_SUPPORTED_BLOCKS = list(BLOCK_FROM_LITERAL)


class CSPDarknet(nn.Module):
Expand Down Expand Up @@ -181,3 +183,168 @@ def task_support(self, task):

def cspdarknet(task, conf_model_backbone) -> CSPDarknet:
    """Build a ``CSPDarknet`` backbone from the backbone section of a model config.

    Args:
        task: Task name forwarded to ``CSPDarknet``.
        conf_model_backbone: Config node exposing ``params`` and ``stage_params``.

    Returns:
        The constructed ``CSPDarknet`` module.
    """
    backbone_params = conf_model_backbone.params
    per_stage_params = conf_model_backbone.stage_params
    return CSPDarknet(task, backbone_params, per_stage_params)


class Darknet(nn.Module):
    """
    Darknet backbone: a stem convolution followed by a sequence of stages.

    Stages are registered as ``stage_{i}`` starting from ``stage_1``. Each stage
    is an ``nn.Sequential`` made of one stage stem block (built with
    ``shortcut=False`` and ``depthwise_stride=2``) followed by ``num_blocks``
    ``DarknetBlock`` modules built with ``shortcut=True`` that keep the channel
    count unchanged.

    Args:
        task: Task name (case-insensitive); must be listed in ``SUPPORTING_TASK``.
        params: Backbone-level options. The fields read here are ``act_type``,
            ``norm_type``, ``stage_stem_block_type``, ``stem_stride``,
            ``stem_out_channels``, ``depthwise`` and ``num_feat_layers``.
        stage_params: One entry per stage, each providing ``out_channels``,
            ``num_blocks`` and a two-element ``darknet_expansions``
            (``[stage_stem_expansion, block_expansion]``).

    Raises:
        AssertionError: If the task or block type is unsupported, ``params`` is
            missing, or fewer than two stages are configured.
        NotImplementedError: If a stage's ``darknet_expansions`` does not have
            exactly two entries.
    """

    num_stages: int

    def __init__(
        self,
        task: str,
        params: Optional[DictConfig] = None,
        stage_params: Optional[List] = None,
    ) -> None:
        # nn.Module must be initialised before any attribute is assigned on it.
        super().__init__()

        self.task = task.lower()
        assert (
            self.task in SUPPORTING_TASK
        ), f"Darknet is not supported on {self.task} task now."
        assert stage_params, "please provide stage params of Darknet"
        assert len(stage_params) >= 2
        assert params is not None, "please provide params of Darknet"
        assert (
            params.stage_stem_block_type.lower() in DARKNET_SUPPORTED_BLOCKS
        ), "Block type not supported"
        self.use_intermediate_features = (
            self.task in USE_INTERMEDIATE_FEATURES_TASK_LIST
        )

        self.num_stages = len(stage_params)

        # TODO: Check if inplace activation should be used
        act_type = params.act_type
        norm_type = params.norm_type
        stage_stem_block_type = params.stage_stem_block_type
        stem_stride = params.stem_stride
        stem_out_channels = params.stem_out_channels
        depthwise = params.depthwise

        StageStemBlock = BLOCK_FROM_LITERAL[stage_stem_block_type.lower()]
        # Output channel count of every stage, keyed by stage attribute name.
        predefined_out_features = {}

        # build the stem layer
        self.stem = ConvLayer(
            in_channels=3,
            out_channels=stem_out_channels,
            kernel_size=3,
            stride=stem_stride,
            act_type=act_type,
            norm_type=norm_type,
        )

        prev_out_channels = stem_out_channels

        # build rest of the layers
        # TODO: make it compatible with Yolov3
        for i, stage_param in enumerate(stage_params):
            hidden_expansions = stage_param.darknet_expansions
            out_channels = stage_param.out_channels

            # TODO: Implement other expansion layouts
            if len(hidden_expansions) != 2:
                raise NotImplementedError

            # stage_stem_expansion is defined as hidden_ch // output_ch
            stage_stem_expansion, block_expansion = (
                hidden_expansions[0],
                hidden_expansions[1],
            )

            # The stage stem block changes the channel count from the previous
            # stage's output to this stage's `out_channels`.
            layers = [
                StageStemBlock(
                    in_channels=prev_out_channels,
                    out_channels=out_channels,
                    shortcut=False,
                    expansion=stage_stem_expansion,
                    depthwise=depthwise,
                    act_type=act_type,
                    norm_type=norm_type,
                    no_out_act=False,
                    depthwise_stride=2,
                )
            ]
            prev_out_channels = out_channels

            # The remaining blocks keep the channel count and use a shortcut.
            for _ in range(stage_param.num_blocks):
                layers.append(
                    DarknetBlock(
                        in_channels=prev_out_channels,
                        out_channels=prev_out_channels,
                        shortcut=True,
                        expansion=block_expansion,
                        depthwise=depthwise,
                        norm_type=norm_type,
                        act_type=act_type,
                        no_out_act=True,
                    )
                )

            setattr(self, f"stage_{i+1}", nn.Sequential(*layers))
            predefined_out_features[f"stage_{i+1}"] = stage_param.out_channels

        # Used only when intermediate features are not requested (see forward):
        # pools the last stage output to 1x1 before flattening.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # feature layers: the last `num_feat_layers` stages, in forward order
        first_feat_layer = self.num_stages - params.num_feat_layers + 1
        self.out_features = [
            f"stage_{first_feat_layer + i}" for i in range(params.num_feat_layers)
        ]

        # `forward` pools the output of the LAST stage, so the feature dimension
        # is that stage's channel count (stages are numbered from 1).
        self._feature_dim = predefined_out_features[f"stage_{self.num_stages}"]

        self._intermediate_features_dim = [
            predefined_out_features[out_feature] for out_feature in self.out_features
        ]

        # Initialize: `modules()` already walks the whole tree, so one pass
        # over it is enough to override every BatchNorm2d.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eps = 1e-3
                m.momentum = 0.03

    def forward(self, x):
        """
        Run the stem and every stage in order.

        Returns:
            ``BackboneOutput(intermediate_features=[...])`` holding the outputs of
            the stages named in ``self.out_features`` when the task is in
            ``USE_INTERMEDIATE_FEATURES_TASK_LIST``; otherwise
            ``BackboneOutput(last_feature=x)`` where ``x`` is the last stage output
            after average pooling and flattening from dim 1.
        """
        outputs_dict = {}
        x = self.stem(x)
        outputs_dict["stem"] = x

        for i in range(1, self.num_stages + 1):
            x = getattr(self, f"stage_{i}")(x)
            outputs_dict[f"stage_{i}"] = x

        if self.use_intermediate_features:
            all_hidden_states = [
                outputs_dict[out_name] for out_name in self.out_features
            ]
            return BackboneOutput(intermediate_features=all_hidden_states)

        # TODO: Check if classification head is needed
        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return BackboneOutput(last_feature=x)

    @property
    def feature_dim(self):
        """Channel count of the last stage, i.e. the size of ``last_feature``."""
        return self._feature_dim

    @property
    def intermediate_features_dim(self):
        """Channel counts of the stages listed in ``self.out_features``."""
        return self._intermediate_features_dim

    def task_support(self, task):
        """Return whether ``task`` (case-insensitive) is in ``SUPPORTING_TASK``."""
        return task.lower() in SUPPORTING_TASK


def darknet(task, conf_model_backbone) -> Darknet:
    """Build a ``Darknet`` backbone from the backbone section of a model config.

    Args:
        task: Task name forwarded to ``Darknet``.
        conf_model_backbone: Config node exposing ``params`` and ``stage_params``.

    Returns:
        The constructed ``Darknet`` module.
    """
    backbone_params = conf_model_backbone.params
    per_stage_params = conf_model_backbone.stage_params
    return Darknet(task, backbone_params, per_stage_params)
3 changes: 2 additions & 1 deletion src/netspresso_trainer/models/heads/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# ----------------------------------------------------------------------------

from .experimental.anchor_free_decoupled_head import anchor_free_decoupled_head
from .experimental.anchor_decoupled_head import anchor_decoupled_head
from .experimental.anchor_decoupled_head import anchor_decoupled_head
from .experimental.yolo_fastest_head import yolo_fastest_head
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import List
from omegaconf import DictConfig
import torch
import torch.nn as nn

from ....op.custom import ConvLayer
from ....utils import AnchorBasedDetectionModelOutput
from .detection import AnchorGenerator


class YoloFastestHead(nn.Module):
    """
    Anchor-based detection head with one 1x1-conv branch per feature level.

    The number of levels equals the number of entries in ``params.anchors``.
    Branch ``layer_{i}`` (1-based) is a ``ConvLayer`` with normalisation and no
    activation, followed by a plain ``ConvLayer`` (no norm, no activation) that
    outputs ``num_anchors * (4 + num_classes)`` channels.

    Args:
        num_classes: Number of object classes.
        intermediate_features_dim: Input channel count of each feature level,
            indexed in the same order as ``params.anchors``.
        params: Head options. The fields read here are ``anchors`` (one flat
            ``[w0, h0, w1, h1, ...]`` list per level) and ``norm_type``.
    """

    num_layers: int

    def __init__(
        self,
        num_classes: int,
        intermediate_features_dim: List[int],
        params: DictConfig,
    ):
        super().__init__()

        anchors = params.anchors
        num_layers = len(anchors)
        self.anchors = anchors

        # Turn each level's flat (w, h) list into boxes centred on the origin:
        # rows of [-w/2, -h/2, w/2, h/2].
        tmp_cell_anchors = []
        for a in self.anchors:
            a = torch.tensor(a).view(-1, 2)
            wa = a[:, 0:1]
            ha = a[:, 1:]
            base_anchors = torch.cat([-wa, -ha, wa, ha], dim=-1)/2
            tmp_cell_anchors.append(base_anchors)
        self.anchor_generator = AnchorGenerator(sizes=((128),)) # TODO: dynamic image_size, and anchor_size as a parameters
        # Replace the generator's own cell anchors with the configured ones.
        self.anchor_generator.cell_anchors = tmp_cell_anchors
        # The first level's anchor count is used for every level.
        num_anchors = self.anchor_generator.num_anchors_per_location()[0]
        self.num_anchors = num_anchors
        self.num_layers = num_layers
        self.num_classes = num_classes
        out_channels = num_anchors * (4 + num_classes) # TODO: Add confidence score dim
        norm_type = params.norm_type
        use_act = False
        kernel_size = 1

        for i in range(num_layers):
            in_channels = intermediate_features_dim[i]

            conv_norm = ConvLayer(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=kernel_size,
                norm_type=norm_type,
                use_act=use_act,
            )
            conv = ConvLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                use_norm=False,
                use_act=use_act,
            )

            setattr(self, f"layer_{i+1}", nn.Sequential(conv_norm, conv))

        # `modules()` already walks the whole tree, so one pass over it is
        # enough to override every BatchNorm2d.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eps = 1e-3
                m.momentum = 0.03

    def forward(self, inputs: List[torch.Tensor]):
        """
        Apply each level's branch and split the result into boxes and class logits.

        Args:
            inputs: Feature maps, one per level. Only the first ``num_layers``
                entries are used; entry ``i`` is fed to ``layer_{i+1}``.

        Returns:
            ``AnchorBasedDetectionModelOutput`` with keys ``anchors`` (the
            anchor generator's per-item outputs concatenated along dim 0),
            ``bbox_regression`` (list of ``(N, H*W*A, 4)`` tensors, one per
            level) and ``cls_logits`` (list of ``(N, H*W*A, K)`` tensors).
        """
        # Index by level instead of unpacking a fixed pair, so the head works
        # for however many anchor levels were configured in __init__.
        output = [
            getattr(self, f"layer_{i + 1}")(inputs[i]) for i in range(self.num_layers)
        ]
        all_cls_logits = []
        all_bbox_regression = []
        anchors = torch.cat(self.anchor_generator(output), dim=0)
        for o in output:
            N, _, H, W = o.shape
            # (N, A*(4+K), H, W) -> (N, H, W, A, 4+K)
            o = o.view(N, self.num_anchors, -1, H, W).permute(0, 3, 4, 1, 2)
            bbox_regression = o[..., :4]
            cls_logits = o[..., 4:]
            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
            all_bbox_regression.append(bbox_regression)
            all_cls_logits.append(cls_logits)
        return AnchorBasedDetectionModelOutput({"anchors": anchors,
                                                "bbox_regression": all_bbox_regression,
                                                "cls_logits": all_cls_logits,
                                                })


def yolo_fastest_head(
    num_classes, intermediate_features_dim, conf_model_head, **kwargs
) -> YoloFastestHead:
    """Build a ``YoloFastestHead`` from the head section of a model config.

    Args:
        num_classes: Number of object classes.
        intermediate_features_dim: Channel count of each input feature level.
        conf_model_head: Config node exposing ``params``.
        **kwargs: Accepted and ignored.

    Returns:
        The constructed ``YoloFastestHead`` module.
    """
    head_params = conf_model_head.params
    return YoloFastestHead(num_classes, intermediate_features_dim, head_params)
Loading
Loading