Skip to content

Commit

Permalink
Swtich to explicitly using and import RAIL class
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Jul 31, 2024
1 parent 75460b0 commit 90b2ea7
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions python/lsst/meas/pz/estimate_pz_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,18 @@
import numpy as np
from ceci.config import StageConfig as CeciStageConfig
from ceci.config import StageParameter as CeciParam
from ceci.stage import PipelineStage as CeciPipelineStage
# from ceci.stage import PipelineStage as CeciPipelineStage
from lsst.daf.butler import DeferredDatasetHandle
from lsst.pipe.base import (PipelineTask, PipelineTaskConfig,
PipelineTaskConnections, Struct)
from lsst.pipe.base import (
PipelineTask,
PipelineTaskConfig,
PipelineTaskConnections,
Struct,
)
from pandas import DataFrame
from rail.interfaces import PZFactory
from rail.estimation.algos.k_nearneigh import KNearNeighEstimator
from rail.estimation.algos.train_z import TrainZEstimator


class EstimatePZConnections(
Expand Down Expand Up @@ -100,17 +106,19 @@ class EstimatePZConfigBase(
`estimator_class` and `estimator_module` and invoke _make_fields.
"""

estimator_class = None
estimator_module = None
# estimator_class = None
# estimator_module = None
stage_class = None

stage_name = pexConfig.Field(doc="Rail stage name", dtype=str)
mag_offset = pexConfig.Field(doc="Magnitude offset", dtype=float, default=31.4)

@classmethod
def _make_fields(cls):
stage_class = CeciPipelineStage.get_stage(
cls.estimator_class, cls.estimator_module
)
# stage_class = CeciPipelineStage.get_stage(
# cls.estimator_class, cls.estimator_module
# )
stage_class = cls.estimator_class
for key, val in stage_class.config_options.items():
if isinstance(val, CeciStageConfig):
val = val.get(key)
Expand Down Expand Up @@ -301,12 +309,20 @@ def run(
rail_kwargs = self.config.toDict().copy()
for key in ["saveLogOutput", "stage_name", "mag_offset", "connections"]:
rail_kwargs.pop(key)
rail_kwargs["output_mode"] = "return"

# Build the RAIL stage
self._stage = PZFactory.build_cat_estimator_stage(
# self._stage = PZFactory.build_cat_estimator_stage(
# self.config.stage_name,
# self.config.estimator_class,
# self.config.estimator_module,
# model_path=pzModel,
# input_path="dummy.in",
# **rail_kwargs,
# )
self._stage = PZFactory.build_stage_instance(
self.config.stage_name,
self.config.estimator_class,
self.config.estimator_module,
model_path=pzModel,
input_path="dummy.in",
**rail_kwargs,
Expand Down Expand Up @@ -334,8 +350,9 @@ class EstimatePZTrainZConfig(EstimatePZConfigBase):
for parameters and default values.
"""

estimator_class = "TrainZEstimator"
estimator_module = "rail.estimation.algos.train_z"
# estimator_class = "TrainZEstimator"
# estimator_module = "rail.estimation.algos.train_z"
estimator_class = TrainZEstimator


EstimatePZTrainZConfig._make_fields()
Expand All @@ -348,8 +365,9 @@ class EstimatePZKNNConfig(EstimatePZConfigBase):
for parameters and default values.
"""

estimator_class = "KNearNeighEstimator"
estimator_module = "rail.estimation.algos.k_nearneigh"
# estimator_class = "KNearNeighEstimator"
# estimator_module = "rail.estimation.algos.k_nearneigh"
estimator_class = KNearNeighEstimator


EstimatePZKNNConfig._make_fields()
Expand Down

0 comments on commit 90b2ea7

Please sign in to comment.