From a969a2001045a31bef38f2ab5ebcaeaabeefcc57 Mon Sep 17 00:00:00 2001 From: Aaron Watkins Date: Wed, 20 Nov 2024 01:33:50 -0800 Subject: [PATCH] Refactor AEW --- python/lsst/pipe/tasks/matchBackgrounds.py | 187 +++------------------ 1 file changed, 21 insertions(+), 166 deletions(-) diff --git a/python/lsst/pipe/tasks/matchBackgrounds.py b/python/lsst/pipe/tasks/matchBackgrounds.py index d2e0cf2be..b8b196fda 100644 --- a/python/lsst/pipe/tasks/matchBackgrounds.py +++ b/python/lsst/pipe/tasks/matchBackgrounds.py @@ -23,7 +23,7 @@ import lsstDebug import numpy as np -from lsst.afw.image import LOCAL, PARENT, ExposureF, ImageF, Mask, MaskedImageF +from lsst.afw.image import LOCAL, PARENT, ImageF, Mask, MaskedImageF from lsst.afw.math import ( MEAN, MEANCLIP, @@ -43,7 +43,7 @@ stringToStatisticsProperty, stringToUndersampleStyle, ) -from lsst.geom import Box2D, Box2I, PointI +from lsst.geom import Box2I, PointI from lsst.pex.config import ChoiceField, Field, ListField, RangeField from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct, TaskError from lsst.pipe.base.connectionTypes import Input, Output @@ -137,7 +137,16 @@ class MatchBackgroundsConfig(PipelineTaskConfig, pipelineConnections=MatchBackgr ) badMaskPlanes = ListField[str]( doc="Names of mask planes to ignore while estimating the background.", - default=["NO_DATA", "DETECTED", "DETECTED_NEGATIVE", "SAT", "BAD", "INTRP", "CR"], + default=[ + "NO_DATA", + "DETECTED", + "DETECTED_NEGATIVE", + "SAT", + "BAD", + "INTRP", + "CR", + "NOT_DEBLENDED", + ], itemCheck=lambda x: x in Mask().getMaskPlaneDict(), ) gridStatistic = ChoiceField( @@ -237,9 +246,6 @@ def __init__(self, *args, **kwargs): super().__init__(**kwargs) self.statsFlag = stringToStatisticsProperty(self.config.gridStatistic) self.statsCtrl = StatisticsControl() - # TODO: Check that setting the mask planes here work - these planes - # can vary from exposure to exposure, I think? - # Aaron: I think only the bit values vary, not the names, which this is referencing. self.statsCtrl.setAndMask(Mask.getPlaneBitMask(self.config.badMaskPlanes)) self.statsCtrl.setNanSafe(True) self.statsCtrl.setNumSigmaClip(self.config.numSigmaClip) @@ -278,7 +284,7 @@ def run(self, warps): raise TaskError("No exposures to match") # Define a reference warp; 'warps' is modified in-place to exclude it - refWarp, refInd = self._defineWarps(warps=warps, refWarpVisit=self.config.refWarpVisit) + refWarp, refInd, bkgd = self._defineWarps(warps=warps, refWarpVisit=self.config.refWarpVisit) # Images must be scaled to a common ZP # Converting everything to nJy to accomplish this @@ -287,29 +293,11 @@ def run(self, warps): self.log.info("Matching %d Exposures", numExp) - # Creating a null BackgroundList object by fitting a blank image - statsFlag = stringToStatisticsProperty(self.config.gridStatistic) - self.statsCtrl.setNumSigmaClip(self.config.numSigmaClip) - self.statsCtrl.setNumIter(self.config.numIter) - - # TODO: refactor below to construct blank bg model - im = refExposure.getMaskedImage() - blankIm = im.clone() - blankIm.image.array *= 0 - - width = blankIm.getWidth() - height = blankIm.getHeight() - nx = width // self.config.binSize - if width % self.config.binSize != 0: - nx += 1 - ny = height // self.config.binSize - if height % self.config.binSize != 0: - ny += 1 - - bctrl = BackgroundControl(nx, ny, self.statsCtrl, statsFlag) - bctrl.setUndersampleStyle(self.config.undersampleStyle) - - bkgd = makeBackground(blankIm, bctrl) + # Blank ref warp background as reference background + bkgdIm = bkgd.getImageF() + bkgdStatsIm = bkgd.getStatsImage() + bkgdIm *= 0 + bkgdStatsIm *= 0 blank = BackgroundList( ( bkgd, @@ -325,7 +313,6 @@ def run(self, warps): backgroundInfoList = [] matchedImageList = [] for exp in warps: - # TODO: simplify what this prints? self.log.info( "Matching background of %s to %s", exp.dataId, @@ -347,7 +334,6 @@ def run(self, warps): toMatchExposure.image /= instFluxToNanojansky # Back to cts matchedImageList.append(toMatchExposure) - # TODO: more elegant solution than inserting blank model at ref ind? backgroundInfoList.insert(refInd, blank) refExposure.image /= instFluxToNanojanskyRef # Back to cts matchedImageList.insert(refInd, refExposure) @@ -377,6 +363,8 @@ def _defineWarps(self, warps, refWarpVisit=None): Reference warped exposure. refWarpIndex : `int` Index of the reference removed from the list of warps. + warpBg : `~lsst.afw.math.BackgroundMI` + Temporary background model, used to make a blank BG for the ref Notes ----- @@ -454,7 +442,7 @@ def _defineWarps(self, warps, refWarpVisit=None): ind = np.nanargmin(costFunctionVals) refWarp = warps.pop(ind) self.log.info("Using best reference visit %d", refWarp.dataId["visit"]) - return refWarp, ind + return refWarp, ind, warpBg def _makeBackground(self, warp: MaskedImageF, binSize) -> tuple[BackgroundMI, BackgroundControl]: """Generate a simple binned background masked image for warped data. @@ -528,11 +516,6 @@ def matchBackgrounds(self, refExposure, sciExposure): model : `~lsst.afw.math.BackgroundMI` Background model of difference image, reference - science """ - # TODO: this is deprecated - if lsstDebug.Info(__name__).savefits: - refExposure.writeFits(lsstDebug.Info(__name__).figpath + "refExposure.fits") - sciExposure.writeFits(lsstDebug.Info(__name__).figpath + "sciExposure.fits") - # Check Configs for polynomials: if self.config.usePolynomial: x, y = sciExposure.getDimensions() @@ -622,17 +605,6 @@ def matchBackgrounds(self, refExposure, sciExposure): resids = bgZ - modelValueArr rms = np.sqrt(np.mean(resids[~np.isnan(resids)] ** 2)) - # TODO: also deprecated; _gridImage() maybe can go? - if lsstDebug.Info(__name__).savefits: - sciExposure.writeFits(lsstDebug.Info(__name__).figpath + "sciMatchedExposure.fits") - - if lsstDebug.Info(__name__).savefig: - bbox = Box2D(refExposure.getMaskedImage().getBBox()) - try: - self._debugPlot(bgX, bgY, bgZ, bgdZ, bkgdImage, bbox, modelValueArr, resids) - except Exception as e: - self.log.warning("Debug plot not generated: %s", e) - meanVar = makeStatistics(diffMI.getVariance(), diffMI.getMask(), MEANCLIP, self.statsCtrl).getValue() diffIm = diffMI.getImage() @@ -642,7 +614,6 @@ def matchBackgrounds(self, refExposure, sciExposure): outBkgd = approx if self.config.usePolynomial else bkgd # Convert this back into counts - # TODO: is there a one-line way to do this? statsIm = outBkgd.getStatsImage() statsIm /= instFluxToNanojansky bkgdIm = outBkgd.getImageF() @@ -667,119 +638,3 @@ def matchBackgrounds(self, refExposure, sciExposure): False, ) ) - - def _debugPlot(self, X, Y, Z, dZ, modelImage, bbox, model, resids): - """ - Consider deleting this entirely - Generate a plot showing the background fit and residuals. - - It is called when lsstDebug.Info(__name__).savefig = True. - Saves the fig to lsstDebug.Info(__name__).figpath. - Displays on screen if lsstDebug.Info(__name__).display = True. - - Parameters - ---------- - X : `np.ndarray`, (N,) - Array of x positions. - Y : `np.ndarray`, (N,) - Array of y positions. - Z : `np.ndarray` - Array of the grid values that were interpolated. - dZ : `np.ndarray`, (len(Z),) - Array of the error on the grid values. - modelImage : `Unknown` - Image of the model of the fit. - model : `np.ndarray`, (len(Z),) - Array of len(Z) containing the grid values predicted by the model. - resids : `Unknown` - Z - model. - """ - import matplotlib.colors - import matplotlib.pyplot as plt - from mpl_toolkits.axes_grid1 import ImageGrid - - zeroIm = MaskedImageF(Box2I(bbox)) - zeroIm += modelImage - x0, y0 = zeroIm.getXY0() - dx, dy = zeroIm.getDimensions() - if len(Z) == 0: - self.log.warning("No grid. Skipping plot generation.") - else: - max, min = np.max(Z), np.min(Z) - norm = matplotlib.colors.normalize(vmax=max, vmin=min) - maxdiff = np.max(np.abs(resids)) - diffnorm = matplotlib.colors.normalize(vmax=maxdiff, vmin=-maxdiff) - rms = np.sqrt(np.mean(resids**2)) - fig = plt.figure(1, (8, 6)) - meanDz = np.mean(dZ) - grid = ImageGrid( - fig, - 111, - nrows_ncols=(1, 2), - axes_pad=0.1, - share_all=True, - label_mode="L", - cbar_mode="each", - cbar_size="7%", - cbar_pad="2%", - cbar_location="top", - ) - im = grid[0].imshow( - zeroIm.getImage().getArray(), extent=(x0, x0 + dx, y0 + dy, y0), norm=norm, cmap="Spectral" - ) - im = grid[0].scatter( - X, Y, c=Z, s=15.0 * meanDz / dZ, edgecolor="none", norm=norm, marker="o", cmap="Spectral" - ) - im2 = grid[1].scatter(X, Y, c=resids, edgecolor="none", norm=diffnorm, marker="s", cmap="seismic") - grid.cbar_axes[0].colorbar(im) - grid.cbar_axes[1].colorbar(im2) - grid[0].axis([x0, x0 + dx, y0 + dy, y0]) - grid[1].axis([x0, x0 + dx, y0 + dy, y0]) - grid[0].set_xlabel("model and grid") - grid[1].set_xlabel("residuals. rms = %0.3f" % (rms)) - if lsstDebug.Info(__name__).savefig: - fig.savefig(lsstDebug.Info(__name__).figpath + self.debugDataIdString + ".png") - if lsstDebug.Info(__name__).display: - plt.show() - plt.clf() - - def _gridImage(self, maskedImage, binsize, statsFlag): - """Private method to grid an image for debugging.""" - width, height = maskedImage.getDimensions() - x0, y0 = maskedImage.getXY0() - xedges = np.arange(0, width, binsize) - yedges = np.arange(0, height, binsize) - xedges = np.hstack((xedges, width)) # add final edge - yedges = np.hstack((yedges, height)) # add final edge - - # Use lists/append to protect against the case where - # a bin has no valid pixels and should not be included in the fit - bgX = [] - bgY = [] - bgZ = [] - bgdZ = [] - - for ymin, ymax in zip(yedges[0:-1], yedges[1:]): - for xmin, xmax in zip(xedges[0:-1], xedges[1:]): - subBBox = Box2I( - PointI(int(x0 + xmin), int(y0 + ymin)), - PointI(int(x0 + xmax - 1), int(y0 + ymax - 1)), - ) - subIm = MaskedImageF(maskedImage, subBBox, PARENT, False) - stats = makeStatistics( - subIm, - MEAN | MEANCLIP | MEDIAN | NPOINT | STDEV, - self.statsCtrl, - ) - npoints, _ = stats.getResult(NPOINT) - if npoints >= 2: - stdev, _ = stats.getResult(STDEV) - if stdev < self.config.gridStdevEpsilon: - stdev = self.config.gridStdevEpsilon - bgX.append(0.5 * (x0 + xmin + x0 + xmax)) - bgY.append(0.5 * (y0 + ymin + y0 + ymax)) - bgdZ.append(stdev / np.sqrt(npoints)) - est, _ = stats.getResult(statsFlag) - bgZ.append(est) - - return np.array(bgX), np.array(bgY), np.array(bgZ), np.array(bgdZ)