diff --git a/mmhuman3d/data/datasets/human_image_dataset.py b/mmhuman3d/data/datasets/human_image_dataset.py index 6775af02..1006b9c0 100644 --- a/mmhuman3d/data/datasets/human_image_dataset.py +++ b/mmhuman3d/data/datasets/human_image_dataset.py @@ -393,7 +393,7 @@ def _parse_result(self, res, mode='keypoint', body_part=None): global_orient=global_orient, gender=gender) gt_keypoints3d = gt_output['joints'].detach().cpu().numpy() - gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 24)) + gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), pred_keypoints3d.shape[1])) elif self.dataset_name == 'h36m': _, h36m_idxs, _ = get_mapping('human_data', 'h36m') gt_keypoints3d = \ @@ -436,7 +436,7 @@ def _parse_result(self, res, mode='keypoint', body_part=None): pred_keypoints3d[:, 3]) / 2 gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2 - # H36M for testing! + # H36M convention for testing! elif gt_keypoints3d.shape[1] == 17: assert pred_keypoints3d.shape[1] == 17