From 74c5d5d8c548ea4dd73906cf5eb6a6c4b9c90f68 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 4 Jan 2023 23:44:56 +0800 Subject: [PATCH] add segmentation for PETRv2 --- ...sk_p4_800x320_map_doublesize_cos_epoch.yml | 235 ++++ paddle3d/datasets/__init__.py | 2 +- paddle3d/datasets/nuscenes/__init__.py | 1 + .../nuscenes/nuscenes_multi_view_det_seg.py | 711 +++++++++++ paddle3d/models/detection/petr/__init__.py | 1 + paddle3d/models/detection/petr/petr3d_seg.py | 445 +++++++ paddle3d/models/heads/dense_heads/__init__.py | 1 + .../models/heads/dense_heads/petr_head_seg.py | 1053 +++++++++++++++++ paddle3d/models/losses/__init__.py | 1 + paddle3d/models/losses/lane_loss.py | 76 ++ paddle3d/transforms/reader.py | 36 +- 11 files changed, 2559 insertions(+), 3 deletions(-) create mode 100644 configs/petr/petrv2_seg_vovnet_gridmask_p4_800x320_map_doublesize_cos_epoch.yml create mode 100644 paddle3d/datasets/nuscenes/nuscenes_multi_view_det_seg.py create mode 100644 paddle3d/models/detection/petr/petr3d_seg.py create mode 100644 paddle3d/models/heads/dense_heads/petr_head_seg.py create mode 100644 paddle3d/models/losses/lane_loss.py diff --git a/configs/petr/petrv2_seg_vovnet_gridmask_p4_800x320_map_doublesize_cos_epoch.yml b/configs/petr/petrv2_seg_vovnet_gridmask_p4_800x320_map_doublesize_cos_epoch.yml new file mode 100644 index 00000000..cb106080 --- /dev/null +++ b/configs/petr/petrv2_seg_vovnet_gridmask_p4_800x320_map_doublesize_cos_epoch.yml @@ -0,0 +1,235 @@ +batch_size: 1 +epochs: 24 + +train_dataset: + type: NuscenesMVSegDataset + dataset_root: data/nuscenes/ + ann_file: data/nuscenes/mmdet3d_nuscenes_30f_infos_train.pkl + lane_ann_file: data/nuscenes/HDmaps-nocover_infos_train.pkl + mode: train + class_names: [ + 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', + 'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' + ] + transforms: + - type: LoadMultiViewImageFromFiles + data_root: data/nuscenes/ + to_float32: True + - type: LoadMapsFromFiles + map_data_root: data/nuscenes/HDmaps-nocover/ + k: 0 + - type: LoadMultiViewImageFromMultiSweepsFiles + data_root: data/nuscenes/ + sweeps_num: 1 + to_float32: True + pad_empty_sweeps: True + sweep_range: [3, 27] + test_mode: False + - type: LoadAnnotations3D + with_bbox_3d: True + with_label_3d: True + - type: SampleRangeFilter + point_cloud_range: [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] + - type: SampleNameFilter + classes: [ + 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', + 'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' + ] + - type: ResizeCropFlipImage + sample_aug_cfg: + resize_lim: [0.47, 0.625] + final_dim: [320, 800] + bot_pct_lim: [0.0, 0.0] + rot_lim: [0.0, 0.0] + H: 900 + W: 1600 + rand_flip: True + training: True + - type: GlobalRotScaleTransImage + rot_range: [-0.3925, 0.3925] + translation_std: [0, 0, 0] + scale_ratio_range: [0.95, 1.05] + reverse_angle: True + training: True + - type: NormalizeMultiviewImage + mean: [103.530, 116.280, 123.675] + std: [57.375, 57.120, 58.395] + - type: PadMultiViewImage + size_divisor: 32 + - type: SampleFilerByKey + keys: ['gt_bboxes_3d', 'gt_labels_3d', 'img', 'maps'] + meta_keys: ['filename', 'ori_shape', 'img_shape', 'lidar2img', + 'intrinsics', 'extrinsics', 'pad_shape', + 'scale_factor', 'flip', 'box_mode_3d', 'box_type_3d', 'img_norm_cfg', 'sample_idx', + 'timestamp'] + +val_dataset: + type: NuscenesMVSegDataset + dataset_root: data/nuscenes/ + ann_file: data/nuscenes/mmdet3d_nuscenes_30f_infos_val.pkl + lane_ann_file: 
data/nuscenes/HDmaps-nocover_infos_val.pkl + mode: val + class_names: ['car', 'truck', 'construction_vehicle', 'bus', 'trailer', + 'barrier', 'motorcycle', 'bicycle', 'pedestrian', + 'traffic_cone'] + transforms: + - type: LoadMultiViewImageFromFiles + data_root: data/nuscenes/ + to_float32: True + - type: LoadMapsFromFiles + map_data_root: data/nuscenes/HDmaps-nocover/ + k: 0 + - type: LoadMultiViewImageFromMultiSweepsFiles + data_root: data/nuscenes/ + sweeps_num: 1 + to_float32: True + pad_empty_sweeps: True + sweep_range: [3, 27] + - type: ResizeCropFlipImage + sample_aug_cfg: + resize_lim: [0.47, 0.625] + final_dim: [320, 800] + bot_pct_lim: [0.0, 0.0] + rot_lim: [0.0, 0.0] + H: 900 + W: 1600 + rand_flip: True + training: False + - type: NormalizeMultiviewImage + mean: [103.530, 116.280, 123.675] + std: [57.375, 57.120, 58.395] + - type: PadMultiViewImage + size_divisor: 32 + - type: SampleFilerByKey + keys: ['img','gt_map','maps'] + meta_keys: ['filename', 'ori_shape', 'img_shape', 'lidar2img', + 'intrinsics', 'extrinsics', 'pad_shape', + 'scale_factor', 'flip', 'box_type_3d', 'img_norm_cfg', 'sample_idx', + 'timestamp'] + +optimizer: + type: AdamW + weight_decay: 0.01 + grad_clip: + type: ClipGradByGlobalNorm + clip_norm: 35 + # auto_skip_clip: True + +lr_scheduler: + type: LinearWarmup + learning_rate: + # type: CosineAnnealingDecay + type: CosineAnnealingDecayByEpoch + learning_rate: 0.0002 + # T_max: 84408 # 3517 * 24 + T_max: 24 + eta_min: 0.0000002 + warmup_steps: 500 + start_lr: 0.00006666666 + end_lr: 0.0002 + +model: + type: Petr3D_seg + use_recompute: True + use_grid_mask: True + backbone: + type: VoVNetCP ###use checkpoint to save memory + spec_name: V-99-eSE + norm_eval: True + frozen_stages: -1 + input_ch: 3 + out_features: ('stage4','stage5',) + neck: + type: CPFPN ###remove unused parameters + in_channels: [768, 1024] + out_channels: 256 + num_outs: 2 + pts_bbox_head: + type: PETRHeadSeg + num_classes: 10 + in_channels: 256 + num_query: 900 + num_lane: 1024 + LID: true + with_multiview: true + with_position: true + with_fpe: true + with_time: true + with_multi: true + position_range: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0] + code_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + normedlinear: False + transformer: + type: PETRTransformer + decoder_embed_dims: 256 + decoder: + type: PETRTransformerDecoder + return_intermediate: True + num_layers: 6 + transformerlayers: + type: PETRTransformerDecoderLayer + attns: + - type: MultiHeadAttention + embed_dims: 256 + num_heads: 8 + attn_drop: 0.1 + drop_prob: 0.1 + - type: PETRMultiheadAttention + embed_dims: 256 + num_heads: 8 + attn_drop: 0.1 + drop_prob: 0.1 + batch_first: True + feedforward_channels: 2048 + ffn_dropout: 0.1 + operation_order: ['self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'] + transformer_lane: + type: PETRTransformer + decoder_embed_dims: 256 + decoder: + type: PETRTransformerDecoder + return_intermediate: True + num_layers: 6 + transformerlayers: + type: PETRTransformerDecoderLayer + attns: + - type: MultiHeadAttention + embed_dims: 256 + num_heads: 8 + attn_drop: 0.1 + drop_prob: 0.1 + - type: PETRMultiheadAttention + embed_dims: 256 + num_heads: 8 + attn_drop: 0.1 + drop_prob: 0.1 + batch_first: True + feedforward_channels: 2048 + ffn_dropout: 0.1 + operation_order: ['self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'] + positional_encoding: + type: SinePositionalEncoding3D + num_feats: 128 + normalize: True + bbox_coder: + type: NMSFreeCoder + post_center_range: [-61.2, 
-61.2, -10.0, 61.2, 61.2, 10.0] + pc_range: [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] + max_num: 300 + voxel_size: [0.2, 0.2, 8] + num_classes: 10 + loss_cls: + type: WeightedFocalLoss + gamma: 2.0 + alpha: 0.25 + loss_weight: 2.0 + reduction: sum + loss_bbox: + type: WeightedL1Loss + loss_weight: 0.25 + reduction: sum + loss_lane_mask: + type: SigmoidCELoss + loss_weight: 1.0 + reduction: mean + \ No newline at end of file diff --git a/paddle3d/datasets/__init__.py b/paddle3d/datasets/__init__.py index ea087a02..ea096a4b 100644 --- a/paddle3d/datasets/__init__.py +++ b/paddle3d/datasets/__init__.py @@ -15,5 +15,5 @@ from .base import BaseDataset from .kitti import KittiDepthDataset, KittiMonoDataset, KittiPCDataset from .modelnet40 import ModelNet40 -from .nuscenes import NuscenesMVDataset, NuscenesPCDataset +from .nuscenes import NuscenesMVDataset, NuscenesPCDataset, NuscenesMVSegDataset from .waymo import WaymoPCDataset diff --git a/paddle3d/datasets/nuscenes/__init__.py b/paddle3d/datasets/nuscenes/__init__.py index 9ed1cd43..faad650a 100644 --- a/paddle3d/datasets/nuscenes/__init__.py +++ b/paddle3d/datasets/nuscenes/__init__.py @@ -14,3 +14,4 @@ from .nuscenes_multi_view_det import NuscenesMVDataset from .nuscenes_pointcloud_det import NuscenesPCDataset +from .nuscenes_multi_view_det_seg import NuscenesMVSegDataset \ No newline at end of file diff --git a/paddle3d/datasets/nuscenes/nuscenes_multi_view_det_seg.py b/paddle3d/datasets/nuscenes/nuscenes_multi_view_det_seg.py new file mode 100644 index 00000000..d70a6a88 --- /dev/null +++ b/paddle3d/datasets/nuscenes/nuscenes_multi_view_det_seg.py @@ -0,0 +1,711 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +import os +import os.path as osp +import pickle +from collections.abc import Mapping, Sequence +from functools import reduce +from pathlib import Path +from typing import List, Optional, Union + +import numpy as np +import paddle +from nuscenes.utils import splits as nuscenes_split +from nuscenes.utils.data_classes import Box as NuScenesBox +from nuscenes.utils.geometry_utils import transform_matrix +from pyquaternion import Quaternion + +import paddle3d.transforms as T +from paddle3d.apis import manager +from paddle3d.datasets.nuscenes.nuscenes_det import NuscenesDetDataset +from paddle3d.datasets.nuscenes.nuscenes_manager import NuScenesManager +from paddle3d.geometries import BBoxes3D, CoordMode +from paddle3d.sample import Sample, SampleMeta +from paddle3d.transforms import TransformABC +from paddle3d.utils.logger import logger + + +def is_filepath(x): + return isinstance(x, str) or isinstance(x, Path) + + +@manager.DATASETS.add_component +class NuscenesMVSegDataset(NuscenesDetDataset): + """ + Nuscecens dataset for multi-view camera detection task. 
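+    Compared with ``NuscenesMVDataset``, each sample additionally carries the
+    BEV map mask referenced by ``lane_ann_file``, so a segmentation branch
+    can be supervised alongside 3D detection.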
+ """ + DATASET_NAME = "Nuscenes" + + def __init__(self, + dataset_root: str, + ann_file: str = None, + lane_ann_file: str = None, + load_interval: int = 1, + mode: str = "train", + transforms: Union[TransformABC, List[TransformABC]] = None, + max_sweeps: int = 10, + class_names: Union[list, tuple] = None, + use_valid_flag: bool = False): + self.mode = mode + self.dataset_root = dataset_root + self.filter_empty_gt = True + self.box_type_3d = 'LiDAR' + self.box_mode_3d = None + self.ann_file = ann_file + self.lane_ann_file = lane_ann_file + self.load_interval = load_interval + self.version = self.VERSION_MAP[self.mode] + + self.max_sweeps = max_sweeps + self._build_data() + self.metadata = self.data_infos['metadata'] + + self.data_infos = list( + sorted(self.data_infos['infos'], key=lambda e: e['timestamp'])) + self.lane_infos = self.load_annotations(lane_ann_file) + + if isinstance(transforms, list): + transforms = T.Compose(transforms) + + self.transforms = transforms + + if not self.is_test_mode: + self.flag = np.zeros(len(self), dtype=np.uint8) + + self.modality = dict( + use_camera=True, + use_lidar=False, + use_radar=False, + use_map=False, + use_external=True, + ) + self.with_velocity = True + self.use_valid_flag = use_valid_flag + self.channel = "LIDAR_TOP" + if class_names is not None: + self.class_names = class_names + else: + self.class_names = list(self.CLASS_MAP.keys()) + + def __len__(self): + return len(self.data_infos) + + def _rand_another(self, idx): + """Randomly get another item with the same flag. + + Returns: + int: Another index of item with the same flag. + """ + pool = np.where(self.flag == self.flag[idx])[0] + return np.random.choice(pool) + + def load_annotations(self, ann_file): + """Load annotations from ann_file. + Args: + ann_file (str): Path of the annotation file. + Returns: + list[dict]: List of annotations sorted by timestamps. + """ + #data = mmcv.load(ann_file, file_format='pkl') + data = np.load(ann_file, allow_pickle=True) + data_infos = list(sorted(data['infos'], key=lambda e: e['timestamp'])) + data_infos = data_infos[::self.load_interval] + self.metadata = data['metadata'] + self.version = self.metadata['version'] + return data_infos + + def get_ann_info(self, index): + """Get annotation info according to the given index. + + Args: + index (int): Index of the annotation data to get. + + Returns: + dict: Annotation information consists of the following keys: + + - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): \ + 3D ground truth bboxes + - gt_labels_3d (np.ndarray): Labels of ground truths. + - gt_names (list[str]): Class names of ground truths. 
+ """ + info = self.data_infos[index] + # filter out bbox containing no points + if self.use_valid_flag: + mask = info['valid_flag'] + else: + mask = info['num_lidar_pts'] > 0 + gt_bboxes_3d = info['gt_boxes'][mask] + gt_names_3d = info['gt_names'][mask] + gt_labels_3d = [] + for cat in gt_names_3d: + if cat in self.CLASS_MAP: + # gt_labels_3d.append(self.CLASS_MAP[cat]) + gt_labels_3d.append(self.class_names.index(cat)) + else: + gt_labels_3d.append(-1) + gt_labels_3d = np.array(gt_labels_3d) + + if self.with_velocity: + gt_velocity = info['gt_velocity'][mask] + nan_mask = np.isnan(gt_velocity[:, 0]) + gt_velocity[nan_mask] = [0.0, 0.0] + gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocity], axis=-1) + + # the nuscenes box center is [0.5, 0.5, 0.5], we change it to be + # the same as KITTI (0.5, 0.5, 0) + origin = [0.5, 0.5, 0.5] + dst = np.array([0.5, 0.5, 0]) + src = np.array(origin) + gt_bboxes_3d[:, :3] += gt_bboxes_3d[:, 3:6] * (dst - src) + gt_bboxes_3d = BBoxes3D(gt_bboxes_3d, + coordmode=2, + origin=[0.5, 0.5, 0.5]) + + anns_results = dict(gt_bboxes_3d=gt_bboxes_3d, + gt_labels_3d=gt_labels_3d, + gt_names=gt_names_3d) + return anns_results + + def get_data_info(self, index): + """Get data info according to the given index. + Args: + index (int): Index of the sample data to get. + Returns: + dict: Data information that will be passed to the data \ + preprocessing pipelines. It includes the following keys: + + - sample_idx (str): Sample index. + - pts_filename (str): Filename of point clouds. + - sweeps (list[dict]): Infos of sweeps. + - timestamp (float): Sample timestamp. + - img_filename (str, optional): Image filename. + - lidar2img (list[np.ndarray], optional): Transformations \ + from lidar to different cameras. + - ann_info (dict): Annotation info. + """ + info = self.data_infos[index] + lane_info = self.lane_infos[index] + + sample = Sample(path=None, modality="multiview") + sample.sample_idx = info['token'] + sample.meta.id = info['token'] + sample.pts_filename = info['lidar_path'] + sample.sweeps = info['sweeps'] + sample.timestamp = info['timestamp'] / 1e6 + sample.map_filename = lane_info['maps']['map_mask'] + + if self.modality['use_camera']: + image_paths = [] + lidar2img_rts = [] + intrinsics = [] + extrinsics = [] + img_timestamp = [] + for cam_type, cam_info in info['cams'].items(): + img_timestamp.append(cam_info['timestamp'] / 1e6) + image_paths.append(cam_info['data_path']) + # obtain lidar to image transformation matrix + lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation']) + lidar2cam_t = cam_info[ + 'sensor2lidar_translation'] @ lidar2cam_r.T + lidar2cam_rt = np.eye(4) + lidar2cam_rt[:3, :3] = lidar2cam_r.T + lidar2cam_rt[3, :3] = -lidar2cam_t + intrinsic = cam_info['cam_intrinsic'] + viewpad = np.eye(4) + viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic + lidar2img_rt = (viewpad @ lidar2cam_rt.T) + intrinsics.append(viewpad) + # The extrinsics mean the tranformation from lidar to camera. + # If anyone want to use the extrinsics as sensor to lidar, please + # use np.linalg.inv(lidar2cam_rt.T) and modify the ResizeCropFlipImage + # and LoadMultiViewImageFromMultiSweepsFiles. 
+ extrinsics.append(lidar2cam_rt) + lidar2img_rts.append(lidar2img_rt) + + sample.update( + dict(img_timestamp=img_timestamp, + img_filename=image_paths, + lidar2img=lidar2img_rts, + intrinsics=intrinsics, + extrinsics=extrinsics)) + + # if not self.is_test_mode: + if self.mode == 'train': + annos = self.get_ann_info(index) + sample.ann_info = annos + return sample + + def __getitem__(self, idx): + if self.is_test_mode: + pass + + while True: + sample = self.get_data_info(idx) + + if sample is None: + idx = self._rand_another(idx) + continue + + sample['img_fields'] = [] + sample['bbox3d_fields'] = [] + sample['pts_mask_fields'] = [] + sample['pts_seg_fields'] = [] + sample['bbox_fields'] = [] + sample['mask_fields'] = [] + sample['seg_fields'] = [] + sample['box_type_3d'] = self.box_type_3d + sample['box_mode_3d'] = self.box_mode_3d + + sample = self.transforms(sample) + + if self.is_train_mode and self.filter_empty_gt and \ + (sample is None or len(sample['gt_labels_3d']) == 0 ): + idx = self._rand_another(idx) + continue + + return sample + + def _build_data(self): + test = 'test' in self.version + + if self.ann_file is not None: + self.data_infos = pickle.load(open(self.ann_file, 'rb')) + return + + if test: + test_ann_cache_file = os.path.join( + self.dataset_root, + '{}_annotation_test.pkl'.format(self.DATASET_NAME)) + if os.path.exists(test_ann_cache_file): + self.data_infos = pickle.load(open(test_ann_cache_file, 'rb')) + return + else: + train_ann_cache_file = os.path.join( + self.dataset_root, + '{}_annotation_train.pkl'.format(self.DATASET_NAME)) + val_ann_cache_file = os.path.join( + self.dataset_root, + '{}_annotation_val.pkl'.format(self.DATASET_NAME)) + if os.path.exists(train_ann_cache_file): + self.data_infos = pickle.load(open(train_ann_cache_file, 'rb')) + return + #print('self.version for nusc: ', self.version) + self.nusc = NuScenesManager.get(version=self.version, + dataroot=self.dataset_root) + + if self.version == 'v1.0-trainval': + train_scenes = nuscenes_split.train + val_scenes = nuscenes_split.val + elif self.version == 'v1.0-test': + train_scenes = nuscenes_split.test + val_scenes = [] + elif self.version == 'v1.0-mini': + train_scenes = nuscenes_split.mini_train + val_scenes = nuscenes_split.mini_val + else: + raise ValueError('unknown nuscenes dataset version') + + available_scenes = get_available_scenes(self.nusc) + available_scene_names = [s['name'] for s in available_scenes] + + train_scenes = list( + filter(lambda x: x in available_scene_names, train_scenes)) + val_scenes = list( + filter(lambda x: x in available_scene_names, val_scenes)) + train_scenes = set([ + available_scenes[available_scene_names.index(s)]['token'] + for s in train_scenes + ]) + val_scenes = set([ + available_scenes[available_scene_names.index(s)]['token'] + for s in val_scenes + ]) + + if test: + print('test scene: {}'.format(len(train_scenes))) + else: + print('train scene: {}, val scene: {}'.format( + len(train_scenes), len(val_scenes))) + train_nusc_infos, val_nusc_infos = _fill_trainval_infos( + self.nusc, + train_scenes, + val_scenes, + test, + max_sweeps=self.max_sweeps) + + metadata = dict(version=self.version) + + if test: + print('test sample: {}'.format(len(train_nusc_infos))) + data = dict(infos=train_nusc_infos, metadata=metadata) + pickle.dump(data, open(test_ann_cache_file, 'wb')) + else: + print('train sample: {}, val sample: {}'.format( + len(train_nusc_infos), len(val_nusc_infos))) + data = dict(infos=train_nusc_infos, metadata=metadata) + + pickle.dump(data, 
open(train_ann_cache_file, 'wb')) + + data['infos'] = val_nusc_infos + + pickle.dump(data, open(val_ann_cache_file, 'wb')) + + def _filter(self, anno: dict, box: NuScenesBox = None) -> bool: + # filter out objects that are not being scanned + mask = (anno['num_lidar_pts'] + anno['num_radar_pts']) > 0 and \ + anno['category_name'] in self.LABEL_MAP and \ + self.LABEL_MAP[anno['category_name']] in self.class_names + return mask + + def get_sweeps(self, index: int) -> List[str]: + """ + """ + sweeps = [] + sample = self.data[index] + token = sample['data'][self.channel] + sample_data = self.nusc.get('sample_data', token) + + if self.max_sweeps <= 0: + return sweeps + + # Homogeneous transform of current sample from ego car coordinate to sensor coordinate + curr_sample_cs = self.nusc.get("calibrated_sensor", + sample_data["calibrated_sensor_token"]) + curr_sensor_from_car = transform_matrix(curr_sample_cs["translation"], + Quaternion( + curr_sample_cs["rotation"]), + inverse=True) + # Homogeneous transformation matrix of current sample from global coordinate to ego car coordinate + curr_sample_pose = self.nusc.get("ego_pose", + sample_data["ego_pose_token"]) + curr_car_from_global = transform_matrix( + curr_sample_pose["translation"], + Quaternion(curr_sample_pose["rotation"]), + inverse=True, + ) + curr_timestamp = 1e-6 * sample_data["timestamp"] + + prev_token = sample_data['prev'] + while len(sweeps) < self.max_sweeps - 1: + if prev_token == "": + if len(sweeps) == 0: + sweeps.append({ + "lidar_path": + osp.join(self.dataset_root, sample_data['filename']), + "time_lag": + 0, + "ref_from_curr": + None, + }) + else: + sweeps.append(sweeps[-1]) + else: + prev_sample_data = self.nusc.get('sample_data', prev_token) + # Homogeneous transformation matrix of previous sample from ego car coordinate to global coordinate + prev_sample_pose = self.nusc.get( + "ego_pose", prev_sample_data["ego_pose_token"]) + prev_global_from_car = transform_matrix( + prev_sample_pose["translation"], + Quaternion(prev_sample_pose["rotation"]), + inverse=False, + ) + # Homogeneous transform of previous sample from sensor coordinate to ego car coordinate + prev_sample_cs = self.nusc.get( + "calibrated_sensor", + prev_sample_data["calibrated_sensor_token"]) + prev_car_from_sensor = transform_matrix( + prev_sample_cs["translation"], + Quaternion(prev_sample_cs["rotation"]), + inverse=False, + ) + + curr_from_pre = reduce( + np.dot, + [ + curr_sensor_from_car, curr_car_from_global, + prev_global_from_car, prev_car_from_sensor + ], + ) + prev_timestamp = 1e-6 * prev_sample_data["timestamp"] + time_lag = curr_timestamp - prev_timestamp + + sweeps.append({ + "lidar_path": + osp.join(self.dataset_root, prev_sample_data['filename']), + "time_lag": + time_lag, + "ref_from_curr": + curr_from_pre, + }) + prev_token = prev_sample_data['prev'] + return sweeps + + @property + def metric(self): + print('self.version for metric: ', self.version) + if not hasattr(self, 'nusc'): + self.nusc = NuScenesManager.get(version=self.version, + dataroot=self.dataset_root) + return super().metric + + def collate_fn(self, batch: List): + """ + """ + sample = batch[0] + if isinstance(sample, np.ndarray): + try: + batch = np.stack(batch, axis=0) + return batch + except Exception as e: + return batch + elif isinstance(sample, SampleMeta): + return batch + return super().collate_fn(batch) + + +def get_available_scenes(nusc): + """Get available scenes from the input nuscenes class. 
+ + Given the raw data, get the information of available scenes for + further info generation. + + Args: + nusc (class): Dataset class in the nuScenes dataset. + + Returns: + available_scenes (list[dict]): List of basic information for the + available scenes. + """ + available_scenes = [] + print('total scene num: {}'.format(len(nusc.scene))) + for scene in nusc.scene: + scene_token = scene['token'] + scene_rec = nusc.get('scene', scene_token) + sample_rec = nusc.get('sample', scene_rec['first_sample_token']) + sd_rec = nusc.get('sample_data', sample_rec['data']['LIDAR_TOP']) + has_more_frames = True + scene_not_exist = False + while has_more_frames: + lidar_path, boxes, _ = nusc.get_sample_data(sd_rec['token']) + lidar_path = str(lidar_path) + if os.getcwd() in lidar_path: + # path from lyftdataset is absolute path + lidar_path = lidar_path.split(f'{os.getcwd()}/')[-1] + # relative path + if not is_filepath(lidar_path): + scene_not_exist = True + break + else: + break + if scene_not_exist: + continue + available_scenes.append(scene) + print('exist scene num: {}'.format(len(available_scenes))) + return available_scenes + + +def _fill_trainval_infos(nusc, + train_scenes, + val_scenes, + test=False, + max_sweeps=10): + """Generate the train/val infos from the raw data. + + Args: + nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset. + train_scenes (list[str]): Basic information of training scenes. + val_scenes (list[str]): Basic information of validation scenes. + test (bool, optional): Whether use the test mode. In test mode, no + annotations can be accessed. Default: False. + max_sweeps (int, optional): Max number of sweeps. Default: 10. + + Returns: + tuple[list[dict]]: Information of training set and validation set + that will be saved to the info file. + """ + train_nusc_infos = [] + val_nusc_infos = [] + + msg = "Begin to generate a info of nuScenes dataset." 
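+    # Convert every keyframe into an info dict (lidar path, per-camera
+    # calibration, sweeps, annotations) and route it to the train or val
+    # split according to its scene token.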
+ + for sample_idx in logger.range(len(nusc.sample), msg=msg): + sample = nusc.sample[sample_idx] + lidar_token = sample['data']['LIDAR_TOP'] + sd_rec = nusc.get('sample_data', sample['data']['LIDAR_TOP']) + cs_record = nusc.get('calibrated_sensor', + sd_rec['calibrated_sensor_token']) + pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token']) + lidar_path, boxes, _ = nusc.get_sample_data(lidar_token) + + assert os.path.exists(lidar_path) + + info = { + 'lidar_path': lidar_path, + 'token': sample['token'], + 'sweeps': [], + 'cams': dict(), + 'lidar2ego_translation': cs_record['translation'], + 'lidar2ego_rotation': cs_record['rotation'], + 'ego2global_translation': pose_record['translation'], + 'ego2global_rotation': pose_record['rotation'], + 'timestamp': sample['timestamp'], + } + + l2e_r = info['lidar2ego_rotation'] + l2e_t = info['lidar2ego_translation'] + e2g_r = info['ego2global_rotation'] + e2g_t = info['ego2global_translation'] + l2e_r_mat = Quaternion(l2e_r).rotation_matrix + e2g_r_mat = Quaternion(e2g_r).rotation_matrix + + # obtain 6 image's information per frame + camera_types = [ + 'CAM_FRONT', + 'CAM_FRONT_RIGHT', + 'CAM_FRONT_LEFT', + 'CAM_BACK', + 'CAM_BACK_LEFT', + 'CAM_BACK_RIGHT', + ] + for cam in camera_types: + cam_token = sample['data'][cam] + cam_path, _, cam_intrinsic = nusc.get_sample_data(cam_token) + cam_info = obtain_sensor2top(nusc, cam_token, l2e_t, l2e_r_mat, + e2g_t, e2g_r_mat, cam) + cam_info.update(cam_intrinsic=cam_intrinsic) + info['cams'].update({cam: cam_info}) + + # obtain sweeps for a single key-frame + sd_rec = nusc.get('sample_data', sample['data']['LIDAR_TOP']) + sweeps = [] + while len(sweeps) < max_sweeps: + if not sd_rec['prev'] == '': + sweep = obtain_sensor2top(nusc, sd_rec['prev'], l2e_t, + l2e_r_mat, e2g_t, e2g_r_mat, 'lidar') + sweeps.append(sweep) + sd_rec = nusc.get('sample_data', sd_rec['prev']) + else: + break + info['sweeps'] = sweeps + # obtain annotation + if not test: + annotations = [ + nusc.get('sample_annotation', token) for token in sample['anns'] + ] + locs = np.array([b.center for b in boxes]).reshape(-1, 3) + dims = np.array([b.wlh for b in boxes]).reshape(-1, 3) + rots = np.array([b.orientation.yaw_pitch_roll[0] + for b in boxes]).reshape(-1, 1) + velocity = np.array( + [nusc.box_velocity(token)[:2] for token in sample['anns']]) + valid_flag = np.array( + [(anno['num_lidar_pts'] + anno['num_radar_pts']) > 0 + for anno in annotations], + dtype=bool).reshape(-1) + # convert velo from global to lidar + for i in range(len(boxes)): + velo = np.array([*velocity[i], 0.0]) + velo = velo @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv( + l2e_r_mat).T + velocity[i] = velo[:2] + + names = [b.name for b in boxes] + for i in range(len(names)): + # NuscenesDetDataset.LABEL_MAP + if names[i] in NuscenesDetDataset.LABEL_MAP: + names[i] = NuscenesDetDataset.LABEL_MAP[names[i]] + names = np.array(names) + # we need to convert box size to + # the format of our lidar coordinate system + # which is x_size, y_size, z_size (corresponding to l, w, h) + gt_boxes = np.concatenate([locs, dims[:, [1, 0, 2]], rots], axis=1) + assert len(gt_boxes) == len( + annotations), f'{len(gt_boxes)}, {len(annotations)}' + info['gt_boxes'] = gt_boxes + info['gt_names'] = names + info['gt_velocity'] = velocity.reshape(-1, 2) + info['num_lidar_pts'] = np.array( + [a['num_lidar_pts'] for a in annotations]) + info['num_radar_pts'] = np.array( + [a['num_radar_pts'] for a in annotations]) + info['valid_flag'] = valid_flag + + if sample['scene_token'] in train_scenes: + 
train_nusc_infos.append(info) + else: + val_nusc_infos.append(info) + + return train_nusc_infos, val_nusc_infos + + +def obtain_sensor2top(nusc, + sensor_token, + l2e_t, + l2e_r_mat, + e2g_t, + e2g_r_mat, + sensor_type='lidar'): + """Obtain the info with RT matric from general sensor to Top LiDAR. + + Args: + nusc (class): Dataset class in the nuScenes dataset. + sensor_token (str): Sample data token corresponding to the + specific sensor type. + l2e_t (np.ndarray): Translation from lidar to ego in shape (1, 3). + l2e_r_mat (np.ndarray): Rotation matrix from lidar to ego + in shape (3, 3). + e2g_t (np.ndarray): Translation from ego to global in shape (1, 3). + e2g_r_mat (np.ndarray): Rotation matrix from ego to global + in shape (3, 3). + sensor_type (str, optional): Sensor to calibrate. Default: 'lidar'. + + Returns: + sweep (dict): Sweep information after transformation. + """ + sd_rec = nusc.get('sample_data', sensor_token) + cs_record = nusc.get('calibrated_sensor', sd_rec['calibrated_sensor_token']) + pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token']) + data_path = str(nusc.get_sample_data_path(sd_rec['token'])) + if os.getcwd() in data_path: # path from lyftdataset is absolute path + data_path = data_path.split(f'{os.getcwd()}/')[-1] # relative path + sweep = { + 'data_path': data_path, + 'type': sensor_type, + 'sample_data_token': sd_rec['token'], + 'sensor2ego_translation': cs_record['translation'], + 'sensor2ego_rotation': cs_record['rotation'], + 'ego2global_translation': pose_record['translation'], + 'ego2global_rotation': pose_record['rotation'], + 'timestamp': sd_rec['timestamp'] + } + l2e_r_s = sweep['sensor2ego_rotation'] + l2e_t_s = sweep['sensor2ego_translation'] + e2g_r_s = sweep['ego2global_rotation'] + e2g_t_s = sweep['ego2global_translation'] + + # obtain the RT from sensor to Top LiDAR + # sweep->ego->global->ego'->lidar + l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix + e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix + R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ ( + np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T) + T = (l2e_t_s @ e2g_r_s_mat.T + + e2g_t_s) @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T) + T -= e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T + ) + l2e_t @ np.linalg.inv(l2e_r_mat).T + sweep['sensor2lidar_rotation'] = R.T # points @ R.T + T + sweep['sensor2lidar_translation'] = T + return sweep diff --git a/paddle3d/models/detection/petr/__init__.py b/paddle3d/models/detection/petr/__init__.py index e9a905b9..e445aa1d 100644 --- a/paddle3d/models/detection/petr/__init__.py +++ b/paddle3d/models/detection/petr/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .petr3d import Petr3D +from .petr3d_seg import Petr3D_seg \ No newline at end of file diff --git a/paddle3d/models/detection/petr/petr3d_seg.py b/paddle3d/models/detection/petr/petr3d_seg.py new file mode 100644 index 00000000..05edd7cd --- /dev/null +++ b/paddle3d/models/detection/petr/petr3d_seg.py @@ -0,0 +1,445 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR3D (https://github.com/WangYueFt/detr3d) +# Copyright (c) 2021 Wang, Yue +import os +from os import path as osp +import copy +import cv2 +import uuid + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from PIL import Image + +from paddle3d.apis import manager +from paddle3d.geometries import BBoxes3D +from paddle3d.sample import Sample, SampleMeta +from paddle3d.utils import dtype2float32 + +from einops import rearrange + + +def IOU(intputs, targets): + numerator = 2 * (intputs * targets).sum(axis=1) + denominator = intputs.sum(axis=1) + targets.sum(axis=1) + loss = (numerator + 0.01) / (denominator + 0.01) + return loss + + +class GridMask(nn.Layer): + def __init__(self, + use_h, + use_w, + rotate=1, + offset=False, + ratio=0.5, + mode=0, + prob=1.): + super(GridMask, self).__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.st_prob = prob + self.prob = prob + + def set_prob(self, epoch, max_epoch): + self.prob = self.st_prob * epoch / max_epoch #+ 1.#0.5 + + def forward(self, x): + if np.random.rand() > self.prob or not self.training: + return x + n, c, h, w = x.shape + x = x.reshape([-1, h, w]) + hh = int(1.5 * h) + ww = int(1.5 * w) + d = np.random.randint(2, h) + self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) + mask = np.ones((hh, ww), np.float32) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + if self.use_h: + for i in range(hh // d): + s = d * i + st_h + t = min(s + self.l, hh) + mask[s:t, :] *= 0 + if self.use_w: + for i in range(ww // d): + s = d * i + st_w + t = min(s + self.l, ww) + mask[:, s:t] *= 0 + + r = np.random.randint(self.rotate) + mask = Image.fromarray(np.uint8(mask)) + mask = mask.rotate(r) + mask = np.asarray(mask) + mask = mask[(hh - h) // 2:(hh - h) // 2 + + h, (ww - w) // 2:(ww - w) // 2 + w] + + mask = paddle.to_tensor(mask).astype('float32') + if self.mode == 1: + mask = 1 - mask + mask = mask.expand_as(x) + if self.offset: + offset = paddle.to_tensor( + 2 * (np.random.rand(h, w) - 0.5)).astype('float32') + x = x * mask + offset * (1 - mask) + else: + x = x * mask + + return x.reshape([n, c, h, w]) + + +def bbox3d2result(bboxes, scores, labels, attrs=None): + """Convert detection results to a list of numpy arrays. 
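+
+    Returns a dict with ``boxes_3d``, ``scores_3d`` and ``labels_3d`` moved to
+    CPU, plus ``attrs_3d`` when attribute predictions are provided.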
+ """ + result_dict = dict( + boxes_3d=bboxes.cpu(), scores_3d=scores.cpu(), labels_3d=labels.cpu()) + + if attrs is not None: + result_dict['attrs_3d'] = attrs.cpu() + + return result_dict + + +@manager.MODELS.add_component +class Petr3D_seg(nn.Layer): + """Petr3D_seg.""" + + def __init__(self, + use_grid_mask=False, + backbone=None, + neck=None, + pts_bbox_head=None, + img_roi_head=None, + img_rpn_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + use_recompute=False): + super(Petr3D_seg, self).__init__() + + self.pts_bbox_head = pts_bbox_head + self.backbone = backbone + self.neck = neck + self.use_grid_mask = use_grid_mask + self.use_recompute = use_recompute + + if use_grid_mask: + self.grid_mask = GridMask( + True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7) + + self.init_weight() + + def init_weight(self, bias_lr_factor=0.1): + for _, param in self.backbone.named_parameters(): + param.optimize_attr['learning_rate'] = bias_lr_factor + + self.pts_bbox_head.init_weights() + + def extract_img_feat(self, img, img_metas): + """Extract features of images.""" + print('img in extract_img_feat: ', type(img)) + if isinstance(img, list): + img = paddle.stack(img, axis=0) + + B = img.shape[0] + if img is not None: + input_shape = img.shape[-2:] + # update real input shape of each single img + if not (hasattr(self, 'export_model') and self.export_model): + for img_meta in img_metas: + img_meta.update(input_shape=input_shape) + if img.dim() == 5: + if img.shape[0] == 1 and img.shape[1] != 1: + if hasattr(self, 'export_model') and self.export_model: + img = img.squeeze() + else: + img.squeeze_() + else: + B, N, C, H, W = img.shape + img = img.reshape([B * N, C, H, W]) + if self.use_grid_mask: + img = self.grid_mask(img) + img_feats = self.backbone(img) + if isinstance(img_feats, dict): + img_feats = list(img_feats.values()) + else: + return None + + img_feats = self.neck(img_feats) + + img_feats_reshaped = [] + for img_feat in img_feats: + BN, C, H, W = img_feat.shape + img_feats_reshaped.append( + img_feat.reshape([B, int(BN / B), C, H, W])) + + return img_feats_reshaped + + def extract_feat(self, img, img_metas): + """Extract features from images and points.""" + print('img in extract_feat: ', type(img)) + img_feats = self.extract_img_feat(img, img_metas) + return img_feats + + def forward_pts_train(self, + pts_feats, + gt_bboxes_3d, + gt_labels_3d, + maps, + img_metas, + gt_bboxes_ignore=None): + """ + """ + outs = self.pts_bbox_head(pts_feats, img_metas) + loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs, maps] + losses = self.pts_bbox_head.loss(*loss_inputs) + + return losses + + def forward(self, samples, **kwargs): + """ + """ + if self.training: + self.backbone.train() + return self.forward_train(samples, **kwargs) + else: + return self.forward_test(samples, **kwargs) + + def forward_train(self, + samples=None, + points=None, + img_metas=None, + gt_bboxes_3d=None, + gt_labels_3d=None, + gt_labels=None, + maps=None, + gt_bboxes=None, + img=None, + proposals=None, + gt_bboxes_ignore=None, + img_depth=None, + img_mask=None): + """ + """ + + if samples is not None: + img_metas = samples['meta'] + img = samples['img'] + gt_labels_3d = samples['gt_labels_3d'] + gt_bboxes_3d = samples['gt_bboxes_3d'] + maps = samples['maps'] + + if hasattr(self, 'amp_cfg_'): + with paddle.amp.auto_cast(**self.amp_cfg_): + img_feats = self.extract_feat(img=img, img_metas=img_metas) + img_feats = dtype2float32(img_feats) + else: + img_feats = self.extract_feat(img=img, 
img_metas=img_metas) + + losses = dict() + #losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d, img_metas, gt_bboxes_ignore) + losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d, maps, img_metas, gt_bboxes_ignore) + losses.update(losses_pts) + + return dict(loss=losses) + + def forward_test(self, samples, gt_map=None, maps=None, img=None, **kwargs): + img_metas = samples['meta'] + img = samples['img'] + gt_map = samples['gt_map'] + maps = samples['maps'] + + img = [img] if img is None else img + + results = self.simple_test(img_metas, gt_map, img, maps, **kwargs) + return dict(preds=self._parse_results_to_sample(results, samples)) + + def simple_test_pts(self, x, img_metas, gt_map, maps, rescale=False): + """Test function of point cloud branch.""" + + outs = self.pts_bbox_head(x, img_metas) + bbox_list = self.pts_bbox_head.get_bboxes( + outs, img_metas, rescale=rescale) + + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in bbox_list + ] + + with paddle.no_grad(): + lane_preds=outs['all_lane_preds'][5].squeeze(0) #[B,N,H,W] + n,w = lane_preds.shape + #pred_maps = lane_preds.reshape([256,3,16,16]) + pred_maps = lane_preds.reshape([1024,3,16,16]) + f_lane = rearrange(pred_maps.cpu().numpy(), '(h w) c h1 w2 -> c (h h1) (w w2)', h=32, w=32) + f_lane = F.sigmoid(paddle.to_tensor(f_lane)) + f_lane[f_lane>=0.5] = 1 + f_lane[f_lane<0.5] = 0 + f_lane_show=copy.deepcopy(f_lane) + gt_map_show=copy.deepcopy(gt_map[0]) + + f_lane=f_lane.reshape([3,-1]) + gt_map=gt_map[0].reshape([3,-1]) + + ret_iou=IOU(f_lane, gt_map).cpu() + show_res=False + if show_res: + save_uuid = str(uuid.uuid1()) + pres = f_lane_show.cpu().numpy() + pre = np.zeros([512, 512, 3]) + pre += 255 + label = [[71, 130, 255], [255, 255, 0], [255, 144, 30]] + pre[..., 0][pres[0] > 0.5] = label[0][0] + pre[..., 1][pres[0] > 0.23] = label[0][1] + pre[..., 2][pres[0] > 0.56] = label[0][2] + pre[..., 0][pres[2] > 0.5] = label[2][0] + pre[..., 1][pres[2] > 0.23] = label[2][1] + pre[..., 2][pres[2] > 0.56] = label[2][2] + pre[..., 0][pres[1] > 0.5] = label[1][0] + pre[..., 1][pres[1] > 0.23] = label[1][1] + pre[..., 2][pres[1] > 0.56] = label[1][2] + #save_pred_path = '/notebooks/paddle3D/Paddle3D_for_develop/visible/res_pre/{}.png'.format(save_uuid) + #cv2.imwrite(save_pred_path, pre.astype(np.uint8)) + pres = gt_map_show[0] + pre = paddle.zeros([512, 512, 3]) + pre += 255 + pre[..., 0][pres[0] > 0.5] = label[0][0] + pre[..., 1][pres[0] > 0.5] = label[0][1] + pre[..., 2][pres[0] > 0.5] = label[0][2] + pre[..., 0][pres[2] > 0.5] = label[2][0] + pre[..., 1][pres[2] > 0.5] = label[2][1] + pre[..., 2][pres[2] > 0.5] = label[2][2] + pre[..., 0][pres[1] > 0.5] = label[1][0] + pre[..., 1][pres[1] > 0.5] = label[1][1] + pre[..., 2][pres[1] > 0.5] = label[1][2] + #save_gt_path = '/notebooks/paddle3D/Paddle3D_for_develop/visible/res_gt/{}.png'.format(save_uuid) + #cv2.imwrite(save_gt_path, pres.cpu().numpy().astype(np.uint8) * 200) + return bbox_results, ret_iou + + def simple_test(self, img_metas, gt_map=None, img=None, maps=None, rescale=False): + """Test function without augmentaiton.""" + img_feats = self.extract_feat(img=img, img_metas=img_metas) + bbox_list = [dict() for i in range(len(img_metas))] + #bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale) + bbox_pts, ret_iou = self.simple_test_pts(img_feats, img_metas, gt_map, maps, rescale=rescale) + for result_dict, pts_bbox in zip(bbox_list, bbox_pts): + result_dict['pts_bbox'] = pts_bbox + 
result_dict['ret_iou'] = ret_iou + return bbox_list + + def _parse_results_to_sample(self, results: dict, sample: dict): + num_samples = len(results) + new_results = [] + for i in range(num_samples): + data = Sample(None, sample["modality"][i]) + bboxes_3d = results[i]['pts_bbox']["boxes_3d"].numpy() + labels = results[i]['pts_bbox']["labels_3d"].numpy() + confidences = results[i]['pts_bbox']["scores_3d"].numpy() + bottom_center = bboxes_3d[:, :3] + gravity_center = np.zeros_like(bottom_center) + gravity_center[:, :2] = bottom_center[:, :2] + gravity_center[:, 2] = bottom_center[:, 2] + bboxes_3d[:, 5] * 0.5 + bboxes_3d[:, :3] = gravity_center + data.bboxes_3d = BBoxes3D(bboxes_3d[:, 0:7]) + data.bboxes_3d.coordmode = 'Lidar' + data.bboxes_3d.origin = [0.5, 0.5, 0.5] + data.bboxes_3d.rot_axis = 2 + data.bboxes_3d.velocities = bboxes_3d[:, 7:9] + data['bboxes_3d_numpy'] = bboxes_3d[:, 0:7] + data['bboxes_3d_coordmode'] = 'Lidar' + data['bboxes_3d_origin'] = [0.5, 0.5, 0.5] + data['bboxes_3d_rot_axis'] = 2 + data['bboxes_3d_velocities'] = bboxes_3d[:, 7:9] + data.labels = labels + data.confidences = confidences + data.meta = SampleMeta(id=sample["meta"][i]['id']) + if "calibs" in sample: + calib = [calibs.numpy()[i] for calibs in sample["calibs"]] + data.calibs = calib + new_results.append(data) + return new_results + + def aug_test_pts(self, feats, img_metas, rescale=False): + feats_list = [] + for j in range(len(feats[0])): + feats_list_level = [] + for i in range(len(feats)): + feats_list_level.append(feats[i][j]) + feats_list.append(paddle.stack(feats_list_level, -1).mean(-1)) + outs = self.pts_bbox_head(feats_list, img_metas) + bbox_list = self.pts_bbox_head.get_bboxes( + outs, img_metas, rescale=rescale) + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in bbox_list + ] + return bbox_results + + def aug_test(self, img_metas, imgs=None, rescale=False): + """Test function with augmentaiton.""" + img_feats = self.extract_feats(img_metas, imgs) + img_metas = img_metas[0] + bbox_list = [dict() for i in range(len(img_metas))] + bbox_pts = self.aug_test_pts(img_feats, img_metas, rescale) + for result_dict, pts_bbox in zip(bbox_list, bbox_pts): + result_dict['pts_bbox'] = pts_bbox + return bbox_list + + def export_forward(self, img, img_metas, time_stamp=None): + img_metas['image_shape'] = img.shape[-2:] + img_feats = self.extract_feat(img=img, img_metas=None) + + bbox_list = [dict() for i in range(len(img_metas))] + self.pts_bbox_head.export_model = True + outs = self.pts_bbox_head.export_forward(img_feats, img_metas, + time_stamp) + bbox_list = self.pts_bbox_head.get_bboxes(outs, None, rescale=True) + return bbox_list + + def export(self, save_dir: str, **kwargs): + self.forward = self.export_forward + self.export_model = True + + num_cams = 12 if self.pts_bbox_head.with_time else 6 + image_spec = paddle.static.InputSpec( + shape=[1, num_cams, 3, 320, 800], dtype="float32") + img2lidars_spec = { + "img2lidars": + paddle.static.InputSpec( + shape=[1, num_cams, 4, 4], name='img2lidars'), + } + + input_spec = [image_spec, img2lidars_spec] + + model_name = "petr_inference" + if self.pts_bbox_head.with_time: + time_spec = paddle.static.InputSpec( + shape=[1, num_cams], dtype="float32") + input_spec.append(time_spec) + model_name = "petrv2_inference" + + paddle.jit.to_static(self, input_spec=input_spec) + + paddle.jit.save(self, os.path.join(save_dir, model_name)) diff --git a/paddle3d/models/heads/dense_heads/__init__.py 
b/paddle3d/models/heads/dense_heads/__init__.py index ddf246f8..5417e723 100644 --- a/paddle3d/models/heads/dense_heads/__init__.py +++ b/paddle3d/models/heads/dense_heads/__init__.py @@ -15,5 +15,6 @@ from .anchor_head import * from .coders import NMSFreeCoder from .petr_head import PETRHead +from .petr_head_seg import PETRHeadSeg from .point_head import PointHeadSimple from .target_assigner import * diff --git a/paddle3d/models/heads/dense_heads/petr_head_seg.py b/paddle3d/models/heads/dense_heads/petr_head_seg.py new file mode 100644 index 00000000..25fe9cb9 --- /dev/null +++ b/paddle3d/models/heads/dense_heads/petr_head_seg.py @@ -0,0 +1,1053 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR3D (https://github.com/WangYueFt/detr3d) +# Copyright (c) 2021 Wang, Yue +# ------------------------------------------------------------------------ +# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d) +# Copyright (c) OpenMMLab. All rights reserved. +# ------------------------------------------------------------------------ + +import copy +import math +from functools import partial + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle3d.apis import manager +from paddle3d.models.heads.dense_heads.target_assigner.hungarian_assigner import ( + HungarianAssigner3D, nan_to_num, normalize_bbox) +from paddle3d.models.layers import param_init +from paddle3d.models.layers.layer_libs import NormedLinear, inverse_sigmoid +from paddle3d.models.losses.focal_loss import FocalLoss, WeightedFocalLoss +from paddle3d.models.losses.weight_loss import WeightedL1Loss + +from .samplers.pseudo_sampler import PseudoSampler + + +def reduce_mean(tensor): + """"Obtain the mean of tensor on different GPUs.""" + if not paddle.distributed.is_initialized(): + return tensor + tensor = tensor.clone() + paddle.distributed.all_reduce( + tensor.scale_(1. / paddle.distributed.get_world_size())) + return tensor + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. 
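+
+    ``func`` is mapped over the zipped argument lists and the per-call results
+    are transposed into one list per output. Illustrative example:
+
+        >>> multi_apply(lambda a, b: (a + b, a * b), [1, 2], [3, 4])
+        ([4, 6], [3, 8])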
+ """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def pos2posemb3d(pos, num_pos_feats=128, temperature=10000): + scale = 2 * math.pi + pos = pos * scale + dim_t = paddle.arange(num_pos_feats, dtype='int32') + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + pos_x = pos[..., 0, None] / dim_t + pos_y = pos[..., 1, None] / dim_t + pos_z = pos[..., 2, None] / dim_t + pos_x = paddle.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), + axis=-1).flatten(-2) + pos_y = paddle.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), + axis=-1).flatten(-2) + pos_z = paddle.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), + axis=-1).flatten(-2) + posemb = paddle.concat((pos_y, pos_x, pos_z), axis=-1) + return posemb + +def pos2posemb2d(pos, num_pos_feats=128, temperature=10000): + scale = 2 * math.pi + pos = pos * scale + dim_t = paddle.arange(num_pos_feats, dtype='int32') + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + pos_x = pos[..., 0, None] / dim_t + pos_y = pos[..., 1, None] / dim_t + pos_x = paddle.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), axis=-1).flatten(-2) + pos_y = paddle.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), axis=-1).flatten(-2) + posemb = paddle.concat((pos_y, pos_x), axis=-1) + return posemb + + +class SELayer(nn.Layer): + def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid): + super().__init__() + self.conv_reduce = nn.Conv2D(channels, channels, 1, bias_attr=True) + self.act1 = act_layer() + self.conv_expand = nn.Conv2D(channels, channels, 1, bias_attr=True) + self.gate = gate_layer() + + def forward(self, x, x_se): + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + return x * self.gate(x_se) + + +class RegLayer(nn.Layer): + def __init__( + self, + embed_dims=256, + shared_reg_fcs=2, + group_reg_dims=(2, 1, 3, 2, 2), # xy, z, size, rot, velo + act_layer=nn.ReLU, + drop=0.0): + super().__init__() + + reg_branch = [] + for _ in range(shared_reg_fcs): + reg_branch.append(nn.Linear(embed_dims, embed_dims)) + reg_branch.append(act_layer()) + reg_branch.append(nn.Dropout(drop)) + self.reg_branch = nn.Sequential(*reg_branch) + + self.task_heads = nn.LayerList() + for reg_dim in group_reg_dims: + task_head = nn.Sequential( + nn.Linear(embed_dims, embed_dims), act_layer(), + nn.Linear(embed_dims, reg_dim)) + self.task_heads.append(task_head) + + def forward(self, x): + reg_feat = self.reg_branch(x) + outs = [] + for task_head in self.task_heads: + out = task_head(reg_feat.clone()) + outs.append(out) + outs = paddle.concat(outs, -1) + return outs + + +@manager.HEADS.add_component +class PETRHeadSeg(nn.Layer): + """Implements the DETR transformer head. + See `paper: End-to-End Object Detection with Transformers + `_ for details. 
+ """ + + def __init__( + self, + num_classes, + in_channels, + num_query=100, + num_lane=100, + num_reg_fcs=2, + transformer=None, + transformer_lane=None, + sync_cls_avg_factor=False, + positional_encoding=None, + code_weights=None, + bbox_coder=None, + loss_cls=None, + loss_bbox=None, + loss_iou=None, + loss_lane_mask=None, + assigner=None, + with_position=True, + with_multiview=False, + depth_step=0.8, + depth_num=64, + LID=False, + depth_start=1, + position_level=0, + position_range=[-65, -65, -8.0, 65, 65, 8.0], + group_reg_dims=(2, 1, 3, 2, 2), # xy, z, size, rot, velo + init_cfg=None, + normedlinear=False, + with_fpe=False, + with_time=False, + with_multi=False, + **kwargs): + + if 'code_size' in kwargs: + self.code_size = kwargs['code_size'] + else: + self.code_size = 10 + if code_weights is not None: + self.code_weights = code_weights + else: + self.code_weights = [ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2 + ] + self.code_weights = self.code_weights[:self.code_size] + self.bg_cls_weight = 0 + self.sync_cls_avg_factor = sync_cls_avg_factor + + self.assigner = HungarianAssigner3D( + pc_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]) + + self.sampler = PseudoSampler() + + self.num_query = num_query + self.num_lane=num_lane + self.num_classes = num_classes + self.in_channels = in_channels + self.num_reg_fcs = num_reg_fcs + + self.fp16_enabled = False + self.embed_dims = 256 + self.depth_step = depth_step + self.depth_num = depth_num + self.position_dim = 3 * self.depth_num + self.position_range = position_range + self.LID = LID + self.depth_start = depth_start + self.position_level = position_level + self.with_position = with_position + self.with_multiview = with_multiview + + self.num_pred = 6 + self.normedlinear = normedlinear + self.with_fpe = with_fpe + self.with_time = with_time + self.with_multi = with_multi + self.group_reg_dims = group_reg_dims + + super(PETRHeadSeg, self).__init__() + + self.num_classes = num_classes + self.in_channels = in_channels + + self.loss_cls = loss_cls + + self.loss_bbox = loss_bbox + + self.loss_iou = loss_iou + + self.loss_lane_mask = loss_lane_mask + + self.cls_out_channels = num_classes + + self.positional_encoding = positional_encoding + + initializer = paddle.nn.initializer.Assign(self.code_weights) + self.code_weights = self.create_parameter( + [len(self.code_weights)], default_initializer=initializer) + self.code_weights.stop_gradient = True + + self.bbox_coder = bbox_coder + self.pc_range = self.bbox_coder.pc_range + self._init_layers() + self.transformer = transformer + self.transformer_lane = transformer_lane + self.pd_eps = paddle.to_tensor(np.finfo('float32').eps) + + def _init_layers(self): + """Initialize layers of the transformer head.""" + if self.with_position: + self.input_proj = nn.Conv2D( + self.in_channels, self.embed_dims, kernel_size=1) + else: + self.input_proj = nn.Conv2D( + self.in_channels, self.embed_dims, kernel_size=1) + + cls_branch = [] + for _ in range(self.num_reg_fcs): + cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims)) + cls_branch.append(nn.LayerNorm(self.embed_dims)) + cls_branch.append(nn.ReLU()) + if self.normedlinear: + cls_branch.append( + NormedLinear(self.embed_dims, self.cls_out_channels)) + else: + cls_branch.append(nn.Linear(self.embed_dims, self.cls_out_channels)) + fc_cls = nn.Sequential(*cls_branch) + + lane_branch = [] + for _ in range(self.num_reg_fcs): + lane_branch.append(nn.Linear(self.embed_dims, self.embed_dims)) + lane_branch.append(nn.ReLU()) + 
lane_branch.append(nn.Linear(self.embed_dims, 768)) + lane_branch = nn.Sequential(*lane_branch) + + if self.with_multi: + reg_branch = RegLayer(self.embed_dims, self.num_reg_fcs, + self.group_reg_dims) + else: + reg_branch = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(nn.Linear(self.embed_dims, self.code_size)) + reg_branch = nn.Sequential(*reg_branch) + + self.cls_branches = nn.LayerList( + [copy.deepcopy(fc_cls) for _ in range(self.num_pred)]) + self.reg_branches = nn.LayerList( + [copy.deepcopy(reg_branch) for _ in range(self.num_pred)]) + self.lane_branches = nn.LayerList( + [copy.deepcopy(lane_branch) for _ in range(self.num_pred)]) + + if self.with_multiview: + self.adapt_pos3d = nn.Sequential( + nn.Conv2D( + self.embed_dims * 3 // 2, + self.embed_dims * 4, + kernel_size=1, + stride=1, + padding=0), + nn.ReLU(), + nn.Conv2D( + self.embed_dims * 4, + self.embed_dims, + kernel_size=1, + stride=1, + padding=0), + ) + else: + self.adapt_pos3d = nn.Sequential( + nn.Conv2D( + self.embed_dims, + self.embed_dims, + kernel_size=1, + stride=1, + padding=0), + nn.ReLU(), + nn.Conv2D( + self.embed_dims, + self.embed_dims, + kernel_size=1, + stride=1, + padding=0), + ) + + if self.with_position: + self.position_encoder = nn.Sequential( + nn.Conv2D( + self.position_dim, + self.embed_dims * 4, + kernel_size=1, + stride=1, + padding=0), + nn.ReLU(), + nn.Conv2D( + self.embed_dims * 4, + self.embed_dims, + kernel_size=1, + stride=1, + padding=0), + ) + + self.reference_points = nn.Embedding(self.num_query, 3) + self.query_embedding = nn.Sequential( + nn.Linear(self.embed_dims * 3 // 2, self.embed_dims), + nn.ReLU(), + nn.Linear(self.embed_dims, self.embed_dims), + ) + if self.with_fpe: + self.fpe = SELayer(self.embed_dims) + + nx=ny=round(math.sqrt(self.num_lane)) + x = (paddle.arange(nx) + 0.5) / nx + y = (paddle.arange(ny) + 0.5) / ny + xy=paddle.meshgrid(x,y) + self.reference_points_lane = paddle.concat([xy[0].reshape([-1])[..., None],xy[1].reshape([-1])[..., None]], axis=-1) + self.query_embedding_lane = nn.Sequential( + nn.Linear(self.embed_dims * 2 // 2, self.embed_dims), + nn.ReLU(), + nn.Linear(self.embed_dims, self.embed_dims), + ) + + + def init_weights(self): + """Initialize weights of the transformer head.""" + # The initialization for transformer is important + self.input_proj.apply(param_init.reset_parameters) + self.cls_branches.apply(param_init.reset_parameters) + self.reg_branches.apply(param_init.reset_parameters) + self.lane_branches.apply(param_init.reset_parameters) + self.adapt_pos3d.apply(param_init.reset_parameters) + + if self.with_position: + self.position_encoder.apply(param_init.reset_parameters) + + if self.with_fpe: + self.fpe.apply(param_init.reset_parameters) + + self.transformer.init_weights() + self.transformer_lane.init_weights() + + param_init.uniform_init(self.reference_points.weight, 0, 1) + if self.loss_cls.use_sigmoid: + bias_val = param_init.init_bias_by_prob(0.01) + for m in self.cls_branches: + param_init.constant_init(m[-1].bias, value=bias_val) + + def position_embeding(self, img_feats, img_metas, masks=None): + eps = 1e-5 + if hasattr(self, 'export_model') and self.export_model: + pad_h, pad_w = img_metas['image_shape'] + else: + pad_h, pad_w, _ = img_metas[0]['pad_shape'][0] + + B, N, C, H, W = img_feats[self.position_level].shape + coords_h = paddle.arange(H, dtype='float32') * pad_h / H + coords_w = paddle.arange(W, dtype='float32') * pad_w / 
W + + if self.LID: + index = paddle.arange( + start=0, end=self.depth_num, step=1, dtype='float32') + index_1 = index + 1 + bin_size = (self.position_range[3] - self.depth_start) / ( + self.depth_num * (1 + self.depth_num)) + coords_d = self.depth_start + bin_size * index * index_1 + else: + index = paddle.arange( + start=0, end=self.depth_num, step=1, dtype='float32') + bin_size = ( + self.position_range[3] - self.depth_start) / self.depth_num + coords_d = self.depth_start + bin_size * index + + D = coords_d.shape[0] + # W, H, D, 3 + coords = paddle.stack(paddle.meshgrid( + [coords_w, coords_h, coords_d])).transpose([1, 2, 3, 0]) + coords = paddle.concat((coords, paddle.ones_like(coords[..., :1])), -1) + coords[..., :2] = coords[..., :2] * paddle.maximum( + coords[..., 2:3], + paddle.ones_like(coords[..., 2:3]) * eps) + + if not (hasattr(self, 'export_model') and self.export_model): + img2lidars = [] + for img_meta in img_metas: + img2lidar = [] + for i in range(len(img_meta['lidar2img'])): + img2lidar.append(np.linalg.inv(img_meta['lidar2img'][i])) + img2lidars.append(np.asarray(img2lidar)) + + img2lidars = np.asarray(img2lidars) + + # (B, N, 4, 4) + img2lidars = paddle.to_tensor(img2lidars).astype(coords.dtype) + else: + img2lidars = img_metas['img2lidars'] + + coords = coords.reshape([1, 1, W, H, D, 4]).tile( + [B, N, 1, 1, 1, 1]).reshape([B, N, W, H, D, 4, 1]) + + img2lidars = img2lidars.reshape([B, N, 1, 1, 1, 16]).tile( + [1, 1, W, H, D, 1]).reshape([B, N, W, H, D, 4, 4]) + + coords3d = paddle.matmul(img2lidars, coords) + coords3d = coords3d.reshape(coords3d.shape[:-1])[..., :3] + coords3d[..., 0:1] = (coords3d[..., 0:1] - self.position_range[0]) / ( + self.position_range[3] - self.position_range[0]) + coords3d[..., 1:2] = (coords3d[..., 1:2] - self.position_range[1]) / ( + self.position_range[4] - self.position_range[1]) + coords3d[..., 2:3] = (coords3d[..., 2:3] - self.position_range[2]) / ( + self.position_range[5] - self.position_range[2]) + + coords_mask = (coords3d > 1.0) | (coords3d < 0.0) + coords_mask = coords_mask.astype('float32').flatten(-2).sum(-1) > ( + D * 0.5) + coords_mask = masks | coords_mask.transpose([0, 1, 3, 2]) + + coords3d = coords3d.transpose([0, 1, 4, 5, 3, 2]).reshape( + [B * N, self.depth_num * 3, H, W]) + + coords3d = inverse_sigmoid(coords3d) + coords_position_embeding = self.position_encoder(coords3d) + + return coords_position_embeding.reshape([B, N, self.embed_dims, H, + W]), coords_mask + + def forward(self, mlvl_feats, img_metas): + """Forward function. + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 5D-tensor with shape + (B, N, C, H, W). + Returns: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \ + Shape [nb_dec, bs, num_query, 9]. 
+ """ + + x = mlvl_feats[self.position_level] + + batch_size, num_cams = x.shape[0], x.shape[1] + + input_img_h, input_img_w, _ = img_metas[0]['pad_shape'][0] + masks = paddle.ones((batch_size, num_cams, input_img_h, input_img_w)) + + for img_id in range(batch_size): + for cam_id in range(num_cams): + img_h, img_w, _ = img_metas[img_id]['img_shape'][cam_id] + masks[img_id, cam_id, :img_h, :img_w] = 0 + + x = self.input_proj(x.flatten(0, 1)) + x = x.reshape([batch_size, num_cams, *x.shape[-3:]]) + + # interpolate masks to have the same spatial shape with x + masks = F.interpolate(masks, size=x.shape[-2:]).cast('bool') + + if self.with_position: + coords_position_embeding, _ = self.position_embeding( + mlvl_feats, img_metas, masks) + + if self.with_fpe: + coords_position_embeding = self.fpe( + coords_position_embeding.flatten(0, 1), + x.flatten(0, 1)).reshape(x.shape) + + pos_embed = coords_position_embeding + + if self.with_multiview: + sin_embed = self.positional_encoding(masks) + sin_embed = self.adapt_pos3d(sin_embed.flatten(0, 1)).reshape( + x.shape) + pos_embed = pos_embed + sin_embed + else: + pos_embeds = [] + for i in range(num_cams): + xy_embed = self.positional_encoding(masks[:, i, :, :]) + pos_embeds.append(xy_embed.unsqueeze(1)) + sin_embed = paddle.concat(pos_embeds, 1) + sin_embed = self.adapt_pos3d(sin_embed.flatten(0, 1)).reshape( + x.shape) + pos_embed = pos_embed + sin_embed + else: + if self.with_multiview: + pos_embed = self.positional_encoding(masks) + pos_embed = self.adapt_pos3d(pos_embed.flatten(0, 1)).reshape( + x.shape) + else: + pos_embeds = [] + for i in range(num_cams): + pos_embed = self.positional_encoding(masks[:, i, :, :]) + pos_embeds.append(pos_embed.unsqueeze(1)) + pos_embed = paddle.concat(pos_embeds, 1) + + reference_points = self.reference_points.weight + query_det=self.query_embedding(pos2posemb3d(reference_points)) + query_lane=self.query_embedding_lane(pos2posemb2d(self.reference_points_lane)) + + reference_points = reference_points.unsqueeze(0).tile( + [batch_size, 1, 1]) + + outs_dec, _ = self.transformer(x, masks, query_det, pos_embed, + self.reg_branches) + + outs_dec = nan_to_num(outs_dec) + + outs_dec_lane, _ = self.transformer_lane(x, masks, query_lane, pos_embed, self.lane_branches) + outs_dec_lane = nan_to_num(outs_dec_lane) + lane_queries=outs_dec_lane + + + if self.with_time: + time_stamps = [] + for img_meta in img_metas: + time_stamps.append(np.asarray(img_meta['timestamp'])) + + time_stamp = paddle.to_tensor(time_stamps, dtype=x.dtype) + time_stamp = time_stamp.reshape([batch_size, -1, 6]) + + mean_time_stamp = ( + time_stamp[:, 1, :] - time_stamp[:, 0, :]).mean(-1) + + outputs_classes = [] + outputs_coords = [] + outputs_lanes=[] + for lvl in range(outs_dec.shape[0]): + reference = inverse_sigmoid(reference_points.clone()) + assert reference.shape[-1] == 3 + outputs_class = self.cls_branches[lvl](outs_dec[lvl]) + tmp = self.reg_branches[lvl](outs_dec[lvl]) + outputs_lane=self.lane_branches[lvl](lane_queries[lvl]) + + tmp[..., 0:2] += reference[..., 0:2] + + tmp[..., 0:2] = F.sigmoid(tmp[..., 0:2]) + tmp[..., 4:5] += reference[..., 2:3] + + tmp[..., 4:5] = F.sigmoid(tmp[..., 4:5]) + + if self.with_time: + tmp[..., 8:] = tmp[..., 8:] / mean_time_stamp[:, None, None] + + outputs_coord = tmp + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_lanes.append(outputs_lane) + + all_cls_scores = paddle.stack(outputs_classes) + all_bbox_preds = paddle.stack(outputs_coords) + all_lane_preds = 
paddle.stack(outputs_lanes)
+
+        all_bbox_preds[..., 0:1] = (
+            all_bbox_preds[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) +
+            self.pc_range[0])
+        all_bbox_preds[..., 1:2] = (
+            all_bbox_preds[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) +
+            self.pc_range[1])
+        all_bbox_preds[..., 4:5] = (
+            all_bbox_preds[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) +
+            self.pc_range[2])
+
+        outs = {
+            'all_cls_scores': all_cls_scores,
+            'all_bbox_preds': all_bbox_preds,
+            'all_lane_preds': all_lane_preds,
+            'enc_cls_scores': None,
+            'enc_bbox_preds': None,
+        }
+        return outs
+
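The lane queries are anchored on the fixed 2-D grid built in the constructor above (`reference_points_lane`). A standalone sketch of that construction, assuming `num_lane` is a perfect square as in this config (1024 = 32 x 32):

    import math
    import paddle

    num_lane = 1024                                    # pts_bbox_head.num_lane in the config
    n = round(math.sqrt(num_lane))                     # 32 x 32 grid
    x = (paddle.arange(n, dtype='float32') + 0.5) / n  # grid-cell centres in (0, 1)
    y = (paddle.arange(n, dtype='float32') + 0.5) / n
    xx, yy = paddle.meshgrid(x, y)
    ref = paddle.concat(
        [xx.reshape([-1])[..., None], yy.reshape([-1])[..., None]], axis=-1)
    print(ref.shape)                                   # [1024, 2], one 2-D anchor per lane query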
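`position_embeding` above supports two depth discretizations for the camera frustum. A quick numeric comparison of the two branches; `depth_num = 64` and `depth_start = 1.0` are illustrative assumptions (only `position_range[3] = 61.2` is fixed by this config):

    import paddle

    depth_start, depth_max, depth_num = 1.0, 61.2, 64
    index = paddle.arange(0, depth_num, 1, dtype='float32')

    # LID (linear-increasing discretization): bin width grows with the index,
    # so depth resolution concentrates near the camera
    bin_size = (depth_max - depth_start) / (depth_num * (1 + depth_num))
    coords_d_lid = depth_start + bin_size * index * (index + 1)

    # uniform: constant bin width
    coords_d_uni = depth_start + (depth_max - depth_start) / depth_num * index

    print(float(coords_d_lid[1] - coords_d_lid[0]))    # ~0.03 m near the camera
    print(float(coords_d_lid[-1] - coords_d_lid[-2]))  # ~1.82 m at the far end
    print(float(coords_d_uni[1] - coords_d_uni[0]))    # ~0.94 m everywhere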
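`forward()` above refines each query in inverse-sigmoid space (the regression offset is added to the logit of the reference point, then squashed back) and only afterwards rescales the normalized centres to metres. A toy check of both steps, using this config's point-cloud range:

    import paddle
    import paddle.nn.functional as F

    def inverse_sigmoid(x, eps=1e-5):
        x = x.clip(eps, 1 - eps)
        return paddle.log(x / (1 - x))

    pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
    ref_xy = paddle.to_tensor([[0.25, 0.75]])         # normalized reference point
    offset = paddle.to_tensor([[0.4, -0.2]])          # raw regression output
    xy = F.sigmoid(offset + inverse_sigmoid(ref_xy))  # stays in (0, 1); zero offset reproduces ref
    cx = xy[:, 0] * (pc_range[3] - pc_range[0]) + pc_range[0]  # back to metres
    print(xy.numpy(), float(cx[0]))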
+    def export_forward(self, mlvl_feats, img_metas, time_stamp=None):  # TODO: lane branch still needs to be added for export
+        """Forward function.
+        Args:
+            mlvl_feats (tuple[Tensor]): Features from the upstream
+                network, each is a 5D-tensor with shape
+                (B, N, C, H, W).
+        Returns:
+            all_cls_scores (Tensor): Outputs from the classification head, \
+                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
+                cls_out_channels should include background.
+            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
+                head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \
+                Shape [nb_dec, bs, num_query, 9].
+        """
+
+        x = mlvl_feats[self.position_level]
+
+        batch_size, num_cams = x.shape[0], x.shape[1]
+
+        input_img_h, input_img_w = img_metas['image_shape']
+
+        masks = paddle.zeros([batch_size, num_cams, input_img_h, input_img_w])
+
+        x = self.input_proj(x.flatten(0, 1))
+        x = x.reshape([batch_size, num_cams, *x.shape[-3:]])
+
+        # interpolate masks to have the same spatial shape as x
+        masks = F.interpolate(masks, size=x.shape[-2:]).cast('bool')
+
+        if self.with_position:
+            coords_position_embeding, _ = self.position_embeding(
+                mlvl_feats, img_metas, masks)
+
+            if self.with_fpe:
+                coords_position_embeding = self.fpe(
+                    coords_position_embeding.flatten(0, 1),
+                    x.flatten(0, 1)).reshape(x.shape)
+
+            pos_embed = coords_position_embeding
+
+            if self.with_multiview:
+                sin_embed = self.positional_encoding(masks)
+                sin_embed = self.adapt_pos3d(sin_embed.flatten(0, 1)).reshape(
+                    x.shape)
+                pos_embed = pos_embed + sin_embed
+            else:
+                pos_embeds = []
+                for i in range(num_cams):
+                    xy_embed = self.positional_encoding(masks[:, i, :, :])
+                    pos_embeds.append(xy_embed.unsqueeze(1))
+                sin_embed = paddle.concat(pos_embeds, 1)
+                sin_embed = self.adapt_pos3d(sin_embed.flatten(0, 1)).reshape(
+                    x.shape)
+                pos_embed = pos_embed + sin_embed
+        else:
+            if self.with_multiview:
+                pos_embed = self.positional_encoding(masks)
+                pos_embed = self.adapt_pos3d(pos_embed.flatten(0, 1)).reshape(
+                    x.shape)
+            else:
+                pos_embeds = []
+                for i in range(num_cams):
+                    pos_embed = self.positional_encoding(masks[:, i, :, :])
+                    pos_embeds.append(pos_embed.unsqueeze(1))
+                pos_embed = paddle.concat(pos_embeds, 1)
+
+        reference_points = self.reference_points.weight
+        query_embeds = self.query_embedding(pos2posemb3d(reference_points))
+
+        reference_points = reference_points.unsqueeze(0).tile(
+            [batch_size, 1, 1])
+
+        outs_dec, _ = self.transformer(x, masks, query_embeds, pos_embed,
+                                       self.reg_branches)
+
+        outs_dec = nan_to_num(outs_dec)
+
+        if self.with_time:
+            time_stamp = time_stamp.reshape([batch_size, -1, 6])
+            mean_time_stamp = (
+                time_stamp[:, 1, :] - time_stamp[:, 0, :]).mean(-1)
+
+        outputs_classes = []
+        outputs_coords = []
+        for lvl in range(outs_dec.shape[0]):
+            reference = inverse_sigmoid(reference_points.clone())
+            assert reference.shape[-1] == 3
+            outputs_class = self.cls_branches[lvl](outs_dec[lvl])
+            tmp = self.reg_branches[lvl](outs_dec[lvl])
+
+            tmp[..., 0:2] += reference[..., 0:2]
+
+            tmp[..., 0:2] = F.sigmoid(tmp[..., 0:2])
+            tmp[..., 4:5] += reference[..., 2:3]
+
+            tmp[..., 4:5] = F.sigmoid(tmp[..., 4:5])
+
+            if self.with_time:
+                tmp[..., 8:] = tmp[..., 8:] / mean_time_stamp[:, None, None]
+
+            outputs_coord = tmp
+            outputs_classes.append(outputs_class)
+            outputs_coords.append(outputs_coord)
+
+        all_cls_scores = paddle.stack(outputs_classes)
+        all_bbox_preds = paddle.stack(outputs_coords)
+
+        all_bbox_preds[..., 0:1] = (
+            all_bbox_preds[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) +
+            self.pc_range[0])
+        all_bbox_preds[..., 1:2] = (
+            all_bbox_preds[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) +
+            self.pc_range[1])
+        all_bbox_preds[..., 4:5] = (
+            all_bbox_preds[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) +
+            self.pc_range[2])
+
+        outs = {
+            'all_cls_scores': all_cls_scores,
+            'all_bbox_preds': all_bbox_preds,
+            # 'enc_cls_scores': None,
+            # 'enc_bbox_preds': None,
+        }
+        return outs
+
+    def _get_target_single(self,
+                           cls_score,
+                           bbox_pred,
+                           gt_labels,
+                           gt_bboxes,
+                           gt_bboxes_ignore=None):
+        """Compute regression and classification targets for one image.
+        Outputs from a single decoder layer of a single feature level are used.
+        Args:
+            cls_score (Tensor): Box score logits from a single decoder layer
+                for one image. Shape [num_query, cls_out_channels].
+            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
+                for one image, with normalized coordinates and shape
+                [num_query, code_size].
+            gt_bboxes (Tensor): Ground truth 3D bboxes for one image with
+                shape (num_gts, code_size).
+            gt_labels (Tensor): Ground truth class indices for one image
+                with shape (num_gts, ).
+            gt_bboxes_ignore (Tensor, optional): Bounding boxes
+                which can be ignored. Default None.
+        Returns:
+            tuple[Tensor]: a tuple containing the following for one image.
+                - labels (Tensor): Labels of each image.
+                - label_weights (Tensor): Label weights of each image.
+                - bbox_targets (Tensor): BBox targets of each image.
+                - bbox_weights (Tensor): BBox weights of each image.
+                - pos_inds (Tensor): Sampled positive indices for each image.
+                - neg_inds (Tensor): Sampled negative indices for each image.
+        """
+
+        num_bboxes = bbox_pred.shape[0]
+        # assigner and sampler
+        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
+                                             gt_labels, gt_bboxes_ignore)
+        sampling_result = self.sampler.sample(assign_result, bbox_pred,
+                                              gt_bboxes)
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+
+        # label targets: default every query to the background class
+        labels = paddle.full((num_bboxes, ), self.num_classes, dtype='int64')
+
+        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+        label_weights = paddle.ones([num_bboxes])
+
+        # bbox targets
+        code_size = gt_bboxes.shape[1]
+        bbox_targets = paddle.zeros_like(bbox_pred)[..., :code_size]
+        bbox_weights = paddle.zeros_like(bbox_pred)
+        bbox_weights[pos_inds] = 1.0
+
+        # DETR
+        if sampling_result.pos_gt_bboxes.shape[1] == 4:
+            bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes.reshape(
+                [sampling_result.pos_gt_bboxes.shape[0], self.code_size - 1])
+        else:
+            bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
+
+        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+                neg_inds)
+
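`_get_target_single` above initializes every query as background and overwrites only the queries matched by the assigner. A toy run of that target layout:

    import paddle

    num_query, num_classes = 6, 10
    labels = paddle.full((num_query, ), num_classes, dtype='int64')  # all background
    pos_inds = paddle.to_tensor([1, 4])    # pretend the assigner matched these queries
    gt_labels = paddle.to_tensor([2, 7])   # classes of the matched GT boxes
    labels[pos_inds] = gt_labels
    print(labels.numpy())                  # [10  2 10 10  7 10]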
+    def get_targets(self,
+                    cls_scores_list,
+                    bbox_preds_list,
+                    gt_bboxes_list,
+                    gt_labels_list,
+                    gt_bboxes_ignore_list=None):
+        """Compute regression and classification targets for a batch of images.
+        Outputs from a single decoder layer of a single feature level are used.
+        Args:
+            cls_scores_list (list[Tensor]): Box score logits from a single
+                decoder layer for each image with shape [num_query,
+                cls_out_channels].
+            bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
+                decoder layer for each image, with normalized coordinates
+                and shape [num_query, code_size].
+            gt_bboxes_list (list[Tensor]): Ground truth 3D bboxes for each
+                image with shape (num_gts, code_size).
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+                boxes which can be ignored for each image. Default None.
+        Returns:
+            tuple: a tuple containing the following targets.
+                - labels_list (list[Tensor]): Labels for all images.
+                - label_weights_list (list[Tensor]): Label weights for all \
+                    images.
+                - bbox_targets_list (list[Tensor]): BBox targets for all \
+                    images.
+                - bbox_weights_list (list[Tensor]): BBox weights for all \
+                    images.
+                - num_total_pos (int): Number of positive samples in all \
+                    images.
+                - num_total_neg (int): Number of negative samples in all \
+                    images.
+        """
+        assert gt_bboxes_ignore_list is None, \
+            'Only supports gt_bboxes_ignore set to None.'
+        num_imgs = len(cls_scores_list)
+        gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]
+
+        (labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
+             self._get_target_single, cls_scores_list, bbox_preds_list,
+             gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list)
+        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+        return (labels_list, label_weights_list, bbox_targets_list,
+                bbox_weights_list, num_total_pos, num_total_neg)
+
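`get_targets` fans `_get_target_single` out over the batch with `multi_apply`. For reference, a minimal equivalent of that helper (mirroring the mmdet-style utility this code base ports):

    from functools import partial

    def multi_apply(func, *args, **kwargs):
        pfunc = partial(func, **kwargs) if kwargs else func
        map_results = map(pfunc, *args)              # one call per image
        return tuple(map(list, zip(*map_results)))   # regroup outputs across images

    # multi_apply(lambda a, b: (a + b, a * b), [1, 2], [3, 4]) -> ([4, 6], [3, 8])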
+    def loss_single(self,
+                    cls_scores,
+                    bbox_preds,
+                    lane_preds,
+                    gt_bboxes_list,
+                    gt_labels_list,
+                    gt_lane_list,
+                    gt_bboxes_ignore_list=None):
+        """Loss function for outputs from a single decoder layer of a single
+        feature level.
+        Args:
+            cls_scores (Tensor): Box score logits from a single decoder layer
+                for all images. Shape [bs, num_query, cls_out_channels].
+            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+                for all images, with normalized coordinates and shape
+                [bs, num_query, code_size].
+            lane_preds (Tensor): Lane mask logits from a single decoder layer
+                for all images. Shape [bs, num_lane, 768].
+            gt_bboxes_list (list[Tensor]): Ground truth 3D bboxes for each
+                image with shape (num_gts, code_size).
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            gt_lane_list (list[Tensor]): Ground truth lane masks for each
+                image.
+            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+                boxes which can be ignored for each image. Default None.
+        Returns:
+            tuple[Tensor]: loss_cls, loss_bbox and loss_lane_mask for outputs
+                from a single decoder layer.
+        """
+        num_imgs = cls_scores.shape[0]
+        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
+                                           gt_bboxes_list, gt_labels_list,
+                                           gt_bboxes_ignore_list)
+        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+         num_total_pos, num_total_neg) = cls_reg_targets
+        labels = paddle.concat(labels_list, 0)
+        label_weights = paddle.concat(label_weights_list, 0)
+        bbox_targets = paddle.concat(bbox_targets_list, 0)
+        bbox_weights = paddle.concat(bbox_weights_list, 0)
+
+        # classification loss
+        cls_scores = cls_scores.reshape([-1, self.cls_out_channels])
+        # construct weighted avg_factor to match the official DETR repo
+        cls_avg_factor = num_total_pos * 1.0 + \
+            num_total_neg * self.bg_cls_weight
+        if self.sync_cls_avg_factor:
+            cls_avg_factor = reduce_mean(
+                paddle.to_tensor([cls_avg_factor], dtype=cls_scores.dtype))
+
+        cls_avg_factor = max(cls_avg_factor, 1)
+        loss_cls = self.loss_cls(cls_scores, labels,
+                                 label_weights) / (cls_avg_factor + self.pd_eps)
+
+        # Compute the average number of gt boxes across all gpus, for
+        # normalization purposes
+        num_total_pos = paddle.to_tensor([num_total_pos], dtype=loss_cls.dtype)
+        num_total_pos = paddle.clip(reduce_mean(num_total_pos), min=1).item()
+
+        # regression L1 loss
+        bbox_preds = bbox_preds.reshape([-1, bbox_preds.shape[-1]])
+        normalized_bbox_targets = normalize_bbox(bbox_targets, self.pc_range)
+        # keep only rows whose targets are all finite
+        isnotnan = paddle.isfinite(normalized_bbox_targets).all(axis=-1)
+        bbox_weights = bbox_weights * self.code_weights
+
+        loss_bbox = self.loss_bbox(
+            bbox_preds[isnotnan], normalized_bbox_targets[isnotnan],
+            bbox_weights[isnotnan]) / (num_total_pos + self.pd_eps)
+
+        lane_preds = lane_preds.squeeze(0)
+        loss_lane_mask = self.loss_lane_mask(lane_preds, gt_lane_list[0])
+        loss_lane_mask = nan_to_num(loss_lane_mask)
+
+        loss_cls = nan_to_num(loss_cls)
+        loss_bbox = nan_to_num(loss_bbox)
+
+        return loss_cls, loss_bbox, loss_lane_mask
+
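`loss_single` relies on `normalize_bbox` to put ground-truth boxes in the same 10-dim space as the regression output. That helper is not part of this patch; a hedged sketch of the DETR3D-style convention it follows (log-scaled extents, yaw split into sin/cos):

    import paddle

    def normalize_bbox(bboxes, pc_range=None):
        # bboxes: (..., 9) as (cx, cy, cz, w, l, h, yaw, vx, vy)
        cx, cy, cz = bboxes[..., 0:1], bboxes[..., 1:2], bboxes[..., 2:3]
        w, l, h = bboxes[..., 3:4], bboxes[..., 4:5], bboxes[..., 5:6]
        rot = bboxes[..., 6:7]
        out = [cx, cy, w.log(), l.log(), cz, h.log(), rot.sin(), rot.cos()]
        if bboxes.shape[-1] > 7:
            out.append(bboxes[..., 7:9])    # velocities pass through unchanged
        return paddle.concat(out, axis=-1)  # (..., 10), matching code_size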
+    def loss(self,
+             gt_bboxes_list,
+             gt_labels_list,
+             preds_dicts,
+             gt_lanes,
+             gt_bboxes_ignore=None):
+        """Loss function.
+        Args:
+            gt_bboxes_list (list[Tensor]): Ground truth 3D bboxes for each
+                image with shape (num_gts, code_size).
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            preds_dicts:
+                all_cls_scores (Tensor): Classification score of all
+                    decoder layers, has shape
+                    [nb_dec, bs, num_query, cls_out_channels].
+                all_bbox_preds (Tensor): Sigmoid regression
+                    outputs of all decoder layers. Each is a 4D-tensor with
+                    normalized coordinates and shape
+                    [nb_dec, bs, num_query, code_size].
+                enc_cls_scores (Tensor): Classification scores of
+                    points on the encoded feature map, has shape
+                    (N, h*w, num_classes). Only passed when as_two_stage is
+                    True, otherwise is None.
+                enc_bbox_preds (Tensor): Regression results of each point
+                    on the encoded feature map, has shape (N, h*w, 4). Only
+                    passed when as_two_stage is True, otherwise is None.
+            gt_lanes (list[Tensor]): Ground truth lane masks for each image.
+            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+                which can be ignored for each image. Default None.
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        assert gt_bboxes_ignore is None, \
+            f'{self.__class__.__name__} only supports ' \
+            f'gt_bboxes_ignore set to None.'
+
+        all_cls_scores = preds_dicts['all_cls_scores']
+        all_bbox_preds = preds_dicts['all_bbox_preds']
+        enc_cls_scores = preds_dicts['enc_cls_scores']
+        enc_bbox_preds = preds_dicts['enc_bbox_preds']
+        all_lane_preds = preds_dicts['all_lane_preds']
+
+        num_dec_layers = len(all_cls_scores)
+
+        def get_gravity_center(bboxes):
+            bottom_center = bboxes[:, :3]
+            gravity_center = np.zeros_like(bottom_center)
+            gravity_center[:, :2] = bottom_center[:, :2]
+            gravity_center[:, 2] = bottom_center[:, 2] + bboxes[:, 5] * 0.5
+            return gravity_center
+
+        gt_bboxes_list = [
+            paddle.concat((paddle.to_tensor(get_gravity_center(gt_bboxes)),
+                           paddle.to_tensor(gt_bboxes[:, 3:])),
+                          axis=1) for gt_bboxes in gt_bboxes_list
+        ]
+
+        # batch size is 1 for this config; keep only the first sample's lane GT
+        gt_lanes = [gt_lanes[0]]
+
+        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+        all_gt_lanes_list = [gt_lanes for _ in range(num_dec_layers)]
+
+        all_gt_bboxes_ignore_list = [
+            gt_bboxes_ignore for _ in range(num_dec_layers)
+        ]
+
+        losses_cls, losses_bbox, losses_lane_masks = multi_apply(
+            self.loss_single, all_cls_scores, all_bbox_preds, all_lane_preds,
+            all_gt_bboxes_list, all_gt_labels_list, all_gt_lanes_list,
+            all_gt_bboxes_ignore_list)
+
+        loss_dict = dict()
+        # loss of proposals generated from the encoded feature map
+        if enc_cls_scores is not None:
+            binary_labels_list = [
+                paddle.zeros_like(gt_labels_list[i])
+                for i in range(len(all_gt_labels_list))
+            ]
+            enc_loss_cls, enc_losses_bbox = \
+                self.loss_single(enc_cls_scores, enc_bbox_preds,
+                                 gt_bboxes_list, binary_labels_list, gt_bboxes_ignore)
+            loss_dict['enc_loss_cls'] = enc_loss_cls
+            loss_dict['enc_loss_bbox'] = enc_losses_bbox
+
+        # loss from the last decoder layer
+        loss_dict['loss_cls'] = losses_cls[-1]
+        loss_dict['loss_bbox'] = losses_bbox[-1]
+
+        loss_dict['loss_mask'] = losses_lane_masks[-1]
+
+        # loss from other decoder layers
+        num_dec_layer = 0
+        for loss_cls_i, loss_bbox_i, loss_mask_i in zip(
+                losses_cls[:-1], losses_bbox[:-1], losses_lane_masks[:-1]):
+            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+            loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
+            num_dec_layer += 1
+        return loss_dict
+
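A quick numeric check of `get_gravity_center` above: nuScenes boxes store the bottom centre, so the gravity centre lifts z by half the box height (dim 5):

    import numpy as np

    gt = np.array([[10.0, -4.0, -1.8, 1.9, 4.6, 1.6, 0.3]])  # x, y, z, w, l, h, yaw
    center = gt[:, :3].copy()
    center[:, 2] += gt[:, 5] * 0.5
    print(center)                                            # [[10.  -4.  -1. ]]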
+    def get_bboxes(self, preds_dicts, img_metas, rescale=False):
+        """Generate bboxes from bbox head predictions.
+        Args:
+            preds_dicts (tuple[list[dict]]): Prediction results.
+            img_metas (list[dict]): Point cloud and image's meta info.
+        Returns:
+            list[dict]: Decoded bbox, scores and labels after nms.
+        """
+        preds_dicts = self.bbox_coder.decode(preds_dicts)
+        num_samples = len(preds_dicts)
+
+        ret_list = []
+        for i in range(num_samples):
+            preds = preds_dicts[i]
+            bboxes = preds['bboxes']
+            # convert gravity centre back to the bottom-centre convention
+            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
+            scores = preds['scores']
+            labels = preds['labels']
+            ret_list.append([bboxes, scores, labels])
+        return ret_list
+
diff --git a/paddle3d/models/losses/__init__.py b/paddle3d/models/losses/__init__.py
index 51c2ffb6..6095c5af 100644
--- a/paddle3d/models/losses/__init__.py
+++ b/paddle3d/models/losses/__init__.py
@@ -21,3 +21,4 @@ from .disentangled_box3d_loss import DisentangledBox3DLoss, unproject_points2d
 from .weight_loss import (WeightedCrossEntropyLoss, WeightedSmoothL1Loss,
                           get_corner_loss_lidar)
+from .lane_loss import SigmoidCELoss, FocalDiceLoss
\ No newline at end of file
diff --git a/paddle3d/models/losses/lane_loss.py b/paddle3d/models/losses/lane_loss.py
new file mode 100644
index 00000000..3bf1f260
--- /dev/null
+++ b/paddle3d/models/losses/lane_loss.py
@@ -0,0 +1,76 @@
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+from paddle3d.apis import manager
+
+
+@manager.LOSSES.add_component
+class SigmoidCELoss(nn.Layer):
+    def __init__(self, loss_weight=1.0, reduction='mean'):
+        super(SigmoidCELoss, self).__init__()
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def forward(self, inputs, targets):
+        """Compute a class-balanced binary cross-entropy loss on logits.
+        Args:
+            inputs (paddle.Tensor): Predicted logits of shape [N, L].
+            targets (paddle.Tensor): Binary targets of the same shape.
+        Returns:
+            paddle.Tensor: The re-weighted BCE loss, scaled by loss_weight.
+        """
+        pos_weight = paddle.to_tensor((targets == 0), dtype='float32').sum(axis=1) / \
+            paddle.to_tensor((targets == 1), dtype='float32').sum(axis=1).clip(min=1.0)
+        pos_weight = pos_weight.unsqueeze(1)
+        weight_loss = targets * pos_weight + (1 - targets)
+        loss = F.binary_cross_entropy_with_logits(
+            inputs, targets, reduction=self.reduction, weight=weight_loss)
+        return self.loss_weight * loss
+
+
+@manager.LOSSES.add_component
+class FocalDiceLoss(nn.Layer):
+    def __init__(self, alpha=1, gamma=2, loss_weight=1.0, reduction='mean'):
+        super(FocalDiceLoss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def focal_loss(self, inputs, labels):
+        BCE_loss = F.binary_cross_entropy_with_logits(
+            inputs, labels, reduction='none')
+        pt = paddle.exp(-BCE_loss)
+        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
+        if self.reduction == 'mean':
+            return paddle.mean(F_loss)
+        elif self.reduction == 'sum':
+            return paddle.sum(F_loss)
+        return F_loss  # reduction == 'none'
+
+    def dice_loss(self, inputs, labels, smooth=1):
+        inputs = F.sigmoid(inputs)
+        inputs = inputs.flatten()
+        labels = labels.flatten()
+        intersection = (inputs * labels).sum()
+        dice = (2. * intersection + smooth) / (
+            inputs.sum() + labels.sum() + smooth)
+        return 1 - dice
+
+    def forward(self, inputs, labels, smooth=1):
+        focal = self.focal_loss(inputs, labels)
+        dice = self.dice_loss(inputs, labels, smooth=smooth)
+        return focal + dice
+
+
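A toy run of the re-weighting inside SigmoidCELoss above: per row, positives are up-weighted by (#negatives / #positives), so sparse lane pixels are not drowned out by background:

    import paddle

    targets = paddle.to_tensor([[1., 0., 0., 0.],
                                [1., 1., 0., 0.]])
    neg = (targets == 0).astype('float32').sum(axis=1)
    pos = (targets == 1).astype('float32').sum(axis=1).clip(min=1.0)
    weight = targets * (neg / pos).unsqueeze(1) + (1 - targets)
    print(weight.numpy())  # row 0: the lone positive gets weight 3; row 1 is already balanced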
+if __name__ == '__main__':
+    import numpy as np
+    logit = paddle.randn([256, 768])
+    arr = np.random.randn(256 * 768).reshape([256, 768])
+    arr = np.clip(arr, 0, 1)
+    arr = arr.astype(np.int8).astype(np.float32)
+    label = paddle.to_tensor(arr)
+    # loss = SigmoidCELoss()
+    loss = FocalDiceLoss()
+    loss_value = loss(logit, label)
+    print('loss_value: ', loss_value)
\ No newline at end of file
diff --git a/paddle3d/transforms/reader.py b/paddle3d/transforms/reader.py
index 9272bd49..6cc69954 100644
--- a/paddle3d/transforms/reader.py
+++ b/paddle3d/transforms/reader.py
@@ -33,10 +33,41 @@ __all__ = [
    "LoadImage",
    "LoadPointCloud",
    "RemoveCameraInvisiblePointsKITTI",
-    "RemoveCameraInvisiblePointsKITTIV2", "LoadSemanticKITTIRange"
+    "RemoveCameraInvisiblePointsKITTIV2", "LoadSemanticKITTIRange",
+    "LoadMapsFromFiles"
 ]
 
 
+@manager.TRANSFORMS.add_component
+class LoadMapsFromFiles(object):
+    def __init__(self, map_data_root, k=None):
+        self.map_data_root = map_data_root
+        self.k = k
+
+    def resize_map(self, map_mask, scale_factor=2):
+        H, W = map_mask.shape[:2]
+        map_mask_resized = cv2.resize(map_mask.astype(np.uint8),
+                                      (scale_factor * W, scale_factor * H))
+        return map_mask_resized.astype(np.float32)
+
+    def __call__(self, results):
+        map_filename = results['map_filename']
+        map_filename = os.path.join(self.map_data_root,
+                                    os.path.basename(map_filename))
+        maps = np.load(map_filename)
+        map_mask = maps['arr_0'].astype(np.float32)
+
+        # upsample to (512, 512) so gt_map matches the doubled map size
+        map_mask = self.resize_map(map_mask, scale_factor=2)
+
+        maps = map_mask.transpose((2, 0, 1))
+        results['gt_map'] = maps
+        # cut the (3, 512, 512) mask into 1024 patch tokens of 16x16 each
+        maps = rearrange(maps, 'c (h h1) (w w2) -> (h w) c h1 w2', h1=16, w2=16)
+        maps = maps.reshape(1024, 3 * 256)
+        results['map_shape'] = maps.shape
+        results['maps'] = maps
+        return results
+
+
 @manager.TRANSFORMS.add_component
 class LoadImage(TransformABC):
     """
@@ -468,7 +499,8 @@ class LoadMultiViewImageFromFiles(TransformABC):
        - 1: cv2.IMREAD_COLOR
     """
 
-    def __init__(self, to_float32=False, imread_flag=-1):
+    def __init__(self, data_root, to_float32=False, imread_flag=-1):
+        self.data_root = data_root
         self.to_float32 = to_float32
         self.imread_flag = imread_flag
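Shape walk-through of the map tokenisation in LoadMapsFromFiles above, assuming einops is available in reader.py: the (3, 512, 512) mask becomes 32 x 32 = 1024 patches of 3 x 16 x 16 = 768 values each, which is exactly the lane branch's output width (num_lane = 1024 queries, 768 logits per query):

    import numpy as np
    from einops import rearrange

    maps = np.zeros((3, 512, 512), dtype=np.float32)
    tokens = rearrange(maps, 'c (h h1) (w w2) -> (h w) c h1 w2', h1=16, w2=16)
    print(tokens.shape)                 # (1024, 3, 16, 16)
    tokens = tokens.reshape(1024, 3 * 256)
    print(tokens.shape)                 # (1024, 768)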