Skip to content

Commit

Permalink
Switch to using sub-tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Aug 1, 2024
1 parent 90b2ea7 commit 460ddde
Showing 1 changed file with 91 additions and 36 deletions.
127 changes: 91 additions & 36 deletions python/lsst/meas/pz/estimate_pz_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,28 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Names exported by ``from <module> import *``.
# NOTE(review): "EsimatePZTrainZTask" and "EsimatePZKNNTask" (missing "t")
# are listed next to the "Estimate..." spellings -- presumably leftovers of
# a rename; confirm and drop them once nothing imports the old names.
__all__ = [
"EsimatePZTrainZTask",
"EstimatePZTaskConfig",
"EstimatePZTask",
"EstimatePZTrainZTask",
"EstimatePZKNNTask",
"EstimatePZTrainZConfig",
"EsimatePZKNNTask",
"EstimatePZKNNConfig",
]

from typing import Any

from abc import ABC

import lsst.pex.config as pexConfig
import lsst.pipe.base.connectionTypes as cT
import numpy as np
from ceci.config import StageConfig as CeciStageConfig
from ceci.config import StageParameter as CeciParam

# from ceci.stage import PipelineStage as CeciPipelineStage
from lsst.daf.butler import DeferredDatasetHandle
from lsst.pipe.base import (
Task,
PipelineTask,
PipelineTaskConfig,
PipelineTaskConnections,
Expand Down Expand Up @@ -89,11 +95,11 @@ class EstimatePZConnections(
)


class EstimatePZConfigBase(
PipelineTaskConfig, pipelineConnections=EstimatePZConnections
class EstimatePZAlgoConfigBase(
pexConfig.Config,
):
"""Base class for configurations of p(z)
estimation pipetasks.
"""Base class for configurations of algorithm specific p(z)
estimation tasks.
This class mostly just translates the RAIL configuration
parameters to pex.config parameters.
Expand All @@ -103,11 +109,9 @@ class EstimatePZConfigBase(
pex.config parameters.
Subclasses will just have to set
`estimator_class` and `estimator_module` and invoke _make_fields.
`stage_class` and invoke _make_fields.
"""

# estimator_class = None
# estimator_module = None
stage_class = None

stage_name = pexConfig.Field(doc="Rail stage name", dtype=str)
Expand Down Expand Up @@ -149,21 +153,21 @@ def _make_fields(cls):
)


class EsimatePZTaskBase(PipelineTask):
"""Base class for p(z) estimation
class EstimatePZAlgoTask(Task, ABC):
"""Task for algorithm specific p(z) estimation
This will provide almost all of the functionality
needed to run RAIL p(z) algorithms
Subclasses will just need to override
`ConfigClass` and `_DefaultName`
"""

ConfigClass = None
_DefaultName = None
ConfigClass = EstimatePZAlgoConfigBase

mag_conv = np.log(10) * 0.4

def __init__(self, **kwargs):
super().__init__(**kwargs)

@staticmethod
def _flux_to_mag(
flux_vals: np.array,
Expand Down Expand Up @@ -308,18 +312,10 @@ def run(
# so that we can pass the rest to RAIL
rail_kwargs = self.config.toDict().copy()
for key in ["saveLogOutput", "stage_name", "mag_offset", "connections"]:
rail_kwargs.pop(key)
rail_kwargs.pop(key, None)
rail_kwargs["output_mode"] = "return"

# Build the RAIL stage
# self._stage = PZFactory.build_cat_estimator_stage(
# self.config.stage_name,
# self.config.estimator_class,
# self.config.estimator_module,
# model_path=pzModel,
# input_path="dummy.in",
# **rail_kwargs,
# )
self._stage = PZFactory.build_stage_instance(
self.config.stage_name,
self.config.estimator_class,
Expand All @@ -343,54 +339,113 @@ def run(
return Struct(pz_pdfs=pz_pdfs)


class EstimatePZTrainZConfig(EstimatePZConfigBase):
class EstimatePZTaskConfig(
    PipelineTaskConfig,
    pipelineConnections=EstimatePZConnections,
):
    """Configuration for the `EstimatePZTask` pipeline task.

    The configuration carries no algorithm parameters of its own; it only
    lets the user pick and configure one of the available p(z) estimation
    algorithms through the ``pz_algo`` field.
    """

    # Slot for the algorithm-specific sub-task; the default target is the
    # generic `EstimatePZAlgoTask`.
    pz_algo = pexConfig.ConfigurableField(
        doc="Algorithm specific configuration p(z) estimation task",
        target=EstimatePZAlgoTask,
    )


class EstimatePZTask(PipelineTask):
    """PipelineTask for p(z) estimation.

    This task performs no estimation itself: it builds the
    algorithm-specific sub-task configured as ``config.pz_algo`` and
    passes the input data to it.

    Parameters
    ----------
    initInputs : optional
        Init-input datasets, forwarded unchanged to the `PipelineTask`
        constructor. Defaults to `None` so the task can also be
        constructed directly (e.g. ``EstimatePZTask(config=cfg)``).
    **kwargs
        Additional keyword arguments forwarded to the `PipelineTask`
        constructor.
    """

    ConfigClass = EstimatePZTaskConfig
    _DefaultName = "EstimatePZ"

    # ln(10) * 0.4; not referenced by any method of this class as shown here.
    mag_conv = np.log(10) * 0.4

    def __init__(self, initInputs=None, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        # Create the configured algorithm sub-task; it is used below as
        # ``self.pz_algo``.
        self.makeSubtask("pz_algo")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Fetch the inputs, call `run` on them and store its outputs.

        Parameters
        ----------
        butlerQC
            Quantum context used to ``get`` the inputs and ``put`` the
            outputs.
        inputRefs
            References to the input datasets.
        outputRefs
            References to the output datasets.
        """
        inputs = butlerQC.get(inputRefs)
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        pzModel: dict[str, Any],
        objectTable: DeferredDatasetHandle,
    ) -> Struct:
        """Delegate the estimation to the configured algorithm sub-task.

        Parameters
        ----------
        pzModel
            Model data handed to the sub-task unchanged.
        objectTable
            Handle to the object table, handed to the sub-task unchanged.

        Returns
        -------
        Struct
            Struct with a single attribute, ``pz_pdfs``, copied from the
            struct returned by the sub-task.
        """
        ret_struct = self.pz_algo.run(pzModel, objectTable)
        # Re-wrap so that only ``pz_pdfs`` is exposed to the caller.
        return Struct(pz_pdfs=ret_struct.pz_pdfs)


class EstimatePZTrainZConfig(EstimatePZAlgoConfigBase):
    """Config for EstimatePZTrainZTask.

    Selects and configures the TrainZEstimator p(z) estimation algorithm.

    See https://github.com/LSSTDESC/rail_base/blob/main/src/rail/estimation/algos/train_z.py  # noqa
    for parameters and default values.
    """

    # NOTE(review): the base-class docstring mentions `stage_class`, while
    # this class sets `estimator_class` -- confirm which attribute
    # `_make_fields` reads.
    estimator_class = TrainZEstimator


# Subclasses of the base config are expected to invoke `_make_fields`
# once the class has been defined.
EstimatePZTrainZConfig._make_fields()


class EstimatePZKNNConfig(EstimatePZConfigBase):
class EstimatePZKNNConfig(EstimatePZAlgoConfigBase):
    """Config for EstimatePZKNNTask.

    Selects and configures the KNearNeighEstimator p(z) estimation
    algorithm.

    See https://github.com/LSSTDESC/rail_sklearn/blob/main/src/rail/estimation/algos/k_nearneigh.py  # noqa
    for parameters and default values.
    """

    # NOTE(review): the base-class docstring mentions `stage_class`, while
    # this class sets `estimator_class` -- confirm which attribute
    # `_make_fields` reads.
    estimator_class = KNearNeighEstimator


# Subclasses of the base config are expected to invoke `_make_fields`
# once the class has been defined.
EstimatePZKNNConfig._make_fields()


class EsimatePZTrainZTask(EsimatePZTaskBase):
"""Task that runs RAIL TrainZ algorithm for p(z) estimation
class EstimatePZTrainZTask(EstimatePZAlgoTask):
    """SubTask that runs the RAIL TrainZ algorithm for p(z) estimation.

    See https://github.com/LSSTDESC/rail_base/blob/main/src/rail/estimation/algos/train_z.py  # noqa
    for the algorithm implementation.

    TrainZ is just a placeholder algorithm that assigns the same
    p(z) distribution (taken from the input model file) to every object.
    """

    ConfigClass = EstimatePZTrainZConfig
    _DefaultName = "estimate_pz_trainz"

    def _get_mags_and_errs(
        self,
        fluxes: DataFrame,
        mag_offset: float,
    ) -> dict[str, np.array]:
        """Return the single column handed on for this algorithm.

        Parameters
        ----------
        fluxes
            Table that must provide an ``i_gaap1p0Flux`` column.
        mag_offset
            Accepted but not used by this implementation.

        Returns
        -------
        dict
            One entry, ``lsst_i_mag``, holding the ``i_gaap1p0Flux``
            column as-is (no flux-to-magnitude conversion is applied
            here).
        """
        # Presumably the values are irrelevant for TrainZ, which assigns
        # the same p(z) to every object -- only one column is passed on.
        i_band_column = fluxes["i_gaap1p0Flux"]
        return {"lsst_i_mag": i_band_column}


class EsimatePZKNNTask(EsimatePZTaskBase):
"""Task that runs RAIL KNN algorithm for p(z) estimation
class EstimatePZKNNTask(EstimatePZAlgoTask):
    """SubTask that runs the RAIL KNN algorithm for p(z) estimation.

    See https://github.com/LSSTDESC/rail_sklearn/blob/main/src/rail/estimation/algos/k_nearneigh.py  # noqa
    for the algorithm implementation.

    KNN estimates the p(z) distribution by taking a weighted mixture
    of the nearest neighbors in color space.
    """

    # Everything else is inherited from EstimatePZAlgoTask; only the
    # config class and the default task name are set here.
    ConfigClass = EstimatePZKNNConfig
    _DefaultName = "estimate_pz_knn"

0 comments on commit 460ddde

Please sign in to comment.