From 23cc3d03d14bb9eeced27ca9421a292f88fcb757 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Wed, 27 Sep 2023 10:15:20 +0100 Subject: [PATCH] Use expected_fraction to cut branches in DRangedTree early --- .../matching/find_potential_matches_fast.py | 2 +- methods/utils/dranged_tree.py | 31 ++++++++++++++++--- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/methods/matching/find_potential_matches_fast.py b/methods/matching/find_potential_matches_fast.py index e8b9f94..40ff1e5 100644 --- a/methods/matching/find_potential_matches_fast.py +++ b/methods/matching/find_potential_matches_fast.py @@ -76,7 +76,7 @@ def load_k( -0.1, -0.1, -0.1, - ])) + ]), 1 / 300) print("Sending k trees to workers...") diff --git a/methods/utils/dranged_tree.py b/methods/utils/dranged_tree.py index ae8aeaa..aa274a1 100644 --- a/methods/utils/dranged_tree.py +++ b/methods/utils/dranged_tree.py @@ -31,11 +31,12 @@ def size(self) -> int: return 1 @staticmethod - def build(items, widths) -> DRangedTree: + def build(items, widths, expected_fraction) -> DRangedTree: """ Build a DRangedTree for the items in list. Each item has the corresponding width for each dimension +/- to it. If the width for a dimension is negative, it is treated as a relative fractional width. So a width of 10 means the value of item +/- 10, whereas a width of -0.1 means +/- 10% of the original value. + expected_fraction is the proportion of search points we expect to match with. """ items = np.unique(items, axis = 0) # Ditch duplicate points # Reshape each point into a hyper-rect +- width @@ -59,6 +60,7 @@ def build(items, widths) -> DRangedTree: # Build the tree state = TreeState(dimensions) state.logging = False + state.bound_dimension_at = math.ceil(math.log2(1/(1-math.pow(expected_fraction, 1 / len(widths))))) - 1 return _make_tree_internal(rects, bounds, state) class SingletonTree(DRangedTree): @@ -154,6 +156,7 @@ def __init__(self, dimensions_or_tree, j: int = -1, drop: bool = False): assert(j < existing.dimensions) self.depth = existing.depth + 1 self.logging = existing.logging + self.bound_dimension_at = existing.bound_dimension_at if drop: self.dimensions = existing.dimensions - 1 self.descent = without(existing.descent, j) @@ -166,6 +169,7 @@ def __init__(self, dimensions_or_tree, j: int = -1, drop: bool = False): self.dimensions = dimensions_or_tree self.descent = np.zeros(dimensions_or_tree) self.logging = False + self.bound_dimension_at = -1 def descend(self, j: int) -> TreeState: return TreeState(self, j) @@ -221,6 +225,24 @@ def _make_tree_internal(rects, bounds, state: TreeState): subtree = _make_tree_internal(sub_rects, sub_bounds, state.drop(d)) return FulfilledTree(subtree, d) + # Check if we need to bound a dimension + for d in range(dimensions): + if state.descent[d] == state.bound_dimension_at: + # Check if bounds are infinite + if bounds[0, d] == -math.inf: + left = np.min(rects[:, 0, d]) + new_bounds = np.copy(bounds) + new_bounds[0, d] = left + subtree = _make_tree_internal(rects, new_bounds, state.descend(d)) + return SplitDTree(EmptyTree(), subtree, d, left) + if bounds[1, d] == math.inf: + right = np.max(rects[:, 1, d]) + new_bounds = np.copy(bounds) + new_bounds[1, d] = right + subtree = _make_tree_internal(rects, new_bounds, state.descend(d)) + return SplitDTree(subtree, EmptyTree(), d, right) + + # Identify possible split points for each dimension. Logically, these are the edges of the hyper-rects, # as it doesn't make sense to split not on an edge. lefts = [rects[:, 0, d] for d in range(dimensions)] @@ -423,6 +445,8 @@ def build_rects(items): from methods.common.luc import luc_matching_columns from time import time + expected_fraction = 1 / 300 # This proportion of pixels we end up matching + def build_dranged_tree_for_k(k_rows) -> DRangedTree: return DRangedTree.build(np.array([( row.elevation, @@ -435,12 +459,11 @@ def build_dranged_tree_for_k(k_rows) -> DRangedTree: row["cpc10_u"], row["cpc10_d"], ) for row in k_rows - ]), ALLOWED_VARIATION) + ]), ALLOWED_VARIATION, expected_fraction) - luc0, luc5, luc10 = luc_matching_columns(2012) + luc0, luc5, luc10 = luc_matching_columns(2012) source_pixels = pd.read_parquet("./test/data/1201-k.parquet") - expected_fraction = 1 / 300 # This proportion of pixels we end up matching # Split source_pixels into classes source_rows = []