-
Notifications
You must be signed in to change notification settings - Fork 5
/
net.py
43 lines (30 loc) · 1.3 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from cvpods.layers import ShapeSpec
from cvpods.modeling.backbone import Backbone
from cvpods.modeling.backbone.fpn import build_resnet_fpn_backbone
from cvpods.modeling.proposal_generator import RPN
from cvpods.modeling.roi_heads.box_head import FastRCNNConvFCHead
from dataset import * # noqa
from modeling import DoubleHeadRPN, RedetectROIHeads
from rcnn import RCNN
def build_backbone(cfg, input_shape=None):
    """Build the ResNet-FPN backbone for this model.

    Args:
        cfg: config object. ``cfg.MODEL.PIXEL_MEAN`` is read to infer the
            input channel count when ``input_shape`` is not given.
        input_shape (ShapeSpec, optional): shape spec of the backbone input.
            When ``None``, a ``ShapeSpec`` whose ``channels`` equals
            ``len(cfg.MODEL.PIXEL_MEAN)`` is used.

    Returns:
        Backbone: the network produced by ``build_resnet_fpn_backbone``.
    """
    # Only touch cfg.MODEL.PIXEL_MEAN when the caller gave no explicit shape.
    shape = (
        input_shape
        if input_shape is not None
        else ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))
    )
    net = build_resnet_fpn_backbone(cfg, shape)
    # Internal sanity check: the factory must return a Backbone instance.
    assert isinstance(net, Backbone)
    return net
def build_proposal_generator_train(cfg, input_shape):
    """Create the proposal generator used when ``build_model(training=True)``.

    Args:
        cfg: model config, forwarded to ``RPN``.
        input_shape: input shape spec, forwarded unchanged to ``RPN``.

    Returns:
        RPN: a region proposal network built from ``cfg`` and ``input_shape``.
    """
    generator = RPN(cfg, input_shape)
    return generator
def build_proposal_generator_inference(cfg, input_shape):
    """Create the proposal generator used when ``build_model(training=False)``.

    Args:
        cfg: model config, forwarded to ``DoubleHeadRPN``.
        input_shape: input shape spec, forwarded unchanged to ``DoubleHeadRPN``.

    Returns:
        DoubleHeadRPN: the inference-time proposal generator.
    """
    generator = DoubleHeadRPN(cfg, input_shape)
    return generator
def build_roi_heads(cfg, input_shape):
    """Create the ROI heads for the detector.

    Args:
        cfg: model config, forwarded to ``RedetectROIHeads``.
        input_shape: input shape spec, forwarded unchanged.

    Returns:
        RedetectROIHeads: the ROI-head module built from the arguments.
    """
    heads = RedetectROIHeads(cfg, input_shape)
    return heads
def build_box_head(cfg, input_shape):
    """Create the box head used inside the ROI heads.

    Args:
        cfg: model config, forwarded to ``FastRCNNConvFCHead``.
        input_shape: input shape spec, forwarded unchanged.

    Returns:
        FastRCNNConvFCHead: the box-head module built from the arguments.
    """
    head = FastRCNNConvFCHead(cfg, input_shape)
    return head
def build_model(cfg, training=True):
    """Assemble the RCNN detector described by ``cfg``.

    The component builders defined in this module are attached to ``cfg``
    as attributes (``cfg`` is mutated in place), and ``RCNN`` then
    constructs the model from that config.

    Args:
        cfg: config object that receives the builder hooks.
        training (bool): if truthy, proposals come from ``RPN``
            (``build_proposal_generator_train``). Otherwise they come from
            ``DoubleHeadRPN`` (``build_proposal_generator_inference``).

    Returns:
        RCNN: the assembled model.
    """
    if training:
        proposal_builder = build_proposal_generator_train
    else:
        proposal_builder = build_proposal_generator_inference

    # Hooks are attached in the same order as before: backbone, proposal
    # generator, ROI heads, box head.
    cfg.build_backbone = build_backbone
    cfg.build_proposal_generator = proposal_builder
    cfg.build_roi_heads = build_roi_heads
    cfg.build_box_head = build_box_head
    return RCNN(cfg)