Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delucchi/radius #174

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def crossmatch(
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree)
return Catalog(ddf, ddf_map, hc_catalog)

def cone_search(self, ra: float, dec: float, radius: float):
def cone_search(self, ra: float, dec: float, radius_arcsec: float):
"""Perform a cone search to filter the catalog

Filters to points within radius great circle distance to the point specified by ra and dec in degrees.
Expand All @@ -219,13 +219,13 @@ def cone_search(self, ra: float, dec: float, radius: float):
Args:
ra (float): Right Ascension of the center of the cone in degrees
dec (float): Declination of the center of the cone in degrees
radius (float): Radius of the cone in degrees
radius (float): Radius of the cone in arcseconds

Returns:
A new Catalog containing the points filtered to those within the cone, and the partitions that
overlap the cone.
"""
return self._search(ConeSearch(ra, dec, radius, self.hc_structure))
return self._search(ConeSearch(ra, dec, radius_arcsec, self.hc_structure))

def box(self, ra: Tuple[float, float] | None = None, dec: Tuple[float, float] | None = None) -> Catalog:
"""Performs filtering according to right ascension and declination ranges.
Expand Down
19 changes: 10 additions & 9 deletions src/lsdb/core/search/cone_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,34 @@ class ConeSearch(AbstractSearch):
Filters partitions in the catalog to those that have some overlap with the cone.
"""

def __init__(self, ra, dec, radius, metadata):
validate_radius(radius)
def __init__(self, ra, dec, radius_arcsec, metadata):
validate_radius(radius_arcsec)
validate_declination_values(dec)

self.ra = ra
self.dec = dec
self.radius = radius
self.radius_arcsec = radius_arcsec
self.metadata = metadata

def search_partitions(self, pixels: List[HealpixPixel]) -> List[HealpixPixel]:
"""Determine the target partitions for further filtering."""
pixel_tree = PixelTreeBuilder.from_healpix(pixels)
return filter_pixels_by_cone(pixel_tree, self.ra, self.dec, self.radius)
return filter_pixels_by_cone(pixel_tree, self.ra, self.dec, self.radius_arcsec)

def search_points(self, frame: pd.DataFrame) -> pd.DataFrame:
"""Determine the search results within a data frame"""
return cone_filter(frame, self.ra, self.dec, self.radius, self.metadata)
return cone_filter(frame, self.ra, self.dec, self.radius_arcsec, self.metadata)


