Skip to content

Commit

Permalink
Fix DRangedTree bounding on left-hand side of values
Browse files Browse the repository at this point in the history
  • Loading branch information
robinmessage committed Sep 27, 2023
1 parent 23cc3d0 commit 99a7b93
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions methods/utils/dranged_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def contains(self, point):
return self.left.contains(point)
def dump(self, space):
print(space, f"split axis {self.axis} at {self.value}")
print(space + " <")
print(space + " <=")
self.left.dump(space + "\t")
print(space + " >")
self.right.dump(space + "\t")
Expand All @@ -149,6 +149,29 @@ def depth(self):
def size(self):
return 1 + self.left.size() + self.right.size()

class SplitDLeanRightTree(DRangedTree):
"""This is identical to a SplitDTree, except values equal to the value go into the right tree instead of the left"""
def __init__(self, left: DRangedTree, right: DRangedTree, axis: int, value: float):
self.left = left
self.right = right
self.axis = axis
self.value = value
def contains(self, point):
if point[self.axis] >= self.value:
return self.right.contains(point)
else:
return self.left.contains(point)
def dump(self, space):
print(space, f"split axis {self.axis} at {self.value}")
print(space + " <")
self.left.dump(space + "\t")
print(space + " >=")
self.right.dump(space + "\t")
def depth(self):
return 1 + max(self.left.depth(), self.right.depth())
def size(self):
return 1 + self.left.size() + self.right.size()

class TreeState:
def __init__(self, dimensions_or_tree, j: int = -1, drop: bool = False):
if isinstance(dimensions_or_tree, TreeState):
Expand Down Expand Up @@ -234,7 +257,8 @@ def _make_tree_internal(rects, bounds, state: TreeState):
new_bounds = np.copy(bounds)
new_bounds[0, d] = left
subtree = _make_tree_internal(rects, new_bounds, state.descend(d))
return SplitDTree(EmptyTree(), subtree, d, left)
# We need to lean right, because we only want to exclude values strictly less than left.
return SplitDLeanRightTree(EmptyTree(), subtree, d, left)
if bounds[1, d] == math.inf:
right = np.max(rects[:, 1, d])
new_bounds = np.copy(bounds)
Expand Down

0 comments on commit 99a7b93

Please sign in to comment.