diff --git a/python/lsst/pipe/tasks/processBrightStars.py b/python/lsst/pipe/tasks/processBrightStars.py index 350ce716a..660261ec2 100644 --- a/python/lsst/pipe/tasks/processBrightStars.py +++ b/python/lsst/pipe/tasks/processBrightStars.py @@ -21,10 +21,13 @@ """Extract bright star cutouts; normalize and warp to the same pixel grid.""" -__all__ = ["ProcessBrightStarsTask"] +__all__ = ["ProcessBrightStarsConnections", "ProcessBrightStarsConfig", "ProcessBrightStarsTask"] + +from typing import cast import astropy.units as u import numpy as np +from astropy.table import Table from lsst.afw.cameraGeom import PIXELS, TAN_PIXELS from lsst.afw.detection import FootprintSet, Threshold from lsst.afw.geom.transformFactory import makeIdentityTransform, makeTransform @@ -40,7 +43,6 @@ from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader from lsst.meas.algorithms.brightStarStamps import BrightStarStamp, BrightStarStamps from lsst.pex.config import ChoiceField, ConfigField, Field, ListField -from lsst.pex.exceptions import InvalidParameterError from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput from lsst.utils.timer import timeMethod @@ -50,19 +52,19 @@ class ProcessBrightStarsConnections(PipelineTaskConnections, dimensions=("instru """Connections for ProcessBrightStarsTask.""" inputExposure = Input( - doc="Input exposure from which to extract bright star stamps", + doc="Input exposure from which to extract bright star stamps.", name="calexp", storageClass="ExposureF", dimensions=("visit", "detector"), ) skyCorr = Input( - doc="Input Sky Correction to be subtracted from the calexp if doApplySkyCorr=True", + doc="Input sky correction to be subtracted from the calexp if doApplySkyCorr=True.", name="skyCorr", storageClass="Background", dimensions=("instrument", "visit", "detector"), ) refCat = PrerequisiteInput( - doc="Reference catalog that contains bright star positions", + doc="Reference catalog that contains bright star positions.", name="gaia_dr2_20200414", storageClass="SimpleCatalog", dimensions=("skypix",), @@ -77,6 +79,7 @@ class ProcessBrightStarsConnections(PipelineTaskConnections, dimensions=("instru ) def __init__(self, *, config=None): + config = cast(ProcessBrightStarsConfig, config) super().__init__(config=config) if not config.doApplySkyCorr: self.inputs.remove("skyCorr") @@ -85,37 +88,33 @@ def __init__(self, *, config=None): class ProcessBrightStarsConfig(PipelineTaskConfig, pipelineConnections=ProcessBrightStarsConnections): """Configuration parameters for ProcessBrightStarsTask.""" - magLimit = Field( - dtype=float, + magLimit = Field[float]( doc="Magnitude limit, in Gaia G; all stars brighter than this value will be processed.", default=18, ) - stampSize = ListField( - dtype=int, + stampSize = ListField[int]( doc="Size of the stamps to be extracted, in pixels.", default=(250, 250), ) - modelStampBuffer = Field( - dtype=float, + modelStampBuffer = Field[float]( doc=( "'Buffer' factor to be applied to determine the size of the stamp the processed stars will be " "saved in. This will also be the size of the extended PSF model." ), default=1.1, ) - doRemoveDetected = Field( - dtype=bool, - doc="Whether DETECTION footprints, other than that for the central object, should be changed to BAD.", + doRemoveDetected = Field[bool]( + doc="Whether secondary DETECTION footprints (i.e., footprints of objects other than the central " + "primary object) should be changed to BAD.", default=True, ) - doApplyTransform = Field( - dtype=bool, + doApplyTransform = Field[bool]( doc="Apply transform to bright star stamps to correct for optical distortions?", default=True, ) warpingKernelName = ChoiceField( dtype=str, - doc="Warping kernel", + doc="Warping kernel.", default="lanczos5", allowed={ "bilinear": "bilinear interpolation", @@ -124,8 +123,7 @@ class ProcessBrightStarsConfig(PipelineTaskConfig, pipelineConnections=ProcessBr "lanczos5": "Lanczos kernel of order 5", }, ) - annularFluxRadii = ListField( - dtype=int, + annularFluxRadii = ListField[int]( doc="Inner and outer radii of the annulus used to compute AnnularFlux for normalization, in pixels.", default=(70, 80), ) @@ -139,60 +137,43 @@ class ProcessBrightStarsConfig(PipelineTaskConfig, pipelineConnections=ProcessBr "MEANCLIP": "clipped mean", }, ) - numSigmaClip = Field( - dtype=float, + numSigmaClip = Field[float]( doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", default=4, ) - numIter = Field( - dtype=int, + numIter = Field[int]( doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", default=3, ) - badMaskPlanes = ListField( - dtype=str, + badMaskPlanes = ListField[str]( doc="Mask planes that identify pixels to not include in the computation of the annular flux.", default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), ) - minValidAnnulusFraction = Field( - dtype=float, + minValidAnnulusFraction = Field[float]( doc="Minumum number of valid pixels that must fall within the annulus for the bright star to be " "saved for subsequent generation of a PSF.", default=0.0, ) - doApplySkyCorr = Field( - dtype=bool, + doApplySkyCorr = Field[bool]( doc="Apply full focal plane sky correction before extracting stars?", default=True, ) - discardNanFluxStars = Field( - dtype=bool, + discardNanFluxStars = Field[bool]( doc="Should stars with NaN annular flux be discarded?", default=False, ) - refObjLoader = ConfigField( + refObjLoader = ConfigField[LoadReferenceObjectsConfig]( dtype=LoadReferenceObjectsConfig, doc="Reference object loader for astrometric calibration.", ) class ProcessBrightStarsTask(PipelineTask): - """The description of the parameters for this Task are detailed in - :lsst-task:`~lsst.pipe.base.PipelineTask`. - - Parameters - ---------- - initInputs : `Unknown` - *args - Additional positional arguments. - **kwargs - Additional keyword arguments. - - Notes - ----- - `ProcessBrightStarsTask` is used to extract, process, and store small - image cut-outs (or "postage stamps") around bright stars. It relies on - three methods, called in succession: + """Extract bright star cutouts; normalize and warp to the same pixel grid. + + This task is used to extract, process, and store small image cut-outs + (or "postage stamps") around bright stars. It relies on three methods, + called in succession: `extractStamps` Find bright stars within the exposure using a reference catalog and @@ -211,35 +192,128 @@ class ProcessBrightStarsTask(PipelineTask): def __init__(self, initInputs=None, *args, **kwargs): super().__init__(*args, **kwargs) - # Compute (model) stamp size depending on provided "buffer" value - self.modelStampSize = [ - int(self.config.stampSize[0] * self.config.modelStampBuffer), - int(self.config.stampSize[1] * self.config.modelStampBuffer), - ] - # force it to be odd-sized so we have a central pixel - if not self.modelStampSize[0] % 2: - self.modelStampSize[0] += 1 - if not self.modelStampSize[1] % 2: - self.modelStampSize[1] += 1 - # central pixel - self.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 self.setModelStamp() - # configure Gaia refcat - if butler is not None: - self.makeSubtask("refObjLoader", butler=butler) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + self.config = cast(ProcessBrightStarsConfig, self.config) + inputs = butlerQC.get(inputRefs) + inputs["dataId"] = str(butlerQC.quantum.dataId) + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.refObjLoader, + ) + output = self.run(**inputs, refObjLoader=refObjLoader) + # Only ingest stamp if it exists; prevent ingesting an empty FITS file. + if output: + butlerQC.put(output, outputRefs) + + @timeMethod + def run(self, inputExposure, refObjLoader=None, dataId=None, skyCorr=None): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, then preprocess them. + + Bright star preprocessing steps are: shifting, warping and potentially + rotating them to the same pixel grid; computing their annular flux, + and; normalizing them. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright star stamps should be extracted. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (including detector) that bright stars + should be extracted from. + skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional + Full focal plane sky correction obtained by `SkyCorrectionTask`. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + self.config = cast(ProcessBrightStarsConfig, self.config) + + if self.config.doApplySkyCorr: + self.log.info("Applying sky correction to exposure %s (exposure modified in-place).", dataId) + self.applySkyCorr(inputExposure, skyCorr) + + self.log.info("Extracting bright stars from exposure %s", dataId) + # Extract stamps around bright stars. + extractedStamps = self.extractStamps(inputExposure, refObjLoader=refObjLoader) + if not extractedStamps.starStamps: + self.log.info("No suitable bright star found.") + return None + # Warp (and shift, and potentially rotate) them. + self.log.info( + "Applying warp and/or shift to %i star stamps from exposure %s.", + len(extractedStamps.starStamps), + dataId, + ) + warpOutputs = self.warpStamps(extractedStamps.starStamps, extractedStamps.pixCenters) + warpedStars = warpOutputs.warpedStars + xy0s = warpOutputs.xy0s + brightStarList = [ + BrightStarStamp( + stamp_im=warp, + archive_element=transform, + position=xy0s[j], + gaiaGMag=extractedStamps.gMags[j], + gaiaId=extractedStamps.gaiaIds[j], + minValidAnnulusFraction=self.config.minValidAnnulusFraction, + ) + for j, (warp, transform) in enumerate(zip(warpedStars, warpOutputs.warpTransforms)) + ] + # Compute annularFlux and normalize + self.log.info( + "Computing annular flux and normalizing %i bright stars from exposure %s.", + len(warpedStars), + dataId, + ) + # annularFlux statistic set-up, excluding mask planes + statsControl = StatisticsControl() + statsControl.setNumSigmaClip(self.config.numSigmaClip) + statsControl.setNumIter(self.config.numIter) + + innerRadius, outerRadius = self.config.annularFluxRadii + statsFlag = stringToStatisticsProperty(self.config.annularFluxStatistic) + brightStarStamps = BrightStarStamps.initAndNormalize( + brightStarList, + innerRadius=innerRadius, + outerRadius=outerRadius, + nb90Rots=warpOutputs.nb90Rots, + imCenter=self.modelCenter, + use_archive=True, + statsControl=statsControl, + statsFlag=statsFlag, + badMaskPlanes=self.config.badMaskPlanes, + discardNanFluxObjects=(self.config.discardNanFluxStars), + ) + # Do not create empty FITS files if there aren't any normalized stamps. + if not brightStarStamps._stamps: + self.log.info("No normalized stamps exist for this exposure.") + return None + return Struct(brightStarStamps=brightStarStamps) def applySkyCorr(self, calexp, skyCorr): - """Apply correction to the sky background level. + """Apply sky correction to the input exposure. - Sky corrections can be generated using the ``SkyCorrectionTask``. - As the sky model generated there extends over the full focal plane, - this should produce a more optimal sky subtraction solution. + Sky corrections can be generated using the + `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. + As the sky model generated via that task extends over the full focal + plane, this should produce a more optimal sky subtraction solution. Parameters ---------- calexp : `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage` - Calibrated exposure. - skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional + Calibrated exposure to correct. + skyCorr : `~lsst.afw.math.backgroundList.BackgroundList` Full focal plane sky correction from ``SkyCorrectionTask``. Notes @@ -250,116 +324,195 @@ def applySkyCorr(self, calexp, skyCorr): calexp = calexp.getMaskedImage() calexp -= skyCorr.getImage() - def extractStamps(self, inputExposure, refObjLoader=None, inputBrightStarStamps=None): - """Read the position of bright stars within an input exposure using a - refCat and extract them. + def extractStamps( + self, inputExposure, filterName="phot_g_mean", refObjLoader=None, inputBrightStarStamps=None + ): + """Identify the positions of bright stars within an input exposure using + a reference catalog and extract them. Parameters ---------- inputExposure : `~lsst.afw.image.ExposureF` - The image from which bright star stamps should be extracted. + The image to extract bright star stamps from. + filterName : `str`, optional + Name of the camera filter to use for reference catalog filtering. refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional Loader to find objects within a reference catalog. + inputBrightStarStamps: + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`, optional + Provides information about the stars that have already been + extracted from the inputExposure in other steps of the pipeline. + For example, this is used in the `SubtractBrightStarsTask` to avoid + extracting stars that already have been extracted when running + `ProcessBrightStarsTask` to produce brightStarStamps. Returns ------- result : `~lsst.pipe.base.Struct` Results as a struct with attributes: - ``starIms`` + ``starStamps`` Postage stamps (`list`). ``pixCenters`` Corresponding coords to each star's center, in pixels (`list`). - ``GMags`` + ``gMags`` Corresponding (Gaia) G magnitudes (`list`). ``gaiaIds`` Corresponding unique Gaia identifiers (`np.ndarray`). + + Notes + """ if refObjLoader is None: refObjLoader = self.refObjLoader - starIms = [] - pixCenters = [] - GMags = [] - ids = [] + wcs = inputExposure.getWcs() - # select stars within, or close enough to input exposure from refcat - inputIm = inputExposure.maskedImage - inputExpBBox = inputExposure.getBBox() - # Attempt to include stars that are outside of the exposure but their - # stamps overlap with the exposure. + inputBBox = inputExposure.getBBox() + + # Trim the reference catalog to only those objects within the exposure + # bounding box dilated by half the bright star stamp size. This ensures + # all stars that overlap the exposure are included. dilatationExtent = Extent2I(np.array(self.config.stampSize) // 2) - # TODO (DM-25894): handle catalog with stars missing from Gaia - withinCalexp = refObjLoader.loadPixelBox( - inputExpBBox.dilatedBy(dilatationExtent), - wcs, - filterName="phot_g_mean", + withinExposure = refObjLoader.loadPixelBox( + inputBBox.dilatedBy(dilatationExtent), wcs, filterName=filterName + ) + refCat = withinExposure.refCat + fluxField = withinExposure.fluxField + + # Define ref cat bright subset: objects brighter than the mag limit. + fluxLimit = ((self.config.magLimit * u.ABmag).to(u.nJy)).to_value() # AB magnitudes. + refCatBright = Table( + refCat.extract("id", "coord_ra", "coord_dec", fluxField, where=refCat[fluxField] > fluxLimit) ) - refCat = withinCalexp.refCat - # keep bright objects - fluxLimit = ((self.config.magLimit * u.ABmag).to(u.nJy)).to_value() - GFluxes = np.array(refCat["phot_g_mean_flux"]) - bright = GFluxes > fluxLimit - # convert to AB magnitudes - allGMags = np.array([((gFlux * u.nJy).to(u.ABmag)).to_value() for gFlux in GFluxes[bright]]) - allIds = refCat.columns.extract("id", where=bright)["id"] - selectedColumns = refCat.columns.extract("coord_ra", "coord_dec", where=bright) + refCatBright["mag"] = (refCatBright[fluxField][:] * u.nJy).to(u.ABmag).to_value() # AB magnitudes. + + # Remove input bright stars (if provided) from the bright subset. if inputBrightStarStamps is not None: - existings = np.array(inputBrightStarStamps.getGaiaIds()) - existed = np.isin(allIds, existings) - allGMags = allGMags[~existed] - allIds = allIds[~existed] - selectedColumns["coord_ra"] = selectedColumns["coord_ra"][~existed] - selectedColumns["coord_dec"] = selectedColumns["coord_dec"][~existed] - for j, (ra, dec) in enumerate(zip(selectedColumns["coord_ra"], selectedColumns["coord_dec"])): - sp = SpherePoint(ra, dec, radians) - cpix = wcs.skyToPixel(sp) - try: - starIm = inputExposure.getCutout(sp, Extent2I(self.config.stampSize)) - except InvalidParameterError: - # star is beyond boundary - bboxCorner = np.array(cpix) - np.array(self.config.stampSize) / 2 - # compute bbox as it would be otherwise - idealBBox = Box2I(Point2I(bboxCorner), Extent2I(self.config.stampSize)) - clippedStarBBox = Box2I(idealBBox) - clippedStarBBox.clip(inputExpBBox) - if clippedStarBBox.getArea() > 0: - # create full-sized stamp with all pixels - # flagged as NO_DATA - starIm = ExposureF(bbox=idealBBox) - starIm.image[:] = np.nan - starIm.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA")) - # recover pixels from intersection with the exposure - clippedIm = inputIm.Factory(inputIm, clippedStarBBox) - starIm.maskedImage[clippedStarBBox] = clippedIm - # set detector and wcs, used in warpStars - starIm.setDetector(inputExposure.getDetector()) - starIm.setWcs(inputExposure.getWcs()) - else: - continue + # Extract the IDs of stars that have already been extracted. + existing = np.isin(refCatBright["id"][:], inputBrightStarStamps.getGaiaIds()) + refCatBright = refCatBright[~existing] + + # Loop over each reference bright star, extract a stamp around it. + pixCenters = [] + starStamps = [] + badRows = [] + for row, object in enumerate(refCatBright): + coordSky = SpherePoint(object["coord_ra"], object["coord_dec"], radians) + coordPix = wcs.skyToPixel(coordSky) + # TODO: Replace this method with exposure getCutout after DM-40042. + starStamp = self._getCutout(inputExposure, coordPix, self.config.stampSize.list()) + if not starStamp: + badRows.append(row) + continue if self.config.doRemoveDetected: - # give detection footprint of other objects the BAD flag - detThreshold = Threshold(starIm.mask.getPlaneBitMask("DETECTED"), Threshold.BITMASK) - omask = FootprintSet(starIm.mask, detThreshold) - allFootprints = omask.getFootprints() - otherFootprints = [] - for fs in allFootprints: - if not fs.contains(Point2I(cpix)): - otherFootprints.append(fs) - nbMatchingFootprints = len(allFootprints) - len(otherFootprints) - if not nbMatchingFootprints == 1: - self.log.warning( - "Failed to uniquely identify central DETECTION footprint for star " - "%s; found %d footprints instead.", - allIds[j], - nbMatchingFootprints, - ) - omask.setFootprints(otherFootprints) - omask.setMask(starIm.mask, "BAD") - starIms.append(starIm) - pixCenters.append(cpix) - GMags.append(allGMags[j]) - ids.append(allIds[j]) - return Struct(starIms=starIms, pixCenters=pixCenters, GMags=GMags, gaiaIds=ids) + self._replaceSecondaryFootprints(starStamp, coordPix, object["id"]) + starStamps.append(starStamp) + pixCenters.append(coordPix) + + # Remove bad rows from the reference catalog; set up return data. + refCatBright.remove_rows(badRows) + gMags = list(refCatBright["mag"][:]) + ids = list(refCatBright["id"][:]) + return Struct(starStamps=starStamps, pixCenters=pixCenters, gMags=gMags, gaiaIds=ids) + + def _getCutout(self, inputExposure, coordPix: Point2D, stampSize: list[int]): + """Get a cutout from an input exposure, handling edge cases. + + Generate a cutout from an input exposure centered on a given position + and with a given size. + If any part of the cutout is outside the input exposure bounding box, + the cutout is padded with NaNs. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image to extract bright star stamps from. + coordPix : `~lsst.geom.Point2D` + Center of the cutout in pixel space. + stampSize : `list` [`int`] + Size of the cutout, in pixels. + + Returns + ------- + stamp : `~lsst.afw.image.ExposureF` or `None` + The cutout, or `None` if the cutout is entirely outside the input + exposure bounding box. + + Notes + ----- + This method is a short-term workaround until DM-40042 is implemented. + At that point, it should be replaced by a call to the Exposure method + ``getCutout``, which will handle edge cases automatically. + """ + corner = Point2I(np.array(coordPix) - np.array(stampSize) / 2) + dimensions = Extent2I(stampSize) + stampBBox = Box2I(corner, dimensions) + overlapBBox = Box2I(stampBBox) + overlapBBox.clip(inputExposure.getBBox()) + if overlapBBox.getArea() > 0: + # Create full-sized stamp with pixels initially flagged as NO_DATA. + stamp = ExposureF(bbox=stampBBox) + stamp.image[:] = np.nan + stamp.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA")) + # Restore pixels which overlap the input exposure. + inputMI = inputExposure.maskedImage + overlap = inputMI.Factory(inputMI, overlapBBox) + stamp.maskedImage[overlapBBox] = overlap + # Set detector and WCS. + stamp.setDetector(inputExposure.getDetector()) + stamp.setWcs(inputExposure.getWcs()) + else: + stamp = None + return stamp + + def _replaceSecondaryFootprints(self, stamp, coordPix, objectId, find="DETECTED", replace="BAD"): + """Replace all secondary footprints in a stamp with another mask flag. + + This method identifies all secondary footprints in a stamp as those + whose ``find`` footprints do not overlap the given pixel coordinates. + If then sets these secondary footprints to the ``replace`` flag. + + Parameters + ---------- + stamp : `~lsst.afw.image.ExposureF` + The postage stamp to modify. + coordPix : `~lsst.geom.Point2D` + The pixel coordinates of the central primary object. + objectId : `int` + The unique identifier of the central primary object. + find : `str`, optional + The mask plane to use to identify secondary footprints. + replace : `str`, optional + The mask plane to set secondary footprints to. + + Notes + ----- + This method modifies the input ``stamp`` in-place. + """ + # Find a FootprintSet given an Image and a threshold. + detThreshold = Threshold(stamp.mask.getPlaneBitMask(find), Threshold.BITMASK) + footprintSet = FootprintSet(stamp.mask, detThreshold) + allFootprints = footprintSet.getFootprints() + # Identify secondary objects (i.e., not the central primary object). + secondaryFootprints = [] + for footprint in allFootprints: + if not footprint.contains(Point2I(coordPix)): + secondaryFootprints.append(footprint) + # Set secondary object footprints to BAD. + # Note: the value of numPrimaryFootprints can only be 0 or 1. If it is + # 0, then the primary object was not found overlapping a footprint. + # This can occur for low-S/N stars, for example. Processing can still + # continue beyond this point in an attempt to utilize this faint flux. + if (numPrimaryFootprints := len(allFootprints) - len(secondaryFootprints)) == 0: + self.log.warning( + "Could not uniquely identify central %s footprint for star %s; " + "found %d footprints instead.", + find, + objectId, + numPrimaryFootprints, + ) + footprintSet.setFootprints(secondaryFootprints) + footprintSet.setMask(stamp.mask, replace) def warpStamps(self, stamps, pixCenters): """Warps and shifts all given stamps so they are sampled on the same @@ -458,116 +611,15 @@ def warpStamps(self, stamps, pixCenters): return Struct(warpedStars=warpedStars, warpTransforms=warpTransforms, xy0s=xy0s, nb90Rots=nb90Rots) def setModelStamp(self): + """Compute (model) stamp size depending on provided buffer value.""" self.modelStampSize = [ int(self.config.stampSize[0] * self.config.modelStampBuffer), int(self.config.stampSize[1] * self.config.modelStampBuffer), ] - # force it to be odd-sized so we have a central pixel + # Force stamp to be odd-sized so we have a central pixel. if not self.modelStampSize[0] % 2: self.modelStampSize[0] += 1 if not self.modelStampSize[1] % 2: self.modelStampSize[1] += 1 - # central pixel + # Central pixel. self.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 - - @timeMethod - def run(self, inputExposure, refObjLoader=None, dataId=None, skyCorr=None): - """Identify bright stars within an exposure using a reference catalog, - extract stamps around each, then preprocess them. The preprocessing - steps are: shifting, warping and potentially rotating them to the same - pixel grid; computing their annular flux and normalizing them. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The image from which bright star stamps should be extracted. - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional - Loader to find objects within a reference catalog. - dataId : `dict` or `~lsst.daf.butler.DataCoordinate` - The dataId of the exposure (and detector) bright stars should be - extracted from. - skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional - Full focal plane sky correction obtained by `SkyCorrectionTask`. - - Returns - ------- - result : `~lsst.pipe.base.Struct` - Results as a struct with attributes: - - ``brightStarStamps`` - (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) - """ - if self.config.doApplySkyCorr: - self.log.info( - "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId - ) - self.applySkyCorr(inputExposure, skyCorr) - self.log.info("Extracting bright stars from exposure %s", dataId) - # Extract stamps around bright stars - extractedStamps = self.extractStamps(inputExposure, refObjLoader=refObjLoader) - if not extractedStamps.starIms: - self.log.info("No suitable bright star found.") - return None - # Warp (and shift, and potentially rotate) them - self.log.info( - "Applying warp and/or shift to %i star stamps from exposure %s.", - len(extractedStamps.starIms), - dataId, - ) - warpOutputs = self.warpStamps(extractedStamps.starIms, extractedStamps.pixCenters) - warpedStars = warpOutputs.warpedStars - xy0s = warpOutputs.xy0s - brightStarList = [ - BrightStarStamp( - stamp_im=warp, - archive_element=transform, - position=xy0s[j], - gaiaGMag=extractedStamps.GMags[j], - gaiaId=extractedStamps.gaiaIds[j], - minValidAnnulusFraction=self.config.minValidAnnulusFraction, - ) - for j, (warp, transform) in enumerate(zip(warpedStars, warpOutputs.warpTransforms)) - ] - # Compute annularFlux and normalize - self.log.info( - "Computing annular flux and normalizing %i bright stars from exposure %s.", - len(warpedStars), - dataId, - ) - # annularFlux statistic set-up, excluding mask planes - statsControl = StatisticsControl() - statsControl.setNumSigmaClip(self.config.numSigmaClip) - statsControl.setNumIter(self.config.numIter) - innerRadius, outerRadius = self.config.annularFluxRadii - statsFlag = stringToStatisticsProperty(self.config.annularFluxStatistic) - brightStarStamps = BrightStarStamps.initAndNormalize( - brightStarList, - innerRadius=innerRadius, - outerRadius=outerRadius, - nb90Rots=warpOutputs.nb90Rots, - imCenter=self.modelCenter, - use_archive=True, - statsControl=statsControl, - statsFlag=statsFlag, - badMaskPlanes=self.config.badMaskPlanes, - discardNanFluxObjects=(self.config.discardNanFluxStars), - ) - # Dont create empty fits files if there is no normalized stamp! - if not len(brightStarStamps._stamps) > 0: - self.log.info("No normalized stamps exists for this exposure!") - return None - return Struct(brightStarStamps=brightStarStamps) - - def runQuantum(self, butlerQC, inputRefs, outputRefs): - inputs = butlerQC.get(inputRefs) - inputs["dataId"] = str(butlerQC.quantum.dataId) - refObjLoader = ReferenceObjectLoader( - dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], - refCats=inputs.pop("refCat"), - name=self.config.connections.refCat, - config=self.config.refObjLoader, - ) - output = self.run(**inputs, refObjLoader=refObjLoader) - # This if block prevents the code to produce an emtpy fits file in case there is no stamp. - if output: - butlerQC.put(output, outputRefs) diff --git a/python/lsst/pipe/tasks/subtractBrightStars.py b/python/lsst/pipe/tasks/subtractBrightStars.py index 2d18f2805..ff4e9a973 100644 --- a/python/lsst/pipe/tasks/subtractBrightStars.py +++ b/python/lsst/pipe/tasks/subtractBrightStars.py @@ -28,8 +28,8 @@ from operator import ior import numpy as np -from lsst.afw.image import Exposure, ExposureF, MaskedImageF from lsst.afw.geom import SpanSet, Stencil +from lsst.afw.image import Exposure, ExposureF, MaskedImageF from lsst.afw.math import ( StatisticsControl, WarpingControl, @@ -52,9 +52,11 @@ class SubtractBrightStarsConnections( PipelineTaskConnections, dimensions=("instrument", "visit", "detector"), - defaultTemplates={"outputExposureName": "brightStar_subtracted", - "outputBackgroundName": "brightStars", - "badStampsName": "brightStars"}, + defaultTemplates={ + "outputExposureName": "brightStar_subtracted", + "outputBackgroundName": "brightStars", + "badStampsName": "brightStars", + }, ): inputExposure = Input( doc="Input exposure from which to subtract bright star stamps.", @@ -117,8 +119,8 @@ class SubtractBrightStarsConnections( ), ) outputBadStamps = Output( - doc="The stamps the are not normalized and consequently not subtracted from the exposure.", - name="{badStampsName}_unsubtracted_stapms", + doc="The stamps that are not normalized and consequently not subtracted from the exposure.", + name="{badStampsName}_unsubtracted_stamps", storageClass="BrightStarStamps", dimensions=( "visit", @@ -150,23 +152,23 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted", default=18, ) - minValidAnnulusFraction = Field( + minValidAnnulusFraction = Field[float]( dtype=float, - doc="Minumum number of valid pixels that must fall within the annulus for the bright star to be " + doc="Minimum number of valid pixels that must fall within the annulus for the bright star to be " "saved for subsequent generation of a PSF.", default=0.0, ) - numSigmaClip = Field( + numSigmaClip = Field[float]( dtype=float, doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", default=4, ) - numIter = Field( + numIter = Field[int]( dtype=int, doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", default=3, ) - warpingKernelName = ChoiceField[str]( + warpingKernelName = ChoiceField( dtype=str, doc="Warping kernel", default="lanczos5", @@ -179,7 +181,7 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract "lanczos7": "Lanczos kernel of order 7", }, ) - scalingType = ChoiceField[str]( + scalingType = ChoiceField( dtype=str, doc="How the model should be scaled to each bright star; implemented options are " "`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform " @@ -210,16 +212,18 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract # interest) also get set to `BAD`. default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), ) - subtractionBox = ListField( + subtractionBox = ListField[int]( dtype=int, doc="Size of the stamps to be extracted, in pixels.", default=(250, 250), ) - subtractionBoxBuffer = Field( + subtractionBoxBuffer = Field[float]( dtype=float, doc=( - "'Buffer' factor to be applied to determine the size of the stamp the processed stars will be " - "saved in. This will also be the size of the extended PSF model." + "'Buffer' (multiplicative) factor to be applied to determine the size of the stamp the " + "processed stars will be saved in. This is also the size of the extended PSF model. The buffer " + "region is masked and contain no data and subtractionBox determines the region where contains " + "the data." ), default=1.1, ) @@ -251,8 +255,128 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Placeholders to set up Statistics if scalingType is leastSquare. self.statsControl, self.statsFlag = None, None - # warping control; only contains shiftingALg provided in config - self.warpCont = WarpingControl(self.config.warpingKernelName) + # Warping control; only contains shiftingALg provided in config. + self.warpControl = WarpingControl(self.config.warpingKernelName) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + # Docstring inherited. + inputs = butlerQC.get(inputRefs) + dataId = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.refObjLoader, + ) + subtractor, _, badStamps = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) + if self.config.doWriteSubtractedExposure: + outputExposure = inputs["inputExposure"].clone() + outputExposure.image -= subtractor.image + else: + outputExposure = None + outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None + # In its current state, the code produces outputBadStamps which are the + # stamps of stars that have not been subtracted from the image for any + # reason. If all the stars are subtracted from the calexp, the output + # is an empty fits file. + output = Struct( + outputExposure=outputExposure, + outputBackgroundExposure=outputBackgroundExposure, + outputBadStamps=badStamps, + ) + butlerQC.put(output, outputRefs) + + def run( + self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, skyCorr=None, refObjLoader=None + ): + """Iterate over all bright stars in an exposure to scale the extended + PSF model before subtracting bright stars. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright stars should be subtracted. + inputBrightStarStamps : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + Set of stamps centered on each bright star to be subtracted, + produced by running + `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. + inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf` + Extended PSF model, produced by + `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (and detector) bright stars should be + subtracted from. + skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional + Full focal plane sky correction, obtained by running + `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. If + `doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + + Returns + ------- + subtractorExp : `~lsst.afw.image.ExposureF` + An Exposure containing a scaled bright star model fit to every + bright star profile; its image can then be subtracted from the + input exposure. + invImages : `list` [`~lsst.afw.image.MaskedImageF`] + A list of small images ("stamps") containing the model, each scaled + to its corresponding input bright star. + """ + self.inputExpBBox = inputExposure.getBBox() + if self.config.doApplySkyCorr and (skyCorr is not None): + self.log.info( + "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId + ) + self.applySkyCorr(inputExposure, skyCorr) + + # Create an empty image the size of the exposure. + # TODO: DM-31085 (set mask planes). + subtractorExp = ExposureF(bbox=inputExposure.getBBox()) + subtractor = subtractorExp.maskedImage + + # Make a copy of the input model. + self.model = inputExtendedPsf(dataId["detector"]).clone() + self.modelStampSize = self.model.getDimensions() + # Number of 90 deg. rotations to reverse each stamp's rotation. + self.inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4 + self.model = rotateImageBy90(self.model, self.inv90Rots) + + brightStarList = self.makeBrightStarList(inputBrightStarStamps, inputExposure, refObjLoader) + invImages = [] + subtractor, invImages = self.buildSubtractor( + inputBrightStarStamps, subtractor, invImages, multipleAnnuli=False + ) + if brightStarList: + self.setMissedStarsStatsControl() + # This may change when multiple star bins are used for PSF + # creation. + innerRadius = inputBrightStarStamps._innerRadius + outerRadius = inputBrightStarStamps._outerRadius + brightStarStamps, badStamps = BrightStarStamps.initAndNormalize( + brightStarList, + innerRadius=innerRadius, + outerRadius=outerRadius, + nb90Rots=self.warpOutputs.nb90Rots, + imCenter=self.warper.modelCenter, + use_archive=True, + statsControl=self.missedStatsControl, + statsFlag=self.missedStatsFlag, + badMaskPlanes=self.warper.config.badMaskPlanes, + discardNanFluxObjects=False, + forceFindFlux=True, + ) + + self.psf_annular_fluxes = self.findPsfAnnularFluxes(brightStarStamps) + subtractor, invImages = self.buildSubtractor( + brightStarStamps, subtractor, invImages, multipleAnnuli=True + ) + else: + badStamps = [] + badStamps = BrightStarStamps(badStamps) + + return subtractorExp, invImages, badStamps def _setUpStatistics(self, exampleMask): """Configure statistics control and flag, for use if ``scalingType`` is @@ -284,7 +408,7 @@ def applySkyCorr(self, calexp, skyCorr): calexp = calexp.getMaskedImage() calexp -= skyCorr.getImage() - def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=None): + def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=1.0): """Compute scaling factor to be applied to the extended PSF so that its amplitude matches that of an individual star. @@ -292,13 +416,18 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=Non ---------- model : `~lsst.afw.image.MaskedImageF` The extended PSF model, shifted (and potentially warped) to match - the bright star's positioning. + the bright star position. star : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` A stamp centered on the bright star to be subtracted. inPlace : `bool` Whether the model should be scaled in place. Default is `True`. nb90Rots : `int` The number of 90-degrees rotations to apply to the star stamp. + psf_annular_flux: `float`, optional + The annular flux of the PSF model at the radius where the flux of + the given star is determined. This is 1 for stars present in + inputBrightStarStamps, but can be different for stars that are + missing from inputBrightStarStamps. Returns ------- @@ -306,8 +435,6 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=Non The factor by which the model image should be multiplied for it to be scaled to the input bright star. """ - if psf_annular_flux is None: - psf_annular_flux = 1 if self.config.scalingType == "annularFlux": scalingFactor = star.annularFlux * psf_annular_flux elif self.config.scalingType == "leastSquare": @@ -332,9 +459,13 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=Non model.image *= scalingFactor return scalingFactor - def _overRideWarperConfig(self): + def _overrideWarperConfig(self): """Override the warper config with the config of this task. + + This override is necessary for stars that are missing from the + inputBrightStarStamps object but still need to be subtracted. """ + # TODO: Replace these copied values with a warperConfig. self.warper.config.minValidAnnulusFraction = self.config.minValidAnnulusFraction self.warper.config.numSigmaClip = self.config.numSigmaClip self.warper.config.numIter = self.config.numIter @@ -342,10 +473,12 @@ def _overRideWarperConfig(self): self.warper.config.badMaskPlanes = self.config.badMaskPlanes self.warper.config.stampSize = self.config.subtractionBox self.warper.modelStampBuffer = self.config.subtractionBoxBuffer + self.warper.config.magLimit = self.config.magLimit self.warper.setModelStamp() def setMissedStarsStatsControl(self): - """Configure statistics control for processing missing stars from inputBrightStarStamps. + """Configure statistics control for processing missing stars from + inputBrightStarStamps. """ self.missedStatsControl = StatisticsControl() self.missedStatsControl.setNumSigmaClip(self.warper.config.numSigmaClip) @@ -353,20 +486,23 @@ def setMissedStarsStatsControl(self): self.missedStatsFlag = stringToStatisticsProperty(self.warper.config.annularFluxStatistic) def setWarpTask(self): - """Create an instance of ProcessBrightStarsTask that will be used to produce stamps of stars to be - subtracted. + """Create an instance of ProcessBrightStarsTask that will be used to + produce stamps of stars to be subtracted. """ self.warper = ProcessBrightStarsTask() - self._overRideWarperConfig() + self._overrideWarperConfig() self.warper.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 def makeBrightStarList(self, inputBrightStarStamps, inputExposure, refObjLoader): - """Make a list of bright stars that are missing from inputBrightStarStamps to be subtracted. + """Make a list of bright stars that are missing from + inputBrightStarStamps to be subtracted. Parameters ---------- - inputBrightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` - Set of stamps centered on each bright star to be subtracted, produced by running + inputBrightStarStamps : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + Set of stamps centered on each bright star to be subtracted, + produced by running `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. inputExposure : `~lsst.afw.image.ExposureF` The image from which bright stars should be subtracted. @@ -376,26 +512,31 @@ def makeBrightStarList(self, inputBrightStarStamps, inputExposure, refObjLoader) Returns ------- brightStarList: - A list containing `lsst.meas.algorithms.brightStarStamps.BrightStarStamp` of stars to be - subtracted. + A list containing + `lsst.meas.algorithms.brightStarStamps.BrightStarStamp` of stars to + be subtracted. """ self.setWarpTask() missedStars = self.warper.extractStamps( inputExposure, refObjLoader=refObjLoader, inputBrightStarStamps=inputBrightStarStamps ) - self.warpOutputs = self.warper.warpStamps(missedStars.starIms, missedStars.pixCenters) - brightStarList = [ - BrightStarStamp( - stamp_im=warp, - archive_element=transform, - position=self.warpOutputs.xy0s[j], - gaiaGMag=missedStars.GMags[j], - gaiaId=missedStars.gaiaIds[j], - minValidAnnulusFraction=self.warper.config.minValidAnnulusFraction, - ) - for j, (warp, transform) in enumerate(zip(self.warpOutputs.warpedStars, - self.warspOutputs.warpTransforms)) - ] + if missedStars.starStamps: + self.warpOutputs = self.warper.warpStamps(missedStars.starStamps, missedStars.pixCenters) + brightStarList = [ + BrightStarStamp( + stamp_im=warp, + archive_element=transform, + position=self.warpOutputs.xy0s[j], + gaiaGMag=missedStars.gMags[j], + gaiaId=missedStars.gaiaIds[j], + minValidAnnulusFraction=self.warper.config.minValidAnnulusFraction, + ) + for j, (warp, transform) in enumerate( + zip(self.warpOutputs.warpedStars, self.warpOutputs.warpTransforms) + ) + ] + else: + brightStarList = [] return brightStarList def initAnnulusImage(self): @@ -412,11 +553,19 @@ def initAnnulusImage(self): return annulusImage def createAnnulus(self, brightStarStamp): - """Create an annulus of the given star. + """Create a circular annulus around the given star. + + The circular annulus is set based on the inner and outer optimal radii. + These radii describe the annulus where the flux of the star is found. + The aim is to create the same annulus for the PSF model, eventually + measuring the model flux around that annulus. + An optimal radius usually differs from the radius where the PSF model + is normalized. Parameters ---------- - brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + brightStarStamp : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` A stamp of a bright star to be subtracted. Returns @@ -424,7 +573,7 @@ def createAnnulus(self, brightStarStamp): annulus : `~lsst.afw.image.MaskedImageF` An annulus of the given star. """ - # Create SpanSet of annulus + # Create SpanSet of annulus. outerCircle = SpanSet.fromShape( brightStarStamp.optimalOuterRadius, Stencil.CIRCLE, offset=self.warper.modelCenter ) @@ -445,7 +594,8 @@ def applyStatsControl(self, annulusImage): Returns ------- annularFlux: float - The annular flux of the PSF model at the radius where the flux of the given star is determined. + The annular flux of the PSF model at the radius where the flux of + the given star is determined. """ andMask = reduce( ior, (annulusImage.mask.getPlaneBitMask(bm) for bm in self.warper.config.badMaskPlanes) @@ -455,20 +605,25 @@ def applyStatsControl(self, annulusImage): return annulusStat.getValue() def findPsfAnnularFlux(self, brightStarStamp, maskedModel): - """Find the annular flux of the PSF model at the radius where the flux of the given star is - determined. + """Find the annular flux of the PSF model within a specified annulus. + + This flux will be used for re-scaling the PSF to the level of stars + with bad stamps. Stars with bad stamps are those without a flux within + the normalization annulus. Parameters ---------- - brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + brightStarStamp : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` A stamp of a bright star to be subtracted. maskedModel : `~lsst.afw.image.MaskedImageF` A masked image of the PSF model. Returns ------- - annularFlux: float - The annular flux of the PSF model at the radius where the flux of the given star is determined. + annularFlux: float (between 0 and 1) + The annular flux of the PSF model at the radius where the flux of + the given star is determined. """ annulusImage = self.initAnnulusImage() annulus = self.createAnnulus(brightStarStamp) @@ -481,26 +636,28 @@ def findPsfAnnularFluxes(self, brightStarStamps): Parameters ---------- - brightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + brightStarStamps : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` The stamps of stars that will be subtracted from the exposure. Returns ------- PsfAnnularFluxes: numpy.array - A two column numpy.array containing annular fluxes of the PSF at radii where the flux for stars - exist (could be found). + A two column numpy.array containing annular fluxes of the PSF at + radii where the flux for stars exist (could be found). Notes ----- - While the PSF model is normalized at a certain radius, the flux of a star at that radius might be - impossible to find. Therefore, we have to scale the PSF model considering a radius where the star has - an identified flux. To do that, the flux of the model should be found and used to adjust the scaling - step. + While the PSF model is normalized at a certain radius, the annular flux + of a star around that radius might be impossible to find. Therefore, we + have to scale the PSF model considering a radius where the star has an + identified flux. To do that, the flux of the model should be found and + used to adjust the scaling step. """ outerRadii = [] annularFluxes = [] maskedModel = MaskedImageF(self.model.image) - # the model has wrong bbox values. Should be fixed in extended_psf.py? + # The model has wrong bbox values. Should be fixed in extended_psf.py? maskedModel.setXY0(0, 0) for star in brightStarStamps: if star.optimalOuterRadius not in outerRadii: @@ -510,25 +667,39 @@ def findPsfAnnularFluxes(self, brightStarStamps): return np.array([outerRadii, annularFluxes]).T def preparePlaneModelStamp(self, brightStarStamp): - """Prepare the PSF model before scaling. + """Prepare the PSF plane model stamp. + + It is called PlaneModel because, while it is a PSF model stamp that is + warped and rotated to the same orientation of a chosen star, it is not + yet scaled to the brightness level of the star. Parameters ---------- - brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + brightStarStamp : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` The stamp of the star to which the PSF model will be scaled. Returns ------- bbox: `~lsst.geom.Box2I` - Contains the corner coordination and the dimensions of the model stamp. + Contains the corner coordination and the dimensions of the model + stamp. invImage: `~lsst.afw.image.MaskedImageF` - The extended PSF model, shifted (and potentially warped) to match the bright star's positioning. + The extended PSF model, shifted (and potentially warped and + rotated) to match the bright star position. Raises ------ RuntimeError - Raised if warping of the model is failed. + Raised if warping of the model failed. + + Notes + ----- + Since detectors have different orientations, the PSF model should be + rotated to match the orientation of the detectors in some cases. To do + that, the code uses the inverse of the transform that is applied to the + bright star stamp to match the orientation of the detector. """ # Set the origin. self.model.setXY0(brightStarStamp.position) @@ -538,12 +709,13 @@ def preparePlaneModelStamp(self, brightStarStamp): bbox = Box2I(corner=invOrigin, dimensions=self.modelStampSize) invImage = MaskedImageF(bbox) # Apply inverse transform. - goodPix = warpImage(invImage, self.model, invTransform, self.warpCont) + goodPix = warpImage(invImage, self.model, invTransform, self.warpControl) if not goodPix: - # Do we want to find another way or just subtract the non-warped scaled model? + # Do we want to find another way or just subtract the non-warped + # scaled model? # Currently the code just leaves the failed ones un-subtracted. raise RuntimeError( - f"Warping of a model failed for star {brightStarStamp.gaiaId}: " "no good pixel in output" + f"Warping of a model failed for star {brightStarStamp.gaiaId}: no good pixel in output." ) return bbox, invImage @@ -553,189 +725,81 @@ def addScaledModel(self, subtractor, brightStarStamp, multipleAnnuli=False): Parameters ---------- subtractor : `~lsst.afw.image.MaskedImageF` - The Exposure containing the scaled model of brigth stars to be subtracted from the input - exposure. - brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` - The stamp of the star of which the PSF model will be scaled and added to the subtractor. + The full image containing the scaled model of bright stars to be + subtracted from the input exposure. + brightStarStamp : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + The stamp of the star of which the PSF model will be scaled and + added to the subtractor. multipleAnnuli : bool, optional - If true, the model should be scaled based on a flux at a radius other than its normalization - radius. + If true, the model should be scaled based on a flux at a radius + other than its normalization radius. Returns ------- subtractor : `~lsst.afw.image.MaskedImageF` - The input subtractor Exposure with the added scaled model at the given star's location in the - exposure. + The input subtractor full image with the added scaled model at the + given star's location in the exposure. invImage: `~lsst.afw.image.MaskedImageF` - The extended PSF model, shifted (and potentially warped) to match the bright star's positioning. + The extended PSF model, shifted (and potentially warped) to match + the bright star position. """ bbox, invImage = self.preparePlaneModelStamp(brightStarStamp) - if multipleAnnuli: - cond = self.psf_annular_fluxes[:, 0] == brightStarStamp.optimalOuterRadius - psf_annular_flux = self.psf_annular_fluxes[cond, 1][0] - self.scaleModel(invImage, - brightStarStamp, - inPlace=True, - nb90Rots=self.inv90Rots, - psf_annular_flux=psf_annular_flux) - else: - self.scaleModel(invImage, brightStarStamp, inPlace=True, nb90Rots=self.inv90Rots) - # Replace NaNs before subtraction (note all NaN pixels have - # the NO_DATA flag). - invImage.image.array[np.isnan(invImage.image.array)] = 0 bbox.clip(self.inputExpBBox) if bbox.getArea() > 0: + if multipleAnnuli: + cond = self.psf_annular_fluxes[:, 0] == brightStarStamp.optimalOuterRadius + psf_annular_flux = self.psf_annular_fluxes[cond, 1][0] + self.scaleModel( + invImage, + brightStarStamp, + inPlace=True, + nb90Rots=self.inv90Rots, + psf_annular_flux=psf_annular_flux, + ) + else: + self.scaleModel(invImage, brightStarStamp, inPlace=True, nb90Rots=self.inv90Rots) + # Replace NaNs before subtraction (all NaNs have the NO_DATA flag). + invImage.image.array[np.isnan(invImage.image.array)] = 0 subtractor[bbox] += invImage[bbox] return subtractor, invImage def buildSubtractor(self, brightStarStamps, subtractor, invImages, multipleAnnuli=False): - """Build an image containing potentially multiple scaled PSF models, each at the location of a given - brigth star. + """Build an image containing potentially multiple scaled PSF models, + each at the location of a given bright star. Parameters ---------- - brightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` - Set of stamps centered on each bright star to be subtracted, produced by running + brightStarStamps : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + Set of stamps centered on each bright star to be subtracted, + produced by running `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. subtractor : `~lsst.afw.image.MaskedImageF` - The Exposure that will contain the scaled model of brigth stars to be subtracted from the - exposure. + The Exposure that will contain the scaled model of bright stars to + be subtracted from the exposure. invImages : `list` - A list containing extended PSF models, shifted (and potentially warped) to match the bright stars - positionings. + A list containing extended PSF models, shifted (and potentially + warped) to match the bright stars positions. multipleAnnuli : bool, optional This will be passed to addScaledModel method, by default False. Returns ------- subtractor : `~lsst.afw.image.MaskedImageF` - An Exposure containing a scaled bright star model fit to every bright star profile; its image can - then be subtracted from the input exposure. + An Exposure containing a scaled bright star model fit to every + bright star profile; its image can then be subtracted from the + input exposure. invImages: list - A list containing the extended PSF models, shifted (and potentially warped) to match bright - stars' positionings. + A list containing the extended PSF models, shifted (and potentially + warped) to match bright stars' positions. """ for star in brightStarStamps: if star.gaiaGMag < self.config.magLimit: try: - # Adding the scaled model at the star location to the subtractor. + # Add the scaled model at the star location to subtractor. subtractor, invImage = self.addScaledModel(subtractor, star, multipleAnnuli) invImages.append(invImage) except RuntimeError as err: logger.error(err) return subtractor, invImages - - def runQuantum(self, butlerQC, inputRefs, outputRefs): - # Docstring inherited. - inputs = butlerQC.get(inputRefs) - dataId = butlerQC.quantum.dataId - refObjLoader = ReferenceObjectLoader( - dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], - refCats=inputs.pop("refCat"), - name=self.config.connections.refCat, - config=self.config.refObjLoader, - ) - subtractor, _, badStamps = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) - if self.config.doWriteSubtractedExposure: - outputExposure = inputs["inputExposure"].clone() - outputExposure.image -= subtractor.image - else: - outputExposure = None - outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None - # in its current state, the code produces outputBadStamps which are the stamps of stars that have not - # been subtracted from the image for any reason. If all the stars are subtracted from the calexp, the - # output is an empty fits file. - output = Struct(outputExposure=outputExposure, - outputBackgroundExposure=outputBackgroundExposure, - outputBadStamps=badStamps) - butlerQC.put(output, outputRefs) - - def run(self, - inputExposure, - inputBrightStarStamps, - inputExtendedPsf, - dataId, - skyCorr=None, - refObjLoader=None): - """Iterate over all bright stars in an exposure to scale the extended - PSF model before subtracting bright stars. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The image from which bright stars should be subtracted. - inputBrightStarStamps : - `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` - Set of stamps centered on each bright star to be subtracted, - produced by running - `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. - inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf` - Extended PSF model, produced by - `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`. - dataId : `dict` or `~lsst.daf.butler.DataCoordinate` - The dataId of the exposure (and detector) bright stars should be - subtracted from. - skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional - Full focal plane sky correction, obtained by running - `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. If - `doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`. - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional - Loader to find objects within a reference catalog. - - Returns - ------- - subtractorExp : `~lsst.afw.image.ExposureF` - An Exposure containing a scaled bright star model fit to every - bright star profile; its image can then be subtracted from the - input exposure. - invImages : `list` [`~lsst.afw.image.MaskedImageF`] - A list of small images ("stamps") containing the model, each scaled - to its corresponding input bright star. - """ - self.inputExpBBox = inputExposure.getBBox() - if self.config.doApplySkyCorr and (skyCorr is not None): - self.log.info( - "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId - ) - self.applySkyCorr(inputExposure, skyCorr) - # Create an empty image the size of the exposure. - # TODO: DM-31085 (set mask planes). - subtractorExp = ExposureF(bbox=inputExposure.getBBox()) - subtractor = subtractorExp.maskedImage - # Make a copy of the input model. - self.model = inputExtendedPsf(dataId["detector"]).clone() - self.modelStampSize = self.model.getDimensions() - self.inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4 - self.model = rotateImageBy90(self.model, self.inv90Rots) - - brightStarList = self.makeBrightStarList(inputBrightStarStamps, inputExposure, refObjLoader) - self.setMissedStarsStatsControl() - # This might change when we use multiple categories of stars for creating PSF. - innerRadius = inputBrightStarStamps._innerRadius - outerRadius = inputBrightStarStamps._outerRadius - brightStarStamps, badStamps = BrightStarStamps.initAndNormalize( - brightStarList, - innerRadius=innerRadius, - outerRadius=outerRadius, - nb90Rots=self.warpOutputs.nb90Rots, - imCenter=self.warper.modelCenter, - use_archive=True, - statsControl=self.missedStatsControl, - statsFlag=self.missedStatsFlag, - badMaskPlanes=self.warper.config.badMaskPlanes, - discardNanFluxObjects=False, - forceFindFlux=True, - ) - - invImages = [] - subtractor, invImages = self.buildSubtractor( - inputBrightStarStamps, subtractor, invImages, multipleAnnuli=False - ) - if len(brightStarStamps) > 0: - self.psf_annular_fluxes = self.findPsfAnnularFluxes(brightStarStamps) - subtractor, invImages = self.buildSubtractor( - brightStarStamps, subtractor, invImages, multipleAnnuli=True - ) - badStamps = BrightStarStamps(badStamps) - - return subtractorExp, invImages, badStamps