Skip to content

Commit

Permalink
Add MOC Filter (#503)
Browse files Browse the repository at this point in the history
* add moc_filter

* add search to healpix_dataset

* add moc search to catalog

* add unit tests

* fix isort

* refactor search
  • Loading branch information
smcguire-cmu authored Nov 20, 2024
1 parent 72e809a commit 60c9421
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 21 deletions.
21 changes: 17 additions & 4 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import nested_pandas as npd
import pandas as pd
from hats.catalog.index.index_catalog import IndexCatalog as HCIndexCatalog
from mocpy import MOC
from pandas._libs import lib
from pandas._typing import AnyAll, Axis, IndexLabel
from pandas.api.extensions import no_default
Expand All @@ -18,6 +19,7 @@
from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm
from lsdb.core.search import BoxSearch, ConeSearch, IndexSearch, OrderSearch, PolygonSearch
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.core.search.moc_search import MOCSearch
from lsdb.core.search.pixel_search import PixelSearch
from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data
from lsdb.dask.join_catalog_data import (
Expand Down Expand Up @@ -324,6 +326,18 @@ def pixel_search(self, pixels: List[Tuple[int, int]]) -> Catalog:
"""
return self.search(PixelSearch(pixels))

def moc_search(self, moc: MOC, fine: bool = True) -> Catalog:
    """Select the catalog points that lie inside a MOC.

    Args:
        moc (mocpy.MOC): The moc that delimits the search region.
        fine (bool): True to filter individual points, False to filter
            only partitions. Defaults to True.

    Returns:
        A new Catalog restricted to the points that are within the moc.
    """
    # Build the reusable search object first, then hand it to the generic search entry point.
    moc_filter = MOCSearch(moc, fine=fine)
    return self.search(moc_filter)

def search(self, search: AbstractSearch):
"""Find rows by reusable search algorithm.
Expand All @@ -336,10 +350,9 @@ def search(self, search: AbstractSearch):
Returns:
A new Catalog containing the points filtered to those matching the search parameters.
"""
filtered_hc_structure = search.filter_hc_catalog(self.hc_structure)
ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search)
margin = self.margin.search(search) if self.margin is not None else None
return Catalog(search_ndf, ddf_partition_map, filtered_hc_structure, margin=margin)
cat = super().search(search)
cat.margin = self.margin.search(search) if self.margin is not None else None
return cat

def merge(
self,
Expand Down
16 changes: 16 additions & 0 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,22 @@ def _perform_search(
ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)}
return ddf_partition_map, filtered_partitions_ddf

def search(self, search: AbstractSearch):
    """Apply a reusable search algorithm to this dataset.

    The search object first narrows the catalog structure via its
    ``filter_hc_catalog`` method; the narrowed structure and the search are
    then passed to ``_perform_search`` to obtain the filtered partitions.

    Args:
        search (AbstractSearch): Instance of AbstractSearch.

    Returns:
        A new object of the same class as ``self``, built from the filtered
        frame, its partition map and the filtered catalog structure.
    """
    hc_subset = search.filter_hc_catalog(self.hc_structure)
    partition_map, filtered_ndf = self._perform_search(hc_subset, search)
    # Instantiate through the receiver's class so subclasses get their own type back.
    dataset_cls = self.__class__
    return dataset_cls(filtered_ndf, partition_map, hc_subset)

def map_partitions(
self,
func: Callable[..., npd.NestedFrame],
Expand Down
17 changes: 0 additions & 17 deletions src/lsdb/catalog/margin_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import nested_dask as nd

from lsdb.catalog.dataset.healpix_dataset import HealpixDataset
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.types import DaskDFPixelMap


Expand All @@ -24,19 +23,3 @@ def __init__(
hc_structure: hc.catalog.MarginCatalog,
):
super().__init__(ddf, ddf_pixel_map, hc_structure)

def search(self, search: AbstractSearch):
    """Apply a reusable search algorithm to this margin catalog.

    The search object filters the catalog structure via ``filter_hc_catalog``;
    the filtered structure and the search are then passed to
    ``_perform_search`` to obtain the filtered partitions.

    Args:
        search (AbstractSearch): Instance of AbstractSearch.

    Returns:
        A new object of the same class as ``self``, built from the filtered
        frame, its partition map and the filtered catalog structure.
    """
    hc_subset = search.filter_hc_catalog(self.hc_structure)
    partition_map, filtered_ndf = self._perform_search(hc_subset, search)
    margin_cls = self.__class__
    return margin_cls(filtered_ndf, partition_map, hc_subset)
33 changes: 33 additions & 0 deletions src/lsdb/core/search/moc_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import astropy.units as u
import nested_pandas as npd
from hats.catalog import TableProperties
from mocpy import MOC

from lsdb.core.search.abstract_search import AbstractSearch

if TYPE_CHECKING:
from lsdb.types import HCCatalogTypeVar


class MOCSearch(AbstractSearch):
    """Filter the catalog by a MOC.

    Partitions are filtered to those that are in the specified moc, and
    ``search_points`` keeps only the rows whose coordinates the moc contains.
    """

    def __init__(self, moc: MOC, fine: bool = True):
        """Store the moc and forward the ``fine`` flag to the base search.

        Args:
            moc (mocpy.MOC): The moc that defines the search region.
            fine (bool): Passed unchanged to ``AbstractSearch.__init__``.
        """
        super().__init__(fine)
        self.moc = moc

    def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar:
        """Return ``hc_structure`` filtered by this search's moc."""
        return hc_structure.filter_by_moc(self.moc)

    def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:
        """Keep the rows of ``frame`` whose position is contained in the moc.

        Args:
            frame (npd.NestedFrame): Rows to filter; must hold the columns named
                by ``metadata.ra_column`` and ``metadata.dec_column``.
            metadata (TableProperties): Supplies the ra/dec column names.

        Returns:
            The subset of ``frame`` selected by the moc containment mask.
        """
        # Coordinates are tagged as degrees before being handed to the moc.
        lon = frame[metadata.ra_column].to_numpy() * u.deg
        lat = frame[metadata.dec_column].to_numpy() * u.deg
        inside = self.moc.contains_lonlat(lon, lat)
        return frame.iloc[inside]
31 changes: 31 additions & 0 deletions tests/lsdb/catalog/test_moc_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import astropy.units as u
import numpy as np
import pandas as pd
from hats.pixel_math import HealpixPixel
from mocpy import MOC


def test_moc_search_filters_correct_points(small_sky_order1_catalog):
    """A fine moc search keeps exactly the catalog points the moc contains."""
    search_moc = MOC.from_healpix_cells(ipix=np.array([176, 177]), depth=np.array([2, 2]), max_depth=2)

    def in_moc(df):
        # Containment mask for the ra/dec columns of a computed frame.
        return search_moc.contains_lonlat(df["ra"].to_numpy() * u.deg, df["dec"].to_numpy() * u.deg)

    filtered_cat = small_sky_order1_catalog.moc_search(search_moc)
    assert filtered_cat.get_healpix_pixels() == [HealpixPixel(1, 44)]

    filtered_df = filtered_cat.compute()
    full_df = small_sky_order1_catalog.compute()

    # Every returned point is inside the moc ...
    assert np.all(in_moc(filtered_df))
    # ... and no point of the full catalog inside the moc was dropped.
    assert np.sum(in_moc(full_df)) == len(filtered_df)


def test_moc_search_non_fine(small_sky_order1_catalog):
    """A non-fine moc search matches a pixel search over the selected partitions."""
    search_moc = MOC.from_healpix_cells(ipix=np.array([176, 180]), depth=np.array([2, 2]), max_depth=2)
    coarse_cat = small_sky_order1_catalog.moc_search(search_moc, fine=False)
    assert coarse_cat.get_healpix_pixels() == [HealpixPixel(1, 44), HealpixPixel(1, 45)]

    actual_df = coarse_cat.compute()
    expected_cat = small_sky_order1_catalog.pixel_search([HealpixPixel(1, 44), HealpixPixel(1, 45)])
    expected_df = expected_cat.compute()
    pd.testing.assert_frame_equal(actual_df, expected_df)

0 comments on commit 60c9421

Please sign in to comment.