diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index 585b32f0..d407481a 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -7,6 +7,7 @@ import nested_pandas as npd import pandas as pd from hats.catalog.index.index_catalog import IndexCatalog as HCIndexCatalog +from mocpy import MOC from pandas._libs import lib from pandas._typing import AnyAll, Axis, IndexLabel from pandas.api.extensions import no_default @@ -18,6 +19,7 @@ from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm from lsdb.core.search import BoxSearch, ConeSearch, IndexSearch, OrderSearch, PolygonSearch from lsdb.core.search.abstract_search import AbstractSearch +from lsdb.core.search.moc_search import MOCSearch from lsdb.core.search.pixel_search import PixelSearch from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data from lsdb.dask.join_catalog_data import ( @@ -324,6 +326,18 @@ def pixel_search(self, pixels: List[Tuple[int, int]]) -> Catalog: """ return self.search(PixelSearch(pixels)) + def moc_search(self, moc: MOC, fine: bool = True) -> Catalog: + """Finds all catalog points that are contained within a moc. + + Args: + moc (mocpy.MOC): The moc that defines the region for the search. + fine (bool): True if points are to be filtered, False if only partitions. Defaults to True. + + Returns: + A new Catalog containing only the points that are within the moc. + """ + return self.search(MOCSearch(moc, fine=fine)) + def search(self, search: AbstractSearch): """Find rows by reusable search algorithm. @@ -336,10 +350,9 @@ def search(self, search: AbstractSearch): Returns: A new Catalog containing the points filtered to those matching the search parameters. """ - filtered_hc_structure = search.filter_hc_catalog(self.hc_structure) - ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search) - margin = self.margin.search(search) if self.margin is not None else None - return Catalog(search_ndf, ddf_partition_map, filtered_hc_structure, margin=margin) + cat = super().search(search) + cat.margin = self.margin.search(search) if self.margin is not None else None + return cat def merge( self, diff --git a/src/lsdb/catalog/dataset/healpix_dataset.py b/src/lsdb/catalog/dataset/healpix_dataset.py index 9aa5f15e..0f9ea771 100644 --- a/src/lsdb/catalog/dataset/healpix_dataset.py +++ b/src/lsdb/catalog/dataset/healpix_dataset.py @@ -194,6 +194,22 @@ def _perform_search( ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)} return ddf_partition_map, filtered_partitions_ddf + def search(self, search: AbstractSearch): + """Find rows by reusable search algorithm. + + Filters partitions in the catalog to those that match some rough criteria. + Filters to points that match some finer criteria. + + Args: + search (AbstractSearch): Instance of AbstractSearch. + + Returns: + A new Catalog containing the points filtered to those matching the search parameters. + """ + filtered_hc_structure = search.filter_hc_catalog(self.hc_structure) + ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search) + return self.__class__(search_ndf, ddf_partition_map, filtered_hc_structure) + def map_partitions( self, func: Callable[..., npd.NestedFrame], diff --git a/src/lsdb/catalog/margin_catalog.py b/src/lsdb/catalog/margin_catalog.py index f7332a50..db3abe30 100644 --- a/src/lsdb/catalog/margin_catalog.py +++ b/src/lsdb/catalog/margin_catalog.py @@ -2,7 +2,6 @@ import nested_dask as nd from lsdb.catalog.dataset.healpix_dataset import HealpixDataset -from lsdb.core.search.abstract_search import AbstractSearch from lsdb.types import DaskDFPixelMap @@ -24,19 +23,3 @@ def __init__( hc_structure: hc.catalog.MarginCatalog, ): super().__init__(ddf, ddf_pixel_map, hc_structure) - - def search(self, search: AbstractSearch): - """Find rows by reusable search algorithm. - - Filters partitions in the catalog to those that match some rough criteria and their neighbors. - Filters to points that match some finer criteria. - - Args: - search (AbstractSearch): Instance of AbstractSearch. - - Returns: - A new Catalog containing the points filtered to those matching the search parameters. - """ - filtered_hc_structure = search.filter_hc_catalog(self.hc_structure) - ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search) - return self.__class__(search_ndf, ddf_partition_map, filtered_hc_structure) diff --git a/src/lsdb/core/search/moc_search.py b/src/lsdb/core/search/moc_search.py new file mode 100644 index 00000000..ccaec1f2 --- /dev/null +++ b/src/lsdb/core/search/moc_search.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import astropy.units as u +import nested_pandas as npd +from hats.catalog import TableProperties +from mocpy import MOC + +from lsdb.core.search.abstract_search import AbstractSearch + +if TYPE_CHECKING: + from lsdb.types import HCCatalogTypeVar + + +class MOCSearch(AbstractSearch): + """Filter the catalog by a MOC. + + Filters partitions in the catalog to those that are in a specified moc. + """ + + def __init__(self, moc: MOC, fine: bool = True): + super().__init__(fine) + self.moc = moc + + def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar: + return hc_structure.filter_by_moc(self.moc) + + def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame: + df_ras = frame[metadata.ra_column].to_numpy() + df_decs = frame[metadata.dec_column].to_numpy() + mask = self.moc.contains_lonlat(df_ras * u.deg, df_decs * u.deg) + return frame.iloc[mask] diff --git a/tests/lsdb/catalog/test_moc_search.py b/tests/lsdb/catalog/test_moc_search.py new file mode 100644 index 00000000..16aa397d --- /dev/null +++ b/tests/lsdb/catalog/test_moc_search.py @@ -0,0 +1,31 @@ +import astropy.units as u +import numpy as np +import pandas as pd +from hats.pixel_math import HealpixPixel +from mocpy import MOC + + +def test_moc_search_filters_correct_points(small_sky_order1_catalog): + search_moc = MOC.from_healpix_cells(ipix=np.array([176, 177]), depth=np.array([2, 2]), max_depth=2) + filtered_cat = small_sky_order1_catalog.moc_search(search_moc) + assert filtered_cat.get_healpix_pixels() == [HealpixPixel(1, 44)] + filtered_cat_comp = filtered_cat.compute() + cat_comp = small_sky_order1_catalog.compute() + assert np.all( + search_moc.contains_lonlat( + filtered_cat_comp["ra"].to_numpy() * u.deg, filtered_cat_comp["dec"].to_numpy() * u.deg + ) + ) + assert np.sum( + search_moc.contains_lonlat(cat_comp["ra"].to_numpy() * u.deg, cat_comp["dec"].to_numpy() * u.deg) + ) == len(filtered_cat_comp) + + +def test_moc_search_non_fine(small_sky_order1_catalog): + search_moc = MOC.from_healpix_cells(ipix=np.array([176, 180]), depth=np.array([2, 2]), max_depth=2) + filtered_cat = small_sky_order1_catalog.moc_search(search_moc, fine=False) + assert filtered_cat.get_healpix_pixels() == [HealpixPixel(1, 44), HealpixPixel(1, 45)] + pd.testing.assert_frame_equal( + filtered_cat.compute(), + small_sky_order1_catalog.pixel_search([HealpixPixel(1, 44), HealpixPixel(1, 45)]).compute(), + )