Skip to content

Commit

Permalink
Merge pull request #386 from w-k-jones/refactor_pbc_dist_func
Browse files Browse the repository at this point in the history
Refactor of PBC distance functions
  • Loading branch information
w-k-jones authored Jun 4, 2024
2 parents 2d1cc4a + 319d7a3 commit 8e7557c
Show file tree
Hide file tree
Showing 7 changed files with 442 additions and 198 deletions.
5 changes: 2 additions & 3 deletions tobac/feature_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import iris
import xarray as xr

from tobac.tracking import build_distance_function
from tobac.utils import internal as internal_utils
from tobac.utils import decorators

Expand Down Expand Up @@ -1545,8 +1544,8 @@ def filter_min_distance(
# Check if we have PBCs.
if PBC_flag in ["hdim_1", "hdim_2", "both"]:
# Note that we multiply by dxy to get the distances in spatial coordinates
dist_func = build_distance_function(
min_h1 * dxy, max_h1 * dxy, min_h2 * dxy, max_h2 * dxy, PBC_flag
dist_func = pbc_utils.build_distance_function(
min_h1 * dxy, max_h1 * dxy, min_h2 * dxy, max_h2 * dxy, PBC_flag, is_3D
)
features_tree = BallTree(feature_locations, metric="pyfunc", func=dist_func)
neighbours = features_tree.query_radius(feature_locations, r=min_distance)
Expand Down
47 changes: 23 additions & 24 deletions tobac/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def add_markers(
h2_end_coord=hdim_2_max,
PBC_flag=PBC_flag,
)
# Build distance function ahead of time, 3D always true as we then reduce
dist_func = pbc_utils.build_distance_function(
0, h1_len, 0, h2_len, PBC_flag, True
)
for seed_box in all_seed_boxes:
# Need to see if there are any other points seeded
# in this seed box first.
Expand All @@ -205,6 +209,7 @@ def add_markers(
local_index[1] + seed_box[0],
local_index[2] + seed_box[2],
)

# If it's a background marker, we can just set it
# with the feature we're working on.
if curr_box_pt == bg_marker:
Expand All @@ -213,18 +218,14 @@ def add_markers(
# it has another feature in it. Calculate the distance
# from its current set feature and the new feature.
if is_3D:
curr_coord = (row["vdim"], row["hdim_1"], row["hdim_2"])
curr_coord = np.array(
(row["vdim"], row["hdim_1"], row["hdim_2"])
)
else:
curr_coord = (0, row["hdim_1"], row["hdim_2"])

dist_from_curr_pt = pbc_utils.calc_distance_coords_pbc(
np.array(global_index),
np.array(curr_coord),
min_h1=0,
max_h1=h1_len,
min_h2=0,
max_h2=h2_len,
PBC_flag=PBC_flag,
curr_coord = np.array((0, row["hdim_1"], row["hdim_2"]))

dist_from_curr_pt = dist_func(
np.array(global_index), curr_coord
)

# This is technically an O(N^2) operation, but
Expand All @@ -234,21 +235,19 @@ def add_markers(
features["feature"] == curr_box_pt
].iloc[0]
if is_3D:
orig_coord = (
orig_row["vdim"],
orig_row["hdim_1"],
orig_row["hdim_2"],
orig_coord = np.array(
(
orig_row["vdim"],
orig_row["hdim_1"],
orig_row["hdim_2"],
)
)
else:
orig_coord = (0, orig_row["hdim_1"], orig_row["hdim_2"])
dist_from_orig_pt = pbc_utils.calc_distance_coords_pbc(
np.array(global_index),
np.array(orig_coord),
min_h1=0,
max_h1=h1_len,
min_h2=0,
max_h2=h2_len,
PBC_flag=PBC_flag,
orig_coord = np.array(
(0, orig_row["hdim_1"], orig_row["hdim_2"])
)
dist_from_orig_pt = dist_func(
np.array(global_index), orig_coord
)
# The current point center is further away
# than the original point center, so do nothing
Expand Down
Loading

0 comments on commit 8e7557c

Please sign in to comment.