forked from ver0z/Deformable-DETR-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hubconf.py
164 lines (143 loc) · 6.5 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
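# NOTE: the checkpoint URLs below point to the original DETR weights; replace each link with a matching Deformable DETR checkpoint to load pretrained weights correctly.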
import torch
from models.backbone import Backbone, Joiner
from models.deformable_detr import DeformableDETR, PostProcess
from models.position_encoding import PositionEmbeddingSine
from models.segmentation import DETRsegm, PostProcessPanoptic
from models.deformable_transformer import DeformableTransformer
# Packages torch.hub requires to be importable before it loads any entry point defined in this hubconf.
dependencies = ["torch", "torchvision"]
def _make_detr(
    backbone_name: str,
    dilation: bool = False,
    num_classes: int = 91,
    mask: bool = False,
    num_queries: int = 100,
    num_feature_levels: int = 4,
):
    """Build the detector used by every hub entry point in this module.

    The previous body instantiated ``Transformer`` and ``DETR``, names this module
    never imports (it imports ``DeformableTransformer`` / ``DeformableDETR``), so
    every call raised ``NameError``. The model is now built from the imported
    Deformable DETR classes.

    Args:
        backbone_name: backbone name forwarded to ``Backbone`` (e.g. "resnet50").
        dilation: forwarded to ``Backbone``; the ``*_dc5`` entry points set it to True.
        num_classes: number of classes forwarded to ``DeformableDETR``.
        mask: if True, wrap the detector in ``DETRsegm``.
        num_queries: number of object queries (defaults to the previous value, 100).
        num_feature_levels: number of feature levels; passed to BOTH the
            transformer and the detector so their per-level modules agree.

    Returns:
        A ``DeformableDETR`` instance, or ``DETRsegm`` wrapping it when ``mask`` is True.
    """
    hidden_dim = 256
    # Multi-level deformable attention consumes several backbone stages, so ask the
    # backbone for intermediate layers whenever more than one level is used.
    # NOTE(review): mirrors upstream Deformable DETR's build_backbone — confirm for this fork.
    return_interm_layers = mask or num_feature_levels > 1
    backbone = Backbone(
        backbone_name,
        train_backbone=True,
        return_interm_layers=return_interm_layers,
        dilation=dilation,
    )
    pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
    backbone_with_pos_enc = Joiner(backbone, pos_enc)
    # Expose the backbone's channel info on the joined module, which the detector reads.
    backbone_with_pos_enc.num_channels = backbone.num_channels
    transformer = DeformableTransformer(
        d_model=hidden_dim,
        return_intermediate_dec=True,
        num_feature_levels=num_feature_levels,
    )
    detr = DeformableDETR(
        backbone_with_pos_enc,
        transformer,
        num_classes=num_classes,
        num_queries=num_queries,
        num_feature_levels=num_feature_levels,
    )
    if mask:
        return DETRsegm(detr)
    return detr
def detr_resnet50(pretrained=False, num_classes=91, return_postprocessor=False):
    """
    Hub entry point: ResNet-50 backbone, no dilation.

    The original DETR R50 (6 encoder / 6 decoder layers) reports 42/62.4 AP/AP50
    on COCO val5k.

    Args:
        pretrained: if True, download the checkpoint below and load its "model" state dict.
        num_classes: forwarded to ``_make_detr``.
        return_postprocessor: if True, also return a ``PostProcess`` instance.

    Returns:
        The model, or ``(model, PostProcess())`` when ``return_postprocessor`` is True.

    NOTE(review): the URL is an original DETR checkpoint while this module imports
    Deformable DETR classes; the state dict presumably will not match — confirm.
    """
    net = _make_detr("resnet50", dilation=False, num_classes=num_classes)
    if pretrained:
        # The original link for deformable resnet50 is https://drive.google.com/file/d/1WEjQ9_FgfI5sw5OZZ4ix-OKk-IJ_-SDU/view?usp=sharing
        weights_url = "https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth"
        # check_hash=True has torch.hub verify the download against the hash in the filename.
        ckpt = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
        net.load_state_dict(ckpt["model"])
    return (net, PostProcess()) if return_postprocessor else net
def detr_resnet50_dc5(pretrained=False, num_classes=91, return_postprocessor=False):
    """
    Hub entry point: ResNet-50 backbone with dilation (DC5 variant).

    ``dilation=True`` is forwarded to the backbone; per the original notes this
    dilates the last ResNet block to raise output resolution. The original
    DETR-DC5 R50 reports 43.3/63.1 AP/AP50 on COCO val5k.

    Args:
        pretrained: if True, download the checkpoint below and load its "model" state dict.
        num_classes: forwarded to ``_make_detr``.
        return_postprocessor: if True, also return a ``PostProcess`` instance.

    Returns:
        The model, or ``(model, PostProcess())`` when ``return_postprocessor`` is True.
    """
    net = _make_detr("resnet50", dilation=True, num_classes=num_classes)
    if pretrained:
        weights_url = "https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-f0fb7ef5.pth"
        ckpt = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
        net.load_state_dict(ckpt["model"])
    return (net, PostProcess()) if return_postprocessor else net
def detr_resnet101(pretrained=False, num_classes=91, return_postprocessor=False):
    """
    Hub entry point: ResNet-101 backbone, no dilation (NOT the DC5 variant).

    The previous docstring labelled this "DETR-DC5 R101", but the model is built
    with ``dilation=False``; the dilated variant is ``detr_resnet101_dc5``.
    The original DETR R101 reports 43.5/63.8 AP/AP50 on COCO val5k.

    Args:
        pretrained: if True, download the checkpoint below and load its "model" state dict.
        num_classes: forwarded to ``_make_detr``.
        return_postprocessor: if True, also return a ``PostProcess`` instance.

    Returns:
        The model, or ``(model, PostProcess())`` when ``return_postprocessor`` is True.
    """
    net = _make_detr("resnet101", dilation=False, num_classes=num_classes)
    if pretrained:
        weights_url = "https://dl.fbaipublicfiles.com/detr/detr-r101-2c7b67e5.pth"
        ckpt = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
        net.load_state_dict(ckpt["model"])
    return (net, PostProcess()) if return_postprocessor else net
def detr_resnet101_dc5(pretrained=False, num_classes=91, return_postprocessor=False):
    """
    Hub entry point: ResNet-101 backbone with dilation (DC5 variant).

    ``dilation=True`` is forwarded to the backbone; per the original notes this
    dilates the last ResNet block to raise output resolution. The original
    DETR-DC5 R101 reports 44.9/64.7 AP/AP50 on COCO val5k.

    Args:
        pretrained: if True, download the checkpoint below and load its "model" state dict.
        num_classes: forwarded to ``_make_detr``.
        return_postprocessor: if True, also return a ``PostProcess`` instance.

    Returns:
        The model, or ``(model, PostProcess())`` when ``return_postprocessor`` is True.
    """
    net = _make_detr("resnet101", dilation=True, num_classes=num_classes)
    if pretrained:
        weights_url = "https://dl.fbaipublicfiles.com/detr/detr-r101-dc5-a2e86def.pth"
        ckpt = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
        net.load_state_dict(ckpt["model"])
    return (net, PostProcess()) if return_postprocessor else net
def detr_resnet50_panoptic(
    pretrained=False, num_classes=250, threshold=0.85, return_postprocessor=False
):
    """
    Hub entry point: ResNet-50 panoptic model (no dilation, mask head enabled).

    The original DETR R50 panoptic model reports 43.4 PQ on COCO val5k.

    Args:
        pretrained: if True, download the checkpoint below and load its "model" state dict.
        num_classes: forwarded to ``_make_detr``.
        threshold: minimum confidence required for keeping segments in the prediction;
            forwarded to ``PostProcessPanoptic``.
        return_postprocessor: if True, also return a ``PostProcessPanoptic`` instance.

    Returns:
        The model, or ``(model, PostProcessPanoptic(...))`` when ``return_postprocessor`` is True.
    """
    net = _make_detr("resnet50", dilation=False, num_classes=num_classes, mask=True)
    if pretrained:
        weights_url = "https://dl.fbaipublicfiles.com/detr/detr-r50-panoptic-00ce5173.pth"
        ckpt = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
        net.load_state_dict(ckpt["model"])
    if not return_postprocessor:
        return net
    # Category ids 0..90 map to True (thing classes); ids 91..249 map to False.
    thing_flags = {cat_id: cat_id <= 90 for cat_id in range(250)}
    return net, PostProcessPanoptic(thing_flags, threshold=threshold)
def detr_resnet50_dc5_panoptic(
    pretrained=False, num_classes=91, threshold=0.85, return_postprocessor=False
):
    """
    Hub entry point: ResNet-50 panoptic model with dilation (DC5 variant, mask head enabled).

    ``dilation=True`` is forwarded to the backbone; per the original notes this
    dilates the last ResNet block to raise output resolution. The original
    DETR-DC5 R50 panoptic model reports 44.6 PQ on COCO val5k.

    Args:
        pretrained: if True, download the checkpoint below and load its "model" state dict.
        num_classes: forwarded to ``_make_detr``.
        threshold: minimum confidence required for keeping segments in the prediction;
            forwarded to ``PostProcessPanoptic``.
        return_postprocessor: if True, also return a ``PostProcessPanoptic`` instance.

    Returns:
        The model, or ``(model, PostProcessPanoptic(...))`` when ``return_postprocessor`` is True.

    NOTE(review): the default ``num_classes=91`` differs from ``detr_resnet50_panoptic``
    (250), although the thing/stuff map below covers 250 category ids. The pretrained
    panoptic weights presumably need ``num_classes=250`` — pass it explicitly; confirm.
    """
    net = _make_detr("resnet50", dilation=True, num_classes=num_classes, mask=True)
    if pretrained:
        weights_url = "https://dl.fbaipublicfiles.com/detr/detr-r50-dc5-panoptic-da08f1b1.pth"
        ckpt = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
        net.load_state_dict(ckpt["model"])
    if not return_postprocessor:
        return net
    # Category ids 0..90 map to True (thing classes); ids 91..249 map to False.
    is_thing_map = {i: i <= 90 for i in range(250)}
    return net, PostProcessPanoptic(is_thing_map, threshold=threshold)
def detr_resnet101_panoptic(
    pretrained=False, num_classes=91, threshold=0.85, return_postprocessor=False
):
    """
    Hub entry point: ResNet-101 panoptic model, no dilation (NOT the DC5 variant).

    The previous docstring labelled this "DETR-DC5 R101", but the model is built
    with ``dilation=False``. The original DETR R101 panoptic model reports 45.1 PQ
    on COCO val5k.

    Args:
        pretrained: if True, download the checkpoint below and load its "model" state dict.
        num_classes: forwarded to ``_make_detr``.
        threshold: minimum confidence required for keeping segments in the prediction;
            forwarded to ``PostProcessPanoptic``.
        return_postprocessor: if True, also return a ``PostProcessPanoptic`` instance.

    Returns:
        The model, or ``(model, PostProcessPanoptic(...))`` when ``return_postprocessor`` is True.

    NOTE(review): the default ``num_classes=91`` differs from ``detr_resnet50_panoptic``
    (250), although the thing/stuff map below covers 250 category ids. The pretrained
    panoptic weights presumably need ``num_classes=250`` — pass it explicitly; confirm.
    """
    net = _make_detr("resnet101", dilation=False, num_classes=num_classes, mask=True)
    if pretrained:
        weights_url = "https://dl.fbaipublicfiles.com/detr/detr-r101-panoptic-40021d53.pth"
        ckpt = torch.hub.load_state_dict_from_url(url=weights_url, map_location="cpu", check_hash=True)
        net.load_state_dict(ckpt["model"])
    if not return_postprocessor:
        return net
    # Category ids 0..90 map to True (thing classes); ids 91..249 map to False.
    is_thing_map = {i: i <= 90 for i in range(250)}
    return net, PostProcessPanoptic(is_thing_map, threshold=threshold)