@dask.delayed
def cone_filter(data_frame: pd.DataFrame, ra, dec, radius, metadata: hc.catalog.Catalog):
def cone_filter(data_frame: pd.DataFrame, ra, dec, radius_arcsec, metadata: hc.catalog.Catalog):
"""Filters a dataframe to only include points within the specified cone

Args:
data_frame (pd.DataFrame): DataFrame containing points in the sky
ra (float): Right Ascension of the center of the cone in degrees
dec (float): Declination of the center of the cone in degrees
radius (float): Radius of the cone in degrees
radius_arcsec (float): Radius of the cone in arcseconds
metadata (hipscat.Catalog): hipscat `Catalog` with catalog_info that matches `data_frame`

Returns:
Expand All @@ -56,6 +56,7 @@ def cone_filter(data_frame: pd.DataFrame, ra, dec, radius, metadata: hc.catalog.
df_decs = data_frame[metadata.catalog_info.dec_column].values
df_coords = SkyCoord(df_ras, df_decs, unit="deg")
center_coord = SkyCoord(ra, dec, unit="deg")
df_separations = df_coords.separation(center_coord).value
data_frame = data_frame.iloc[df_separations < radius]
df_separations_deg = df_coords.separation(center_coord).value
radius_degrees = radius_arcsec / 3600
data_frame = data_frame.iloc[df_separations_deg < radius_degrees]
return data_frame
2 changes: 2 additions & 0 deletions tests/data/small_sky_to_xmatch/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Norder,Npix,Dir
0,11,0
4 changes: 4 additions & 0 deletions tests/data/small_sky_to_xmatch/partition_join_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Norder,Npix,join_Norder,join_Npix
0,11,1,44
0,11,1,45
0,11,1,46
8 changes: 4 additions & 4 deletions tests/data/small_sky_xmatch/partition_info.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Norder,Dir,Npix,num_rows
1,0,44,41
1,0,45,29
1,0,46,41
Norder,Npix,Dir
1,44,0
1,45,0
1,46,0
6 changes: 6 additions & 0 deletions tests/data/small_sky_xmatch_margin/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Norder,Npix,Dir
0,4,0
1,44,0
1,45,0
1,46,0
1,47,0
2 changes: 1 addition & 1 deletion tests/lsdb/catalog/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_save_catalog_with_some_empty_partitions(small_sky_order1_catalog, tmp_p
base_catalog_path = os.path.join(tmp_path, "small_sky")

# The result of this cone search is known to have one empty partition
cone_search_catalog = small_sky_order1_catalog.cone_search(0, -80, 15)
cone_search_catalog = small_sky_order1_catalog.cone_search(0, -80, 15 * 3600)
assert cone_search_catalog._ddf.npartitions == 2

non_empty_pixels = []
Expand Down
11 changes: 6 additions & 5 deletions tests/lsdb/catalog/test_cone_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
def test_cone_search_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct):
ra = 0
dec = -80
radius = 20
radius_degrees = 20
radius = radius_degrees * 3600
center_coord = SkyCoord(ra, dec, unit="deg")
cone_search_catalog = small_sky_order1_catalog.cone_search(ra, dec, radius)
cone_search_df = cone_search_catalog.compute()
for _, row in small_sky_order1_catalog.compute().iterrows():
row_ra = row[small_sky_order1_catalog.hc_structure.catalog_info.ra_column]
row_dec = row[small_sky_order1_catalog.hc_structure.catalog_info.dec_column]
sep = SkyCoord(row_ra, row_dec, unit="deg").separation(center_coord)
if sep.degree <= radius:
if sep.degree <= radius_degrees:
assert len(cone_search_df.loc[cone_search_df["id"] == row["id"]]) == 1
else:
assert len(cone_search_df.loc[cone_search_df["id"] == row["id"]]) == 0
Expand All @@ -24,7 +25,7 @@ def test_cone_search_filters_correct_points(small_sky_order1_catalog, assert_div
def test_cone_search_filters_partitions(small_sky_order1_catalog):
ra = 0
dec = -80
radius = 20
radius = 20 * 3600
hc_conesearch = small_sky_order1_catalog.hc_structure.filter_by_cone(ra, dec, radius)
consearch_catalog = small_sky_order1_catalog.cone_search(ra, dec, radius)
assert len(hc_conesearch.get_healpix_pixels()) == len(consearch_catalog.get_healpix_pixels())
Expand All @@ -36,7 +37,7 @@ def test_cone_search_filters_partitions(small_sky_order1_catalog):
def test_cone_search_filters_no_matching_points(small_sky_order1_catalog, assert_divisions_are_correct):
ra = 0
dec = -80
radius = 0.2
radius = 0.2 * 3600
cone_search_catalog = small_sky_order1_catalog.cone_search(ra, dec, radius)
cone_search_df = cone_search_catalog.compute()
assert len(cone_search_df) == 0
Expand All @@ -46,7 +47,7 @@ def test_cone_search_filters_no_matching_points(small_sky_order1_catalog, assert
def test_cone_search_filters_no_matching_partitions(small_sky_order1_catalog, assert_divisions_are_correct):
ra = 20
dec = 80
radius = 20
radius = 20 * 3600
cone_search_catalog = small_sky_order1_catalog.cone_search(ra, dec, radius)
cone_search_df = cone_search_catalog.compute()
assert len(cone_search_df) == 0
Expand Down
28 changes: 16 additions & 12 deletions tests/lsdb/catalog/test_crossmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
class TestCrossmatch:
@staticmethod
def test_kdtree_crossmatch(algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct):
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, algorithm=algo, radius_arcsec=0.01 * 3600
).compute()
with pytest.warns(RuntimeWarning, match="Results may be inaccurate"):
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, algorithm=algo, radius_arcsec=0.01 * 3600
).compute()
assert len(xmatched) == len(xmatch_correct)
for _, correct_row in xmatch_correct.iterrows():
assert correct_row["ss_id"] in xmatched["id_small_sky"].values
Expand All @@ -24,9 +25,10 @@ def test_kdtree_crossmatch(algo, small_sky_catalog, small_sky_xmatch_catalog, xm

@staticmethod
def test_kdtree_crossmatch_thresh(algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct_005):
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, radius_arcsec=0.005 * 3600, algorithm=algo
).compute()
with pytest.warns(RuntimeWarning, match="Results may be inaccurate"):
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, radius_arcsec=0.005 * 3600, algorithm=algo
).compute()
assert len(xmatched) == len(xmatch_correct_005)
for _, correct_row in xmatch_correct_005.iterrows():
assert correct_row["ss_id"] in xmatched["id_small_sky"].values
Expand All @@ -38,9 +40,10 @@ def test_kdtree_crossmatch_thresh(algo, small_sky_catalog, small_sky_xmatch_cata
def test_kdtree_crossmatch_multiple_neighbors(
algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct_3n_2t_no_margin
):
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, n_neighbors=3, radius_arcsec=2 * 3600, algorithm=algo
).compute()
with pytest.warns(RuntimeWarning, match="Results may be inaccurate"):
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, n_neighbors=3, radius_arcsec=2 * 3600, algorithm=algo
).compute()
assert len(xmatched) == len(xmatch_correct_3n_2t_no_margin)
for _, correct_row in xmatch_correct_3n_2t_no_margin.iterrows():
assert correct_row["ss_id"] in xmatched["id_small_sky"].values
Expand Down Expand Up @@ -102,9 +105,10 @@ def test_wrong_suffixes(algo, small_sky_catalog, small_sky_xmatch_catalog):


def test_custom_crossmatch_algorithm(small_sky_catalog, small_sky_xmatch_catalog, xmatch_mock):
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, algorithm=MockCrossmatchAlgorithm, mock_results=xmatch_mock
).compute()
with pytest.warns(RuntimeWarning, match="Results may be inaccurate"):
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, algorithm=MockCrossmatchAlgorithm, mock_results=xmatch_mock
).compute()
assert len(xmatched) == len(xmatch_mock)
for _, correct_row in xmatch_mock.iterrows():
assert correct_row["ss_id"] in xmatched["id_small_sky"].values
Expand Down
Loading