DM-44287: Refactor error handling to work better with measurement framework and quiet excessive logging #37

Merged: 4 commits, May 14, 2024
97 changes: 53 additions & 44 deletions python/lsst/meas/extensions/gaap/_gaap.py
@@ -44,33 +44,12 @@


class GaapConvolutionError(measBase.MeasurementError):
"""Collection of any unexpected errors in GAaP during PSF Gaussianization.

The PSF Gaussianization procedure using `modelPsfMatchTask` may throw
exceptions for certain target PSFs. Such errors are caught until all
measurements are at least attempted. The complete traceback information
is lost, but unique error messages are preserved.

Parameters
----------
errors : `dict` [`str`, `Exception`]
The values are exceptions raised, while the keys are the loop variables
(in `str` format) where the exceptions were raised.
"""Raised when there is an error in GAaP convolution.
"""
def __init__(self, errors: dict[str, Exception]):
self.errorDict = errors
message = "Problematic scaling factors = "
message += ", ".join(errors)
message += " Errors: "
message += " | ".join(set(msg.__repr__() for msg in errors.values())) # msg.cpp.what() misses type
super().__init__(message, 1) # the second argument does not matter.


class NoPixelError(Exception):
class NoPixelError(measBase.MeasurementError):
"""Raised when the footprint has no pixels.

This is caught by the measurement framework, which then calls the
`fail` method of the plugin without passing in a value for `error`.
"""


@@ -175,7 +154,6 @@ def _sigmas(self) -> list:

def setDefaults(self) -> None:
# Docstring inherited
# TODO: DM-27482 might change these values.
self._modelPsfMatch.kernel.active.alardNGauss = 1
self._modelPsfMatch.kernel.active.alardDegGaussDeconv = 1
self._modelPsfMatch.kernel.active.alardDegGauss = [4]
@@ -557,21 +535,28 @@ def _gaussianizeAndMeasure(self, measRecord: lsst.afw.table.SourceRecord,
This method is the entry point to the mixin from the concrete derived
classes.
"""
# First make sure we have a PSF.
if (psf := exposure.getPsf()) is None:
raise measBase.FatalAlgorithmError("No PSF in exposure")

# Raise errors if the plugin would fail for this record for all
# scaling factors and sigmas.
if measRecord.getFootprint().getArea() == 0:
self._setFlag(measRecord, self.name, "no_pixel")
raise NoPixelError

if (psf := exposure.getPsf()) is None:
raise measBase.FatalAlgorithmError("No PSF in exposure")
self._setScalingAndSigmaFlags(measRecord, self.config.scalingFactors)
raise NoPixelError("No good pixels in footprint", 1)

Member

If we are moving this around, let's move this below the next block that raises the FatalAlgorithmError. On the very off chance something went so badly in processing that all footprints have zero area and no PSF is available, I'd want it to fail with a FatalAlgorithmError.
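
A minimal sketch of the ordering being suggested here, using only names that appear in the surrounding diff (illustrative, not the committed code):

    # Fail hard first if the exposure carries no PSF at all ...
    if (psf := exposure.getPsf()) is None:
        raise measBase.FatalAlgorithmError("No PSF in exposure")
    # ... and only then flag and skip a record whose footprint has no pixels.
    if measRecord.getFootprint().getArea() == 0:
        self._setScalingAndSigmaFlags(measRecord, self.config.scalingFactors)
        raise NoPixelError("No good pixels in footprint", 1)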


psfSigma = psf.computeShape(center).getTraceRadius()
if not (psfSigma > 0): # This captures NaN and negative values.
errorCollection = {str(scalingFactor): measBase.MeasurementError("PSF size could not be measured")
for scalingFactor in self.config.scalingFactor}
raise GaapConvolutionError(errorCollection)
center = measRecord.getCentroid()
self.log.debug("Invalid PSF sigma; cannot solve for PSF matching kernel in GAaP for (%f, %f): %s",
center.getX(), center.getY(), "GAaP Convolution Error")

Member

Why does the error message say we failed to solve for the kernel when we don't even attempt to solve for it? It made sense when we were handling this in fail, but we can be more specific here.
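
For instance (hypothetical wording, not the committed text), the raise in this branch could name the actual condition rather than the kernel solve that was never attempted:

    # Hypothetical, more specific message for the invalid-PSF-sigma branch.
    raise GaapConvolutionError("Invalid PSF size; Gaussianization was not attempted", 1)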

self._setScalingAndSigmaFlags(
measRecord,
self.config.scalingFactors,
specificFlag="flag_gaussianization",
)
raise GaapConvolutionError("Failed to solve for PSF matching kernel", 1)

Member

I'd raise this as a plain MeasurementError and keep the error handling in GaapConvolutionError instead of trivially subclassing it from MeasurementError.

Contributor Author

I'm not sure what you mean. I don't see the harm in having a slightly more specific error so that if you look at the debug log there's marginally more information (though there is still the message). But I don't like having the error handling in the Error code because that's not what exceptions are supposed to do.

Member

Okay, if we want to keep error handling/error collating outside of GaapConvolutionError, then raising this specific type is fine.
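
Roughly, the two options weighed in this thread (a sketch of the discussion, not the committed code):

    # (a) Raise the framework error directly and drop the subclass:
    #     raise measBase.MeasurementError("Failed to solve for PSF matching kernel", 1)
    # (b) Keep a thin, GAaP-specific subclass and raise that, with all error
    #     collation staying in the plugin -- the option settled on above:
    class GaapConvolutionError(measBase.MeasurementError):
        """Raised when there is an error in GAaP convolution."""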

else:
errorCollection = dict()

@@ -630,7 +615,19 @@ def _gaussianizeAndMeasure(self, measRecord: lsst.afw.table.SourceRecord,
# Raise GaapConvolutionError before exiting the plugin
# if the collection of errors is not empty
if errorCollection:
raise GaapConvolutionError(errorCollection)
message = "Problematic scaling factors = "
message += ", ".join(errorCollection)
message += " Errors: "
message += " | ".join(set(msg.__repr__() for msg in errorCollection.values()))
center = measRecord.getCentroid()
self.log.debug("Failed to solve for PSF matching kernel in GAaP for (%f, %f): %s",
center.getX(), center.getY(), message)
self._setScalingAndSigmaFlags(
measRecord,
errorCollection.keys(),
specificFlag="flag_gaussianization",
)
raise GaapConvolutionError("Failed to solve for PSF matching kernel", 1)

@staticmethod
def _setFlag(measRecord, baseName, flagName=None):
@@ -658,6 +655,27 @@ def _setFlag(measRecord, baseName, flagName=None):
genericFlagKey = measRecord.schema.join(baseName, "flag")
measRecord.set(genericFlagKey, True)

def _setScalingAndSigmaFlags(self, measRecord, scalingFactors, specificFlag=None):
"""Set a full suite of flags for scalingFactors/sigmas.

Parameters
----------
measRecord : `~lsst.afw.table.SourceRecord`
Record describing the source being measured.
scalingFactors : `list` [`float`]
List of scaling factors.
specificFlag : `str`, optional
Specific type of flag to set if needed.
"""
for scalingFactor in scalingFactors:
if specificFlag is not None:
flagName = self.ConfigClass._getGaapResultName(scalingFactor, specificFlag,
self.name)
measRecord.set(flagName, True)
for sigma in self.config._sigmas:
baseName = self.ConfigClass._getGaapResultName(scalingFactor, sigma, self.name)
self._setFlag(measRecord, baseName)

def _isAllFailure(self, measRecord, scalingFactor, targetSigma) -> bool:
"""Check if all measurements would result in failure.

@@ -722,18 +740,9 @@ def fail(self, measRecord, error=None):
error : `Exception`
Error causing failure, or `None`.
"""
if error is not None:
center = measRecord.getCentroid()
self.log.error("Failed to solve for PSF matching kernel in GAaP for (%f, %f): %s",
center.getX(), center.getY(), error)
for scalingFactor in error.errorDict:
flagName = self.ConfigClass._getGaapResultName(scalingFactor, "flag_gaussianization",
self.name)
measRecord.set(flagName, True)
for sigma in self.config._sigmas:
baseName = self.ConfigClass._getGaapResultName(scalingFactor, sigma, self.name)
self._setFlag(measRecord, baseName)
else:
# We only need to set the failKey if no error was specified which
# signifies that the flagging was already handled.
if error is None:
measRecord.set(self._failKey, True)


3 changes: 2 additions & 1 deletion python/lsst/meas/extensions/gaap/_gaussianizePsf.py
@@ -328,7 +328,8 @@ def _solve(self, kernelCellSet, basisList):
spatialKernel, spatialBackground = spatialkv.getSolutionPair()
spatialSolution = spatialkv.getKernelSolution()
except Exception as e:
self.log.error("ERROR: Unable to calculate psf matching kernel")
# This is just a debug log because it is caught by the GAaP plugin.
self.log.debug("Unable to calculate psf matching kernel")
getTraceLogger(self.log.getChild("_solve"), 1).debug("%s", e)
raise e

92 changes: 77 additions & 15 deletions tests/test_gaap.py
@@ -142,6 +142,11 @@ def check(self, psfSigma=0.5, flux=1000., scalingFactors=[1.15], forced=False):
measConfig = TaskClass.ConfigClass()
algName = "ext_gaap_GaapFlux"

# Remove sky coordinate plugin because we don't have the columns
# in the tests.
if "base_SkyCoord" in measConfig.plugins.names:
measConfig.plugins.names.remove("base_SkyCoord")

measConfig.plugins.names.add(algName)

if forced:
@@ -259,43 +264,100 @@ def testFail(self, scalingFactors=[100.], sigmas=[500.]):
exposure, catalog = self.dataset.realize(0.0, sfmTask.schema)
self.recordPsfShape(catalog)

# Expected error messages in the logs when running `sfmTask`.
# Expected debug messages in the logs when running `sfmTask`.
errorMessage = [("Failed to solve for PSF matching kernel in GAaP for (100.000000, 670.000000): "
"Problematic scaling factors = 100.0 "
"Errors: RuntimeError('Unable to determine kernel sum; 0 candidates')"),
("MeasurementError in ext_gaap_GaapFlux.measure on record 1: "
"Failed to solve for PSF matching kernel"),
("Failed to solve for PSF matching kernel in GAaP for (100.000000, 870.000000): "
"Problematic scaling factors = 100.0 "
"Errors: RuntimeError('Unable to determine kernel sum; 0 candidates')"),
("MeasurementError in ext_gaap_GaapFlux.measure on record 2: "
"Failed to solve for PSF matching kernel"),
("Failed to solve for PSF matching kernel in GAaP for (-10.000000, -20.000000): "
"Problematic scaling factors = 100.0 "
"Errors: RuntimeError('Unable to determine kernel sum; 0 candidates')")]
"Errors: RuntimeError('Unable to determine kernel sum; 0 candidates')"),
("MeasurementError in ext_gaap_GaapFlux.measure on record 3: "
"Failed to solve for PSF matching kernel")]

testCatalog = catalog.copy(deep=True)
plugin_logger_name = sfmTask.log.getChild(algName).name
self.assertEqual(plugin_logger_name, "lsst.measurement.ext_gaap_GaapFlux")
with self.assertLogs(plugin_logger_name, "ERROR") as cm:
sfmTask.run(catalog, exposure)
with self.assertLogs(plugin_logger_name, "DEBUG") as cm:
sfmTask.run(testCatalog, exposure)
self.assertEqual([record.message for record in cm.records], errorMessage)

self._checkAllFlags(
testCatalog,
algName,
scalingFactors,
sigmas,
gaapConfig,
specificFlag="flag_gaussianization",
)

# Trigger a "not (psfSigma > 0)" error:
exposureJunkPsf = exposure.clone()
testCatalog = catalog.copy(deep=True)
junkPsf = afwDetection.GaussianPsf(1, 1, 0)
exposureJunkPsf.setPsf(junkPsf)
sfmTask.run(testCatalog, exposureJunkPsf)

self._checkAllFlags(
testCatalog,
algName,
scalingFactors,
sigmas,
gaapConfig,
specificFlag="flag_gaussianization",
)

# Trigger a NoPixelError.
testCatalog = catalog.copy(deep=True)
testCatalog[0].setFootprint(afwDetection.Footprint())
with self.assertLogs(plugin_logger_name, "DEBUG") as cm:
sfmTask.run(testCatalog, exposure)

self.assertEqual(
cm.records[0].message,
"MeasurementError in ext_gaap_GaapFlux.measure on record 1: No good pixels in footprint",
)
self.assertEqual(testCatalog[f"{algName}_flag_no_pixel"][0], True)
self.assertEqual(testCatalog[f"{algName}_flag"][0], True)

self._checkAllFlags(testCatalog[0: 1], algName, scalingFactors, sigmas, gaapConfig, allFailFlag=True)

# Try and "fail" with no PSF.
# Since fatal exceptions are not caught by the measurement framework,
# use a context manager and catch it here.
exposure.setPsf(None)
with self.assertRaises(lsst.meas.base.FatalAlgorithmError):
sfmTask.run(catalog, exposure)

def _checkAllFlags(
self,
catalog,
algName,
scalingFactors,
sigmas,
gaapConfig,
specificFlag=None,
allFailFlag=False
):
for record in catalog:
self.assertFalse(record[algName + "_flag"])
self.assertEqual(record[algName + "_flag"], allFailFlag)
for scalingFactor in scalingFactors:
flagName = gaapConfig._getGaapResultName(scalingFactor, "flag_gaussianization", algName)
self.assertTrue(record[flagName])
if specificFlag is not None:
flagName = gaapConfig._getGaapResultName(scalingFactor, specificFlag, algName)
self.assertTrue(record[flagName])
for sigma in sigmas + ["Optimal"]:
baseName = gaapConfig._getGaapResultName(scalingFactor, sigma, algName)
self.assertTrue(record[baseName + "_flag"])
self.assertFalse(record[baseName + "_flag_bigPsf"])

baseName = gaapConfig._getGaapResultName(scalingFactor, "PsfFlux", algName)
self.assertTrue(record[baseName + "_flag"])

# Try and "fail" with no PSF.
# Since fatal exceptions are not caught by the measurement framework,
# use a context manager and catch it here.
exposure.setPsf(None)
with self.assertRaises(lsst.meas.base.FatalAlgorithmError):
sfmTask.run(catalog, exposure)

def testFlags(self, sigmas=[0.4, 0.5, 0.7], scalingFactors=[1.15, 1.25, 1.4, 100.]):
"""Test that GAaP flags are set properly.
