Skip to content

Commit

Permalink
Add __getitem__ accessor for catalogs (#281)
Browse files Browse the repository at this point in the history
* Add __getitem__ accessor for catalogs

* pylint
  • Loading branch information
smcguire-cmu authored Apr 22, 2024
1 parent 4bcbf3b commit 286fa7a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def __init__(
super().__init__(ddf, hc_structure)
self._ddf_pixel_map = ddf_pixel_map

def __getitem__(self, item):
    """Index the catalog with square brackets, delegating to the underlying dask frame.

    Args:
        item: Anything the wrapped dask DataFrame accepts as a subscript, e.g. a
            column name, a list of column names, or a boolean dask Series mask.

    Returns:
        A new instance of this catalog's class (sharing this catalog's pixel map
        and structure) when the selection is a dask DataFrame; otherwise the raw
        dask result (e.g. a dask Series for a single column) is returned as is.
    """
    result = self._ddf[item]
    # Check against the public ``dd.DataFrame`` name rather than the internal
    # ``dd.core`` module path, which is an implementation detail of dask.
    if isinstance(result, dd.DataFrame):
        return self.__class__(result, self._ddf_pixel_map, self.hc_structure)
    return result

def get_healpix_pixels(self) -> List[HealpixPixel]:
"""Get all HEALPix pixels that are contained in the catalog
Expand Down
26 changes: 26 additions & 0 deletions tests/lsdb/catalog/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from hipscat.pixel_math import HealpixPixel, hipscat_id_to_healpix

import lsdb
from lsdb import Catalog
from lsdb.dask.merge_catalog_functions import filter_by_hipscat_index_to_pixel


Expand Down Expand Up @@ -402,3 +403,28 @@ def test_plot_pixels(small_sky_order1_catalog, mocker):

hp.mollview.assert_called_once()
assert (hp.mollview.call_args[0][0] == img).all()


def test_square_bracket_columns(small_sky_order1_catalog):
    """Selecting a list of columns returns a Catalog holding exactly those columns."""
    columns = ["ra", "dec", "id"]
    column_subset = small_sky_order1_catalog[columns]
    assert all(column_subset.columns == columns)
    assert isinstance(column_subset, Catalog)
    # Materialize each lazy collection once and reuse the result for both checks.
    subset_computed = column_subset.compute()
    catalog_computed = small_sky_order1_catalog.compute()
    pd.testing.assert_frame_equal(subset_computed, catalog_computed[columns])
    assert np.all(subset_computed.index.values == catalog_computed.index.values)


def test_square_bracket_column(small_sky_order1_catalog):
    """Selecting a single column name returns a dask Series with matching values and index."""
    column_name = "ra"
    column = small_sky_order1_catalog[column_name]
    # Cheap type check first, before any computation is triggered.
    assert isinstance(column, dd.Series)
    # Materialize each lazy collection once and reuse the result for both checks.
    column_computed = column.compute()
    catalog_computed = small_sky_order1_catalog.compute()
    pd.testing.assert_series_equal(column_computed, catalog_computed[column_name])
    assert np.all(column_computed.index.values == catalog_computed.index.values)


def test_square_bracket_filter(small_sky_order1_catalog):
    """Indexing with a boolean mask returns a Catalog containing only the matching rows."""
    filtered_id = small_sky_order1_catalog[small_sky_order1_catalog["id"] > 750]
    assert isinstance(filtered_id, Catalog)
    ss_computed = small_sky_order1_catalog.compute()
    # Build the expected pandas selection and the computed result once each.
    expected = ss_computed[ss_computed["id"] > 750]
    filtered_computed = filtered_id.compute()
    pd.testing.assert_frame_equal(filtered_computed, expected)
    assert np.all(filtered_computed.index.values == expected.index.values)

0 comments on commit 286fa7a

Please sign in to comment.