diff --git a/configs/gupnet/gupnet_dla34_kitti.yml b/configs/gupnet/gupnet_dla34_kitti.yml new file mode 100644 index 00000000..7c0fddc2 --- /dev/null +++ b/configs/gupnet/gupnet_dla34_kitti.yml @@ -0,0 +1,52 @@ +batch_size: 32 +epochs: 140 + +train_dataset: + type: GUPKittiMonoDataset + dataset_root: /root/kitti + use_3d_center: True + class_name: ['Pedestrian', 'Car', 'Cyclist'] + resolution: [1280, 384] + random_flip: 0.5 + random_crop: 0.5 + scale: 0.4 + shift: 0.1 + mode: train + +val_dataset: + type: GUPKittiMonoDataset + dataset_root: /root/kitti + use_3d_center: True + class_name: ['Pedestrian', 'Car', 'Cyclist'] + resolution: [1280, 384] + random_flip: 0.5 + random_crop: 0.5 + scale: 0.4 + shift: 0.1 + mode: val + +optimizer: + type: Adam + weight_decay: 0.00001 + +lr_scheduler: + type: CosineWarmupMultiStepDecayByEpoch + warmup_steps: 5 + learning_rate: 0.00125 + milestones: [90, 120] + decay_rate: 0.1 + start_lr: 0.00001 + +model: + type: GUPNET + backbone: + type: GUP_DLA34 + down_ratio: 4 + pretrained: ./checkpoint_root/dla34-ba72cf86_paddle_new.pdparams + head: + type: GUPNETPredictor + head_conv: 256 + threshold: 0.2 + stat_epoch_nums: 5 + max_epoch: 140 + train_datasets_length: 3712 diff --git a/paddle3d/datasets/__init__.py b/paddle3d/datasets/__init__.py index ea096a4b..c88ad8e5 100644 --- a/paddle3d/datasets/__init__.py +++ b/paddle3d/datasets/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from .base import BaseDataset -from .kitti import KittiDepthDataset, KittiMonoDataset, KittiPCDataset +from .kitti import KittiDepthDataset, KittiMonoDataset, KittiPCDataset, GUPKittiMonoDataset from .modelnet40 import ModelNet40 from .nuscenes import NuscenesMVDataset, NuscenesPCDataset, NuscenesMVSegDataset from .waymo import WaymoPCDataset diff --git a/paddle3d/datasets/kitti/__init__.py b/paddle3d/datasets/kitti/__init__.py index a7a5fd08..d060bf4f 100644 --- a/paddle3d/datasets/kitti/__init__.py +++ b/paddle3d/datasets/kitti/__init__.py @@ -15,3 +15,4 @@ from .kitti_depth_det import KittiDepthDataset from .kitti_mono_det import KittiMonoDataset from .kitti_pointcloud_det import KittiPCDataset +from .kitti_gupnet import GUPKittiMonoDataset diff --git a/paddle3d/datasets/kitti/kitti_gupnet.py b/paddle3d/datasets/kitti/kitti_gupnet.py new file mode 100644 index 00000000..72a5d939 --- /dev/null +++ b/paddle3d/datasets/kitti/kitti_gupnet.py @@ -0,0 +1,532 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
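+
+# A rough usage sketch (illustrative, not part of the original patch): the
+# YAML config above maps one-to-one onto this dataset's constructor, so the
+# training split could be built by hand roughly like this, with `dataset_root`
+# pointing at the same KITTI root as the config:
+#
+#   dataset = GUPKittiMonoDataset(
+#       dataset_root='/root/kitti',
+#       use_3d_center=True,
+#       class_name=['Pedestrian', 'Car', 'Cyclist'],
+#       resolution=[1280, 384],
+#       random_flip=0.5,
+#       random_crop=0.5,
+#       scale=0.4,
+#       shift=0.1,
+#       mode='train')
+#   inputs, P2, coord_range, targets, info, sample = dataset[0]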
+ +import os +import numpy as np +from PIL import Image +from typing import List, Dict +from paddle3d.datasets.kitti.kitti_det import KittiDetDataset +from paddle3d.datasets.kitti.kitti_utils import Object3d, Calibration, box_lidar_to_camera, filter_fake_result, camera_record_to_object +from paddle3d.datasets.kitti.kitti_gupnet_utils import get_affine_transform, affine_transform, gaussian_radius, draw_umich_gaussian, angle2class +from paddle3d.apis import manager +from paddle3d.sample import Sample +from paddle3d.datasets.metrics import MetricABC +from paddle3d.geometries.bbox import (BBoxes2D, BBoxes3D, CoordMode, + project_to_image) +from paddle3d.thirdparty import kitti_eval +from paddle3d.utils.logger import logger + + +@manager.DATASETS.add_component +class GUPKittiMonoDataset(KittiDetDataset): + """ + """ + + def __init__(self, + dataset_root, + use_3d_center=True, + class_name=['Pedestrian', 'Car', 'Cyclist'], + resolution=[1280, 384], + random_flip=0.5, + random_crop=0.5, + scale=0.4, + shift=0.1, + mode='train'): + super().__init__(dataset_root=dataset_root, mode=mode) + self.dataset_root = dataset_root + # basic configuration + self.num_classes = 3 + self.max_objs = 50 + self.class_name = class_name + self.cls2id = {'Pedestrian': 0, 'Car': 1, 'Cyclist': 2} + self.resolution = np.array(resolution) # W * H + self.use_3d_center = use_3d_center + + # l,w,h + self.cls_mean_size = np.array( + [[1.76255119, 0.66068622, 0.84422524], + [1.52563191462, 1.62856739989, 3.88311640418], + [1.73698127, 0.59706367, 1.76282397]]) + + # data mode loading + assert mode in ['train', 'val', 'trainval', 'test'] + self.mode = mode.lower() + split_dir = os.path.join(dataset_root, 'ImageSets', mode + '.txt') + self.idx_list = [x.strip() for x in open(split_dir).readlines()] + + # path configuration + self.image_dir = os.path.join(self.base_dir, 'image_2') + # data augmentation configuration + self.data_augmentation = True if mode in ['train', 'trainval' + ] else False + self.random_flip = random_flip + self.random_crop = random_crop + self.scale = scale + self.shift = shift + + # statistics + self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) + self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32) + + # others + self.downsample = 4 + + def get_image(self, idx): + img_file = os.path.join(self.image_dir, '%06d.png' % idx) + # print(img_file) + assert os.path.exists(img_file) + return Image.open(img_file) # (H, W, 3) RGB mode + + def get_label(self, idx): + label_file = os.path.join(self.label_dir, '%06d.txt' % idx) + assert os.path.exists(label_file) + return self.get_objects_from_label(label_file) + + def get_calib(self, idx): + calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx) + # assert os.path.exists(calib_file) + if not os.path.exists(calib_file): + print('Non-exist: ', calib_file) + calib = self.get_calib_from_file(calib_file) + return Calibration(calib) + + def get_objects_from_label(self, label_file): + with open(label_file, 'r') as f: + lines = f.readlines() + objects = [Object3d(line) for line in lines] + return objects + + def get_calib_from_file(self, calib_file): + with open(calib_file) as f: + lines = f.readlines() + + obj = lines[2].strip().split(' ')[1:] + P2 = np.array(obj, dtype=np.float32) + obj = lines[3].strip().split(' ')[1:] + P3 = np.array(obj, dtype=np.float32) + obj = lines[4].strip().split(' ')[1:] + R0 = np.array(obj, dtype=np.float32) + obj = lines[5].strip().split(' ')[1:] + Tr_velo_to_cam = np.array(obj, dtype=np.float32) + + return { + 'P2': 
P2.reshape(3, 4), + 'P3': P3.reshape(3, 4), + 'R0': R0.reshape(3, 3), + 'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4) + } + + def __len__(self): + return self.idx_list.__len__() + + def __getitem__(self, item): + # get samples + filename = '{}.png'.format(self.data[item]) + path = os.path.join(self.image_dir, filename) + calibs = self.load_calibration_info(item) + + sample = Sample(path=path, modality="image") + # P2 + sample.meta.camera_intrinsic = calibs[2][:3, :3] + sample.meta.id = self.data[item] + sample.calibs = calibs + + kitti_records, _ = self.load_annotation(item) + bboxes_2d, bboxes_3d, labels = camera_record_to_object(kitti_records) + + sample.bboxes_2d = bboxes_2d + sample.bboxes_3d = bboxes_3d + sample.labels = np.array([self.CLASS_MAP[label] for label in labels], + dtype=np.int32) + + # get inputs + index = int(self.idx_list[item]) # index mapping, get real data id + # image loading + img = self.get_image(index) + img_size = np.array(img.size) + + # data augmentation for image + center = np.array(img_size) / 2 + crop_size = img_size + random_flip_flag = False + if self.data_augmentation: + if np.random.random() < self.random_flip: + random_flip_flag = True + img = img.transpose(Image.FLIP_LEFT_RIGHT) + + if np.random.random() < self.random_crop: + crop_size = img_size * \ + np.clip(np.random.randn() * self.scale + + 1, 1 - self.scale, 1 + self.scale) + center[0] += img_size[0] * \ + np.clip(np.random.randn() * self.shift, - + 2 * self.shift, 2 * self.shift) + center[1] += img_size[1] * \ + np.clip(np.random.randn() * self.shift, - + 2 * self.shift, 2 * self.shift) + + # add affine transformation for 2d images. + trans, trans_inv = get_affine_transform( + center, crop_size, 0, self.resolution, inv=1) + img = img.transform( + tuple(self.resolution.tolist()), + method=Image.AFFINE, + data=tuple(trans_inv.reshape(-1).tolist()), + resample=Image.BILINEAR) + coord_range = np.array([center - crop_size / 2, + center + crop_size / 2]).astype(np.float32) + + # image encoding + img = np.array(img).astype(np.float32) / 255.0 + img = (img - self.mean) / self.std + img = img.transpose(2, 0, 1) # C * H * W + + # get calib + calib = self.get_calib(index) + + features_size = self.resolution // self.downsample # W * H + # ============================ get labels ============================== + if self.mode != 'test': + objects = self.get_label(index) + # data augmentation for labels + if random_flip_flag: + calib.flip(img_size) + for object in objects: + [x1, _, x2, _] = object.box2d + object.box2d[0], object.box2d[2] = img_size[0] - \ + x2, img_size[0] - x1 + object.ry = np.pi - object.ry + object.loc[0] *= -1 + if object.ry > np.pi: + object.ry -= 2 * np.pi + if object.ry < -np.pi: + object.ry += 2 * np.pi + # labels encoding + heatmap = np.zeros( + (self.num_classes, features_size[1], features_size[0]), + dtype=np.float32) # C * H * W + size_2d = np.zeros((self.max_objs, 2), dtype=np.float32) + offset_2d = np.zeros((self.max_objs, 2), dtype=np.float32) + depth = np.zeros((self.max_objs, 1), dtype=np.float32) + heading_bin = np.zeros((self.max_objs, 1), dtype=np.int64) + heading_res = np.zeros((self.max_objs, 1), dtype=np.float32) + src_size_3d = np.zeros((self.max_objs, 3), dtype=np.float32) + size_3d = np.zeros((self.max_objs, 3), dtype=np.float32) + offset_3d = np.zeros((self.max_objs, 2), dtype=np.float32) + cls_ids = np.zeros((self.max_objs), dtype=np.int64) + indices = np.zeros((self.max_objs), dtype=np.int64) + mask_2d = np.zeros((self.max_objs), dtype=np.uint8) + object_num = len( + 
objects) if len(objects) < self.max_objs else self.max_objs + for i in range(object_num): + # filter objects by class_name + if objects[i].cls_type not in self.class_name: + continue + + # filter inappropriate samples by difficulty + if objects[i].level_str == 'UnKnown' or objects[i].loc[-1] < 2: + continue + + # process 2d bbox & get 2d center + bbox_2d = objects[i].box2d.copy() + # add affine transformation for 2d boxes. + bbox_2d[:2] = affine_transform(bbox_2d[:2], trans) + bbox_2d[2:] = affine_transform(bbox_2d[2:], trans) + # modify the 2d bbox according to pre-compute downsample ratio + bbox_2d[:] /= self.downsample + + # process 3d bbox & get 3d center + center_2d = np.array([(bbox_2d[0] + bbox_2d[2]) / 2, + (bbox_2d[1] + bbox_2d[3]) / 2], + dtype=np.float32) # W * H + # real 3D center in 3D space + center_3d = objects[i].loc + [0, -objects[i].h / 2, 0] + center_3d = center_3d.reshape(-1, 3) # shape adjustment (N, 3) + # project 3D center to image plane + center_3d, _ = calib.rect_to_img(center_3d) + center_3d = center_3d[0] # shape adjustment + center_3d = affine_transform(center_3d.reshape(-1), trans) + center_3d /= self.downsample + + # generate the center of gaussian heatmap [optional: 3d center or 2d center] + center_heatmap = center_3d.astype( + np.int32) if self.use_3d_center else center_2d.astype( + np.int32) + if center_heatmap[0] < 0 or center_heatmap[0] >= features_size[ + 0]: + continue + if center_heatmap[1] < 0 or center_heatmap[1] >= features_size[ + 1]: + continue + + # generate the radius of gaussian heatmap + w, h = bbox_2d[2] - bbox_2d[0], bbox_2d[3] - bbox_2d[1] + radius = gaussian_radius((w, h)) + radius = max(0, int(radius)) + + if objects[i].cls_type in ['Van', 'Truck', 'DontCare']: + draw_umich_gaussian(heatmap[1], center_heatmap, radius) + continue + + cls_id = self.cls2id[objects[i].cls_type] + cls_ids[i] = cls_id + draw_umich_gaussian(heatmap[cls_id], center_heatmap, radius) + + # encoding 2d/3d offset & 2d size + indices[i] = center_heatmap[1] * \ + features_size[0] + center_heatmap[0] + offset_2d[i] = center_2d - center_heatmap + size_2d[i] = 1. * w, 1. 
* h + + # encoding depth + depth[i] = objects[i].loc[-1] + + # encoding heading angle + # heading_angle = objects[i].alpha + heading_angle = calib.ry2alpha( + objects[i].ry, + (objects[i].box2d[0] + objects[i].box2d[2]) / 2) + if heading_angle > np.pi: + heading_angle -= 2 * np.pi # check range + if heading_angle < -np.pi: + heading_angle += 2 * np.pi + # Convert continuous angle to discrete class and residual + heading_bin[i], heading_res[i] = angle2class(heading_angle) + + # encoding 3d offset & size_3d + offset_3d[i] = center_3d - center_heatmap + src_size_3d[i] = np.array( + [objects[i].h, objects[i].w, objects[i].l], + dtype=np.float32) + mean_size = self.cls_mean_size[self.cls2id[objects[i].cls_type]] + size_3d[i] = src_size_3d[i] - mean_size + + # objects[i].trucation <=0.5 and objects[i].occlusion<=2 and (objects[i].box2d[3]-objects[i].box2d[1])>=25: + if objects[i].truncation <= 0.5 and objects[i].occlusion <= 2: + mask_2d[i] = 1 + + targets = { + 'depth': depth, + 'size_2d': size_2d, + 'heatmap': heatmap, + 'offset_2d': offset_2d, + 'indices': indices, + 'size_3d': size_3d, + 'offset_3d': offset_3d, + 'heading_bin': heading_bin, + 'heading_res': heading_res, + 'cls_ids': cls_ids, + 'mask_2d': mask_2d + } + else: + targets = {} + # collect return data + inputs = img + info = { + 'img_id': index, + 'img_size': img_size, + 'bbox_downsample_ratio': img_size / features_size + } + + return inputs, calib.P2, coord_range, targets, info, sample + + @property + def name(self) -> str: + return "KITTI" + + @property + def labels(self) -> List[str]: + return self.class_name + + @property + def metric(self): + gt = [] + for idx in range(len(self)): + annos = self.load_annotation(idx) + if len(annos[0]) > 0 and len(annos[1]) > 0: + gt.append(np.concatenate((annos[0], annos[1]), axis=0)) + elif len(annos[0]) > 0: + gt.append(annos[0]) + else: + gt.append(annos[1]) + return GUPKittiMetric( + groundtruths=gt, + classmap={i: name + for i, name in enumerate(self.class_names)}, + indexes=self.data) + + +class GUPKittiMetric(MetricABC): + def __init__(self, groundtruths: List[np.ndarray], classmap: Dict[int, str], + indexes: List): + self.gt_annos = groundtruths + self.predictions = [] + self.classmap = classmap + self.indexes = indexes + + def _parse_gt_to_eval_format(self, + groundtruths: List[np.ndarray]) -> List[dict]: + res = [] + for rows in groundtruths: + if rows.size == 0: + res.append({ + 'name': np.zeros([0]), + 'truncated': np.zeros([0]), + 'occluded': np.zeros([0]), + 'alpha': np.zeros([0]), + 'bbox': np.zeros([0, 4]), + 'dimensions': np.zeros([0, 3]), + 'location': np.zeros([0, 3]), + 'rotation_y': np.zeros([0]), + 'score': np.zeros([0]) + }) + else: + res.append({ + 'name': rows[:, 0], + 'truncated': rows[:, 1].astype(np.float64), + 'occluded': rows[:, 2].astype(np.int64), + 'alpha': rows[:, 3].astype(np.float64), + 'bbox': rows[:, 4:8].astype(np.float64), + 'dimensions': rows[:, [10, 8, 9]].astype(np.float64), + 'location': rows[:, 11:14].astype(np.float64), + 'rotation_y': rows[:, 14].astype(np.float64) + }) + + return res + + def get_camera_box2d(self, bboxes_3d: BBoxes3D, proj_mat: np.ndarray): + box_corners = bboxes_3d.corners_3d + box_corners_in_image = project_to_image(box_corners, proj_mat) + minxy = np.min(box_corners_in_image, axis=1) + maxxy = np.max(box_corners_in_image, axis=1) + box_2d_preds = BBoxes2D(np.concatenate([minxy, maxxy], axis=1)) + + return box_2d_preds + + def _parse_predictions_to_eval_format( + self, predictions: List[Sample]) -> List[dict]: + res = {} + for 
pred in predictions: + filter_fake_result(pred) + id = pred.meta.id + if pred.bboxes_3d is None: + det = { + 'truncated': np.zeros([0]), + 'occluded': np.zeros([0]), + 'alpha': np.zeros([0]), + 'name': np.zeros([0]), + 'bbox': np.zeros([0, 4]), + 'dimensions': np.zeros([0, 3]), + 'location': np.zeros([0, 3]), + 'rotation_y': np.zeros([0]), + 'score': np.zeros([0]), + } + else: + num_boxes = pred.bboxes_3d.shape[0] + names = np.array( + [self.classmap[label] for label in pred.labels]) + calibs = pred.calibs + + alpha = pred.get('alpha', np.zeros([num_boxes])) + + if pred.bboxes_3d.coordmode != CoordMode.KittiCamera: + bboxes_3d = box_lidar_to_camera(pred.bboxes_3d, calibs) + else: + bboxes_3d = pred.bboxes_3d + + if bboxes_3d.origin != [.5, 1., .5]: + bboxes_3d[:, :3] += bboxes_3d[:, 3:6] * ( + np.array([.5, 1., .5]) - np.array(bboxes_3d.origin)) + bboxes_3d.origin = [.5, 1., .5] + + if pred.bboxes_2d is None: + bboxes_2d = self.get_camera_box2d(bboxes_3d, calibs[2]) + else: + bboxes_2d = pred.bboxes_2d + + loc = bboxes_3d[:, :3] + dim = bboxes_3d[:, 3:6] + + det = { + # fake value + 'truncated': np.zeros([num_boxes]), + 'occluded': np.zeros([num_boxes]), + # predict value + 'alpha': alpha, + 'name': names, + 'bbox': bboxes_2d, + 'dimensions': dim, + # TODO: coord trans + 'location': loc, + 'rotation_y': bboxes_3d[:, 6], + 'score': pred.confidences, + } + + res[id] = det + + return [res[idx] for idx in self.indexes] + + def update(self, predictions: List[Sample], **kwargs): + """ + """ + self.predictions += predictions + + def compute(self, verbose=False, **kwargs) -> dict: + """ + """ + gt_annos = self._parse_gt_to_eval_format(self.gt_annos) + dt_annos = self._parse_predictions_to_eval_format(self.predictions) + + if len(dt_annos) != len(gt_annos): + raise RuntimeError( + 'The number of predictions({}) is not equal to the number of GroundTruths({})' + .format(len(dt_annos), len(gt_annos))) + + metric_r40_dict = kitti_eval( + gt_annos, + dt_annos, + current_classes=list(self.classmap.values()), + metric_types=["bbox", "bev", "3d"], + recall_type='R40') + + metric_r11_dict = kitti_eval( + gt_annos, + dt_annos, + current_classes=list(self.classmap.values()), + metric_types=["bbox", "bev", "3d"], + recall_type='R11') + + if verbose: + for cls, cls_metrics in metric_r40_dict.items(): + logger.info("{}:".format(cls)) + for overlap_thresh, metrics in cls_metrics.items(): + for metric_type, thresh in zip(["bbox", "bev", "3d"], + overlap_thresh): + if metric_type in metrics: + logger.info( + "{} AP_R40@{:.0%}: {:.2f} {:.2f} {:.2f}".format( + metric_type.upper().ljust(4), thresh, + *metrics[metric_type])) + + for cls, cls_metrics in metric_r11_dict.items(): + logger.info("{}:".format(cls)) + for overlap_thresh, metrics in cls_metrics.items(): + for metric_type, thresh in zip(["bbox", "bev", "3d"], + overlap_thresh): + if metric_type in metrics: + logger.info( + "{} AP_R11@{:.0%}: {:.2f} {:.2f} {:.2f}".format( + metric_type.upper().ljust(4), thresh, + *metrics[metric_type])) + return metric_r40_dict, metric_r11_dict diff --git a/paddle3d/datasets/kitti/kitti_gupnet_utils.py b/paddle3d/datasets/kitti/kitti_gupnet_utils.py new file mode 100644 index 00000000..77e1f9fa --- /dev/null +++ b/paddle3d/datasets/kitti/kitti_gupnet_utils.py @@ -0,0 +1,186 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+################### affine transform ###################
+
+
+def get_dir(src_point, rot_rad):
+    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+
+    src_result = [0, 0]
+    src_result[0] = src_point[0] * cs - src_point[1] * sn
+    src_result[1] = src_point[0] * sn + src_point[1] * cs
+
+    return src_result
+
+
+def get_3rd_point(a, b):
+    direct = a - b
+    return b + np.array([-direct[1], direct[0]], dtype=np.float32)
+
+
+def get_affine_transform(center,
+                         scale,
+                         rot,
+                         output_size,
+                         shift=np.array([0, 0], dtype=np.float32),
+                         inv=0):
+    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
+        scale = np.array([scale, scale], dtype=np.float32)
+
+    scale_tmp = scale
+    src_w = scale_tmp[0]
+    dst_w = output_size[0]
+    dst_h = output_size[1]
+
+    rot_rad = np.pi * rot / 180
+    src_dir = get_dir([0, src_w * -0.5], rot_rad)
+    dst_dir = np.array([0, dst_w * -0.5], np.float32)
+
+    src = np.zeros((3, 2), dtype=np.float32)
+    dst = np.zeros((3, 2), dtype=np.float32)
+    src[0, :] = center + scale_tmp * shift
+    src[1, :] = center + src_dir + scale_tmp * shift
+    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
+
+    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
+    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
+
+    if inv:
+        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+        trans_inv = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+        return trans, trans_inv
+    else:
+        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+        return trans
+
+
+def affine_transform(pt, t):
+    new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T
+    new_pt = np.dot(t, new_pt)
+    return new_pt[:2]
+
+
+def roty(t):
+    ''' Rotation about the y-axis. '''
+    c = np.cos(t)
+    s = np.sin(t)
+    return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
+
+
+def compute_box_3d(obj, calib):
+    ''' Takes an object and computes its 3d bounding box corners in camera
+        coordinates.
+        Returns:
+            corners_3d: (8,3) array in rect camera coord.
+    '''
+    # compute rotational matrix around yaw axis
+    R = roty(obj.ry)
+
+    # 3d bounding box dimensions
+    l = obj.l
+    w = obj.w
+    h = obj.h
+
+    # 3d bounding box corners
+    x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2]
+    y_corners = [0, 0, 0, 0, -h, -h, -h, -h]
+    # y_corners = [h/2,h/2,h/2,h/2,-h/2,-h/2,-h/2,-h/2]
+    z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2]
+
+    # rotate and translate 3d bounding box
+    corners_3d = np.dot(R, np.vstack([x_corners, y_corners, z_corners]))
+    corners_3d[0, :] = corners_3d[0, :] + obj.pos[0]
+    corners_3d[1, :] = corners_3d[1, :] + obj.pos[1]
+    corners_3d[2, :] = corners_3d[2, :] + obj.pos[2]
+
+    return np.transpose(corners_3d)
+
+
+################### affine transform ###################
+
+
+################### heatmap gaussian ###################
+def gaussian_radius(bbox_size, min_overlap=0.7):
+    # radius candidates for the three overlap cases (CornerNet derivation)
+    height, width = bbox_size
+
+    a1 = 1
+    b1 = (height + width)
+    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+    sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
+    r1 = (b1 + sq1) / 2
+
+    a2 = 4
+    b2 = 2 * (height + width)
+    c2 = (1 - min_overlap) * width * height
+    sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
+    r2 = (b2 + sq2) / 2
+
+    a3 = 4 * min_overlap
+    b3 = -2 * min_overlap * (height + width)
+    c3 = (min_overlap - 1) * width * height
+    sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
+    r3 = (b3 + sq3) / 2
+    return min(r1, r2, r3)
+
+
+def gaussian2D(shape, sigma=1):
+    m, n = [(ss - 1.) / 2. for ss in shape]
+    y, x = np.ogrid[-m:m + 1, -n:n + 1]
+
+    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
+    h[h < np.finfo(h.dtype).eps * h.max()] = 0
+    return h
+
+
+def draw_umich_gaussian(heatmap, center, radius, k=1):
+    diameter = 2 * radius + 1
+    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
+    x, y = int(center[0]), int(center[1])
+    height, width = heatmap.shape[0:2]
+
+    left, right = min(x, radius), min(width - x, radius + 1)
+    top, bottom = min(y, radius), min(height - y, radius + 1)
+
+    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+    masked_gaussian = gaussian[radius - top:radius + bottom,
+                               radius - left:radius + right]
+    if min(masked_gaussian.shape) > 0 and min(
+            masked_heatmap.shape) > 0:  # TODO debug
+        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
+    return heatmap
+
+
+################### heatmap gaussian ###################
+
+
+################### others ###################
+def angle2class(angle):
+    ''' Convert continuous angle to discrete class and residual.
+    '''
+    # hyper param: split the heading angle into 12 classes (360° / 12 = 30°)
+    num_heading_bin = 12
+    angle = angle % (2 * np.pi)
+    assert (angle >= 0 and angle <= 2 * np.pi)
+    angle_per_class = 2 * np.pi / float(num_heading_bin)
+    shifted_angle = (angle + angle_per_class / 2) % (2 * np.pi)
+    class_id = int(shifted_angle / angle_per_class)
+    residual_angle = shifted_angle - \
+        (class_id * angle_per_class + angle_per_class / 2)
+    return class_id, residual_angle
diff --git a/paddle3d/datasets/kitti/kitti_utils.py b/paddle3d/datasets/kitti/kitti_utils.py
index 3404ef63..5b2cf4d4 100644
--- a/paddle3d/datasets/kitti/kitti_utils.py
+++ b/paddle3d/datasets/kitti/kitti_utils.py
@@ -352,8 +352,8 @@ def generate_corners3d(self):
 
     def to_str(self):
         print_str = '%s %.3f %.3f %.3f box2d: %s hwl: [%.3f %.3f %.3f] pos: %s ry: %.3f' \
-                    % (self.cls_type, self.truncation, self.occlusion, self.alpha, self.box2d, self.h, self.w, self.l,
-                       self.loc, self.ry)
+            % (self.cls_type, self.truncation, self.occlusion, self.alpha, self.box2d, self.h, self.w, self.l,
+               self.loc, self.ry)
         return print_str
 
     def to_kitti_format(self):
@@ -480,3 +480,74 @@ def corners3d_to_img_boxes(self, corners3d):
             (x.reshape(-1, 8, 1), y.reshape(-1, 8, 1)), axis=2)
 
         return boxes, boxes_corner
+
+    # GUPNET-related function
+    def alpha2ry(self, alpha, u):
+        """
+        Get rotation_y from alpha and the viewing-ray angle:
+        ry = alpha + arctan((u - cu) / fu)
+        alpha : Observation angle of object, ranging [-pi..pi]
+        u : Object center x in pixels
+        rotation_y : Rotation ry around Y-axis in camera coordinates [-pi..pi]
+        """
+        ry = alpha + np.arctan2(u - self.cu, self.fu)
+
+        if ry > np.pi:
+            ry -= 2 * np.pi
+        if ry < -np.pi:
+            ry += 2 * np.pi
+
+        return ry
+
+    # GUPNET-related function
+    def ry2alpha(self, ry, u):
+        alpha = ry - np.arctan2(u - self.cu, self.fu)
+
+        if alpha > np.pi:
+            alpha -= 2 * np.pi
+        if alpha < -np.pi:
+            alpha += 2 * np.pi
+
+        return alpha
+
+    # GUPNET-related function
+    def flip(self, img_size):
+        wsize = 4
+        hsize = 2
+        p2ds = (np.concatenate([
+            np.expand_dims(
+                np.tile(
+                    np.expand_dims(np.linspace(0, img_size[0], wsize), 0),
+                    [hsize, 1]), -1),
+            np.expand_dims(
+                np.tile(
+                    np.expand_dims(np.linspace(0, img_size[1], hsize), 1),
+                    [1, wsize]), -1),
+            np.linspace(2, 78, wsize * hsize).reshape(hsize, wsize, 1)
+        ], -1)).reshape(-1, 3)
+        p3ds = self.img_to_rect(p2ds[:, 0:1], p2ds[:, 1:2], p2ds[:, 2:3])
+        p3ds[:, 0] *= -1
+        p2ds[:, 0] = img_size[0] - p2ds[:, 0]
+
+        # self.P2[0,3] *= -1
+        cos_matrix = np.zeros([wsize * hsize, 2, 7])
+        cos_matrix[:, 0, 0] = p3ds[:, 0]
+        cos_matrix[:, 0, 1] = cos_matrix[:, 1, 2] = p3ds[:, 2]
+        cos_matrix[:, 1, 0] = p3ds[:, 1]
+        cos_matrix[:, 0, 3] = cos_matrix[:, 1, 4] = 1
+        cos_matrix[:, :, -2] = -p2ds[:, :2]
+        cos_matrix[:, :, -1] = (-p2ds[:, :2] * p3ds[:, 2:3])
+        new_calib = np.linalg.svd(cos_matrix.reshape(-1, 7))[-1][-1]
+        new_calib /= new_calib[-1]
+
+        new_calib_matrix = np.zeros([4, 3]).astype(np.float32)
+        new_calib_matrix[0, 0] = new_calib_matrix[1, 1] = new_calib[0]
+        new_calib_matrix[2, 0:2] = new_calib[1:3]
+        new_calib_matrix[3, :] = new_calib[3:6]
+        new_calib_matrix[-1, -1] = self.P2[-1, -1]
+        self.P2 = new_calib_matrix.T
+        self.cu = self.P2[0, 2]
+        self.cv = self.P2[1, 2]
+        self.fu = self.P2[0, 0]
+        self.fv = self.P2[1, 1]
+        self.tx = self.P2[0, 3] / (-self.fu)
+        self.ty = self.P2[1, 3] / (-self.fv)
diff --git a/paddle3d/models/detection/__init__.py b/paddle3d/models/detection/__init__.py
index ed2a218c..67a0aeb9 100644
--- a/paddle3d/models/detection/__init__.py
+++ b/paddle3d/models/detection/__init__.py
@@ -26,3 +26,4 @@
 from .voxel_rcnn import *
 from .bevdet import *
 from .rtebev import *
+from .gupnet import *
diff --git a/paddle3d/models/detection/gupnet/__init__.py b/paddle3d/models/detection/gupnet/__init__.py
new file mode 100644
index 00000000..9c505362
--- /dev/null
+++ b/paddle3d/models/detection/gupnet/__init__.py
@@ -0,0 +1 @@
+from .gupnet import GUPNET
diff --git a/paddle3d/models/detection/gupnet/gupnet.py b/paddle3d/models/detection/gupnet/gupnet.py
new file mode 100644
index 00000000..0565142a
--- /dev/null
+++ b/paddle3d/models/detection/gupnet/gupnet.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import numpy as np
+from typing import List
+from paddle3d.apis import manager
+from paddle3d.models.detection.gupnet.gupnet_dla import GUP_DLA34
+from paddle3d.models.detection.gupnet.gupnet_processor import GUPNETPostProcessor
+from paddle3d.models.detection.gupnet.gupnet_predictor import GUPNETPredictor
+from paddle3d.models.detection.gupnet.gupnet_loss import GUPNETLoss, Hierarchical_Task_Learning
+from paddle3d.models.base import BaseMonoModel
+from paddle3d.geometries import BBoxes2D, BBoxes3D, CoordMode
+from paddle3d.sample import Sample
+
+
+@manager.MODELS.add_component
+class GUPNET(BaseMonoModel):
+    """GUPNet: Geometry Uncertainty Projection Network for monocular 3D
+    object detection.
+    """
+
+    def __init__(self,
+                 backbone,
+                 head,
+                 max_detection: int = 50,
+                 threshold=0.2,
+                 stat_epoch_nums=5,
+                 max_epoch=140,
+                 train_datasets_length=3712):
+        super(GUPNET, self).__init__()
+
+        self.max_detection = max_detection
+        self.train_datasets_length = train_datasets_length
+        mean_size = np.array([[1.76255119, 0.66068622, 0.84422524],
+                              [1.52563191462, 1.62856739989, 3.88311640418],
+                              [1.73698127, 0.59706367, 1.76282397]])
+        self.mean_size = paddle.to_tensor(mean_size, dtype=paddle.float32)
+        self.cls_num = self.mean_size.shape[0]
+        self.backbone = backbone
+        self.head = head
+        self.loss = GUPNETLoss()
+        self.ei_loss = {
+            'seg_loss': paddle.to_tensor(110.),
+            'offset2d_loss': paddle.to_tensor(1.6),
+            'size2d_loss': paddle.to_tensor(30.),
+            'depth_loss': paddle.to_tensor(8.5),
+            'offset3d_loss': paddle.to_tensor(0.6),
+            'size3d_loss': paddle.to_tensor(0.7),
+            'heading_loss': paddle.to_tensor(3.6)
+        }
+        self.cur_loss = paddle.zeros([1])
+        self.cur_loss_weightor = Hierarchical_Task_Learning(
+            self.ei_loss, stat_epoch_nums=stat_epoch_nums, max_epoch=max_epoch)
+        self.post_processor = GUPNETPostProcessor(mean_size, threshold)
+
+    # TODO: fix export function
+    def export_forward(self, samples):
+        images = samples['images']
+        features = self.backbone(images)
+
+        if isinstance(features, (list, tuple)):
+            features = features[-1]
+
+        predictions = self.head(features)
+        return self.post_processor.export_forward(
+            predictions, [samples['trans_cam_to_img'], samples['down_ratios']])
+
+    def train_forward(self, samples):
+        # encode epoch
+        if not hasattr(self, 'cur_epoch'):
+            self.cur_epoch = 1
+            self.pre_epochs = 1
+            self.loss_weights = {}  # initialize loss weights
+
+        input, calibs_p2, coord_ranges, targets, info, sample = samples
+
+        if not hasattr(self, 'have_load_img_ids'):
+            self.have_load_img_ids = info['img_id']
+            self.trained_batch = 1
+            self.stat_dict = {}
+
+        elif info['img_id'][0] not in self.have_load_img_ids:
+            self.have_load_img_ids = paddle.concat((self.have_load_img_ids,
+                                                    info['img_id']))
+            self.trained_batch += 1
+        else:
+            self.cur_epoch += 1
+            del self.have_load_img_ids
+            del self.trained_batch
+            del self.stat_dict
+            self.have_load_img_ids = info['img_id']
+            self.trained_batch = 1
+            self.stat_dict = {}
+
+        feat = self.backbone(input)
+        ret = self.head(feat, targets, calibs_p2, coord_ranges, is_train=True)
+
+        loss_terms = self.loss(ret, targets)
+
+        if not self.loss_weights:
+            self.loss_weights = self.cur_loss_weightor.compute_weight(
+                self.ei_loss, self.cur_epoch)
+        elif self.cur_epoch != self.pre_epochs:
+            self.loss_weights = self.cur_loss_weightor.compute_weight(
+                self.ei_loss, self.cur_epoch)
+            self.pre_epochs += 1
+
+        # update loss with loss_weights
+        loss = paddle.zeros([1])
+        for key in self.loss_weights.keys():
+            loss += self.loss_weights[key].detach() * loss_terms[key]
+
+        # accumulate statistics
+        for key in loss_terms.keys():
+            if key not in self.stat_dict.keys():
+                self.stat_dict[key] = 0
+            self.stat_dict[key] += loss_terms[key]
+
+        if len(self.have_load_img_ids) == self.train_datasets_length:
+            for key in self.stat_dict.keys():
+                self.stat_dict[key] /= self.trained_batch
+            self.ei_loss = self.stat_dict
+
+        return {'loss': loss}
+
+    def test_forward(self, samples):
+        input, calibs_p2, coord_ranges, targets, info, sample = samples
+        feat = self.backbone(input)
+        ret = self.head(feat, targets, calibs_p2, coord_ranges, is_train=False)
+        predictions = self.post_processor(ret, info, calibs_p2)
+
+        res = []
+        for id, img_id in enumerate(predictions.keys()):
+            res.append(
+                self._parse_results_to_sample(predictions[img_id], sample, id))
+
+        return {'preds': res}
+
+    def _parse_results_to_sample(self, results: paddle.Tensor, sample: dict,
+                                 index: int):
+        ret = Sample(sample['path'][index], sample['modality'][index])
+        ret.meta.update(
+            {key: value[index]
+             for key, value in sample['meta'].items()})
+
+        if 'calibs' in sample:
+            ret.calibs = [
+                sample['calibs'][i][index]
+                for i in range(len(sample['calibs']))
+            ]
+
+        if len(results):
+            results = paddle.to_tensor(results)
+            results = results.numpy()
+            clas = results[:, 0]
+            bboxes_2d = BBoxes2D(results[:, 2:6])
+
+            bboxes_3d = BBoxes3D(
+                results[:, [9, 10, 11, 8, 6, 7, 12]],
+                coordmode=CoordMode.KittiCamera,
+                origin=(0.5, 1, 0.5),
+                rot_axis=1)
+
+            confidences = results[:, 13]
+
+            ret.confidences = confidences
+            ret.bboxes_2d = bboxes_2d
+            ret.bboxes_3d = bboxes_3d
+
+            for i in range(len(clas)):
+                clas[i] = clas[i] - 1 if clas[i] > 0 else 2
+            ret.labels = clas
+
+        return ret
+
+    def init_weight(self, layers):
+        for m in layers.sublayers():
+            if isinstance(m, nn.Conv2D):
+                m.weight.set_value(
+                    paddle.normal(mean=0., std=0.001, shape=m.weight.shape))
+                if m.bias is not None:
+                    m.bias.set_value(paddle.zeros(m.bias.shape))
+
+    @property
+    def inputs(self) -> List[dict]:
+        images = {
+            'name': 'images',
+            'dtype': 'float32',
+            'shape': [1, 3, self.image_height, self.image_width]
+        }
+        res = [images]
+
+        intrinsics = {
+            'name': 'trans_cam_to_img',
+            'dtype': 'float32',
+            'shape': [1, 3, 3]
+        }
+        res.append(intrinsics)
+
+        down_ratios = {
+            'name': 'down_ratios',
+            'dtype': 'float32',
+            'shape': [1, 2]
+        }
+        res.append(down_ratios)
+        return res
+
+    @property
+    def outputs(self) -> List[dict]:
+        data = {
+ 'name': 'gupnet_output', + 'dtype': 'float32', + 'shape': [self.max_detection, 14] + } + return [data] diff --git a/paddle3d/models/detection/gupnet/gupnet_dla.py b/paddle3d/models/detection/gupnet/gupnet_dla.py new file mode 100644 index 00000000..96898d2a --- /dev/null +++ b/paddle3d/models/detection/gupnet/gupnet_dla.py @@ -0,0 +1,499 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import math +import paddle +import paddle.nn as nn +from paddle3d.utils import checkpoint +from paddle3d.utils.logger import logger +from paddle3d.apis import manager + +__all__ = ["GUP_DLA", "GUP_DLA34"] + + +def _make_conv_level(in_channels, out_channels, num_convs, stride=1, + dilation=1): + """ + make conv layers based on its number. + """ + layers = [] + for i in range(num_convs): + layers.extend([ + nn.Conv2D( + in_channels, + out_channels, + kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, + bias_attr=False, + dilation=dilation), + nn.BatchNorm2D(out_channels), + nn.ReLU() + ]) + + in_channels = out_channels + + return nn.Sequential(*layers) + + +class Conv2d(nn.Layer): + def __init__(self, + in_planes, + out_planes, + kernal_szie=3, + stride=1, + bias=True): + super(Conv2d, self).__init__() + self.conv = nn.Conv2D( + in_planes, + out_planes, + kernel_size=kernal_szie, + stride=stride, + padding=kernal_szie // 2, + bias_attr=bias) + self.bn = nn.BatchNorm2D(out_planes) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class BasicBlock(nn.Layer): + """Basic Block + """ + + def __init__(self, in_channels, out_channels, stride=1, dilation=1): + super().__init__() + + self.conv1 = nn.Conv2D( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + bias_attr=False, + dilation=dilation) + self.bn1 = nn.BatchNorm2D(out_channels) + + self.relu = nn.ReLU() + + self.conv2 = nn.Conv2D( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=dilation, + bias_attr=False, + dilation=dilation) + self.bn2 = nn.BatchNorm2D(out_channels) + + def forward(self, x, residual=None): + """forward + """ + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class Tree(nn.Layer): + def __init__(self, + level, + block, + in_channels, + out_channels, + stride=1, + level_root=False, + root_dim=0, + root_kernel_size=1, + dilation=1, + root_residual=False): + super(Tree, self).__init__() + + if root_dim == 0: + root_dim = 2 * out_channels + + if level_root: + root_dim += in_channels + + if level == 1: + self.tree1 = block( + in_channels, out_channels, stride, dilation=dilation) + + self.tree2 = block( + out_channels, out_channels, stride=1, dilation=dilation) + else: + new_level = level - 1 + self.tree1 = Tree( + new_level, + block, + in_channels, + 
out_channels,
+                stride,
+                root_dim=0,
+                root_kernel_size=root_kernel_size,
+                dilation=dilation,
+                root_residual=root_residual)
+
+            self.tree2 = Tree(
+                new_level,
+                block,
+                out_channels,
+                out_channels,
+                root_dim=root_dim + out_channels,
+                root_kernel_size=root_kernel_size,
+                dilation=dilation,
+                root_residual=root_residual)
+        if level == 1:
+            self.root = Root(root_dim, out_channels, root_kernel_size,
+                             root_residual)
+
+        self.level_root = level_root
+        self.root_dim = root_dim
+        self.level = level
+
+        self.downsample = None
+        if stride > 1:
+            self.downsample = nn.MaxPool2D(stride, stride=stride)
+
+        self.project = None
+        # If 'self.tree1' is a Tree (not BasicBlock), the projected residual is
+        # not used; the projection is nevertheless built for every
+        # channel-changing level.
+        # if in_channels != out_channels and not isinstance(self.tree1, Tree):
+        if in_channels != out_channels:
+            self.project = nn.Sequential(
+                nn.Conv2D(
+                    in_channels,
+                    out_channels,
+                    kernel_size=1,
+                    stride=1,
+                    bias_attr=False), nn.BatchNorm2D(out_channels))
+
+    def forward(self, x, residual=None, children=None):
+        """forward
+        """
+        children = [] if children is None else children
+        bottom = self.downsample(x) if self.downsample else x
+        residual = self.project(bottom) if self.project else bottom
+
+        if self.level_root:
+            children.append(bottom)
+        x1 = self.tree1(x, residual)
+
+        if self.level == 1:
+            x2 = self.tree2(x1)
+            x = self.root(x2, x1, *children)
+        else:
+            children.append(x1)
+            x = self.tree2(x1, children=children)
+        return x
+
+
+class Root(nn.Layer):
+    """Root module
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, residual):
+        super(Root, self).__init__()
+        self.conv = nn.Conv2D(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            bias_attr=False,
+            padding=(kernel_size - 1) // 2)
+        self.bn = nn.BatchNorm2D(out_channels)
+        self.relu = nn.ReLU()
+        self.residual = residual
+
+    def forward(self, *x):
+        """forward
+        """
+        children = x
+        x = self.conv(paddle.concat(x, 1))
+        x = self.bn(x)
+        if self.residual:
+            x += children[0]
+        x = self.relu(x)
+
+        return x
+
+
+class GUP_DLABase(nn.Layer):
+    """DLA base module
+    """
+
+    def __init__(self,
+                 levels,
+                 channels,
+                 block,
+                 down_ratio=4,
+                 last_level=5,
+                 residual_root=False):
+        super().__init__()
+        assert down_ratio in [2, 4, 8, 16]
+        self.channels = channels
+        self.level_length = len(levels)
+        self.first_level = int(np.log2(down_ratio))
+        self.last_level = last_level
+        if block is None:
+            block = BasicBlock
+        else:
+            block = eval(block)
+
+        self.base_layer = nn.Sequential(
+            nn.Conv2D(
+                3,
+                channels[0],
+                kernel_size=7,
+                stride=1,
+                padding=3,
+                bias_attr=False), nn.BatchNorm2D(channels[0]), nn.ReLU())
+
+        self.level0 = _make_conv_level(
+            in_channels=channels[0],
+            out_channels=channels[0],
+            num_convs=levels[0])
+
+        self.level1 = _make_conv_level(
+            in_channels=channels[0],
+            out_channels=channels[1],
+            num_convs=levels[1],
+            stride=2)
+
+        self.level2 = Tree(
+            level=levels[2],
+            block=block,
+            in_channels=channels[1],
+            out_channels=channels[2],
+            stride=2,
+            level_root=False,
+            root_residual=residual_root)
+
+        self.level3 = Tree(
+            level=levels[3],
+            block=block,
+            in_channels=channels[2],
+            out_channels=channels[3],
+            stride=2,
+            level_root=True,
+            root_residual=residual_root)
+
+        self.level4 = Tree(
+            level=levels[4],
+            block=block,
+            in_channels=channels[3],
+            out_channels=channels[4],
+            stride=2,
+            level_root=True,
+            root_residual=residual_root)
+
+        self.level5 = Tree(
+            level=levels[5],
+            block=block,
+            in_channels=channels[4],
+            out_channels=channels[5],
+            stride=2,
level_root=True, + root_residual=residual_root) + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m.weight.shape[0] * m.weight.shape[1] * m.weight.shape[2] + v = np.random.normal( + loc=0., scale=np.sqrt(2. / n), + size=m.weight.shape).astype('float32') + m.weight.set_value(v) + elif isinstance(m, nn.BatchNorm2D): + m.weight.set_value(np.ones(m.weight.shape).astype('float32')) + m.bias.set_value(np.zeros(m.bias.shape).astype('float32')) + + def forward(self, x): + """forward + """ + y = [] + x = self.base_layer(x) + + for i in range(self.level_length): + x = getattr(self, 'level{}'.format(i))(x) + y.append(x) + + return y + + def load_pretrained_model(self, path): + checkpoint.load_pretrained_model(self, path) + + +class GUP_DLAUp(nn.Layer): + """DLA Up module + """ + + def __init__(self, in_channels_list, scales_list=(1, 2, 4, 8, 16)): + super(GUP_DLAUp, self).__init__() + scales_list = np.array(scales_list, dtype=int) + + for i in range(len(in_channels_list) - 1): + j = -i - 2 + setattr( + self, 'ida_{}'.format(i), + GUP_IDAUp( + in_channels_list=in_channels_list[j:], + up_factors_list=scales_list[j:] // scales_list[j], + out_channels=in_channels_list[j])) + scales_list[j + 1:] = scales_list[j] + in_channels_list[j + 1:] = [ + in_channels_list[j] for _ in in_channels_list[j + 1:] + ] + + def forward(self, layers): + layers = list(layers) + assert len(layers) > 1 + for i in range(len(layers) - 1): + ida = getattr(self, 'ida_{}'.format(i)) + layers[-i - 2:] = ida(layers[-i - 2:]) + return layers[-1] + + +class GUP_IDAUp(nn.Layer): + ''' + input: features map of different layers + output: up-sampled features + ''' + + def __init__(self, in_channels_list, up_factors_list, out_channels): + super(GUP_IDAUp, self).__init__() + self.in_channels_list = in_channels_list + self.out_channels = out_channels + + for i in range(1, len(in_channels_list)): + in_channels = in_channels_list[i] + up_factors = int(up_factors_list[i]) + + proj = Conv2d( + in_channels, out_channels, kernal_szie=3, stride=1, bias=False) + node = Conv2d( + out_channels * 2, + out_channels, + kernal_szie=3, + stride=1, + bias=False) + up = nn.Conv2DTranspose( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=up_factors * 2, + stride=up_factors, + padding=up_factors // 2, + output_padding=0, + groups=out_channels, + bias_attr=False) + # self.fill_up_weights(up) + + setattr(self, 'proj_' + str(i), proj) + setattr(self, 'up_' + str(i), up) + setattr(self, 'node_' + str(i), node) + + # weight init + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m.weight.shape[0] * m.weight.shape[1] * m.weight.shape[2] + v = np.random.normal( + loc=0., scale=np.sqrt(2. / n), + size=m.weight.shape).astype('float32') + m.weight.set_value(v) + elif isinstance(m, nn.BatchNorm2D): + m.weight.set_value(np.ones(m.weight.shape).astype('float32')) + m.bias.set_value(np.zeros(m.bias.shape).astype('float32')) + + # weight init for up-sample layers [tranposed conv2d] + def fill_up_weights(self, up): + w = up.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2. 
* f)
+        for i in range(w.size(2)):
+            for j in range(w.size(3)):
+                w[0, 0, i, j] = \
+                    (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
+        for c in range(1, w.size(0)):
+            w[c, 0, :, :] = w[0, 0, :, :]
+
+    def forward(self, layers):
+        assert len(self.in_channels_list) == len(layers), \
+            '{} vs {} layers'.format(len(self.in_channels_list), len(layers))
+
+        for i in range(1, len(layers)):
+            upsample = getattr(self, 'up_' + str(i))
+            project = getattr(self, 'proj_' + str(i))
+            node = getattr(self, 'node_' + str(i))
+
+            layers[i] = upsample(project(layers[i]))
+            layers[i] = node(paddle.concat([layers[i - 1], layers[i]], 1))
+
+        return layers
+
+
+@manager.BACKBONES.add_component
+class GUP_DLA(nn.Layer):
+    """DLA backbone: GUP_DLABase plus GUP_DLAUp feature aggregation.
+    """
+
+    def __init__(self, levels, channels, block, down_ratio=4, pretrained=None):
+        super().__init__()
+        assert down_ratio in [2, 4, 8, 16]
+
+        self.pretrained = pretrained
+        self.first_level = int(np.log2(down_ratio))
+        self.base = GUP_DLABase(levels, channels, block)
+
+        # only load the pretrained parameters of the feature-extraction network
+        self.load_pretrained_model()
+
+        scales = [2**i for i in range(len(channels[self.first_level:]))]
+        self.dla_up = GUP_DLAUp(
+            in_channels_list=channels[self.first_level:], scales_list=scales)
+
+    def forward(self, x):
+        """forward
+        """
+        x = self.base(x)
+        feat = self.dla_up(x[self.first_level:])
+        return feat
+
+    def load_pretrained_model(self):
+        if self.pretrained is not None:
+            checkpoint.load_pretrained_model(self, self.pretrained)
+
+
+@manager.BACKBONES.add_component
+def GUP_DLA34(**kwargs):
+    model = GUP_DLA(
+        levels=[1, 1, 1, 2, 2, 1],
+        channels=[16, 32, 64, 128, 256, 512],
+        block="BasicBlock",
+        **kwargs)
+
+    return model
diff --git a/paddle3d/models/detection/gupnet/gupnet_helper.py b/paddle3d/models/detection/gupnet/gupnet_helper.py
new file mode 100644
index 00000000..c520b12c
--- /dev/null
+++ b/paddle3d/models/detection/gupnet/gupnet_helper.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
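+
+# Shape walk-through for the CenterNet-style decoding helpers below
+# (illustrative numbers only, assuming the default 1280x384 input and a
+# downsample factor of 4, i.e. a 320x96 feature map):
+#
+#   heatmap:             [B, 3, 96, 320]
+#   _nms(heatmap)        -> same shape, non-peak locations zeroed out
+#   _topk(heatmap, K=50) -> (scores, inds, cls_ids, xs, ys), each [B, 50]
+#   _transpose_and_gather_feat(offset_2d, inds) -> [B, 50, 2]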
+
+import paddle.nn.functional as F
+import paddle
+
+
+def _nms(heatmap, kernel=3):
+    padding = (kernel - 1) // 2
+    heatmapmax = F.max_pool2d(
+        heatmap, (kernel, kernel), stride=1, padding=padding)
+    keep = (heatmapmax == heatmap).astype('float32')
+    return heatmap * keep
+
+
+def _topk(heatmap, K=50):
+    batch, cat, height, width = heatmap.shape
+
+    # batch * cls_ids * 50
+    topk_scores, topk_inds = paddle.topk(heatmap.reshape((batch, cat, -1)), K)
+    topk_inds = topk_inds % (height * width)
+    topk_ys = (topk_inds / width).astype('int32').astype('float32')
+    topk_xs = (topk_inds % width).astype('int32').astype('float32')
+
+    # batch * cls_ids * 50
+    topk_score, topk_ind = paddle.topk(topk_scores.reshape((batch, -1)), K)
+    topk_cls_ids = (topk_ind / K).astype('int32')
+    topk_inds = _gather_feat(topk_inds.reshape((batch, -1, 1)),
+                             topk_ind).reshape((batch, K))
+    topk_ys = _gather_feat(topk_ys.reshape((batch, -1, 1)), topk_ind).reshape(
+        (batch, K))
+    topk_xs = _gather_feat(topk_xs.reshape((batch, -1, 1)), topk_ind).reshape(
+        (batch, K))
+
+    return topk_score, topk_inds, topk_cls_ids, topk_xs, topk_ys
+
+
+def _gather_feat(feat, ind, mask=None):
+    '''
+    Args:
+        feat: tensor shaped in B * (H*W) * C
+        ind: tensor shaped in B * K (default: 50)
+        mask: tensor shaped in B * K (default: 50)
+
+    Returns: tensor shaped in B * K * C, or sum(mask) * C when mask is given
+    '''
+    dim = feat.shape[2]  # get channel dim
+    # B*len(ind) --> B*len(ind)*1 --> B*len(ind)*C
+    ind = ind.unsqueeze(2).expand(shape=[ind.shape[0], ind.shape[1], dim])
+    feat = paddle.take_along_axis(feat, indices=ind, axis=1)
+
+    if mask is not None:
+        mask = mask.unsqueeze(2).expand_as(feat)  # B*50 ---> B*K*1 --> B*K*C
+        feat = feat[mask]
+        feat = feat.reshape((-1, dim))
+    return feat
+
+
+def _transpose_and_gather_feat(feat, ind):
+    '''
+    Args:
+        feat: feature maps shaped in B * C * H * W
+        ind: indices tensor shaped in B * K
+    Returns: gathered features shaped in B * K * C
+    '''
+    feat = feat.transpose(perm=(0, 2, 3, 1))  # B * C * H * W ---> B * H * W * C
+    feat = feat.reshape((feat.shape[0], -1,
+                         feat.shape[3]))  # B * H * W * C ---> B * (H*W) * C
+    feat = _gather_feat(feat, ind)  # B * len(ind) * C
+    return feat
+
+
+def extract_input_from_tensor(input, ind, mask):
+    input = _transpose_and_gather_feat(input, ind)  # B*C*H*W --> B*K*C
+    return input[mask]  # B*K*C --> M * C
+
+
+def extract_target_from_tensor(target, mask):
+    return target[mask]
diff --git a/paddle3d/models/detection/gupnet/gupnet_loss.py b/paddle3d/models/detection/gupnet/gupnet_loss.py
new file mode 100644
index 00000000..2486a93b
--- /dev/null
+++ b/paddle3d/models/detection/gupnet/gupnet_loss.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
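+
+# A sketch of the Hierarchical Task Learning schedule implemented below (our
+# reading of compute_weight, not additional behaviour): tasks without
+# prerequisites in `loss_graph` keep weight 1.0, while a dependent task t is
+# weighted as
+#
+#   w_t = T ** (1 - prod_{p in pre(t)} c_p)
+#
+# where T = min((epoch - 5) / (max_epoch - 5), 1.0) is a global time factor
+# and c_p = 1 - relu(recent trend of loss p / its initial trend) measures how
+# far prerequisite p has converged. A 3D task therefore only ramps up once
+# its 2D prerequisites have stabilised.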
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle3d.models.detection.gupnet.gupnet_helper import _transpose_and_gather_feat +from paddle3d.apis import manager + + +class Hierarchical_Task_Learning: + def __init__(self, epoch0_loss, stat_epoch_nums=5, max_epoch=140): + self.index2term = [*epoch0_loss.keys()] + self.term2index = { + term: self.index2term.index(term) + for term in self.index2term + } # term2index + # self.term2index: { + # 'seg_loss': 0, 'offset2d_loss': 1, 'size2d_loss': 2, + # 'depth_loss': 3, 'offset3d_loss': 4, 'size3d_loss': 5, 'heading_loss': 6} + self.stat_epoch_nums = stat_epoch_nums + self.past_losses = [] + self.loss_graph = { + 'seg_loss': [], + 'size2d_loss': [], + 'offset2d_loss': [], + 'offset3d_loss': ['size2d_loss', 'offset2d_loss'], + 'size3d_loss': ['size2d_loss', 'offset2d_loss'], + 'heading_loss': ['size2d_loss', 'offset2d_loss'], + 'depth_loss': ['size2d_loss', 'size3d_loss', 'offset2d_loss'] + } + self.max_epoch = max_epoch + + def compute_weight(self, current_loss, epoch): + # compute initial weights + loss_weights = {} + eval_loss_input = paddle.concat( + [_.unsqueeze(0) for _ in current_loss.values()]).unsqueeze(0) + for term in self.loss_graph: + if len(self.loss_graph[term]) == 0: + loss_weights[term] = paddle.to_tensor(1.0) + else: + loss_weights[term] = paddle.to_tensor(0.0) + + if len(self.past_losses) == self.stat_epoch_nums: + past_loss = paddle.concat(self.past_losses) + mean_diff = (past_loss[:-2] - past_loss[2:]).mean(0) + if not hasattr(self, 'init_diff'): + self.init_diff = mean_diff + c_weights = 1 - \ + paddle.nn.functional.relu_( + mean_diff / self.init_diff).unsqueeze(0) + time_value = min(((epoch - 5) / (self.max_epoch - 5)), 1.0) + for current_topic in self.loss_graph: + if len(self.loss_graph[current_topic]) != 0: + control_weight = 1.0 + for pre_topic in self.loss_graph[current_topic]: + control_weight *= c_weights[0][ + self.term2index[pre_topic]] + loss_weights[current_topic] = time_value**( + 1 - control_weight) + self.past_losses.pop(0) + self.past_losses.append(eval_loss_input) + return loss_weights + + def update_e0(self, eval_loss): + self.epoch0_loss = paddle.concat( + [_.unsqueeze(0) for _ in eval_loss.values()]).unsqueeze(0) + + +@manager.LOSSES.add_component +class GUPNETLoss(nn.Layer): + def __init__(self): + super().__init__() + self.stat = {} + + def forward(self, preds, targets): + ''' + Args: + preds: prediction {dict 9} + 'heatmap', 'offset_2d', 'size_2d', 'train_tag', + 'heading', 'depth', 'offset_3d', 'size_3d', 'h3d_log_variance' + + targets: ground truth {dict 11} + 'depth', 'size_2d', 'heatmap', 'offset_2d', 'indices', + 'size_3d', 'offset_3d', 'heading_bin', 'heading_res', 'cls_ids', 'mask_2d' + ''' + self.stat['seg_loss'] = self.compute_segmentation_loss(preds, targets) + self.stat['offset2d_loss'], self.stat[ + 'size2d_loss'] = self.compute_bbox2d_loss(preds, targets) + self.stat['depth_loss'], self.stat['offset3d_loss'], self.stat[ + 'size3d_loss'], self.stat[ + 'heading_loss'] = self.compute_bbox3d_loss(preds, targets) + return self.stat + + def compute_segmentation_loss(self, input, target): + input['heatmap'] = paddle.clip( + paddle.nn.functional.sigmoid(input['heatmap']), + min=1e-4, + max=1 - 1e-4) + loss = focal_loss_cornernet(input['heatmap'], target['heatmap']) + return loss + + def compute_bbox2d_loss(self, input, target): + # compute size2d loss + size2d_input = extract_input_from_tensor( + input['size_2d'], target['indices'], target['mask_2d']) + 
size2d_target = extract_target_from_tensor(target['size_2d'], + target['mask_2d']) + size2d_loss = F.l1_loss(size2d_input, size2d_target, reduction='mean') + # compute offset2d loss + offset2d_input = extract_input_from_tensor( + input['offset_2d'], target['indices'], target['mask_2d']) + offset2d_target = extract_target_from_tensor(target['offset_2d'], + target['mask_2d']) + offset2d_loss = F.l1_loss( + offset2d_input, offset2d_target, reduction='mean') + + return offset2d_loss, size2d_loss + + def compute_bbox3d_loss(self, input, target, mask_type='mask_2d'): + # compute depth loss + depth_input = input['depth'][input['train_tag']] + depth_input, depth_log_variance = depth_input[:, 0:1], depth_input[:, 1: + 2] + depth_target = extract_target_from_tensor(target['depth'], + target[mask_type]) + depth_loss = laplacian_aleatoric_uncertainty_loss( + depth_input, depth_target, depth_log_variance) + + # compute offset3d loss + offset3d_input = input['offset_3d'][input['train_tag']] + offset3d_target = extract_target_from_tensor(target['offset_3d'], + target[mask_type]) + offset3d_loss = F.l1_loss( + offset3d_input, offset3d_target, reduction='mean') + + # compute size3d loss + size3d_input = input['size_3d'][input['train_tag']] + size3d_target = extract_target_from_tensor(target['size_3d'], + target[mask_type]) + size3d_loss = F.l1_loss(size3d_input[:, 1:], size3d_target[:, 1:], reduction='mean') * 2 / 3 + \ + laplacian_aleatoric_uncertainty_loss(size3d_input[:, 0:1], size3d_target[:, 0:1], + input['h3d_log_variance'][input['train_tag']]) / 3 + heading_loss = compute_heading_loss( + input['heading'][input['train_tag']], + target[mask_type], # mask_2d + target['heading_bin'], + target['heading_res']) + + return depth_loss, offset3d_loss, size3d_loss, heading_loss + + +# ====================== auxiliary functions ======================= + + +def extract_input_from_tensor(input, ind, mask): + input = _transpose_and_gather_feat(input, ind) # B*C*H*W --> B*K*C + return input[mask.astype(paddle.bool)] # B*K*C --> M * C + + +def extract_target_from_tensor(target, mask): + return target[mask.astype(paddle.bool)] + + +# compute heading loss two stage style + + +def compute_heading_loss(input, mask, target_cls, target_reg): + mask = mask.reshape((1, -1)).squeeze(0) # B * K ---> (B*K) + target_cls = target_cls.reshape((1, -1)).squeeze(0) # B * K * 1 ---> (B*K) + target_reg = target_reg.reshape((1, -1)).squeeze(0) # B * K * 1 ---> (B*K) + + # classification loss + input_cls = input[:, 0:12] + target_cls = target_cls[mask.astype(paddle.bool)] + cls_loss = F.cross_entropy(input_cls, target_cls, reduction='mean') + + # regression loss + input_reg = input[:, 12:24] + target_reg = target_reg[mask.astype(paddle.bool)] + cls_onehot = paddle.put_along_axis( + arr=paddle.zeros([target_cls.shape[0], 12]), + axis=1, + indices=target_cls.reshape((-1, 1)), + values=paddle.to_tensor(1).astype('float32')) + input_reg = paddle.sum(input_reg * cls_onehot, 1) + reg_loss = F.l1_loss(input_reg, target_reg, reduction='mean') + + return cls_loss + reg_loss + + +def laplacian_aleatoric_uncertainty_loss(input, + target, + log_variance, + reduction='mean'): + ''' + References: + MonoPair: Monocular 3D Object Detection Using Pairwise Spatial Relationships, CVPR'20 + Geometry and Uncertainty in Deep Learning for Computer Vision, University of Cambridge + ''' + assert reduction in ['mean', 'sum'] + loss = 1.4142 * paddle.exp(-0.5 * log_variance) * \ + paddle.abs(input - target) + 0.5 * log_variance + return loss.mean() if reduction 
== 'mean' else loss.sum() + + +def focal_loss_cornernet(input, target, gamma=2.): + ''' + Args: + input: prediction, 'batch x c x h x w' + target: ground truth, 'batch x c x h x w' + gamma: hyper param, default in 2.0 + Reference: Cornernet: Detecting Objects as Paired Keypoints, ECCV'18 + ''' + pos_inds = paddle.equal(target, 1).astype('float32') + neg_inds = paddle.less_than(target, + paddle.ones(target.shape)).astype('float32') + # pos_inds = target.eq(1).float() + # neg_inds = target.lt(1).float() + neg_weights = paddle.pow(1 - target, 4) + + loss = 0 + pos_loss = paddle.log(input) * paddle.pow(1 - input, gamma) * pos_inds + neg_loss = paddle.log(1 - input) * paddle.pow( + input, gamma) * neg_inds * neg_weights + + num_pos = pos_inds.sum() + + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + if num_pos == 0: + loss = loss - neg_loss + else: + loss = loss - (pos_loss + neg_loss) / num_pos + + return loss.mean() diff --git a/paddle3d/models/detection/gupnet/gupnet_predictor.py b/paddle3d/models/detection/gupnet/gupnet_predictor.py new file mode 100644 index 00000000..1889eaa6 --- /dev/null +++ b/paddle3d/models/detection/gupnet/gupnet_predictor.py @@ -0,0 +1,379 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +from paddle.vision.ops import roi_align +from paddle3d.apis import manager +from paddle3d.models.detection.gupnet.gupnet_helper import _nms, _topk, extract_input_from_tensor +from paddle3d.models.layers import param_init + + +def fill_fc_weights(layer): + if isinstance(layer, nn.Conv2D): + param_init.normal_init(layer.weight, std=0.001) + if layer.bias is not None: + param_init.constant_init(layer.bias, value=0.0) + + +def weights_init_xavier(layer): + if isinstance(layer, nn.Linear): + param_init.normal_init(layer.weight, std=0.001) + param_init.constant_init(layer.bias, value=0.0) + + if isinstance(layer, nn.Conv2D): + param_init.xavier_uniform_init(layer.weight) + if layer.bias is not None: + param_init.constant_init(layer.bias, value=0.0) + + elif isinstance(layer, nn.BatchNorm2D): + param_init.constant_init(layer.weight, value=1.0) + param_init.constant_init(layer.bias, value=0.0) + + +@manager.MODELS.add_component +class GUPNETPredictor(nn.Layer): + """ + """ + + def __init__(self, + channels=[16, 32, 64, 128, 256, 512], + head_conv=256, + max_detection: int = 50, + downsample=4): + super(GUPNETPredictor, self).__init__() + self.max_detection = max_detection + self.head_conv = head_conv # default setting for head conv + self.first_level = int(np.log2(downsample)) + self.mean_size = np.array( + [[1.76255119, 0.66068622, 0.84422524], + [1.52563191462, 1.62856739989, 3.88311640418], + [1.73698127, 0.59706367, 1.76282397]]) + self.mean_size = paddle.to_tensor(self.mean_size, dtype=paddle.float32) + self.cls_num = self.mean_size.shape[0] + # initialize the head of pipeline, according to heads setting. 
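+        # Dense 2D heads (heatmap / offset_2d / size_2d) run on the full
+        # backbone feature map: heatmap predicts 3 class-wise centre maps,
+        # the other two predict 2 channels each. The 3D heads defined below
+        # instead consume 7x7 ROI features concatenated with a 2-channel
+        # coordinate map and a cls_num-wide one-hot map, hence their
+        # channels[self.first_level] + 2 + self.cls_num input width; size_3d
+        # outputs 4 channels (3 dims + height log-variance), depth outputs 2
+        # (value + log-variance) and heading 24 (12 bins + 12 residuals).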
+ self.heatmap = nn.Sequential( + nn.Conv2D( + channels[self.first_level], + self.head_conv, + kernel_size=3, + padding=1, + bias_attr=True), nn.ReLU(), + nn.Conv2D( + self.head_conv, + 3, + kernel_size=1, + stride=1, + padding=0, + bias_attr=nn.initializer.Constant(value=-2.19))) + self.offset_2d = nn.Sequential( + nn.Conv2D( + channels[self.first_level], + self.head_conv, + kernel_size=3, + padding=1, + bias_attr=True), nn.ReLU(), + nn.Conv2D( + self.head_conv, + 2, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True)) + self.size_2d = nn.Sequential( + nn.Conv2D( + channels[self.first_level], + self.head_conv, + kernel_size=3, + padding=1, + bias_attr=True), nn.ReLU(), + nn.Conv2D( + self.head_conv, + 2, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True)) + + self.depth = nn.Sequential( + nn.Conv2D( + channels[self.first_level] + 2 + self.cls_num, + self.head_conv, + kernel_size=3, + padding=1, + bias_attr=True), nn.BatchNorm2D(self.head_conv), nn.ReLU(), + nn.AdaptiveAvgPool2D(1), + nn.Conv2D( + self.head_conv, + 2, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True)) + self.offset_3d = nn.Sequential( + nn.Conv2D( + channels[self.first_level] + 2 + self.cls_num, + self.head_conv, + kernel_size=3, + padding=1, + bias_attr=True), nn.BatchNorm2D(self.head_conv), nn.ReLU(), + nn.AdaptiveAvgPool2D(1), + nn.Conv2D( + self.head_conv, + 2, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True)) + self.size_3d = nn.Sequential( + nn.Conv2D( + channels[self.first_level] + 2 + self.cls_num, + self.head_conv, + kernel_size=3, + padding=1, + bias_attr=True), nn.BatchNorm2D(self.head_conv), nn.ReLU(), + nn.AdaptiveAvgPool2D(1), + nn.Conv2D( + self.head_conv, + 4, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True)) + self.heading = nn.Sequential( + nn.Conv2D( + channels[self.first_level] + 2 + self.cls_num, + self.head_conv, + kernel_size=3, + padding=1, + bias_attr=True), nn.BatchNorm2D(self.head_conv), nn.ReLU(), + nn.AdaptiveAvgPool2D(1), + nn.Conv2D( + self.head_conv, + 24, + kernel_size=1, + stride=1, + padding=0, + bias_attr=True)) + + # init layers + self.offset_2d.apply(fill_fc_weights) + self.size_2d.apply(fill_fc_weights) + self.depth.apply(weights_init_xavier) + self.offset_3d.apply(weights_init_xavier) + self.size_3d.apply(weights_init_xavier) + self.heading.apply(weights_init_xavier) + + def forward(self, features, targets, calibs_p2, coord_ranges, is_train): + ret = {} + ret['heatmap'] = self.heatmap(features) + ret['offset_2d'] = self.offset_2d(features) + ret['size_2d'] = self.size_2d(features) + if is_train: + inds, cls_ids = targets['indices'], targets['cls_ids'] + masks = targets['mask_2d'].astype(paddle.bool) + else: + inds, cls_ids = _topk( + _nms( + paddle.clip( + paddle.nn.functional.sigmoid(ret['heatmap']), + min=1e-4, + max=1 - 1e-4)), + K=self.max_detection)[1:3] + masks = paddle.ones(inds.shape).astype(paddle.bool) + ret.update( + self.get_roi_feat(features, inds, masks, ret, calibs_p2, + coord_ranges, cls_ids, self.max_detection)) + return ret + + def get_roi_feat_by_mask(self, feat, box2d_maps, inds, mask, calibs, + coord_ranges, cls_ids, K): + BATCH_SIZE, _, HEIGHT, WIDE = feat.shape + num_masked_bin = mask.sum() + res = {} + if num_masked_bin != 0: + # get box2d of each roi region + box2d_masked = extract_input_from_tensor(box2d_maps, inds, mask) + + # get roi feature + boxes_num = paddle.to_tensor([0] * BATCH_SIZE).astype('int32') + for x in box2d_masked[:, 0]: + boxes_num[x.astype('int32')] += 1 + roi_feature_masked = roi_align( + 
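+                # roi_align crops each predicted 2D box (box2d_masked[:, 1:],
+                # given in feature-map coordinates) into a fixed 7x7 grid;
+                # boxes_num tells the op how many of the boxes belong to each
+                # image of the batch.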
feat, + box2d_masked[:, 1:], + boxes_num=boxes_num, + output_size=[7, 7], + aligned=False) + + # get coord range of each roi + coord_ranges_mask2d = coord_ranges[box2d_masked[:, 0].astype( + paddle.int32)] + + # map box2d coordinate from feature map size domain to original image size domain + box2d_masked = paddle.concat([ + box2d_masked[:, 0:1], box2d_masked[:, 1:2] / WIDE * + (coord_ranges_mask2d[:, 1, 0:1] - coord_ranges_mask2d[:, 0, 0:1] + ) + coord_ranges_mask2d[:, 0, 0:1], + box2d_masked[:, 2:3] / HEIGHT * + (coord_ranges_mask2d[:, 1, 1:2] - coord_ranges_mask2d[:, 0, 1:2] + ) + coord_ranges_mask2d[:, 0, 1:2], + box2d_masked[:, 3:4] / WIDE * (coord_ranges_mask2d[:, 1, 0:1] - + coord_ranges_mask2d[:, 0, 0:1]) + + coord_ranges_mask2d[:, 0, 0:1], box2d_masked[:, 4:5] / HEIGHT * + (coord_ranges_mask2d[:, 1, 1:2] - coord_ranges_mask2d[:, 0, 1:2] + ) + coord_ranges_mask2d[:, 0, 1:2] + ], 1) + roi_calibs = calibs[box2d_masked[:, 0].astype(paddle.int32)] + + # project the coordinate in the normal image to the camera coord by calibs + coords_in_camera_coord = paddle.concat([ + self.project2rect( + roi_calibs, + paddle.concat([ + box2d_masked[:, 1:3], + paddle.ones([num_masked_bin, 1]) + ], -1))[:, :2], + self.project2rect( + roi_calibs, + paddle.concat([ + box2d_masked[:, 3:5], + paddle.ones([num_masked_bin, 1]) + ], -1))[:, :2] + ], -1) + + coords_in_camera_coord = paddle.concat( + [box2d_masked[:, 0:1], coords_in_camera_coord], -1) + # generate coord maps + coord_maps = paddle.concat([ + paddle.tile( + paddle.concat([ + coords_in_camera_coord[:, 1:2] + i * + (coords_in_camera_coord[:, 3:4] - + coords_in_camera_coord[:, 1:2]) / 6 for i in range(7) + ], -1).unsqueeze(1), + repeat_times=([1, 7, 1])).unsqueeze(1), + paddle.tile( + paddle.concat([ + coords_in_camera_coord[:, 2:3] + i * + (coords_in_camera_coord[:, 4:5] - + coords_in_camera_coord[:, 2:3]) / 6 for i in range(7) + ], -1).unsqueeze(2), + repeat_times=([1, 1, 7])).unsqueeze(1) + ], 1) + + # concatenate coord maps with feature maps in the channel dim + cls_hots = paddle.zeros([num_masked_bin, self.cls_num]) + cls_hots[paddle.arange(num_masked_bin), cls_ids[mask]. + astype(paddle.int32)] = 1.0 + + roi_feature_masked = paddle.concat([ + roi_feature_masked, coord_maps, + paddle.tile( + cls_hots.unsqueeze(-1).unsqueeze(-1), + repeat_times=([1, 1, 7, 7])) + ], 1) + + # compute heights of projected objects + box2d_height = paddle.clip( + box2d_masked[:, 4] - box2d_masked[:, 2], min=1.0) + # compute real 3d height + size3d_offset = self.size_3d(roi_feature_masked)[:, :, 0, 0] + h3d_log_std = size3d_offset[:, 3:4] + size3d_offset = size3d_offset[:, :3] + size_3d = (self.mean_size[cls_ids[mask].astype(paddle.int32)] + + size3d_offset) + depth_geo = size_3d[:, 0] / \ + box2d_height.squeeze() * roi_calibs[:, 0, 0] + depth_net_out = self.depth(roi_feature_masked)[:, :, 0, 0] + # σ_p^2 + depth_geo_log_std = ( + h3d_log_std.squeeze() + 2 * + (roi_calibs[:, 0, 0].log() - box2d_height.log())).unsqueeze(-1) + # log(σ_d^2) = log(σ_p^2 + σ_b^2) + depth_net_log_std = paddle.logsumexp( + paddle.concat([depth_net_out[:, 1:2], depth_geo_log_std], -1), + -1, + keepdim=True) + depth_net_out = paddle.concat( + [(1. / + (paddle.nn.functional.sigmoid(depth_net_out[:, 0:1]) + 1e-6) - + 1.) 
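+                # final depth = (1 / sigmoid(x) - 1) + depth_geo: a positive,
+                # learned correction on top of the geometric prior
+                # depth_geo = f * h3d / h2d, paired with the fused
+                # log-variance computed via logsumexp above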
+ depth_geo.unsqueeze(-1), depth_net_log_std], -1) + + res['train_tag'] = paddle.ones(num_masked_bin).astype(paddle.bool) + res['heading'] = self.heading(roi_feature_masked)[:, :, 0, 0] + res['depth'] = depth_net_out + res['offset_3d'] = self.offset_3d(roi_feature_masked)[:, :, 0, 0] + res['size_3d'] = size3d_offset + res['h3d_log_variance'] = h3d_log_std + else: + res['depth'] = paddle.zeros([1, 2]) + res['offset_3d'] = paddle.zeros([1, 2]) + res['size_3d'] = paddle.zeros([1, 3]) + res['train_tag'] = paddle.zeros([1]).astype(paddle.bool) + res['heading'] = paddle.zeros([1, 24]) + res['h3d_log_variance'] = paddle.zeros([1, 1]) + return res + + def get_roi_feat(self, feat, inds, mask, ret, calibs, coord_ranges, cls_ids, + K): + BATCH_SIZE, _, HEIGHT, WIDE = feat.shape + coord_map = paddle.tile( + paddle.concat([ + paddle.tile( + paddle.arange(WIDE).unsqueeze(0), + repeat_times=([HEIGHT, 1])).unsqueeze(0), + paddle.tile( + paddle.arange(HEIGHT).unsqueeze(-1), + repeat_times=([1, WIDE])).unsqueeze(0) + ], + axis=0).unsqueeze(0), + repeat_times=([BATCH_SIZE, 1, 1, 1])).astype('float32') + + box2d_centre = coord_map + ret['offset_2d'] + box2d_maps = paddle.concat([ + box2d_centre - ret['size_2d'] / 2, box2d_centre + ret['size_2d'] / 2 + ], 1) + box2d_maps = paddle.concat([ + paddle.tile( + paddle.arange(BATCH_SIZE).unsqueeze(-1).unsqueeze(-1).unsqueeze( + -1), + repeat_times=([1, 1, HEIGHT, WIDE])).astype('float32'), + box2d_maps + ], 1) + # box2d_maps is box2d in each bin + res = self.get_roi_feat_by_mask(feat, box2d_maps, inds, mask, calibs, + coord_ranges, cls_ids, K) + return res + + def project2rect(self, calib, point_img): + c_u = calib[:, 0, 2] + c_v = calib[:, 1, 2] + f_u = calib[:, 0, 0] + f_v = calib[:, 1, 1] + b_x = calib[:, 0, 3] / (-f_u) # relative + b_y = calib[:, 1, 3] / (-f_v) + x = (point_img[:, 0] - c_u) * point_img[:, 2] / f_u + b_x + y = (point_img[:, 1] - c_v) * point_img[:, 2] / f_v + b_y + z = point_img[:, 2] + centre_by_obj = paddle.concat( + [x.unsqueeze(-1), y.unsqueeze(-1), + z.unsqueeze(-1)], -1) + return centre_by_obj + + def logsumexp(self, x): + x_max = x.data.max() + return paddle.log(paddle.sum(paddle.exp(x - x_max), 1, + keepdim=True)) + x_max diff --git a/paddle3d/models/detection/gupnet/gupnet_processor.py b/paddle3d/models/detection/gupnet/gupnet_processor.py new file mode 100644 index 00000000..a459f9a9 --- /dev/null +++ b/paddle3d/models/detection/gupnet/gupnet_processor.py @@ -0,0 +1,203 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
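+
+# Post-processing overview: GUPNETPostProcessor turns raw head outputs into
+# KITTI-format detections. extract_dets_from_outputs keeps the top-K heatmap
+# peaks and multiplies each class score by a depth-uncertainty score
+# exp(-sigma_d); decode_detections then rescales boxes to the original image
+# resolution, recovers ry from the 12-bin heading encoding and lifts the
+# projected 3D centre into camera coordinates through the P2 calibration.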
+ +import numpy as np +import paddle +from paddle import nn +from paddle3d.apis import manager +from paddle3d.models.detection.gupnet.gupnet_helper import _nms, _topk, _transpose_and_gather_feat + +num_heading_bin = 12 + + +@manager.MODELS.add_component +class GUPNETPostProcessor(nn.Layer): + def __init__(self, cls_mean_size, threshold): + super().__init__() + self.cls_mean_size = cls_mean_size + self.threshold = threshold + + def forward(self, ret, info, calibs_p2): + # prediction result convert + predictions = self.extract_dets_from_outputs(ret, K=50) + predictions = predictions.detach().cpu().numpy() + calibs_p2 = calibs_p2.detach().cpu().numpy() + # get corresponding calibs & transform tensor to numpy + info = {key: val.detach().cpu().numpy() for key, val in info.items()} + predictions = self.decode_detections( + dets=predictions, + info=info, + calibs_p2=calibs_p2, + cls_mean_size=self.cls_mean_size, + threshold=self.threshold) + + return predictions + + def img_to_rect(self, u, v, depth_rect): + """ + :param u: (N) + :param v: (N) + :param depth_rect: (N) + :return: + """ + x = ((u - self.cu) * depth_rect) / self.fu + self.tx + y = ((v - self.cv) * depth_rect) / self.fv + self.ty + pts_rect = np.concatenate( + (x.reshape(-1, 1), y.reshape(-1, 1), depth_rect.reshape(-1, 1)), + axis=1) + return pts_rect + + # GUPNET relative function + def alpha2ry(self, alpha, u): + """ + Get rotation_y by alpha + theta - 180 + alpha : Observation angle of object, ranging [-pi..pi] + x : Object center x to the camera center (x-W/2), in pixels + rotation_y : Rotation ry around Y-axis in camera coordinates [-pi..pi] + """ + ry = alpha + np.arctan2(u - self.cu, self.fu) + + if ry > np.pi: + ry -= 2 * np.pi + if ry < -np.pi: + ry += 2 * np.pi + + return ry + + def decode_detections(self, dets, info, calibs_p2, cls_mean_size, + threshold): + '''NOTE: THIS IS A NUMPY FUNCTION + input: dets, numpy array, shape in [batch x max_dets x dim] + input: img_info, dict, necessary information of input images + input: calibs, corresponding calibs for the input batch + output: + ''' + results = {} + for i in range(dets.shape[0]): # batch + preds = [] + for j in range(dets.shape[1]): # max_dets + # encoder calib + + self.cu = calibs_p2[i][0, 2] + self.cv = calibs_p2[i][1, 2] + self.fu = calibs_p2[i][0, 0] + self.fv = calibs_p2[i][1, 1] + self.tx = calibs_p2[i][0, 3] / (-self.fu) + self.ty = calibs_p2[i][1, 3] / (-self.fv) + + cls_id = int(dets[i, j, 0]) + score = dets[i, j, 1] + if score < threshold: + continue + + # 2d bboxs decoding + x = dets[i, j, 2] * info['bbox_downsample_ratio'][i][0] + y = dets[i, j, 3] * info['bbox_downsample_ratio'][i][1] + w = dets[i, j, 4] * info['bbox_downsample_ratio'][i][0] + h = dets[i, j, 5] * info['bbox_downsample_ratio'][i][1] + bbox = [x - w / 2, y - h / 2, x + w / 2, y + h / 2] + + # 3d bboxs decoding + # depth decoding + depth = dets[i, j, 6] + + # heading angle decoding + alpha = get_heading_angle(dets[i, j, 7:31]) + ry = self.alpha2ry(alpha, x) + # ry = calibs[i].alpha2ry(alpha, x) + + # dimensions decoding + dimensions = dets[i, j, 31:34] + dimensions += cls_mean_size[int(cls_id)] + if True in (dimensions < 0.0): + continue + + # positions decoding + x3d = dets[i, j, 34] * info['bbox_downsample_ratio'][i][0] + y3d = dets[i, j, 35] * info['bbox_downsample_ratio'][i][1] + locations = self.img_to_rect(x3d, y3d, depth).reshape(-1) + locations[1] += dimensions[0] / 2 + + preds.append([cls_id, alpha] + bbox + dimensions.tolist() + + locations.tolist() + [ry, score]) + 
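+                # each appended row follows the KITTI result layout:
+                # [cls_id, alpha, x1, y1, x2, y2, h, w, l, x, y, z, ry, score]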
results[info['img_id'][i]] = preds + return results + + # two stage style + def extract_dets_from_outputs(self, outputs, K=50): + # get src outputs + heatmap = outputs['heatmap'] + size_2d = outputs['size_2d'] + offset_2d = outputs['offset_2d'] + + batch, channel, height, width = heatmap.shape # get shape + + heading = outputs['heading'].reshape((batch, K, -1)) + depth = outputs['depth'].reshape((batch, K, -1))[:, :, 0:1] + size_3d = outputs['size_3d'].reshape((batch, K, -1)) + offset_3d = outputs['offset_3d'].reshape((batch, K, -1)) + + heatmap = paddle.clip( + paddle.nn.functional.sigmoid(heatmap), min=1e-4, max=1 - 1e-4) + + # perform nms on heatmaps + heatmap = _nms(heatmap) + scores, inds, cls_ids, xs, ys = _topk(heatmap, K=K) + + offset_2d = _transpose_and_gather_feat(offset_2d, inds) + offset_2d = offset_2d.reshape((batch, K, 2)) + xs2d = xs.reshape((batch, K, 1)) + offset_2d[:, :, 0:1] + ys2d = ys.reshape((batch, K, 1)) + offset_2d[:, :, 1:2] + + xs3d = xs.reshape((batch, K, 1)) + offset_3d[:, :, 0:1] + ys3d = ys.reshape((batch, K, 1)) + offset_3d[:, :, 1:2] + + cls_ids = cls_ids.reshape((batch, K, 1)).astype('float32') + depth_score = (-(0.5 * outputs['depth'].reshape( + (batch, K, -1))[:, :, 1:2]).exp()).exp() + scores = scores.reshape((batch, K, 1)) * depth_score + + # check shape + xs2d = xs2d.reshape((batch, K, 1)) + ys2d = ys2d.reshape((batch, K, 1)) + xs3d = xs3d.reshape((batch, K, 1)) + ys3d = ys3d.reshape((batch, K, 1)) + + size_2d = _transpose_and_gather_feat(size_2d, inds) + size_2d = size_2d.reshape((batch, K, 2)) + + detections = paddle.concat([ + cls_ids, scores, xs2d, ys2d, size_2d, depth, heading, size_3d, xs3d, + ys3d + ], + axis=2) + + return detections + + +def class2angle(cls, residual, to_label_format=False): + ''' Inverse function to angle2class. ''' + angle_per_class = 2 * np.pi / float(num_heading_bin) + angle_center = cls * angle_per_class + angle = angle_center + residual + if to_label_format and angle > np.pi: + angle = angle - 2 * np.pi + return angle + + +def get_heading_angle(heading): + heading_bin, heading_res = heading[0:12], heading[12:24] + cls = np.argmax(heading_bin) + res = heading_res[cls] + return class2angle(cls, res, to_label_format=True) diff --git a/paddle3d/models/optimizers/lr_schedulers.py b/paddle3d/models/optimizers/lr_schedulers.py index d598f753..9e1cc9e6 100644 --- a/paddle3d/models/optimizers/lr_schedulers.py +++ b/paddle3d/models/optimizers/lr_schedulers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ Apache-2.0 license [see LICENSE for details]. 
""" from functools import partial - +import math import paddle from paddle.optimizer.lr import LRScheduler @@ -32,7 +32,6 @@ @manager.LR_SCHEDULERS.add_component class OneCycleWarmupDecayLr(LRScheduler): - def __init__(self, base_learning_rate, lr_ratio_peak=10, @@ -66,7 +65,6 @@ def get_lr(self, curr_iter): class LRSchedulerCycle(LRScheduler): - def __init__(self, total_step, lr_phases, mom_phases): self.total_step = total_step @@ -78,12 +76,12 @@ def __init__(self, total_step, lr_phases, mom_phases): if isinstance(lambda_func, str): lambda_func = eval(lambda_func) if i < len(lr_phases) - 1: - self.lr_phases.append( - (int(start * total_step), - int(lr_phases[i + 1][0] * total_step), lambda_func)) + self.lr_phases.append((int(start * total_step), + int(lr_phases[i + 1][0] * total_step), + lambda_func)) else: - self.lr_phases.append( - (int(start * total_step), total_step, lambda_func)) + self.lr_phases.append((int(start * total_step), total_step, + lambda_func)) assert self.lr_phases[0][0] == 0 self.mom_phases = [] for i, (start, lambda_func) in enumerate(mom_phases): @@ -92,19 +90,18 @@ def __init__(self, total_step, lr_phases, mom_phases): if isinstance(lambda_func, str): lambda_func = eval(lambda_func) if i < len(mom_phases) - 1: - self.mom_phases.append( - (int(start * total_step), - int(mom_phases[i + 1][0] * total_step), lambda_func)) + self.mom_phases.append((int(start * total_step), + int(mom_phases[i + 1][0] * total_step), + lambda_func)) else: - self.mom_phases.append( - (int(start * total_step), total_step, lambda_func)) + self.mom_phases.append((int(start * total_step), total_step, + lambda_func)) assert self.mom_phases[0][0] == 0 super().__init__() @manager.OPTIMIZERS.add_component class OneCycle(LRSchedulerCycle): - def __init__(self, total_step, lr_max, moms, div_factor, pct_start): self.lr_max = lr_max self.moms = moms @@ -154,10 +151,48 @@ def get_lr(self): if self.last_epoch == 0: return self.base_lr else: - cur_epoch = (self.last_epoch + - self.warmup_iters) // self.iters_per_epoch + cur_epoch = ( + self.last_epoch + self.warmup_iters) // self.iters_per_epoch return annealing_cos(self.base_lr, self.eta_min, cur_epoch / self.T_max) def _get_closed_form_lr(self): return self.get_lr() + + +@manager.LR_SCHEDULERS.add_component +class CosineWarmupMultiStepDecayByEpoch(LRScheduler): + def __init__(self, + learning_rate, + warmup_steps, + start_lr, + milestones, + decay_rate, + end_lr=None): + self.iters_per_epoch = 1 + self.warmup_iters = 0 + self.warmup_epochs = warmup_steps + self.start_lr = start_lr + self.end_lr = end_lr if end_lr is not None else learning_rate + self.milestones = milestones + self.decay_rate = decay_rate + super(CosineWarmupMultiStepDecayByEpoch, self).__init__(learning_rate) + + def get_lr(self): + # update current epoch + cur_epoch = ( + self.last_epoch + self.warmup_iters) // self.iters_per_epoch + + # cosine warmup + if cur_epoch < self.warmup_epochs: + return self.start_lr + (self.end_lr - self.start_lr) * ( + 1 - math.cos(math.pi * cur_epoch / self.warmup_epochs)) / 2 + else: + if self.last_epoch in [ + self.milestones[0] * self.iters_per_epoch, + self.milestones[1] * self.iters_per_epoch + ]: + self.end_lr *= self.decay_rate + return self.end_lr + else: + return self.end_lr