diff --git a/mmdet3d/datasets/convert_utils.py b/mmdet3d/datasets/convert_utils.py index 66dba29117..cb4d97e137 100644 --- a/mmdet3d/datasets/convert_utils.py +++ b/mmdet3d/datasets/convert_utils.py @@ -10,7 +10,7 @@ from shapely.geometry import MultiPoint, box from shapely.geometry.polygon import Polygon -from mmdet3d.structures import Box3DMode, LiDARInstance3DBoxes, points_cam2img +from mmdet3d.structures import Box3DMode, CameraInstance3DBoxes, points_cam2img from mmdet3d.structures.ops import box_np_ops kitti_categories = ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck', @@ -318,11 +318,8 @@ def get_kitti_style_2d_boxes(info: dict, def convert_annos(info: dict, cam_idx: int) -> dict: """Convert front-cam anns to i-th camera (KITTI-style info).""" rect = info['calib']['R0_rect'].astype(np.float32) - if cam_idx == 0: - lidar2cami = info['calib']['Tr_velo_to_cam'].astype(np.float32) - else: - lidar2cami = info['calib'][f'Tr_velo_to_cam{cam_idx}'].astype( - np.float32) + lidar2cam0 = info['calib']['Tr_velo_to_cam'].astype(np.float32) + lidar2cami = info['calib'][f'Tr_velo_to_cam{cam_idx}'].astype(np.float32) annos = info['annos'] converted_annos = copy.deepcopy(annos) loc = annos['location'] @@ -330,11 +327,12 @@ def convert_annos(info: dict, cam_idx: int) -> dict: rots = annos['rotation_y'] gt_bboxes_3d = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1).astype(np.float32) - # BC-breaking: gt_bboxes_3d is already in lidar coordinates + # convert gt_bboxes_3d to velodyne coordinates + gt_bboxes_3d = CameraInstance3DBoxes(gt_bboxes_3d).convert_to( + Box3DMode.LIDAR, np.linalg.inv(rect @ lidar2cam0), correct_yaw=True) # convert gt_bboxes_3d to cam coordinates - gt_bboxes_3d = LiDARInstance3DBoxes(gt_bboxes_3d).convert_to( + gt_bboxes_3d = gt_bboxes_3d.convert_to( Box3DMode.CAM, rect @ lidar2cami, correct_yaw=True).numpy() - converted_annos['location'] = gt_bboxes_3d[:, :3] converted_annos['dimensions'] = gt_bboxes_3d[:, 3:6] converted_annos['rotation_y'] = 
gt_bboxes_3d[:, 6] diff --git a/projects/DSVT/README.md b/projects/DSVT/README.md index a4b45b570d..c1d8cf2744 100644 --- a/projects/DSVT/README.md +++ b/projects/DSVT/README.md @@ -57,18 +57,24 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1- ### Training commands -The support of training DSVT is on the way. +In MMDetection3D's root directory, run the following command to train the model: + +```bash +tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch +``` ## Results and models ### Waymo -| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download | -| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: | -| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | | +| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download | +| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :----: | :-----: | :----: | :---------: | :-------: | +| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 
[ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | 75.5 | 72.4 | 69.2 | 66.3 | [log](<>) | **Note** that `ResSECOND` denotes the base block in SECOND has residual layers. +**Note** Regrettably, we are unable to provide the pre-trained model weights due to the [Waymo Dataset License Agreement](https://waymo.com/open/terms/). However, we can provide the training logs. + ## Citation ```latex diff --git a/projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py b/projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py index 9d0be465e8..bf56259183 100644 --- a/projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py +++ b/projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py @@ -6,7 +6,6 @@ voxel_size = [0.32, 0.32, 6] grid_size = [468, 468, 1] point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4.0] -# data_root = 'data/waymo_mini/kitti_format/' data_root = 'data/waymo/kitti_format/' class_names = ['Car', 'Pedestrian', 'Cyclist'] metainfo = dict(classes=class_names) @@ -194,7 +193,7 @@ dataset=dict( type=dataset_type, data_root=data_root, - ann_file='waymo_wo_cam_ins_infos_train.pkl', + ann_file='waymo_infos_train.pkl', data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'), pipeline=train_pipeline, modality=input_modality, @@ -216,7 +215,7 @@ type=dataset_type, data_root=data_root, data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'), - ann_file='waymo_wo_cam_ins_infos_val.pkl', + ann_file='waymo_infos_val.pkl', pipeline=test_pipeline, modality=input_modality, test_mode=True, @@ -225,32 +224,23 @@ backend_args=backend_args)) test_dataloader = val_dataloader -# val_evaluator = dict( -# type='WaymoMetric', -# ann_file='./data/waymo_mini/kitti_format/waymo_infos_val.pkl', -# waymo_bin_file='./data/waymo_mini/waymo_format/gt_mini.bin', 
-# backend_args=backend_args, -# convert_kitti_format=False) val_evaluator = dict( type='WaymoMetric', - ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl', waymo_bin_file='./data/waymo/waymo_format/gt.bin', - backend_args=backend_args, - convert_kitti_format=False) + result_prefix='./dsvt_pred') test_evaluator = val_evaluator -vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')] -# vis_backends = [dict(type='LocalVisBackend')] +# vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')] +vis_backends = [dict(type='LocalVisBackend')] visualizer = dict( type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# schedules lr = 1e-5 -# This schedule is mainly used by models on nuScenes dataset -# max_norm=10 is better for SECOND optim_wrapper = dict( type='OptimWrapper', optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)), clip_grad=dict(max_norm=10, norm_type=2)) -# learning rate param_scheduler = [ dict( type='CosineAnnealingLR', diff --git a/tools/dataset_converters/kitti_converter.py b/tools/dataset_converters/kitti_converter.py index e904918f60..367cfd7ba9 100644 --- a/tools/dataset_converters/kitti_converter.py +++ b/tools/dataset_converters/kitti_converter.py @@ -5,7 +5,6 @@ import mmcv import mmengine import numpy as np -from mmengine import logging, print_log from nuscenes.utils.geometry_utils import view_points from mmdet3d.structures import points_cam2img @@ -250,12 +249,6 @@ def create_waymo_info_file(data_path, max_sweeps (int, optional): Max sweeps before the detection frame to be used. Default: 5. """ - print_log( - 'Deprecation Warning: related functions has been migrated to ' - '`Waymo2KITTI.create_waymo_info_file`. 
It will be removed in ' - 'the future!', - logger='current', - level=logging.WARNING) imageset_folder = Path(data_path) / 'ImageSets' train_img_ids = _read_imageset_file(str(imageset_folder / 'train.txt')) val_img_ids = _read_imageset_file(str(imageset_folder / 'val.txt')) diff --git a/tools/dataset_converters/kitti_data_utils.py b/tools/dataset_converters/kitti_data_utils.py index b0e90bd8d3..64c3bc415b 100644 --- a/tools/dataset_converters/kitti_data_utils.py +++ b/tools/dataset_converters/kitti_data_utils.py @@ -148,14 +148,10 @@ def get_label_anno(label_path): for x in content]).reshape(-1, 3) annotations['rotation_y'] = np.array([float(x[14]) for x in content]).reshape(-1) - if len(content) != 0 and len(content[0]) >= 16: # have score + if len(content) != 0 and len(content[0]) == 16: # have score annotations['score'] = np.array([float(x[15]) for x in content]) else: annotations['score'] = np.zeros((annotations['bbox'].shape[0], )) - # have num_lidar_points_in_box, given in waymo - if len(content) != 0 and len(content[0]) == 17: - annotations['num_points_in_gt'] = np.array( - [int(x[16]) for x in content]) index = list(range(num_objects)) + [-1] * (num_gt - num_objects) annotations['index'] = np.array(index, dtype=np.int32) annotations['group_ids'] = np.arange(num_gt, dtype=np.int32) @@ -556,7 +552,6 @@ def gather(self, image_ids): image_infos = mmengine.track_parallel_progress(self.gather_single, image_ids, self.num_worker) - # image_infos = mmengine.track_progress(self.gather_single, image_ids) return list(image_infos)