Commit
Merge pull request #140 from astronomy-commons/sean/refactor-join-crossmatch

Refactor join and crossmatch partition alignment and applying code
smcguire-cmu authored Feb 2, 2024
2 parents 1ca91c3 + 724014d commit 07225d1
Showing 4 changed files with 141 additions and 167 deletions.
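Summary of the change: both crossmatch_catalog_data and the two join functions previously aligned partitions by hand (align_catalogs_to_alignment_mapping), wrapped the per-partition function with np.vectorize, and assembled the result with get_partition_map_from_alignment_pixels, get_pixels_divisions, and dd.from_delayed. They now delegate that boilerplate to two shared helpers imported from lsdb.dask.merge_catalog_functions: align_and_apply and construct_catalog_args. A minimal sketch of the new call pattern, assembled from the call sites visible in this diff (the helpers' own definitions are not among the files shown here, and the surrounding variable setup is assumed from context):

# Sketch only: the calls below mirror the refactored crossmatch_catalog_data
# body in this diff; left, right, crossmatch_algorithm, suffixes, meta_df, and
# kwargs are assumed to be set up as in that function.
alignment = align_trees(
    left.hc_structure.pixel_tree,
    right.hc_structure.pixel_tree,
    alignment_type=PixelAlignmentType.INNER,
)
left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment)

# Each catalog is paired with its aligned pixel list; align_and_apply maps the
# delayed per-partition function over the aligned partitions and forwards the
# trailing arguments to it.
joined_partitions = align_and_apply(
    [(left, left_pixels), (right, right_pixels)],
    perform_crossmatch,
    crossmatch_algorithm,
    suffixes,
    **kwargs,
)

# construct_catalog_args replaces the old hand-rolled partition map, divisions,
# and dd.from_delayed call, returning the (Dask DataFrame, partition map,
# alignment) tuple the caller expects.
ddf, ddf_map, pixel_alignment = construct_catalog_args(joined_partitions, meta_df, alignment)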
72 changes: 20 additions & 52 deletions src/lsdb/dask/crossmatch_catalog_data.py
@@ -1,22 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Tuple, Type, cast
from typing import TYPE_CHECKING, Tuple, Type

import dask
import dask.dataframe as dd
import numpy as np
from hipscat.pixel_tree import PixelAlignment, PixelAlignmentType, align_trees

from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm
from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm
from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch
from lsdb.dask.divisions import get_pixels_divisions
from lsdb.dask.merge_catalog_functions import (
align_catalogs_to_alignment_mapping,
align_and_apply,
construct_catalog_args,
filter_by_hipscat_index_to_pixel,
generate_meta_df_for_joined_tables,
get_healpix_pixels_from_alignment,
get_partition_map_from_alignment_pixels,
)
from lsdb.types import DaskDFPixelMap

@@ -29,15 +27,13 @@
# pylint: disable=too-many-arguments
@dask.delayed
def perform_crossmatch(
algorithm,
left_df,
right_df,
left_order,
left_pixel,
right_order,
right_pixel,
left_pix,
right_pix,
left_hc_structure,
right_hc_structure,
algorithm,
suffixes,
**kwargs,
):
@@ -46,15 +42,15 @@ def perform_crossmatch(
Filters the left catalog before performing the cross-match to stop duplicate points appearing in
the result.
"""
if right_order > left_order:
left_df = filter_by_hipscat_index_to_pixel(left_df, right_order, right_pixel)
if right_pix.order > left_pix.order:
left_df = filter_by_hipscat_index_to_pixel(left_df, right_pix.order, right_pix.pixel)
return algorithm(
left_df,
right_df,
left_order,
left_pixel,
right_order,
right_pixel,
left_pix.order,
left_pix.pixel,
right_pix.order,
right_pix.pixel,
left_hc_structure,
right_hc_structure,
suffixes,
@@ -95,54 +91,26 @@ def crossmatch_catalog_data(
right.hc_structure.pixel_tree,
alignment_type=PixelAlignmentType.INNER,
)
join_pixels = alignment.pixel_mapping

# align partitions from the catalogs to match the pixel alignment
left_aligned_partitions, right_aligned_partitions = align_catalogs_to_alignment_mapping(
join_pixels, left, right
)

# get lists of HEALPix pixels from alignment to pass to cross-match
left_pixels, right_pixels = get_healpix_pixels_from_alignment(join_pixels)
left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment)

# perform the crossmatch on each partition pairing using dask delayed for lazy computation
apply_crossmatch = np.vectorize(
lambda left_df, right_df, left_pix, right_pix: perform_crossmatch(
crossmatch_algorithm,
left_df,
right_df,
left_pix.order,
left_pix.pixel,
right_pix.order,
right_pix.pixel,
left.hc_structure,
right.hc_structure,
suffixes,
**kwargs,
)
)

joined_partitions = apply_crossmatch(
left_aligned_partitions,
right_aligned_partitions,
left_pixels,
right_pixels,
joined_partitions = align_and_apply(
[(left, left_pixels), (right, right_pixels)],
perform_crossmatch,
crossmatch_algorithm,
suffixes,
**kwargs,
)

# generate dask df partition map from alignment
partition_map = get_partition_map_from_alignment_pixels(join_pixels)

# generate meta table structure for dask df
meta_df = generate_meta_df_for_joined_tables(
[left, right], suffixes, extra_columns=crossmatch_algorithm.extra_columns
)

# create dask df from delayed partitions
divisions = get_pixels_divisions(list(partition_map.keys()))
ddf = dd.from_delayed(joined_partitions, meta=meta_df, divisions=divisions)
ddf = cast(dd.DataFrame, ddf)

return ddf, partition_map, alignment
return construct_catalog_args(joined_partitions, meta_df, alignment)


def get_crossmatch_algorithm(
104 changes: 34 additions & 70 deletions src/lsdb/dask/join_catalog_data.py
@@ -2,26 +2,23 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Tuple, cast
from typing import TYPE_CHECKING, Tuple

import dask
import dask.dataframe as dd
import hipscat as hc
import numpy as np
import pandas as pd
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN
from hipscat.pixel_tree import PixelAlignment, PixelAlignmentType, align_trees

from lsdb.catalog.association_catalog import AssociationCatalog
from lsdb.dask.divisions import get_pixels_divisions
from lsdb.dask.merge_catalog_functions import (
align_catalog_to_partitions,
align_catalogs_to_alignment_mapping,
align_and_apply,
construct_catalog_args,
filter_by_hipscat_index_to_pixel,
generate_meta_df_for_joined_tables,
get_healpix_pixels_from_alignment,
get_partition_map_from_alignment_pixels,
)
from lsdb.types import DaskDFPixelMap

@@ -50,25 +47,30 @@ def rename_columns_with_suffixes(left: pd.DataFrame, right: pd.DataFrame, suffix
return left, right


# pylint: disable=unused-argument
@dask.delayed
def perform_join_on(
left: pd.DataFrame,
right: pd.DataFrame,
left_on: str,
right_on: str,
left_pixel: HealpixPixel,
right_pixel: HealpixPixel,
left_structure: hc.catalog.Catalog,
right_structure: hc.catalog.Catalog,
left_on: str,
right_on: str,
suffixes: Tuple[str, str],
):
"""Performs a join on two catalog partitions
Args:
left (pd.DataFrame): the left partition to merge
right (pd.DataFrame): the right partition to merge
left_on (str): the column to join on from the left partition
right_on (str): the column to join on from the right partition
left_pixel (HealpixPixel): the HEALPix pixel of the left partition
right_pixel (HealpixPixel): the HEALPix pixel of the right partition
left_structure (hc.Catalog): the hipscat structure of the left catalog
right_structure (hc.Catalog): the hipscat structure of the right catalog
left_on (str): the column to join on from the left partition
right_on (str): the column to join on from the right partition
suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
Returns:
@@ -82,14 +84,18 @@ def perform_join_on(
return merged


# pylint: disable=unused-argument
@dask.delayed
def perform_join_through(
left: pd.DataFrame,
right: pd.DataFrame,
through: pd.DataFrame,
left_pixel: HealpixPixel,
right_pixel: HealpixPixel,
catalog_info: hc.catalog.association_catalog.AssociationCatalogInfo,
through_pixel: HealpixPixel,
left_catalog: hc.catalog.Catalog,
right_catalog: hc.catalog.Catalog,
assoc_catalog: hc.catalog.AssociationCatalog,
suffixes: Tuple[str, str],
):
"""Performs a join on two catalog partitions through an association catalog
@@ -100,12 +106,16 @@
through (pd.DataFrame): the association column partition to merge with
left_pixel (HealpixPixel): the HEALPix pixel of the left partition
right_pixel (HealpixPixel): the HEALPix pixel of the right partition
catalog_info (AssociationCatalogInfo): the catalog_info of the association catalog
through_pixel (HealpixPixel): the HEALPix pixel of the association partition
left_catalog (hc.Catalog): the hipscat structure of the left catalog
right_catalog (hc.Catalog): the hipscat structure of the right catalog
assoc_catalog (hc.AssociationCatalog): the hipscat structure of the association catalog
suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
Returns:
A dataframe with the result of merging the left and right partitions on the specified columns
"""
catalog_info = assoc_catalog.catalog_info
if catalog_info.primary_column is None or catalog_info.join_column is None:
raise ValueError("Invalid catalog_info")
if right_pixel.order > left_pixel.order:
@@ -157,35 +167,16 @@ def join_catalog_data_on(
alignment = align_trees(
left.hc_structure.pixel_tree, right.hc_structure.pixel_tree, alignment_type=PixelAlignmentType.INNER
)
join_pixels = alignment.pixel_mapping
left_aligned_partitions, right_aligned_partitions = align_catalogs_to_alignment_mapping(
join_pixels, left, right
)

left_pixels, right_pixels = get_healpix_pixels_from_alignment(join_pixels)
left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment)

apply_join = np.vectorize(
lambda left_df, right_df, left_pix, right_pix: perform_join_on(
left_df,
right_df,
left_on,
right_on,
left_pix,
right_pix,
suffixes,
)
joined_partitions = align_and_apply(
[(left, left_pixels), (right, right_pixels)], perform_join_on, left_on, right_on, suffixes
)

joined_partitions = apply_join(
left_aligned_partitions, right_aligned_partitions, left_pixels, right_pixels
)

partition_map = get_partition_map_from_alignment_pixels(join_pixels)
meta_df = generate_meta_df_for_joined_tables([left, right], suffixes)
divisions = get_pixels_divisions(list(partition_map.keys()))
ddf = dd.from_delayed(joined_partitions, meta=meta_df, divisions=divisions)
ddf = cast(dd.DataFrame, ddf)
return ddf, partition_map, alignment

return construct_catalog_args(joined_partitions, meta_df, alignment)


def join_catalog_data_through(
@@ -222,48 +213,21 @@ def join_catalog_data_through(
alignment = align_trees(
left.hc_structure.pixel_tree, right.hc_structure.pixel_tree, alignment_type=PixelAlignmentType.INNER
)
join_pixels = alignment.pixel_mapping
left_aligned_partitions, right_aligned_partitions = align_catalogs_to_alignment_mapping(
join_pixels, left, right
)
association_aligned_to_join_partitions = align_catalog_to_partitions(
association,
join_pixels,
order_col=PixelAlignment.PRIMARY_ORDER_COLUMN_NAME,
pixel_col=PixelAlignment.PRIMARY_PIXEL_COLUMN_NAME,
)

left_pixels, right_pixels = get_healpix_pixels_from_alignment(join_pixels)
left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment)

apply_join = np.vectorize(
lambda left_df, right_df, assoc_df, left_pix, right_pix: perform_join_through(
left_df,
right_df,
assoc_df,
left_pix,
right_pix,
association.hc_structure.catalog_info,
suffixes,
)
joined_partitions = align_and_apply(
[(left, left_pixels), (right, right_pixels), (association, left_pixels)],
perform_join_through,
suffixes,
)

joined_partitions = apply_join(
left_aligned_partitions,
right_aligned_partitions,
association_aligned_to_join_partitions,
left_pixels,
right_pixels,
)

partition_map = get_partition_map_from_alignment_pixels(alignment.pixel_mapping)
association_join_columns = [
association.hc_structure.catalog_info.primary_column_association,
association.hc_structure.catalog_info.join_column_association,
]
# pylint: disable=protected-access
extra_df = association._ddf._meta.drop(NON_JOINING_ASSOCIATION_COLUMNS + association_join_columns, axis=1)
meta_df = generate_meta_df_for_joined_tables([left, extra_df, right], [suffixes[0], "", suffixes[1]])
divisions = get_pixels_divisions(list(partition_map.keys()))
ddf = dd.from_delayed(joined_partitions, meta=meta_df, divisions=divisions)
ddf = cast(dd.DataFrame, ddf)
return ddf, partition_map, alignment

return construct_catalog_args(joined_partitions, meta_df, alignment)
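For reference, the refactored per-partition functions above (perform_crossmatch, perform_join_on, perform_join_through) all follow the same argument order: one dataframe per input catalog, then one HealpixPixel per input, then one hipscat structure per input, followed by any extra positional arguments passed to align_and_apply. A hypothetical example written against that inferred convention (count_pairs and its label argument are illustrative only, not part of the repository):

# Hypothetical per-partition function following the argument order inferred
# from the refactored signatures in this diff.
import dask
import hipscat as hc
import pandas as pd
from hipscat.pixel_math import HealpixPixel


@dask.delayed
def count_pairs(
    left_df: pd.DataFrame,                 # aligned partition of the first catalog in the list
    right_df: pd.DataFrame,                # aligned partition of the second catalog
    left_pix: HealpixPixel,                # HEALPix pixel of the left partition
    right_pix: HealpixPixel,               # HEALPix pixel of the right partition
    left_structure: hc.catalog.Catalog,    # hipscat structure of the left catalog
    right_structure: hc.catalog.Catalog,   # hipscat structure of the right catalog
    label: str,                            # extra positional argument forwarded by align_and_apply
):
    """Toy per-partition function: report the sizes of an aligned partition pair."""
    return pd.DataFrame(
        {
            "label": [label],
            "left_pixel": [str(left_pix)],
            "right_pixel": [str(right_pix)],
            "n_left": [len(left_df)],
            "n_right": [len(right_df)],
        }
    )

Under that convention it would be applied the same way as the built-in functions, e.g. align_and_apply([(left, left_pixels), (right, right_pixels)], count_pairs, "my-label"), with the resulting delayed partitions passed to construct_catalog_args together with a matching meta dataframe.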