Skip to content

Commit

Permalink
Merge branch 'tickets/DM-47526'
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed Jan 16, 2025
2 parents afaee74 + 1f908ed commit 8ffcb69
Showing 1 changed file with 64 additions and 4 deletions.
68 changes: 64 additions & 4 deletions python/lsst/pipe/tasks/fit_coadd_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,44 @@ def adjustQuantum(self, inputs, outputs, label, data_id):
# Check which bands are going to be fit
bands_fit, bands_read_only = self.config.get_band_sets()
bands_needed = bands_fit + [band for band in bands_read_only if band not in bands_fit]
bands_needed_set = set(bands_needed)

adjusted_inputs = {}
bands_found, connection_first = None, None
for connection_name, (connection, dataset_refs) in inputs.items():
# Datasets without bands in their dimensions should be fine
if 'band' in connection.dimensions:
datasets_by_band = {dref.dataId['band']: dref for dref in dataset_refs}
if not set(bands_needed).issubset(datasets_by_band.keys()):
bands_set = set(datasets_by_band.keys())
if self.config.allow_missing_bands:
# Use the first dataset found as the reference since all
# dataset types with band should have the same bands
# This will only break if one of the calexp/meas datasets
# is missing from a given band, which would surely be an
# upstream problem anyway
if bands_found is None:
bands_found, connection_first = bands_set, connection_name
if len(bands_found) == 0:
raise pipeBase.NoWorkFound(
f'DatasetRefs={dataset_refs} for {connection_name=} is empty'
)
elif not set(bands_read_only).issubset(bands_set):
raise pipeBase.NoWorkFound(
f'DatasetRefs={dataset_refs} has {bands_set=} which is missing at least one'
f' of {bands_read_only=}'
)
# Put the bands to fit first, then any other bands
# needed for initialization/priors only last
bands_needed = [band for band in bands_fit if band in bands_found] + [
band for band in bands_read_only if band not in bands_found
]
elif bands_found != bands_set:
raise RuntimeError(
f'DatasetRefs={dataset_refs} with {connection_name=} has {bands_set=} !='
f' {bands_found=} from {connection_first=}'
)
# All configured bands are treated as necessary
elif not bands_needed_set.issubset(bands_set):
raise pipeBase.NoWorkFound(
f'DatasetRefs={dataset_refs} have data with bands in the'
f' set={set(datasets_by_band.keys())},'
Expand Down Expand Up @@ -252,6 +283,10 @@ class CoaddMultibandFitBaseConfig(
):
"""Base class for multiband fitting."""

allow_missing_bands = pexConfig.Field[bool](
doc="Whether to still fit even if some bands are missing",
default=True,
)
drop_psf_connection = pexConfig.Field[bool](
doc="Whether to drop the PSF model connection, e.g. because PSF parameters are in the input catalog",
default=False,
Expand Down Expand Up @@ -311,7 +346,8 @@ def build_catexps(self, butlerQC, inputRefs, inputs) -> list[CatalogExposureInpu
models_psf = inputs_sorted[2] if has_psf_models else None
dataIds = set(cats).union(set(exps))
models_scarlet = inputs["models_scarlet"]
catexps = {}
catexp_dict = {}
dataId = None
for dataId in dataIds:
catalog = cats[dataId]
exposure = exps[dataId]
Expand All @@ -323,14 +359,33 @@ def build_catexps(self, butlerQC, inputRefs, inputs) -> list[CatalogExposureInpu
removeScarletData=True,
updateFluxColumns=False,
)
catexps[dataId['band']] = CatalogExposureInputs(
catexp_dict[dataId['band']] = CatalogExposureInputs(
catalog=catalog,
exposure=exposure,
table_psf_fits=models_psf[dataId] if has_psf_models else astropy.table.Table(),
dataId=dataId,
id_tract_patch=id_tp,
)
catexps = [catexps[band] for band in self.config.get_band_sets()[0]]
# This shouldn't happen unless this is called with no inputs, but check anyway
if dataId is None:
raise RuntimeError(f"Did not build any catexps for {inputRefs=}")
catexps = []
for band in self.config.get_band_sets()[0]:
if band in catexp_dict:
catexp = catexp_dict[band]
else:
# Make a dummy catexp with a dataId if there's no data
# This should be handled by any subtasks
dataId_band = dataId.to_simple(minimal=True)
dataId_band.dataId["band"] = band
catexp = CatalogExposureInputs(
catalog=afwTable.SourceCatalog(),
exposure=None,
table_psf_fits=astropy.table.Table(),
dataId=dataId.from_simple(dataId_band, universe=dataId.universe),
id_tract_patch=id_tp,
)
catexps.append(catexp)
return catexps


Expand Down Expand Up @@ -360,6 +415,11 @@ def make_kwargs(self, butlerQC, inputRefs, inputs):
def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
catexps = self.build_catexps(butlerQC, inputRefs, inputs)
if not self.config.allow_missing_bands and any([catexp is None for catexp in catexps]):
raise RuntimeError(
f"Got a None catexp with {self.config.allow_missing_band=}; NoWorkFound should have been"
f" raised earlier"
)
kwargs = self.make_kwargs(butlerQC, inputRefs, inputs)
outputs = self.run(catexps=catexps, cat_ref=inputs['cat_ref'], **kwargs)
butlerQC.put(outputs, outputRefs)
Expand Down

0 comments on commit 8ffcb69

Please sign in to comment.