Skip to content

Commit

Permalink
Replace pandas.DataFrame usage with astropy.Table
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed Sep 16, 2024
1 parent 5065538 commit 5d20891
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 81 deletions.
148 changes: 86 additions & 62 deletions python/lsst/pipe/tasks/diff_matched_tract_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from abc import ABCMeta, abstractmethod
from astropy.stats import mad_std
import astropy.table
import astropy.units as u
from dataclasses import dataclass
from decimal import Decimal
Expand Down Expand Up @@ -75,14 +76,14 @@ class DiffMatchedTractCatalogConnections(
cat_ref = cT.Input(
doc="Reference object catalog to match from",
name="{name_input_cat_ref}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
cat_target = cT.Input(
doc="Target object catalog to match",
name="{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
Expand All @@ -95,33 +96,33 @@ class DiffMatchedTractCatalogConnections(
cat_match_ref = cT.Input(
doc="Reference match catalog with indices of target matches",
name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
cat_match_target = cT.Input(
doc="Target match catalog with indices of references matches",
name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
columns_match_target = cT.Input(
doc="Target match catalog columns",
name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
storageClass="DataFrameIndex",
storageClass="ArrowColumnList",
dimensions=("tract", "skymap"),
)
cat_matched = cT.Output(
doc="Catalog with reference and target columns for joined sources",
name="matched_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
)
diff_matched = cT.Output(
doc="Table with aggregated counts, difference and chi statistics",
name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
)

Expand All @@ -137,6 +138,8 @@ def __init__(self, *, config=None):
dimensions=(),
deferLoad=old.deferLoad,
)
if not (config.compute_stats and len(config.columns_flux) > 0):
del self.diff_matched


class MatchedCatalogFluxesConfig(pexConfig.Config):
Expand Down Expand Up @@ -685,25 +688,25 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):

def run(
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
catalog_match_ref: pd.DataFrame,
catalog_match_target: pd.DataFrame,
catalog_ref: pd.DataFrame | astropy.table.Table,
catalog_target: pd.DataFrame | astropy.table.Table,
catalog_match_ref: pd.DataFrame | astropy.table.Table,
catalog_match_target: pd.DataFrame | astropy.table.Table,
wcs: afwGeom.SkyWcs = None,
) -> pipeBase.Struct:
"""Load matched reference and target (measured) catalogs, measure summary statistics, and output
a combined matched catalog with columns from both inputs.
Parameters
----------
catalog_ref : `pandas.DataFrame`
catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
A reference catalog to diff objects/sources from.
catalog_target : `pandas.DataFrame`
catalog_target : `pandas.DataFrame` | `astropy.table.Table`
A target catalog to diff reference objects/sources to.
catalog_match_ref : `pandas.DataFrame`
catalog_match_ref : `pandas.DataFrame` | `astropy.table.Table`
A catalog with match indices of target sources and selection flags
for each reference source.
catalog_match_target : `pandas.DataFrame`
catalog_match_target : `pandas.DataFrame` | `astropy.table.Table`
A catalog with selection flags for each target source.
wcs : `lsst.afw.image.SkyWcs`
A coordinate system to convert catalog positions to sky coordinates,
Expand All @@ -718,16 +721,31 @@ def run(
# Would be nice if this could refer directly to ConfigClass
config: DiffMatchedTractCatalogConfig = self.config

select_ref = catalog_match_ref['match_candidate'].values
is_ref_pd = isinstance(catalog_ref, pd.DataFrame)
is_target_pd = isinstance(catalog_target, pd.DataFrame)
is_match_ref_pd = isinstance(catalog_match_ref, pd.DataFrame)
is_match_target_pd = isinstance(catalog_match_target, pd.DataFrame)
if is_ref_pd:
catalog_ref = astropy.table.Table.from_pandas(catalog_ref)
if is_target_pd:
catalog_target = astropy.table.Table.from_pandas(catalog_target)
if is_match_ref_pd:
catalog_match_ref = astropy.table.Table.from_pandas(catalog_match_ref)
if is_match_target_pd:
catalog_match_target = astropy.table.Table.from_pandas(catalog_match_target)
if is_ref_pd or is_target_pd:
self.log.warning("pandas DataFrame inputs are deprecated")

select_ref = catalog_match_ref['match_candidate']
# Add additional selection criteria for target sources beyond those for matching
# (not recommended, but can be done anyway)
select_target = (catalog_match_target['match_candidate'].values
select_target = (catalog_match_target['match_candidate']
if 'match_candidate' in catalog_match_target.columns
else np.ones(len(catalog_match_target), dtype=bool))
for column in config.columns_target_select_true:
select_target &= catalog_target[column].values
select_target &= catalog_target[column]
for column in config.columns_target_select_false:
select_target &= ~catalog_target[column].values
select_target &= ~catalog_target[column]

ref, target = config.coord_format.format_catalogs(
catalog_ref=catalog_ref, catalog_target=catalog_target,
Expand All @@ -739,9 +757,9 @@ def run(

if config.include_unmatched:
for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)):
cat_add['match_candidate'] = cat_match['match_candidate'].values
cat_add['match_candidate'] = cat_match['match_candidate']

match_row = catalog_match_ref['match_row'].values
match_row = catalog_match_ref['match_row']
matched_ref = match_row >= 0
matched_row = match_row[matched_ref]
matched_target = np.zeros(n_target, dtype=bool)
Expand All @@ -761,48 +779,44 @@ def run(
) if config.coord_format.coords_spherical else np.hypot(
target_match_c1 - target_ref_c1, target_match_c2 - target_ref_c2,
)
cat_target_matched = cat_target[matched_row]
# This will convert a masked array to an array filled with nans
# wherever there are bad values (otherwise sphdist can raise)
c1_err, c2_err = (
np.ma.getdata(cat_target_matched[c_err]) for c_err in (coord1_target_err, coord2_target_err)
)
# Should probably explicitly add cosine terms if ref has errors too
dist_err[matched_row] = sphdist(
target_match_c1, target_match_c2,
target_match_c1 + cat_target.iloc[matched_row][coord1_target_err].values,
target_match_c2 + cat_target.iloc[matched_row][coord2_target_err].values,
) if config.coord_format.coords_spherical else np.hypot(
cat_target.iloc[matched_row][coord1_target_err].values,
cat_target.iloc[matched_row][coord2_target_err].values
)
target_match_c1, target_match_c2, target_match_c1 + c1_err, target_match_c2 + c2_err
) if config.coord_format.coords_spherical else np.hypot(c1_err, c2_err)
cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err

# Create a matched table, preserving the target catalog's named index (if it has one)
cat_left = cat_target.iloc[matched_row]
has_index_left = cat_left.index.name is not None
cat_right = cat_ref[matched_ref].reset_index()
cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1)
cat_left = cat_target[matched_row]
cat_right = cat_ref[matched_ref]
cat_right.rename_columns(
list(cat_right.columns),
new_names=[f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns],
)
cat_matched = astropy.table.hstack((cat_left, cat_right))

if config.include_unmatched:
# Create an unmatched table with the same schema as the matched one
# ... but only for objects with no matches (for completeness/purity)
# and that were selected for matching (or inclusion via config)
cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False)
cat_right.columns = (f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns)
match_row_target = catalog_match_target['match_row'].values
cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index(
drop=not has_index_left)
cat_right = astropy.table.Table(
cat_ref[~matched_ref & select_ref]
)
cat_right.rename_columns(
cat_right.colnames,
[f"{config.column_matched_prefix_ref}{col}" for col in cat_right.colnames],
)
match_row_target = catalog_match_target['match_row']
cat_left = cat_target[~(match_row_target >= 0) & select_target]
# This may be slower than pandas but will, for example, create
# masked columns for booleans, which pandas does not support.
# See https://github.com/pandas-dev/pandas/issues/46662
# astropy masked columns would handle this much more gracefully
# Unfortunately, that would require storageClass migration
# So we use pandas "extended" nullable types for now
for cat_i in (cat_left, cat_right):
for colname in cat_i.columns:
column = cat_i[colname]
dtype = str(column.dtype)
if dtype == "bool":
cat_i[colname] = column.astype("boolean")
elif dtype.startswith("int"):
cat_i[colname] = column.astype(f"Int{dtype[3:]}")
elif dtype.startswith("uint"):
cat_i[colname] = column.astype(f"UInt{dtype[3:]}")
cat_unmatched = pd.concat(objs=(cat_left, cat_right))
cat_unmatched = astropy.table.vstack([cat_left, cat_right])

for columns_convert_base, prefix in (
(config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref),
Expand All @@ -812,8 +826,14 @@ def run(
columns_convert = {
f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items()
} if prefix else columns_convert_base
for cat_convert in (cat_matched, cat_unmatched):
cat_convert.rename(columns=columns_convert, inplace=True)
to_convert = [cat_matched]
if config.include_unmatched:
to_convert.append(cat_unmatched)
for cat_convert in to_convert:
cat_convert.rename_columns(
tuple(columns_convert.keys()),
tuple(columns_convert.values()),
)
for column_flux in columns_convert.values():
cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])

Expand All @@ -822,7 +842,8 @@ def run(
n_bands = len(band_fluxes)

# TODO: Deprecated by RFC-1017 and to be removed in DM-44988
if self.config.compute_stats and (n_bands > 0):
do_stats = self.config.compute_stats and (n_bands > 0)
if do_stats:
# Slightly smelly hack for when a column (like distance) is already relative to truth
column_dummy = 'dummy'
cat_ref[column_dummy] = np.zeros_like(ref.coord1)
Expand All @@ -831,7 +852,7 @@ def run(
# TODO: remove the assumption of a boolean column
extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)

extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut
extended_target = cat_target[config.column_target_extended] >= config.extendedness_cut

# Define difference/chi columns and statistics thereof
suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
Expand Down Expand Up @@ -999,7 +1020,7 @@ def run(

if n_match > 0:
rows_matched = match_row_bin[match_good]
subset_target = cat_target.iloc[rows_matched]
subset_target = cat_target[rows_matched]
if (is_extended is not None) and (idx_model == 0):
right_type = extended_target[rows_matched] == is_extended
n_total = len(right_type)
Expand All @@ -1016,15 +1037,15 @@ def run(
# compute stats for this bin, for all columns
for column, (column_ref, column_target, column_err_target, skip_diff) \
in columns_target.items():
values_ref = cat_ref[column_ref][match_good].values
values_ref = cat_ref[column_ref][match_good]
errors_target = (
subset_target[column_err_target].values
subset_target[column_err_target]
if column_err_target is not None
else None
)
compute_stats(
values_ref,
subset_target[column_target].values,
subset_target[column_target],
errors_target,
row,
stats,
Expand Down Expand Up @@ -1066,7 +1087,10 @@ def run(
mag_ref_first = mag_ref

if config.include_unmatched:
cat_matched = pd.concat((cat_matched, cat_unmatched))
# This is probably less efficient than just doing an outer join originally; worth checking
cat_matched = astropy.table.vstack([cat_matched, cat_unmatched])

retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
retStruct = pipeBase.Struct(cat_matched=cat_matched)
if do_stats:
retStruct.diff_matched = astropy.table.Table(data)
return retStruct
26 changes: 13 additions & 13 deletions python/lsst/pipe/tasks/match_tract_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from abc import ABC, abstractmethod

import astropy.table
import pandas as pd
from typing import Tuple, Set

Expand All @@ -50,14 +51,14 @@ class MatchTractCatalogConnections(
cat_ref = cT.Input(
doc="Reference object catalog to match from",
name="{name_input_cat_ref}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
cat_target = cT.Input(
doc="Target object catalog to match",
name="{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
deferLoad=True,
)
Expand All @@ -67,17 +68,16 @@ class MatchTractCatalogConnections(
storageClass="SkyMap",
dimensions=("skymap",),
)
# TODO: Change outputs to ArrowAstropy in DM-44159
cat_output_ref = cT.Output(
doc="Reference matched catalog with indices of target matches",
name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
)
cat_output_target = cT.Output(
doc="Target matched catalog with indices of reference matches",
name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
storageClass="DataFrame",
storageClass="ArrowAstropy",
dimensions=("tract", "skymap"),
)

Expand Down Expand Up @@ -128,17 +128,17 @@ def __init__(self, **kwargs):
@abstractmethod
def run(
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
catalog_ref: pd.DataFrame | astropy.table.Table,
catalog_target: pd.DataFrame | astropy.table.Table,
wcs: afwGeom.SkyWcs = None,
) -> pipeBase.Struct:
"""Match sources in a reference tract catalog with a target catalog.
Parameters
----------
catalog_ref : `pandas.DataFrame`
catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
A reference catalog to match objects/sources from.
catalog_target : `pandas.DataFrame`
catalog_target : `pandas.DataFrame` | `astropy.table.Table`
A target catalog to match reference objects/sources to.
wcs : `lsst.afw.image.SkyWcs`
A coordinate system to convert catalog positions to sky coordinates.
Expand Down Expand Up @@ -211,17 +211,17 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):

def run(
self,
catalog_ref: pd.DataFrame,
catalog_target: pd.DataFrame,
catalog_ref: pd.DataFrame | astropy.table.Table,
catalog_target: pd.DataFrame | astropy.table.Table,
wcs: afwGeom.SkyWcs = None,
) -> pipeBase.Struct:
"""Match sources in a reference tract catalog with a target catalog.
Parameters
----------
catalog_ref : `pandas.DataFrame`
catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
A reference catalog to match objects/sources from.
catalog_target : `pandas.DataFrame`
catalog_target : `pandas.DataFrame` | `astropy.table.Table`
A target catalog to match reference objects/sources to.
wcs : `lsst.afw.image.SkyWcs`
A coordinate system to convert catalog positions to sky coordinates,
Expand Down
Loading

0 comments on commit 5d20891

Please sign in to comment.