
Commit

reset unused
sunjiahao1999 committed Dec 27, 2023
1 parent 5a1dc65 commit 49a564a
Showing 5 changed files with 25 additions and 43 deletions.
16 changes: 7 additions & 9 deletions mmdet3d/datasets/convert_utils.py
@@ -10,7 +10,7 @@
from shapely.geometry import MultiPoint, box
from shapely.geometry.polygon import Polygon

from mmdet3d.structures import Box3DMode, LiDARInstance3DBoxes, points_cam2img
from mmdet3d.structures import Box3DMode, CameraInstance3DBoxes, points_cam2img
from mmdet3d.structures.ops import box_np_ops

kitti_categories = ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
@@ -318,23 +318,21 @@ def get_kitti_style_2d_boxes(info: dict,
def convert_annos(info: dict, cam_idx: int) -> dict:
"""Convert front-cam anns to i-th camera (KITTI-style info)."""
rect = info['calib']['R0_rect'].astype(np.float32)
if cam_idx == 0:
lidar2cami = info['calib']['Tr_velo_to_cam'].astype(np.float32)
else:
lidar2cami = info['calib'][f'Tr_velo_to_cam{cam_idx}'].astype(
np.float32)
lidar2cam0 = info['calib']['Tr_velo_to_cam'].astype(np.float32)
lidar2cami = info['calib'][f'Tr_velo_to_cam{cam_idx}'].astype(np.float32)
annos = info['annos']
converted_annos = copy.deepcopy(annos)
loc = annos['location']
dims = annos['dimensions']
rots = annos['rotation_y']
gt_bboxes_3d = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1).astype(np.float32)
# BC-breaking: gt_bboxes_3d is already in lidar coordinates
# convert gt_bboxes_3d to velodyne coordinates
gt_bboxes_3d = CameraInstance3DBoxes(gt_bboxes_3d).convert_to(
Box3DMode.LIDAR, np.linalg.inv(rect @ lidar2cam0), correct_yaw=True)
# convert gt_bboxes_3d to cam coordinates
gt_bboxes_3d = LiDARInstance3DBoxes(gt_bboxes_3d).convert_to(
gt_bboxes_3d = gt_bboxes_3d.convert_to(
Box3DMode.CAM, rect @ lidar2cami, correct_yaw=True).numpy()

converted_annos['location'] = gt_bboxes_3d[:, :3]
converted_annos['dimensions'] = gt_bboxes_3d[:, 3:6]
converted_annos['rotation_y'] = gt_bboxes_3d[:, 6]
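
To make the conversion chain in `convert_annos` easier to follow, the snippet below sketches the translation part of what the code above does to a single box centre: a centre given in rectified cam0 coordinates is first mapped into the velodyne frame with `np.linalg.inv(rect @ lidar2cam0)` and then into the i-th camera frame with `rect @ lidar2cami`. This is an illustrative helper only (the name `convert_center_cam0_to_cami` is hypothetical), and it omits the dimension and yaw handling that `convert_to(..., correct_yaw=True)` performs on the full boxes.

```python
import numpy as np


def convert_center_cam0_to_cami(center_cam0, rect, lidar2cam0, lidar2cami):
    """Sketch: move one 3D box centre from rectified cam0 to cam i coords.

    All matrices are assumed to be 4x4 homogeneous transforms, as stored in
    the KITTI-style ``info['calib']`` dict.
    """
    center_h = np.append(center_cam0, 1.0)         # homogeneous coordinates
    cam02lidar = np.linalg.inv(rect @ lidar2cam0)  # rectified cam0 -> velodyne
    center_lidar = cam02lidar @ center_h
    lidar2cami_mat = rect @ lidar2cami             # velodyne -> rectified cam i
    return (lidar2cami_mat @ center_lidar)[:3]
```
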
14 changes: 10 additions & 4 deletions projects/DSVT/README.md
@@ -57,18 +57,24 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-

### Training commands

The support of training DSVT is on the way.
In MMDetection3D's root directory, run the following command to train the model:

```bash
tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch
```

## Results and models

### Waymo

| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :----: | :-----: | :----: | :---------: | :-------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | 75.5 | 72.4 | 69.2 | 66.3 | [log](<>) |

**Note** that `ResSECOND` denotes a variant of the SECOND backbone in which the basic blocks are equipped with residual connections.
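
For readers unfamiliar with the naming, the sketch below illustrates the general idea of a basic convolution block with a residual connection. It is a toy example under assumed simplifications (equal input/output channels, stride 1) and is not the actual `ResSECOND` implementation used in this project.

```python
import torch
from torch import nn


class ResidualBasicBlock(nn.Module):
    """Toy SECOND-style 3x3 conv block with a residual (skip) connection."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)  # residual connection
```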

**Note** that, regrettably, we are unable to provide the pre-trained model weights due to the [Waymo Dataset License Agreement](https://waymo.com/open/terms/). However, we can provide the training logs.

## Citation

```latex
24 changes: 7 additions & 17 deletions projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py
@@ -6,7 +6,6 @@
voxel_size = [0.32, 0.32, 6]
grid_size = [468, 468, 1]
point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4.0]
# data_root = 'data/waymo_mini/kitti_format/'
data_root = 'data/waymo/kitti_format/'
class_names = ['Car', 'Pedestrian', 'Cyclist']
metainfo = dict(classes=class_names)
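
As a quick sanity check on the values above, `grid_size` follows directly from `point_cloud_range` and `voxel_size`: the 149.76 m horizontal extent divided by 0.32 m voxels gives 468 cells per axis, and the 6 m vertical extent divided by a 6 m voxel gives a single layer. The snippet below only illustrates that arithmetic; it is not part of the config.

```python
point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4.0]
voxel_size = [0.32, 0.32, 6]
grid = [
    round((point_cloud_range[i + 3] - point_cloud_range[i]) / voxel_size[i])
    for i in range(3)
]
print(grid)  # [468, 468, 1], matching grid_size in the config
```
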
@@ -194,7 +193,7 @@
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='waymo_wo_cam_ins_infos_train.pkl',
ann_file='waymo_infos_train.pkl',
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
pipeline=train_pipeline,
modality=input_modality,
@@ -216,7 +215,7 @@
type=dataset_type,
data_root=data_root,
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
ann_file='waymo_wo_cam_ins_infos_val.pkl',
ann_file='waymo_infos_val.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
@@ -225,32 +224,23 @@
backend_args=backend_args))
test_dataloader = val_dataloader

# val_evaluator = dict(
# type='WaymoMetric',
# ann_file='./data/waymo_mini/kitti_format/waymo_infos_val.pkl',
# waymo_bin_file='./data/waymo_mini/waymo_format/gt_mini.bin',
# backend_args=backend_args,
# convert_kitti_format=False)
val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
backend_args=backend_args,
convert_kitti_format=False)
result_prefix='./dsvt_pred')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
# vis_backends = [dict(type='LocalVisBackend')]
# vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# schedules
lr = 1e-5
# This schedule is mainly used by models on nuScenes dataset
# max_norm=10 is better for SECOND
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)),
clip_grad=dict(max_norm=10, norm_type=2))
# learning rate
param_scheduler = [
dict(
type='CosineAnnealingLR',
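
Because the evaluator hunk above interleaves the removed and added keys, the block below shows what the resulting evaluator config would look like, under the assumption that `ann_file`, `backend_args`, and `convert_kitti_format` are the keys being dropped and `result_prefix` is the one being added. It is a sketch of the assumed end state, not an excerpt from the file.

```python
# Assumed state of the evaluator config after this commit (sketch only).
val_evaluator = dict(
    type='WaymoMetric',
    waymo_bin_file='./data/waymo/waymo_format/gt.bin',
    result_prefix='./dsvt_pred')
test_evaluator = val_evaluator
```
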
7 changes: 0 additions & 7 deletions tools/dataset_converters/kitti_converter.py
@@ -5,7 +5,6 @@
import mmcv
import mmengine
import numpy as np
from mmengine import logging, print_log
from nuscenes.utils.geometry_utils import view_points

from mmdet3d.structures import points_cam2img
@@ -250,12 +249,6 @@ def create_waymo_info_file(data_path,
max_sweeps (int, optional): Max sweeps before the detection frame
to be used. Default: 5.
"""
print_log(
'Deprecation Warning: related functions has been migrated to '
'`Waymo2KITTI.create_waymo_info_file`. It will be removed in '
'the future!',
logger='current',
level=logging.WARNING)
imageset_folder = Path(data_path) / 'ImageSets'
train_img_ids = _read_imageset_file(str(imageset_folder / 'train.txt'))
val_img_ids = _read_imageset_file(str(imageset_folder / 'val.txt'))
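
For context, `_read_imageset_file` simply loads the split files referenced above. The sketch below assumes the KITTI-style `ImageSets/*.txt` layout of one numeric sample index per line; it is illustrative, not the repository's implementation.

```python
from pathlib import Path


def read_imageset_file(path: str) -> list:
    """Sketch: read one integer sample index per line from a split file."""
    lines = Path(path).read_text().splitlines()
    return [int(line) for line in lines if line.strip()]
```
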
7 changes: 1 addition & 6 deletions tools/dataset_converters/kitti_data_utils.py
@@ -148,14 +148,10 @@ def get_label_anno(label_path):
for x in content]).reshape(-1, 3)
annotations['rotation_y'] = np.array([float(x[14])
for x in content]).reshape(-1)
if len(content) != 0 and len(content[0]) >= 16: # have score
if len(content) != 0 and len(content[0]) == 16: # have score
annotations['score'] = np.array([float(x[15]) for x in content])
else:
annotations['score'] = np.zeros((annotations['bbox'].shape[0], ))
# have num_lidar_points_in_box, given in waymo
if len(content) != 0 and len(content[0]) == 17:
annotations['num_points_in_gt'] = np.array(
[int(x[16]) for x in content])
index = list(range(num_objects)) + [-1] * (num_gt - num_objects)
annotations['index'] = np.array(index, dtype=np.int32)
annotations['group_ids'] = np.arange(num_gt, dtype=np.int32)
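
For context on the `len(content[0]) == 16` check above: a standard KITTI label line carries 15 fields (type, truncation, occlusion, alpha, the 2D box, dimensions, location and rotation_y), and detection result files append a 16th field holding the confidence score. The parser below is a simplified illustration of that layout, not the repository's implementation.

```python
def parse_kitti_line(line: str) -> dict:
    """Illustrative sketch of the KITTI label/result field layout."""
    fields = line.strip().split(' ')
    anno = {
        'name': fields[0],
        'truncated': float(fields[1]),
        'occluded': int(fields[2]),
        'alpha': float(fields[3]),
        'bbox': [float(x) for x in fields[4:8]],         # 2D box: x1 y1 x2 y2
        'dimensions': [float(x) for x in fields[8:11]],  # h, w, l
        'location': [float(x) for x in fields[11:14]],   # x, y, z in camera coords
        'rotation_y': float(fields[14]),
    }
    if len(fields) == 16:  # result files carry an extra confidence score
        anno['score'] = float(fields[15])
    return anno
```
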
@@ -556,7 +552,6 @@ def gather(self, image_ids):
image_infos = mmengine.track_parallel_progress(self.gather_single,
image_ids,
self.num_worker)
# image_infos = mmengine.track_progress(self.gather_single, image_ids)
return list(image_infos)


