diff --git a/models/gt_matches_generation.py b/models/gt_matches_generation.py index 97bf33f..52ec367 100644 --- a/models/gt_matches_generation.py +++ b/models/gt_matches_generation.py @@ -54,14 +54,19 @@ def generate_gt_matches(data: Dict[str, Any], symmetric_dist = 0.5 * (min_dist0[cross_check_consistent0] + min_dist1[cross_check_consistent1]) - gt_matches0[cross_check_consistent0][symmetric_dist > positive_threshold] = IGNORE_INDEX - gt_matches0[cross_check_consistent0][symmetric_dist > negative_threshold] = UNMATCHED_INDEX + gt_matches0_cross = gt_matches0[cross_check_consistent0].clone() + gt_matches1_cross = gt_matches1[cross_check_consistent1].clone() + gt_matches0_uncross = gt_matches0[~cross_check_consistent0].clone() + gt_matches1_uncross = gt_matches1[~cross_check_consistent1].clone() - gt_matches1[cross_check_consistent1][symmetric_dist > positive_threshold] = IGNORE_INDEX - gt_matches1[cross_check_consistent1][symmetric_dist > negative_threshold] = UNMATCHED_INDEX + gt_matches0_cross[symmetric_dist > positive_threshold] = IGNORE_INDEX + gt_matches0_cross[symmetric_dist > negative_threshold] = UNMATCHED_INDEX - gt_matches0[~cross_check_consistent0][min_dist0[~cross_check_consistent0] <= negative_threshold] = IGNORE_INDEX - gt_matches1[~cross_check_consistent1][min_dist1[~cross_check_consistent1] <= negative_threshold] = IGNORE_INDEX + gt_matches1_cross[symmetric_dist > positive_threshold] = IGNORE_INDEX + gt_matches1_cross[symmetric_dist > negative_threshold] = UNMATCHED_INDEX + + gt_matches0_uncross[min_dist0[~cross_check_consistent0] <= negative_threshold] = IGNORE_INDEX + gt_matches1_uncross[min_dist1[~cross_check_consistent1] <= negative_threshold] = IGNORE_INDEX # mutual NN with sym_dist <= pos.th ==> MATCHED # mutual NN with pos.th < sym_dist <= neg.th ==> IGNORED @@ -74,8 +79,13 @@ def generate_gt_matches(data: Dict[str, Any], gt_matches1[~mask1] = IGNORE_INDEX # also ignore MATCHED point if its nearest neighbor is invalid - gt_matches0[cross_check_consistent0][~mask1.gather(1, nn_matches0)[cross_check_consistent0]] = IGNORE_INDEX - gt_matches1[cross_check_consistent1][~mask0.gather(1, nn_matches1)[cross_check_consistent1]] = IGNORE_INDEX + gt_matches0_cross[~mask1.gather(1, nn_matches0)[cross_check_consistent0]] = IGNORE_INDEX + gt_matches1_cross[~mask0.gather(1, nn_matches1)[cross_check_consistent1]] = IGNORE_INDEX + + # update gt_matches0, gt_matches1 + gt_matches0.masked_scatter_(cross_check_consistent0, gt_matches0_cross) + gt_matches1.masked_scatter_(cross_check_consistent1, gt_matches1_cross) + data = { **data,