Preparations for release
NamanMakkar committed Nov 14, 2024
1 parent 4b6bc96 commit a5e7851
Showing 6 changed files with 622 additions and 438 deletions.
4 changes: 2 additions & 2 deletions vajra/core/exporter.py
@@ -18,7 +18,7 @@
 from vajra.dataset.dataset import VajraDetDataset
 from vajra.dataset.build import build_dataloader
 from vajra.dataset.utils import check_det_dataset, check_cls_dataset, check_class_names, default_class_names
-from vajra.nn.modules import VajraMerudandaBhag1, VajraMerudandaBhag4, AttentionBottleneckV2
+from vajra.nn.modules import VajraMerudandaBhag1, VajraMerudandaBhag4, VajraMerudandaBhag7, AttentionBottleneckV2
 from vajra.nn.head import Detection
 from vajra.nn.vajra import DetectionModel, SegmentationModel, VajraWorld
 from vajra.utils import (
@@ -168,7 +168,7 @@ def __call__(self, model=None):
                 module.dynamic = self.args.dynamic
                 module.export = True
                 module.format = self.args.format
-            elif isinstance(module, (VajraMerudandaBhag4, AttentionBottleneckV2)) and not is_tf_format:
+            elif isinstance(module, (VajraMerudandaBhag4, AttentionBottleneckV2, VajraMerudandaBhag7)) and not is_tf_format:
                 module.forward = module.forward_split

         y = None
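For context: the forward/forward_split swap is a common export trick. torch.chunk inside a block can lower poorly in some deployment backends, so an equivalent split-based forward is substituted before tracing (TensorFlow formats keep the chunked path). A minimal sketch of the pattern in plain PyTorch; the block below is an illustrative stand-in, not vajra's actual VajraMerudandaBhag7:

import torch
import torch.nn as nn

class ChunkedBlock(nn.Module):
    # Illustrative stand-in for blocks like VajraMerudandaBhag7; not vajra's actual code.
    def __init__(self, c):
        super().__init__()
        self.half = c
        self.cv1 = nn.Conv2d(c, 2 * c, 1)
        self.m = nn.Conv2d(c, c, 3, padding=1)
        self.cv2 = nn.Conv2d(3 * c, c, 1)

    def forward(self, x):
        a, b = self.cv1(x).chunk(2, 1)  # convenient in eager mode
        return self.cv2(torch.cat((a, b, self.m(b)), 1))

    def forward_split(self, x):
        # split() with explicit sizes tends to trace and lower more predictably
        a, b = self.cv1(x).split((self.half, self.half), 1)
        return self.cv2(torch.cat((a, b, self.m(b)), 1))

block = ChunkedBlock(16)
block.forward = block.forward_split  # what the exporter does for each matching module
out = block(torch.randn(1, 16, 32, 32))  # shape: (1, 16, 32, 32)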
40 changes: 28 additions & 12 deletions vajra/nn/head.py
@@ -63,7 +63,7 @@ def __init__(self, num_classes=80, in_channels=[]) -> None:
             for ch in in_channels
         )
         self.branch_cls = nn.ModuleList(
-            nn.Sequential(nn.Sequential(DepthwiseConvBNAct(ch, ch, 1, 3), ConvBNAct(ch, c3, 1, 1)),#ConvBNAct(ch, c3, 1, 3),
+            nn.Sequential(nn.Sequential(DepthwiseConvBNAct(ch, ch, 1, 3), ConvBNAct(ch, c3, 1, 1)),
                           nn.Sequential(DepthwiseConvBNAct(c3, c3, 1, 3), ConvBNAct(c3, c3, 1, 1)),
                           nn.Conv2d(c3, self.num_classes, 1))
             for ch in in_channels
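The classification branch pairs a depthwise 3x3 with a pointwise 1x1, i.e. a depthwise-separable factorization that cuts parameters and FLOPs relative to a dense 3x3. A minimal sketch of that factorization in plain PyTorch (vajra's DepthwiseConvBNAct and ConvBNAct presumably also fold in BatchNorm and an activation; their argument order is specific to this repo):

import torch.nn as nn

def dw_separable(ch_in, ch_out):
    # Depthwise 3x3 mixes spatially within each channel; pointwise 1x1 mixes channels.
    # Parameter count: ch_in * (9 + ch_out) versus 9 * ch_in * ch_out for a dense 3x3.
    return nn.Sequential(
        nn.Conv2d(ch_in, ch_in, 3, padding=1, groups=ch_in, bias=False),
        nn.BatchNorm2d(ch_in),
        nn.SiLU(),
        nn.Conv2d(ch_in, ch_out, 1, bias=False),
        nn.BatchNorm2d(ch_out),
        nn.SiLU(),
    )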
@@ -158,11 +158,11 @@ def __init__(self, num_classes=80, num_masks=32, num_protos=256, in_channels=[])
         self.num_masks = num_masks
         self.num_protos = num_protos
         self.proto = ProtoMaskModule(in_channels[0], self.num_protos, self.num_masks)
-        self.detection = Detection.forward
         c4 = max(in_channels[0] // 4, self.num_masks)
         self.branch_seg = nn.ModuleList(
             nn.Sequential(
                 ConvBNAct(ch, c4, kernel_size=3, stride=1),
+                ConvBNAct(c4, c4, kernel_size=3, stride=1),
                 nn.Conv2d(c4, self.num_masks, 1)
             )
             for ch in in_channels
@@ -174,7 +174,7 @@ def forward(self, x):

         mask_coefficients = torch.cat([self.branch_seg[i](x[i]).view(batch_size, self.num_masks, -1)
                                        for i in range(self.num_det_layers)], 2)
-        x = self.detection(self, x)
+        x = Detection.forward(self, x)
         if self.training:
             return x, mask_coefficients, proto_masks
         return (torch.cat([x, mask_coefficients], 1), proto_masks) if self.export else (torch.cat([x[0], mask_coefficients], 1), (x[1], mask_coefficients, proto_masks))
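This hunk, like the matching ones in the Pose and OBBDetection heads below, drops the stored function attribute self.detection = Detection.forward in favor of a direct unbound call, Detection.forward(self, x). Both reach the same base-class implementation; the explicit call keeps a bare function out of the module's instance state and makes the dispatch obvious at the call site. A toy sketch of the equivalence (illustrative classes, not vajra's heads):

import torch
import torch.nn as nn

class ToyDetection(nn.Module):  # stand-in for Detection
    def forward(self, x):
        return [t * 2 for t in x]

class ToySegmentation(ToyDetection):
    def forward(self, x):
        # Unbound call on the base class: same result as super().forward(x),
        # but nothing function-valued is stored on the instance.
        return ToyDetection.forward(self, x)

head = ToySegmentation()
out = head([torch.ones(1)])  # [tensor([2.])]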
@@ -192,6 +192,7 @@ def __init__(self, num_classes=80, keypoint_shape=(17, 3), in_channels=[]) -> No
         self.branch_pose_detect = nn.ModuleList(
             nn.Sequential(
                 ConvBNAct(ch, c4, kernel_size=3, stride=1),
+                ConvBNAct(c4, c4, kernel_size=3, stride=1),
                 nn.Conv2d(c4, self.num_keypoints, 1)
             )
             for ch in in_channels
@@ -200,8 +201,18 @@ def __init__(self, num_classes=80, keypoint_shape=(17, 3), in_channels=[]) -> No
     def decode_keypoints(self, batch_size, keypoints):
         ndim = self.keypoint_shape[1]
         if self.export:
-            y = keypoints.view(batch_size, *self.keypoint_shape, -1)
-            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
+            if self.format in {
+                "tflite",
+                "edgetpu",
+            }:
+                y = keypoints.view(batch_size, *self.keypoint_shape, -1)
+                grid_h, grid_w = self.shape[2], self.shape[3]
+                grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
+                norm = self.strides / (self.stride[0] * grid_size)
+                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
+            else:
+                y = keypoints.view(batch_size, *self.keypoint_shape, -1)
+                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
             if ndim == 3:
                 a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
             return a.view(batch_size, self.num_keypoints, -1)
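The new TFLite/EdgeTPU branch scales decoded keypoints by strides / (stride[0] * grid_size) instead of raw strides. Assuming self.shape holds the first feature level's (b, c, h, w), stride[0] * grid_w is roughly the input image width, so outputs stay near [0, 1]; small normalized ranges quantize to int8 much better than pixel coordinates, and the two decodings then differ only by a constant image-size factor. A quick numeric check under those assumptions:

import torch

img_size = 640
strides = torch.tensor([8.0, 16.0, 32.0])      # one entry per detection level
grid_w = img_size // 8                          # first level grid width: 80
pixel_scale = strides                           # pixel-space decoding
norm_scale = strides / (strides[0] * grid_w)    # quantization-friendly decoding
# Multiplying the normalized outputs by the image size recovers pixels:
assert torch.allclose(norm_scale * img_size, pixel_scale)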
@@ -217,7 +228,7 @@ def forward(self, x):
         batch_size = x[0].shape[0]  # batch_size
         keypoint = torch.cat([self.branch_pose_detect[i](x[i]).view(batch_size, self.num_keypoints, -1)
                               for i in range(self.num_det_layers)], -1)  # (batch_size, 17*3, h*w)
-        x = self.detection(self, x)
+        x = Detection.forward(self, x)
         if self.training:
             return x, keypoint
         pred_keypoints = self.decode_keypoints(batch_size, keypoint)
@@ -233,7 +244,7 @@ def __init__(self, num_classes=80, embed_dim=512, with_bn=False, in_channels=[])
         self.with_bn = with_bn
         self.in_channels = in_channels
         c3 = max(in_channels[0], min(self.num_classes, 100))
-        self.branch3 = nn.ModuleList(nn.Sequential(ConvBNAct(x, c3, 3), nn.Conv2d(c3, embed_dim, 1)) for x in in_channels)
+        self.branch3 = nn.ModuleList(nn.Sequential(ConvBNAct(ch, c3, stride=1, kernel_size=3), ConvBNAct(c3, c3, kernel_size=3, stride=1), nn.Conv2d(c3, embed_dim, 1)) for ch in in_channels)
         self.branch4 = nn.ModuleList(BNContrastiveHead(embed_dim) if with_bn else ContrastiveHead() for _ in in_channels)

     def forward(self, x, text):
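Unlike the plain Detection head, WorldDetection scores classes with branch4, a ContrastiveHead (or BNContrastiveHead) that compares the embedding produced by branch3 against text embeddings. A hypothetical sketch of how such a head typically works, using scaled cosine similarity; this is not vajra's actual implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyContrastiveHead(nn.Module):
    # Hypothetical sketch: class logits are scaled cosine similarities between
    # per-pixel image embeddings and per-class text embeddings.
    def __init__(self):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.tensor(0.0))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, feats, text):
        # feats: (b, d, h, w) image embeddings, text: (b, k, d) class embeddings
        feats = F.normalize(feats, dim=1)
        text = F.normalize(text, dim=-1)
        sim = torch.einsum("bdhw,bkd->bkhw", feats, text)
        return sim * self.logit_scale.exp() + self.bias

head = ToyContrastiveHead()
logits = head(torch.randn(2, 512, 20, 20), torch.randn(2, 80, 512))  # (2, 80, 20, 20)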
@@ -260,13 +271,18 @@ def forward(self, x, text):
             grid_w = shape[3]
             grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
             norm = self.strides / (self.stride[0] * grid_size)
-            dist_box = self.decode_bboxes(self.distributed_focal_loss(box), self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1)
+            dist_box = self.decode_bboxes(self.distributed_focal_loss(box), self.anchors.unsqueeze(0) * norm[:, :2])
         else:
-            dist_box = self.decode_bboxes(self.distributed_focal_loss(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
+            dist_box = self.decode_bboxes(self.distributed_focal_loss(box), self.anchors.unsqueeze(0)) * self.strides

         y = torch.cat((dist_box, cls.sigmoid()), 1)

         return y if self.export else (y, x)

+    def bias_init(self):
+        detection_module = self
+        for branch_a, branch_b, stride in zip(detection_module.branch_det, detection_module.branch_cls, detection_module.stride):
+            branch_a[-1].bias.data[:] = 1.0
+
     def get_module_info(self):
         return f"WorldDetection", f"[{self.num_classes}, {self.embed_dim}, {self.with_bn}, {self.in_channels}]"
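The added bias_init seeds only the box branch: the final 1x1 conv of each branch_det level gets a constant bias of 1.0, while branch_cls is left alone (its scores come from the contrastive text head above, so the usual class-prior bias initialization does not apply; branch_b and stride go unused in the loop). A rough sketch of the same loop over a toy box branch (shapes illustrative, not vajra's real channel widths):

import torch.nn as nn

# Toy per-level box branches ending in a 1x1 conv, mirroring branch_det above.
branch_det = nn.ModuleList(
    nn.Sequential(nn.Conv2d(c, c, 3, padding=1), nn.Conv2d(c, 64, 1))
    for c in (64, 128, 256)
)
strides = (8, 16, 32)

for branch, stride in zip(branch_det, strides):
    branch[-1].bias.data[:] = 1.0  # constant prior on the box regression logits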
@@ -276,11 +292,11 @@ class OBBDetection(Detection):
     def __init__(self, num_classes=80, num_extra_params=1, in_channels=[]) -> None:
         super().__init__(num_classes, in_channels)
         self.num_extra = num_extra_params
-        self.detect = Detection.forward
         c4 = max(in_channels[0] // 4, self.num_extra)
         self.oriented_branch = nn.ModuleList(
             nn.Sequential(
-                ConvBNAct(ch, c4, 3),
+                ConvBNAct(ch, c4, kernel_size=3, stride=1),
+                ConvBNAct(c4, c4, kernel_size=3, stride=1),
                 nn.Conv2d(c4, self.num_extra, 1)
             )
             for ch in in_channels
@@ -293,7 +309,7 @@ def forward(self, x):

         if not self.training:
             self.angle = angle
-        x = self.detect(self, x)
+        x = Detection.forward(self, x)
         if self.training:
             return x, angle

(diffs for the remaining 4 changed files are not shown)
