diff --git a/projects/PETR/petr/transforms_3d.py b/projects/PETR/petr/transforms_3d.py index fa43f4950b..a15f7f3028 100644 --- a/projects/PETR/petr/transforms_3d.py +++ b/projects/PETR/petr/transforms_3d.py @@ -175,35 +175,33 @@ def transform(self, results): return results def rotate_bev_along_z(self, results, angle): - rot_cos = torch.cos(torch.tensor(angle)) - rot_sin = torch.sin(torch.tensor(angle)) + rot_cos = np.cos(angle) + rot_sin = np.sin(angle) - rot_mat = torch.tensor([[rot_cos, -rot_sin, 0, 0], - [rot_sin, rot_cos, 0, 0], [0, 0, 1, 0], - [0, 0, 0, 1]]) - rot_mat_inv = torch.inverse(rot_mat) + rot_mat = np.array([[rot_cos, rot_sin, 0, 0], + [-rot_sin, rot_cos, 0, 0], [0, 0, 1, 0], + [0, 0, 0, 1]]) + rot_mat_inv = np.linalg.inverse(rot_mat) num_view = len(results['lidar2cam']) for view in range(num_view): results['lidar2cam'][view] = ( - torch.tensor(np.array(results['lidar2cam'][view]).T).float() - @ rot_mat_inv).T.numpy() + results['lidar2cam'][view] @ rot_mat_inv) return def scale_xyz(self, results, scale_ratio): - rot_mat = torch.tensor([ + scale_mat = np.array([ [scale_ratio, 0, 0, 0], [0, scale_ratio, 0, 0], [0, 0, scale_ratio, 0], [0, 0, 0, 1], ]) - rot_mat_inv = torch.inverse(rot_mat) + scale_mat_inv = np.linalg.inverse(scale_mat) num_view = len(results['lidar2cam']) for view in range(num_view): - results['lidar2cam'][view] = (torch.tensor( - rot_mat_inv.T - @ results['lidar2cam'][view].T).float()).T.numpy() + results['lidar2cam'][view] = ( + scale_mat_inv @ results['lidar2cam'][view]) return diff --git a/projects/PETR/petr/utils.py b/projects/PETR/petr/utils.py index edf2b763c9..292e406561 100644 --- a/projects/PETR/petr/utils.py +++ b/projects/PETR/petr/utils.py @@ -14,17 +14,19 @@ def normalize_bbox(bboxes, pc_range): width = bboxes[..., 4:5].log() height = bboxes[..., 5:6].log() + # normalize boxes to match the old checkpoints trained prior to + # coordinate system refactoring ( 7: vx = bboxes[..., 7:8] vy = bboxes[..., 8:9] normalized_bboxes = torch.cat( - (cx, cy, length, width, cz, height, rot.sin(), rot.cos(), vx, vy), + (cx, cy, width, length, cz, height, rot.sin(), rot.cos(), vx, vy), dim=-1) else: normalized_bboxes = torch.cat( - (cx, cy, length, width, cz, height, rot.sin(), rot.cos()), dim=-1) + (cx, cy, width, length, cz, height, rot.sin(), rot.cos()), dim=-1) return normalized_bboxes @@ -42,9 +44,11 @@ def denormalize_bbox(normalized_bboxes, pc_range): cy = normalized_bboxes[..., 1:2] cz = normalized_bboxes[..., 4:5] + # boxes are expected to match the format from old checkpoints, + # denormalize them to match the new format # size - length = normalized_bboxes[..., 2:3] - width = normalized_bboxes[..., 3:4] + width = normalized_bboxes[..., 2:3] + length = normalized_bboxes[..., 3:4] height = normalized_bboxes[..., 5:6] width = width.exp()