Skip to content

Commit

Permalink
Merge pull request #192 from astronomy-commons/sean/margin-filter
Browse files Browse the repository at this point in the history
Add spatial filtering to margin catalogs
  • Loading branch information
smcguire-cmu authored Feb 29, 2024
2 parents b538a40 + 199b893 commit 580e428
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 11 deletions.
14 changes: 4 additions & 10 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import dataclasses
from typing import Any, Dict, List, Tuple, Type, Union, cast
from typing import Any, Dict, List, Tuple, Type, Union

import dask.dataframe as dd
import hipscat as hc
Expand All @@ -17,7 +17,6 @@
from lsdb.core.search import ConeSearch, IndexSearch, PolygonSearch
from lsdb.core.search.box_search import BoxSearch
from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data
from lsdb.dask.divisions import get_pixels_divisions
from lsdb.dask.join_catalog_data import join_catalog_data_on, join_catalog_data_through
from lsdb.types import DaskDFPixelMap

Expand Down Expand Up @@ -266,14 +265,9 @@ def _search(self, search):
"""
filtered_pixels = search.search_partitions(self.hc_structure.get_healpix_pixels())
filtered_hc_structure = self.hc_structure.filter_from_pixel_list(filtered_pixels)
partitions = self._ddf.to_delayed()
targeted_partitions = [partitions[self._ddf_pixel_map[pixel]] for pixel in filtered_pixels]
filtered_partitions = [search.search_points(partition) for partition in targeted_partitions]
divisions = get_pixels_divisions(filtered_pixels)
search_ddf = dd.from_delayed(filtered_partitions, meta=self._ddf._meta, divisions=divisions)
search_ddf = cast(dd.DataFrame, search_ddf)
ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)}
return Catalog(search_ddf, ddf_partition_map, filtered_hc_structure)
ddf_partition_map, search_ddf = self._perform_search(filtered_pixels, search)
margin = self.margin._search(search) if self.margin is not None else None
return Catalog(search_ddf, ddf_partition_map, filtered_hc_structure, margin=margin)

def merge(
self,
Expand Down
25 changes: 24 additions & 1 deletion src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, cast

import dask.dataframe as dd
import numpy as np
Expand All @@ -8,9 +8,12 @@
from typing_extensions import Self

from lsdb.catalog.dataset.dataset import Dataset
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.dask.divisions import get_pixels_divisions
from lsdb.types import DaskDFPixelMap


# pylint: disable=W0212
class HealpixDataset(Dataset):
"""LSDB Catalog DataFrame to perform analysis of sky catalogs and efficient
spatial operations.
Expand Down Expand Up @@ -106,3 +109,23 @@ def query(self, expr: str) -> Self:
"""
ddf = self._ddf.query(expr)
return self.__class__(ddf, self._ddf_pixel_map, self.hc_structure)

def _perform_search(self, filtered_pixels: List[HealpixPixel], search: AbstractSearch):
"""Performs a search on the catalog from a list of pixels to search in
Args:
filtered_pixels (List[HealpixPixel]): List of pixels in the catalog to be searched
search (AbstractSearch): The search object to perform the search with
Returns:
A tuple containing a dictionary mapping pixel to partition index and a dask dataframe
containing the search results
"""
partitions = self._ddf.to_delayed()
targeted_partitions = [partitions[self._ddf_pixel_map[pixel]] for pixel in filtered_pixels]
filtered_partitions = [search.search_points(partition) for partition in targeted_partitions]
divisions = get_pixels_divisions(filtered_pixels)
search_ddf = dd.from_delayed(filtered_partitions, meta=self._ddf._meta, divisions=divisions)
search_ddf = cast(dd.DataFrame, search_ddf)
ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)}
return ddf_partition_map, search_ddf
51 changes: 51 additions & 0 deletions src/lsdb/catalog/margin_catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import dask.dataframe as dd
import healpy as hp
import hipscat as hc
import numpy as np
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.filter import get_filtered_pixel_list
from hipscat.pixel_tree.pixel_tree_builder import PixelTreeBuilder

from lsdb.catalog.dataset.healpix_dataset import HealpixDataset
from lsdb.types import DaskDFPixelMap
Expand All @@ -23,3 +28,49 @@ def __init__(
hc_structure: hc.catalog.MarginCatalog,
):
super().__init__(ddf, ddf_pixel_map, hc_structure)

def _search(self, search):
"""Find rows by reusable search algorithm.
Filters partitions in the catalog to those that match some rough criteria and their neighbors.
Filters to points that match some finer criteria.
Args:
search: instance of AbstractSearch
Returns:
A new Catalog containing the points filtered to those matching the search parameters.
"""

# if the margin size is greater than the size of a pixel, this is an invalid search
max_order = self.hc_structure.pixel_tree.get_max_depth()
max_order_size = hp.nside2resol(2**max_order, arcmin=True)
if self.hc_structure.catalog_info.margin_threshold > max_order_size * 60:
raise ValueError(
f"Margin size {self.hc_structure.catalog_info.margin_threshold} is greater than the size of "
f"a pixel at the highest order {max_order}."
)

# Get the pixels that match the search pixels
filtered_search_pixels = search.search_partitions(self.hc_structure.get_healpix_pixels())

# Get the margin pixels at the max order + 1 from the search pixels
# the get_margin function requires a higher order than the given pixel
margin_order = max(pixel.order for pixel in filtered_search_pixels) + 1
margin_pixels = [
hc.pixel_math.get_margin(pixel.order, pixel.pixel, margin_order - pixel.order)
for pixel in filtered_search_pixels
]

# Remove duplicate margin pixels and construct HealpixPixel objects
margin_pixels = list(set(np.concatenate(margin_pixels)))
margin_pixels = [HealpixPixel(margin_order, pixel) for pixel in margin_pixels]

# Align the margin pixels with the catalog pixels and combine with the search pixels
margin_pixel_tree = PixelTreeBuilder.from_healpix(margin_pixels)
filtered_margin_pixels = get_filtered_pixel_list(self.hc_structure.pixel_tree, margin_pixel_tree)
filtered_pixels = list(set(filtered_search_pixels + filtered_margin_pixels))

filtered_hc_structure = self.hc_structure.filter_from_pixel_list(filtered_pixels)
ddf_partition_map, search_ddf = self._perform_search(filtered_pixels, search)
return self.__class__(search_ddf, ddf_partition_map, filtered_hc_structure)
16 changes: 16 additions & 0 deletions tests/lsdb/catalog/test_box_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@ def test_box_search_ra(small_sky_order1_catalog, assert_divisions_are_correct):
assert_divisions_are_correct(ra_search_catalog)


def test_box_search_ra_margin(small_sky_order1_source_with_margin, assert_divisions_are_correct):
ra_search_catalog = small_sky_order1_source_with_margin.box(ra=(280, 300))
ra_search_df = ra_search_catalog.compute()
ra_values = ra_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column]
assert len(ra_search_df) < len(small_sky_order1_source_with_margin.compute())
assert all(280 <= ra <= 300 for ra in ra_values)
assert_divisions_are_correct(ra_search_catalog)

assert ra_search_catalog.margin is not None
ra_margin_search_df = ra_search_catalog.margin.compute()
ra_values = ra_margin_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column]
assert len(ra_margin_search_df) < len(small_sky_order1_source_with_margin.margin.compute())
assert all(280 <= ra <= 300 for ra in ra_values)
assert_divisions_are_correct(ra_search_catalog.margin)


def test_box_search_ra_complement(small_sky_order1_catalog):
ra_column = small_sky_order1_catalog.hc_structure.catalog_info.ra_column

Expand Down
38 changes: 38 additions & 0 deletions tests/lsdb/catalog/test_cone_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,44 @@ def test_cone_search_filters_correct_points(small_sky_order1_catalog, assert_div
assert_divisions_are_correct(cone_search_catalog)


def test_cone_search_filters_correct_points_margin(
small_sky_order1_source_with_margin, assert_divisions_are_correct
):
ra = -35
dec = -55
radius_degrees = 2
radius = radius_degrees * 3600
center_coord = SkyCoord(ra, dec, unit="deg")
cone_search_catalog = small_sky_order1_source_with_margin.cone_search(ra, dec, radius)
assert cone_search_catalog.margin is not None
cone_search_df = cone_search_catalog.compute()
for _, row in small_sky_order1_source_with_margin.compute().iterrows():
row_ra = row[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column]
row_dec = row[small_sky_order1_source_with_margin.hc_structure.catalog_info.dec_column]
sep = SkyCoord(row_ra, row_dec, unit="deg").separation(center_coord)
if sep.degree <= radius_degrees:
assert len(cone_search_df.loc[cone_search_df["id"] == row["id"]]) == 1
else:
assert len(cone_search_df.loc[cone_search_df["id"] == row["id"]]) == 0
cone_search_margin_df = cone_search_catalog.margin.compute()
for _, row in small_sky_order1_source_with_margin.margin.compute().iterrows():
row_ra = row[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column]
row_dec = row[small_sky_order1_source_with_margin.hc_structure.catalog_info.dec_column]
sep = SkyCoord(row_ra, row_dec, unit="deg").separation(center_coord)
if sep.degree <= radius_degrees:
assert len(cone_search_margin_df.loc[cone_search_margin_df["id"] == row["id"]]) == 1
else:
assert len(cone_search_margin_df.loc[cone_search_margin_df["id"] == row["id"]]) == 0
assert_divisions_are_correct(cone_search_catalog)
assert_divisions_are_correct(cone_search_catalog.margin)


def test_cone_search_big_margin(small_sky_order1_source_with_margin):
small_sky_order1_source_with_margin.margin.hc_structure.catalog_info.margin_threshold = 600000
with pytest.raises(ValueError, match="Margin size"):
small_sky_order1_source_with_margin.cone_search(0, 0, 1)


def test_cone_search_filters_partitions(small_sky_order1_catalog):
ra = 0
dec = -80
Expand Down
28 changes: 28 additions & 0 deletions tests/lsdb/catalog/test_polygon_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,34 @@ def test_polygon_search_filters_correct_points(small_sky_order1_catalog, assert_
assert_divisions_are_correct(polygon_search_catalog)


def test_polygon_search_filters_correct_points_margin(
small_sky_order1_source_with_margin, assert_divisions_are_correct
):
vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]
polygon, _ = get_cartesian_polygon(vertices)
polygon_search_catalog = small_sky_order1_source_with_margin.polygon_search(vertices)
polygon_search_df = polygon_search_catalog.compute()
ra_values_radians = np.radians(
polygon_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column]
)
dec_values_radians = np.radians(
polygon_search_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.dec_column]
)
assert all(polygon.contains(ra_values_radians, dec_values_radians))
assert_divisions_are_correct(polygon_search_catalog)

assert polygon_search_catalog.margin is not None
polygon_search_margin_df = polygon_search_catalog.margin.compute()
ra_values_radians = np.radians(
polygon_search_margin_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.ra_column]
)
dec_values_radians = np.radians(
polygon_search_margin_df[small_sky_order1_source_with_margin.hc_structure.catalog_info.dec_column]
)
assert all(polygon.contains(ra_values_radians, dec_values_radians))
assert_divisions_are_correct(polygon_search_catalog.margin)


def test_polygon_search_filters_partitions(small_sky_order1_catalog):
vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]
_, vertices_xyz = get_cartesian_polygon(vertices)
Expand Down

0 comments on commit 580e428

Please sign in to comment.