Skip to content

Commit

Permalink
update unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiangxu-0103 committed Dec 26, 2023
1 parent f1558bb commit 802b1c7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
10 changes: 5 additions & 5 deletions tests/test_models/test_decode_heads/test_cylinder3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def test_cylinder3d_head_loss(self):

sparse_voxels = SparseConvTensor(voxel_feats, voxel_coors, grid_size,
batch_size)
voxel_dict = dict(
feat_dict = dict(
voxel_feats=sparse_voxels,
voxel_coors=voxel_coors,
coors=coors,
point2voxel_maps=[point2voxel_map])
# Test forward
voxel_dict = cylinder3d_head.forward(voxel_dict)
feat_dict = cylinder3d_head.forward(feat_dict)

self.assertEqual(voxel_dict['logits'].shape,
self.assertEqual(feat_dict['logits'].shape,
torch.Size([voxel_coors.shape[0], 20]))

# When truth is non-empty then losses
Expand All @@ -62,13 +62,13 @@ def test_cylinder3d_head_loss(self):
datasample = Det3DDataSample()
datasample.gt_pts_seg = gt_pts_seg

losses = cylinder3d_head.loss_by_feat(voxel_dict, [datasample])
losses = cylinder3d_head.loss_by_feat(feat_dict, [datasample])

loss_ce = losses['loss_ce'].item()
loss_lovasz = losses['loss_lovasz'].item()

self.assertGreater(loss_ce, 0, 'ce loss should be positive')
self.assertGreater(loss_lovasz, 0, 'lovasz loss should be positive')

point_logits = cylinder3d_head.predict(voxel_dict, [datasample])
point_logits = cylinder3d_head.predict(feat_dict, [datasample])
assert point_logits[0].shape == torch.Size([100, 20])
20 changes: 10 additions & 10 deletions tests/test_models/test_voxel_encoders/test_voxel_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def test_seg_VFE():
coor = F.pad(coor, (1, 0), mode='constant', value=i)
coors.append(coor)
coors = torch.cat(coors, dim=0).cuda()
voxel_dict = dict(voxels=features, coors=coors)
voxel_dict = seg_VFE(voxel_dict)
assert voxel_dict['voxel_feats'].shape[0] == voxel_dict[
'voxel_coors'].shape[0]
assert len(voxel_dict['point_feats']) == 4
assert voxel_dict['point_feats'][0].shape == torch.Size([240000, 64])
assert voxel_dict['point_feats'][1].shape == torch.Size([240000, 128])
assert voxel_dict['point_feats'][2].shape == torch.Size([240000, 256])
assert voxel_dict['point_feats'][3].shape == torch.Size([240000, 256])
assert len(voxel_dict['point2voxel_maps']) == 4
feat_dict = dict(voxels=features, coors=coors)
feat_dict = seg_VFE(feat_dict)
assert feat_dict['voxel_feats'].shape[0] == feat_dict['voxel_coors'].shape[
0]
assert len(feat_dict['point_feats']) == 4
assert feat_dict['point_feats'][0].shape == torch.Size([240000, 64])
assert feat_dict['point_feats'][1].shape == torch.Size([240000, 128])
assert feat_dict['point_feats'][2].shape == torch.Size([240000, 256])
assert feat_dict['point_feats'][3].shape == torch.Size([240000, 256])
assert len(feat_dict['point2voxel_maps']) == 4

0 comments on commit 802b1c7

Please sign in to comment.