Skip to content

Commit

Permalink
Fix empty margin catalogs in from_dataframe (#508)
Browse files Browse the repository at this point in the history
* Create empty catalog for margins

* Fix margin catalog validation

* Calculate threshold when order is specified
  • Loading branch information
camposandro authored Nov 26, 2024
1 parent 8269182 commit 2a36166
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 40 deletions.
21 changes: 21 additions & 0 deletions src/lsdb/catalog/margin_catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import hats as hc
import nested_dask as nd
import pyarrow as pa
from hats.io import paths

from lsdb.catalog.dataset.healpix_dataset import HealpixDataset
from lsdb.types import DaskDFPixelMap
Expand All @@ -23,3 +25,22 @@ def __init__(
hc_structure: hc.catalog.MarginCatalog,
):
super().__init__(ddf, ddf_pixel_map, hc_structure)


def _validate_margin_catalog(margin_hc_catalog, hc_catalog):
    """Validate that the margin and main catalogs have compatible schemas.

    The margin schema is expected to contain the fields produced by
    `_create_margin_schema` for the main catalog's schema. The order of the
    pyarrow fields does not matter and duplicated fields are allowed.

    Args:
        margin_hc_catalog: The margin catalog structure; its `schema` is checked.
        hc_catalog: The main catalog structure; its `schema` is the reference.

    Raises:
        ValueError: If the (name, type) pairs of the two schemas do not match.
    """

    def _sorted_fields(schema):
        # pyarrow data types cannot be ordered with `<`, so ties on the field
        # name are broken with the type's string form rather than the type
        # object. Sorting the bare tuples would raise TypeError for same-named
        # fields with different types instead of reporting a schema mismatch.
        pairs = [(f.name, f.type) for f in schema]
        return sorted(pairs, key=lambda pair: (pair[0], str(pair[1])))

    expected_margin_schema = _create_margin_schema(hc_catalog.schema)
    if _sorted_fields(margin_hc_catalog.schema) != _sorted_fields(expected_margin_schema):
        raise ValueError("The margin catalog and the main catalog must have the same schema.")


def _create_margin_schema(main_catalog_schema: pa.Schema) -> pa.Schema:
    """Derive the margin catalog schema from the schema of the main catalog.

    Three `margin_`-prefixed partitioning fields (order, dir, pixel) are
    appended, in that sequence, after the main catalog's own fields.
    """
    margin_columns = (
        (paths.PARTITION_ORDER, pa.uint8()),
        (paths.PARTITION_DIR, pa.uint64()),
        (paths.PARTITION_PIXEL, pa.uint64()),
    )
    margin_schema = main_catalog_schema
    for column_name, column_type in margin_columns:
        margin_schema = margin_schema.append(pa.field(f"margin_{column_name}", column_type))
    return margin_schema
19 changes: 9 additions & 10 deletions src/lsdb/loaders/dataframe/from_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def from_dataframe(
partition_size (int): The desired partition size, in number of bytes in-memory.
threshold (int): The maximum number of data points per pixel.
margin_order (int): The order at which to generate the margin cache.
margin_threshold (float): The size of the margin cache boundary, in arcseconds. If None,
the margin cache is not generated. Defaults to 5 arcseconds.
margin_threshold (float): The size of the margin cache boundary, in arcseconds. If None, and
margin order is not specified, the margin cache is not generated. Defaults to 5 arcseconds.
should_generate_moc (bool): should we generate a MOC (multi-order coverage map)
of the data. can improve performance when joining/crossmatching to
other hats-sharded datasets.
Expand Down Expand Up @@ -74,12 +74,11 @@ def from_dataframe(
schema=schema,
**kwargs,
).load_catalog()
if margin_threshold:
catalog.margin = MarginCatalogGenerator(
catalog,
margin_order,
margin_threshold,
use_pyarrow_types,
**kwargs,
).create_catalog()
catalog.margin = MarginCatalogGenerator(
catalog,
margin_order,
margin_threshold,
use_pyarrow_types,
**kwargs,
).create_catalog()
return catalog
49 changes: 37 additions & 12 deletions src/lsdb/loaders/dataframe/margin_catalog_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import Dict, List, Tuple

import hats as hc
Expand All @@ -13,7 +14,7 @@
from hats.pixel_math.healpix_pixel_function import get_pixel_argsort

from lsdb import Catalog
from lsdb.catalog.margin_catalog import MarginCatalog
from lsdb.catalog.margin_catalog import MarginCatalog, _create_margin_schema
from lsdb.loaders.dataframe.from_dataframe_utils import (
_extra_property_dict,
_format_margin_partition_dataframe,
Expand All @@ -28,7 +29,7 @@ def __init__(
self,
catalog: Catalog,
margin_order: int = -1,
margin_threshold: float = 5.0,
margin_threshold: float | None = 5.0,
use_pyarrow_types: bool = True,
**kwargs,
) -> None:
Expand All @@ -45,9 +46,9 @@ def __init__(
self.hc_structure = catalog.hc_structure
self.margin_threshold = margin_threshold
self.margin_order = margin_order
self._resolve_margin_order()
self.use_pyarrow_types = use_pyarrow_types
self.catalog_info = self._create_catalog_info(**kwargs)
self.catalog_info_kwargs = kwargs
self.margin_schema = _create_margin_schema(catalog.hc_structure.schema)

def _resolve_margin_order(self):
"""Calculate the order of the margin cache to be generated. If not provided
Expand All @@ -61,6 +62,9 @@ def _resolve_margin_order(self):

if self.margin_order < 0:
self.margin_order = hp.margin2order(margin_thr_arcmin=self.margin_threshold / 60.0)
else:
self.margin_threshold = hp.order2mindist(self.margin_order) * 60.0
warnings.warn("Ignoring margin_threshold because margin_order was specified.", RuntimeWarning)

if self.margin_order < highest_order + 1:
raise ValueError(
Expand All @@ -72,22 +76,44 @@ def _resolve_margin_order(self):
raise ValueError("margin pixels must be larger than margin_threshold")

def create_catalog(self) -> MarginCatalog | None:
"""Create a margin catalog for another pre-computed catalog
"""Create a margin catalog for another pre-computed catalog.
Only one of margin order / threshold can be specified. If the margin order
is not specified: if the threshold is zero the margin is an empty catalog;
if the threshold is None, the margin is not generated (it is None).
Returns:
Margin catalog object, or None if the margin is empty.
Margin catalog object or None if the margin is not generated.
"""
if self.margin_order < 0:
if self.margin_threshold is None:
return None
if self.margin_threshold < 0:
raise ValueError("margin_threshold must be positive.")
if self.margin_threshold == 0:
return self._create_empty_catalog()
return self._create_catalog()

def _create_catalog(self) -> MarginCatalog:
"""Create a non-empty margin catalog"""
self._resolve_margin_order()
pixels, partitions = self._get_margins()
if len(pixels) == 0:
return None
return self._create_empty_catalog()
ddf, ddf_pixel_map, total_rows = self._generate_dask_df_and_map(pixels, partitions)
self.catalog_info.total_rows = total_rows
catalog_info = self._create_catalog_info(**self.catalog_info_kwargs, total_rows=total_rows)
margin_pixels = list(ddf_pixel_map.keys())
margin_structure = hc.catalog.MarginCatalog(
self.catalog_info, margin_pixels, schema=self.hc_structure.schema
)
margin_structure = hc.catalog.MarginCatalog(catalog_info, margin_pixels, schema=self.margin_schema)
return MarginCatalog(ddf, ddf_pixel_map, margin_structure)

def _create_empty_catalog(self) -> MarginCatalog:
    """Build a margin catalog that holds no pixels and no rows.

    The frame is created from an empty table with the margin schema, the
    pixel map is an empty dict, and the catalog info is generated with
    ``total_rows=0``.
    """
    empty_frame = self.margin_schema.empty_table().to_pandas()
    nested_ddf = nd.NestedFrame.from_pandas(empty_frame, npartitions=1)
    info = self._create_catalog_info(**self.catalog_info_kwargs, total_rows=0)
    structure = hc.catalog.MarginCatalog(info, [], schema=self.margin_schema)
    return MarginCatalog(nested_ddf, {}, structure)

def _get_margins(self) -> Tuple[List[HealpixPixel], List[npd.NestedFrame]]:
"""Generates the list of pixels that have margin data, and the dataframes with the margin data for
each partition
Expand Down Expand Up @@ -206,7 +232,6 @@ def _create_catalog_info(self, catalog_name: str | None = None, **kwargs) -> Tab
catalog_type=CatalogType.MARGIN,
ra_column=self.hc_structure.catalog_info.ra_column,
dec_column=self.hc_structure.catalog_info.dec_column,
total_rows=self.hc_structure.catalog_info.total_rows,
primary_catalog=catalog_name,
margin_threshold=self.margin_threshold,
**kwargs,
Expand Down
14 changes: 1 addition & 13 deletions src/lsdb/loaders/hats/read_hats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pyarrow as pa
from hats.catalog import CatalogType
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
from hats.io import paths
from hats.io.file_io import file_io
from hats.pixel_math import HealpixPixel
from hats.pixel_math.healpix_pixel_function import get_pixel_argsort
Expand All @@ -19,6 +18,7 @@

from lsdb.catalog.association_catalog import AssociationCatalog
from lsdb.catalog.catalog import Catalog, DaskDFPixelMap, MarginCatalog
from lsdb.catalog.margin_catalog import _validate_margin_catalog
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.dask.divisions import get_pixels_divisions
from lsdb.loaders.hats.hats_loading_config import HatsLoadingConfig
Expand Down Expand Up @@ -154,18 +154,6 @@ def _load_object_catalog(hc_catalog, config):
return catalog


def _validate_margin_catalog(margin_hc_catalog, hc_catalog):
    """Check that the margin and main catalogs agree once partition columns are dropped.

    The partitioning columns are removed from the main schema; those columns
    and their `margin_`-prefixed counterparts are removed from the margin
    schema. The two remaining schemas must be equal.

    Raises:
        ValueError: If the remaining schemas differ.
    """
    partition_names = [paths.PARTITION_ORDER, paths.PARTITION_DIR, paths.PARTITION_PIXEL]
    ignored_in_margin = partition_names + ["margin_" + name for name in partition_names]

    def _without(schema, ignored):
        # Rebuild the schema keeping only the fields whose name is not ignored.
        return pa.schema([field for field in schema if field.name not in ignored])

    main_schema = _without(hc_catalog.schema, partition_names)
    margin_schema = _without(margin_hc_catalog.schema, ignored_in_margin)
    if main_schema.equals(margin_schema):
        return
    raise ValueError("The margin catalog and the main catalog must have the same schema")


def _create_dask_meta_schema(schema: pa.Schema, config) -> npd.NestedFrame:
"""Creates the Dask meta DataFrame from the HATS catalog schema."""
dask_meta_schema = schema.empty_table().to_pandas(types_mapper=config.get_dtype_mapper())
Expand Down
60 changes: 55 additions & 5 deletions tests/lsdb/loaders/dataframe/test_from_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mocpy import MOC

import lsdb
from lsdb.catalog.margin_catalog import MarginCatalog
from lsdb.catalog.margin_catalog import MarginCatalog, _validate_margin_catalog


def get_catalog_kwargs(catalog, **kwargs):
Expand Down Expand Up @@ -258,11 +258,41 @@ def test_from_dataframe_small_sky_source_with_margins(small_sky_source_df, small

assert catalog.hc_structure.catalog_info.__pydantic_extra__["hats_builder"].startswith("lsdb")
assert margin.hc_structure.catalog_info.__pydantic_extra__["hats_builder"].startswith("lsdb")
# The margin and main catalog's schemas are the same
assert margin.hc_structure.schema is catalog.hc_structure.schema

# The margin and main catalog's schemas are valid
_validate_margin_catalog(margin.hc_structure, catalog.hc_structure)

def test_from_dataframe_invalid_margin_order(small_sky_source_df):

def test_from_dataframe_margin_threshold_from_order(small_sky_source_df):
    """The margin threshold is taken from the margin order when the order is given."""
    load_kwargs = {
        "ra_column": "source_ra",
        "dec_column": "source_dec",
        "lowest_order": 0,
        "highest_order": 2,
        "threshold": 3000,
        "margin_order": 3,
    }
    # By default, the threshold is set to 5 arcsec, triggering a warning
    with pytest.warns(RuntimeWarning, match="Ignoring margin_threshold"):
        catalog = lsdb.from_dataframe(small_sky_source_df, **load_kwargs)
    margin = catalog.margin
    assert len(margin.get_healpix_pixels()) == 17
    expected_threshold = hp.order2mindist(3) * 60.0
    assert margin.hc_structure.catalog_info.margin_threshold == expected_threshold
    assert margin._ddf.index.name == catalog._ddf.index.name
    _validate_margin_catalog(margin.hc_structure, catalog.hc_structure)


def test_from_dataframe_invalid_margin_args(small_sky_source_df):
# The provided margin threshold is negative
with pytest.raises(ValueError, match="positive"):
lsdb.from_dataframe(
small_sky_source_df,
ra_column="source_ra",
dec_column="source_dec",
lowest_order=2,
margin_threshold=-1,
)
# Margin order is inferior to the main catalog's highest order
with pytest.raises(ValueError, match="margin_order"):
lsdb.from_dataframe(
small_sky_source_df,
Expand All @@ -281,7 +311,27 @@ def test_from_dataframe_margin_is_empty(small_sky_order1_df):
highest_order=5,
threshold=100,
)
assert catalog.margin is None
assert len(catalog.margin.get_healpix_pixels()) == 0
assert catalog.margin._ddf_pixel_map == {}
assert catalog.margin._ddf.index.name == catalog._ddf.index.name
assert catalog.margin.hc_structure.catalog_info.margin_threshold == 5.0
_validate_margin_catalog(catalog.margin.hc_structure, catalog.hc_structure)


def test_from_dataframe_margin_threshold_zero(small_sky_order1_df):
    """A margin threshold of zero yields an empty, but valid, margin catalog."""
    load_kwargs = {
        "catalog_name": "small_sky_order1",
        "catalog_type": "object",
        "highest_order": 5,
        "threshold": 100,
        "margin_threshold": 0,
    }
    catalog = lsdb.from_dataframe(small_sky_order1_df, **load_kwargs)
    margin = catalog.margin
    assert len(margin.get_healpix_pixels()) == 0
    assert margin._ddf_pixel_map == {}
    assert margin._ddf.index.name == catalog._ddf.index.name
    assert margin.hc_structure.catalog_info.margin_threshold == 0
    _validate_margin_catalog(margin.hc_structure, catalog.hc_structure)


def test_from_dataframe_moc(small_sky_order1_catalog):
Expand Down

0 comments on commit 2a36166

Please sign in to comment.