Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
zakajd committed Jun 4, 2020
1 parent 8222993 commit 0568102
Show file tree
Hide file tree
Showing 20 changed files with 473 additions and 411 deletions.
8 changes: 3 additions & 5 deletions pytorch_tools/detection_models/efficientdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,14 @@ def predict(self, x):
"""
class_outputs, box_outputs = self.forward(x)
anchors = box_utils.generate_anchors_boxes(x.shape[-2:])[0]
return box_utils.decode(
class_outputs, box_outputs, anchors, #img_shape=x.shape[-2:]
)
return box_utils.decode(class_outputs, box_outputs, anchors)

def _initialize_weights(self):
# init everything except encoder
no_encoder_m = [m for n, m in self.named_modules() if not "encoder" in n]
initialize_iterator(no_encoder_m)
# need to init last bias so that after sigmoid it's 0.01
cls_bias_init = -torch.log(torch.tensor((1 - 0.01) / 0.01)) # -4.59
# need to init last bias so that after sigmoid it's 0.01
cls_bias_init = -torch.log(torch.tensor((1 - 0.01) / 0.01)) # -4.59
nn.init.constant_(self.cls_head_convs[-1][1].bias, cls_bias_init)


Expand Down
13 changes: 6 additions & 7 deletions pytorch_tools/detection_models/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class RetinaNet(nn.Module):

def __init__(
self,
pretrained="coco", # not used here for proper signature
pretrained="coco", # not used here for proper signature
encoder_name="resnet50",
encoder_weights="imagenet",
pyramid_channels=256,
Expand Down Expand Up @@ -90,7 +90,7 @@ def make_final_convs():
self.box_convs = make_final_convs()
self.box_head_conv = conv3x3(pyramid_channels, 4 * anchors_per_location, bias=True)
self.num_classes = num_classes
self. _initialize_weights()
self._initialize_weights()

# Name from mmdetectin for convenience
def extract_features(self, x):
Expand Down Expand Up @@ -126,18 +126,17 @@ def predict(self, x):
"""Run forward on given images and decode raw prediction into bboxes"""
class_outputs, box_outputs = self.forward(x)
anchors = box_utils.generate_anchors_boxes(x.shape[-2:])[0]
return box_utils.decode(
class_outputs, box_outputs, anchors, img_shape=x.shape[-2:]
)
return box_utils.decode(class_outputs, box_outputs, anchors)

def _initialize_weights(self):
# init everything except encoder
no_encoder_m = [m for n, m in self.named_modules() if not "encoder" in n]
initialize_iterator(no_encoder_m)
# need to init last bias so that after sigmoid it's 0.01
cls_bias_init = -torch.log(torch.tensor((1 - 0.01) / 0.01)) # -4.59
# need to init last bias so that after sigmoid it's 0.01
cls_bias_init = -torch.log(torch.tensor((1 - 0.01) / 0.01)) # -4.59
nn.init.constant_(self.cls_head_conv.bias, cls_bias_init)


# Don't really know input size for the models. 512 is just a guess
PRETRAIN_SETTINGS = {**DEFAULT_IMAGENET_SETTINGS, "input_size": (512, 512), "crop_pct": 1, "num_classes": 80}

Expand Down
2 changes: 1 addition & 1 deletion pytorch_tools/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@
from .bit_resnet import bit_m_101x1
from .bit_resnet import bit_m_101x3
from .bit_resnet import bit_m_152x2
from .bit_resnet import bit_m_152x4
from .bit_resnet import bit_m_152x4
Loading

0 comments on commit 0568102

Please sign in to comment.