Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-44159: Replace pandas.DataFrame usage with astropy.Table #978

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 86 additions & 62 deletions python/lsst/pipe/tasks/diff_matched_tract_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from abc import ABCMeta, abstractmethod
from astropy.stats import mad_std
import astropy.table
import astropy.units as u
from dataclasses import dataclass
from decimal import Decimal
Expand Down Expand Up @@ -75,14 +76,14 @@ class DiffMatchedTractCatalogConnections(
cat_ref = cT.Input(
doc="Reference object catalog to match from",
name="{name_input_cat_ref}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
cat_target = cT.Input(
doc="Target object catalog to match",
name="{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
Expand All @@ -95,33 +96,33 @@ class DiffMatchedTractCatalogConnections(
cat_match_ref = cT.Input(
doc="Reference match catalog with indices of target matches",
name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
cat_match_target = cT.Input(
doc="Target match catalog with indices of references matches",
name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
columns_match_target = cT.Input(
doc="Target match catalog columns",
name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
storageClass="DataFrameIndex",
storageClass="ArrowColumnList",
dimensions=("tract", "skymap"),
)
cat_matched = cT.Output(
doc="Catalog with reference and target columns for joined sources",
name="matched_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
)
diff_matched = cT.Output(
doc="Table with aggregated counts, difference and chi statistics",
name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
)

Expand All @@ -137,6 +138,8 @@ def __init__(self, *, config=None):
dimensions=(),
deferLoad=old.deferLoad,
)
if not (config.compute_stats and len(config.columns_flux) > 0):
del self.diff_matched


class MatchedCatalogFluxesConfig(pexConfig.Config):
Expand Down Expand Up @@ -685,25 +688,25 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):

def run(
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
catalog_match_ref: pd.DataFrame,
catalog_match_target: pd.DataFrame,
catalog_ref: pd.DataFrame | astropy.table.Table,
catalog_target: pd.DataFrame | astropy.table.Table,
catalog_match_ref: pd.DataFrame | astropy.table.Table,
catalog_match_target: pd.DataFrame | astropy.table.Table,
wcs: afwGeom.SkyWcs = None,
) -> pipeBase.Struct:
"""Load matched reference and target (measured) catalogs, measure summary statistics, and output
a combined matched catalog with columns from both inputs.

Parameters
----------
catalog_ref : `pandas.DataFrame`
catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
A reference catalog to diff objects/sources from.
catalog_target : `pandas.DataFrame`
catalog_target : `pandas.DataFrame` | `astropy.table.Table`
A target catalog to diff reference objects/sources to.
catalog_match_ref : `pandas.DataFrame`
catalog_match_ref : `pandas.DataFrame` | `astropy.table.Table`
A catalog with match indices of target sources and selection flags
for each reference source.
catalog_match_target : `pandas.DataFrame`
catalog_match_target : `pandas.DataFrame` | `astropy.table.Table`
A catalog with selection flags for each target source.
wcs : `lsst.afw.image.SkyWcs`
A coordinate system to convert catalog positions to sky coordinates,
Expand All @@ -718,16 +721,31 @@ def run(
# Would be nice if this could refer directly to ConfigClass
config: DiffMatchedTractCatalogConfig = self.config

select_ref = catalog_match_ref['match_candidate'].values
is_ref_pd = isinstance(catalog_ref, pd.DataFrame)
is_target_pd = isinstance(catalog_target, pd.DataFrame)
is_match_ref_pd = isinstance(catalog_match_ref, pd.DataFrame)
is_match_target_pd = isinstance(catalog_match_target, pd.DataFrame)
if is_ref_pd:
catalog_ref = astropy.table.Table.from_pandas(catalog_ref)
if is_target_pd:
catalog_target = astropy.table.Table.from_pandas(catalog_target)
if is_match_ref_pd:
catalog_match_ref = astropy.table.Table.from_pandas(catalog_match_ref)
if is_match_target_pd:
catalog_match_target = astropy.table.Table.from_pandas(catalog_match_target)
if is_ref_pd or is_target_pd:
self.log.warning("pandas DataFrame inputs are deprecated")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a removal ticket and a proper deprecation warning as with meas_astrom.


select_ref = catalog_match_ref['match_candidate']
# Add additional selection criteria for target sources beyond those for matching
# (not recommended, but can be done anyway)
select_target = (catalog_match_target['match_candidate'].values
select_target = (catalog_match_target['match_candidate']
if 'match_candidate' in catalog_match_target.columns
else np.ones(len(catalog_match_target), dtype=bool))
for column in config.columns_target_select_true:
select_target &= catalog_target[column].values
select_target &= catalog_target[column]
for column in config.columns_target_select_false:
select_target &= ~catalog_target[column].values
select_target &= ~catalog_target[column]

ref, target = config.coord_format.format_catalogs(
catalog_ref=catalog_ref, catalog_target=catalog_target,
Expand All @@ -739,9 +757,9 @@ def run(

if config.include_unmatched:
for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)):
cat_add['match_candidate'] = cat_match['match_candidate'].values
cat_add['match_candidate'] = cat_match['match_candidate']

match_row = catalog_match_ref['match_row'].values
match_row = catalog_match_ref['match_row']
matched_ref = match_row >= 0
matched_row = match_row[matched_ref]
matched_target = np.zeros(n_target, dtype=bool)
Expand All @@ -761,48 +779,44 @@ def run(
) if config.coord_format.coords_spherical else np.hypot(
target_match_c1 - target_ref_c1, target_match_c2 - target_ref_c2,
)
cat_target_matched = cat_target[matched_row]
# This will convert a masked array to an array filled with nans
# wherever there are bad values (otherwise sphdist can raise)
c1_err, c2_err = (
np.ma.getdata(cat_target_matched[c_err]) for c_err in (coord1_target_err, coord2_target_err)
)
# Should probably explicitly add cosine terms if ref has errors too
dist_err[matched_row] = sphdist(
target_match_c1, target_match_c2,
target_match_c1 + cat_target.iloc[matched_row][coord1_target_err].values,
target_match_c2 + cat_target.iloc[matched_row][coord2_target_err].values,
) if config.coord_format.coords_spherical else np.hypot(
cat_target.iloc[matched_row][coord1_target_err].values,
cat_target.iloc[matched_row][coord2_target_err].values
)
target_match_c1, target_match_c2, target_match_c1 + c1_err, target_match_c2 + c2_err
) if config.coord_format.coords_spherical else np.hypot(c1_err, c2_err)
cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err

# Create a matched table, preserving the target catalog's named index (if it has one)
cat_left = cat_target.iloc[matched_row]
has_index_left = cat_left.index.name is not None
cat_right = cat_ref[matched_ref].reset_index()
cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1)
cat_left = cat_target[matched_row]
cat_right = cat_ref[matched_ref]
cat_right.rename_columns(
list(cat_right.columns),
new_names=[f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns],
)
cat_matched = astropy.table.hstack((cat_left, cat_right))

if config.include_unmatched:
# Create an unmatched table with the same schema as the matched one
# ... but only for objects with no matches (for completeness/purity)
# and that were selected for matching (or inclusion via config)
cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False)
cat_right.columns = (f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns)
match_row_target = catalog_match_target['match_row'].values
cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index(
drop=not has_index_left)
cat_right = astropy.table.Table(
cat_ref[~matched_ref & select_ref]
)
cat_right.rename_columns(
cat_right.colnames,
[f"{config.column_matched_prefix_ref}{col}" for col in cat_right.colnames],
)
match_row_target = catalog_match_target['match_row']
cat_left = cat_target[~(match_row_target >= 0) & select_target]
# This may be slower than pandas but will, for example, create
# masked columns for booleans, which pandas does not support.
# See https://github.com/pandas-dev/pandas/issues/46662
# astropy masked columns would handle this much more gracefully
# Unfortunately, that would require storageClass migration
# So we use pandas "extended" nullable types for now
for cat_i in (cat_left, cat_right):
for colname in cat_i.columns:
column = cat_i[colname]
dtype = str(column.dtype)
if dtype == "bool":
cat_i[colname] = column.astype("boolean")
elif dtype.startswith("int"):
cat_i[colname] = column.astype(f"Int{dtype[3:]}")
elif dtype.startswith("uint"):
cat_i[colname] = column.astype(f"UInt{dtype[3:]}")
cat_unmatched = pd.concat(objs=(cat_left, cat_right))
cat_unmatched = astropy.table.vstack([cat_left, cat_right])

for columns_convert_base, prefix in (
(config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref),
Expand All @@ -812,8 +826,14 @@ def run(
columns_convert = {
f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items()
} if prefix else columns_convert_base
for cat_convert in (cat_matched, cat_unmatched):
cat_convert.rename(columns=columns_convert, inplace=True)
to_convert = [cat_matched]
if config.include_unmatched:
to_convert.append(cat_unmatched)
for cat_convert in to_convert:
cat_convert.rename_columns(
tuple(columns_convert.keys()),
tuple(columns_convert.values()),
)
for column_flux in columns_convert.values():
cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])

Expand All @@ -822,7 +842,8 @@ def run(
n_bands = len(band_fluxes)

# TODO: Deprecated by RFC-1017 and to be removed in DM-44988
if self.config.compute_stats and (n_bands > 0):
do_stats = self.config.compute_stats and (n_bands > 0)
if do_stats:
# Slightly smelly hack for when a column (like distance) is already relative to truth
column_dummy = 'dummy'
cat_ref[column_dummy] = np.zeros_like(ref.coord1)
Expand All @@ -831,7 +852,7 @@ def run(
# TODO: remove the assumption of a boolean column
extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)

extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut
extended_target = cat_target[config.column_target_extended] >= config.extendedness_cut

# Define difference/chi columns and statistics thereof
suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
Expand Down Expand Up @@ -999,7 +1020,7 @@ def run(

if n_match > 0:
rows_matched = match_row_bin[match_good]
subset_target = cat_target.iloc[rows_matched]
subset_target = cat_target[rows_matched]
if (is_extended is not None) and (idx_model == 0):
right_type = extended_target[rows_matched] == is_extended
n_total = len(right_type)
Expand All @@ -1016,15 +1037,15 @@ def run(
# compute stats for this bin, for all columns
for column, (column_ref, column_target, column_err_target, skip_diff) \
in columns_target.items():
values_ref = cat_ref[column_ref][match_good].values
values_ref = cat_ref[column_ref][match_good]
errors_target = (
subset_target[column_err_target].values
subset_target[column_err_target]
if column_err_target is not None
else None
)
compute_stats(
values_ref,
subset_target[column_target].values,
subset_target[column_target],
errors_target,
row,
stats,
Expand Down Expand Up @@ -1066,7 +1087,10 @@ def run(
mag_ref_first = mag_ref

if config.include_unmatched:
cat_matched = pd.concat((cat_matched, cat_unmatched))
# This is probably less efficient than just doing an outer join originally; worth checking
cat_matched = astropy.table.vstack([cat_matched, cat_unmatched])

retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
retStruct = pipeBase.Struct(cat_matched=cat_matched)
if do_stats:
retStruct.diff_matched = astropy.table.Table(data)
return retStruct
26 changes: 13 additions & 13 deletions python/lsst/pipe/tasks/match_tract_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from abc import ABC, abstractmethod

import astropy.table
import pandas as pd
from typing import Tuple, Set

Expand All @@ -50,14 +51,14 @@ class MatchTractCatalogConnections(
cat_ref = cT.Input(
doc="Reference object catalog to match from",
name="{name_input_cat_ref}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
cat_target = cT.Input(
doc="Target object catalog to match",
name="{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
Expand All @@ -67,17 +68,16 @@ class MatchTractCatalogConnections(
storageClass="SkyMap",
dimensions=("skymap",),
)
# TODO: Change outputs to ArrowAstropy in DM-44159
cat_output_ref = cT.Output(
doc="Reference matched catalog with indices of target matches",
name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
)
cat_output_target = cT.Output(
doc="Target matched catalog with indices of reference matches",
name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
)

Expand Down Expand Up @@ -128,17 +128,17 @@ def __init__(self, **kwargs):
@abstractmethod
def run(
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
catalog_ref: pd.DataFrame | astropy.table.Table,
catalog_target: pd.DataFrame | astropy.table.Table,
wcs: afwGeom.SkyWcs = None,
) -> pipeBase.Struct:
"""Match sources in a reference tract catalog with a target catalog.

Parameters
----------
catalog_ref : `pandas.DataFrame`
catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
A reference catalog to match objects/sources from.
catalog_target : `pandas.DataFrame`
catalog_target : `pandas.DataFrame` | `astropy.table.Table`
A target catalog to match reference objects/sources to.
wcs : `lsst.afw.image.SkyWcs`
A coordinate system to convert catalog positions to sky coordinates.
Expand Down Expand Up @@ -211,17 +211,17 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):

def run(
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
catalog_ref: pd.DataFrame | astropy.table.Table,
catalog_target: pd.DataFrame | astropy.table.Table,
wcs: afwGeom.SkyWcs = None,
) -> pipeBase.Struct:
"""Match sources in a reference tract catalog with a target catalog.

Parameters
----------
catalog_ref : `pandas.DataFrame`
catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
A reference catalog to match objects/sources from.
catalog_target : `pandas.DataFrame`
catalog_target : `pandas.DataFrame` | `astropy.table.Table`
A target catalog to match reference objects/sources to.
wcs : `lsst.afw.image.SkyWcs`
A coordinate system to convert catalog positions to sky coordinates,
Expand Down
Loading
Loading