Skip to content

Commit

Permalink
Fix R_L changes and more type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
jlashner committed Aug 28, 2024
1 parent 654e0ca commit 270a826
Showing 1 changed file with 101 additions and 79 deletions.
180 changes: 101 additions & 79 deletions sotodlib/site_pipeline/update_det_cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@
from tqdm.auto import tqdm
import logging
import sys
from typing import Optional, Union, Dict, List, cast
from typing import Optional, Union, Dict, List, cast, Any
from queue import Queue

from sotodlib import core
from sotodlib.io.metadata import write_dataset, ResultSet
from sotodlib.io.load_book import get_cal_obsids
import sotodlib.site_pipeline.util as sp_util
import multiprocessing as mp
import sodetlib.tes_param_correction as tpc # type: ignore
import sodetlib.tes_param_correction as tpc
from sodetlib.operations.iv import IVAnalysis
from sodetlib.operations.bias_steps import BiasStepAnalysis


# stolen from pysmurf, max bias volt / num_bits
Expand Down Expand Up @@ -75,59 +77,78 @@ class DetCalCfg:
Path to the root of the data directory. If None, it will be determined
from the context object.
"""
def __init__(
self,
root_dir: str,
context_path: str,
*,
data_root: Optional[str] = None,
raise_exceptions: bool = False,
apply_cal_correction: bool = True,
index_path: str = "det_cal.sqlite",
h5_path: str = "det_cal.h5",
cache_failed_obsids: bool = True,
failed_cache_file: str = "failed_obsids.yaml",
show_pb: bool = True,
param_correction_config: Union[Dict[str, Any], None, tpc.AnalysisCfg] = None,
run_method: str = "site",
nprocs_obs_info: int = 1,
nprocs_result_set: int = 10,
num_obs: Optional[int] = None,
log_level: str = "DEBUG"
) -> None:
self.root_dir = root_dir
self.context_path = os.path.expandvars(context_path)
self.ctx = core.Context(self.context_path)
if data_root is None:
self.data_root = get_data_root(self.ctx)
self.raise_exceptions = raise_exceptions
self.apply_cal_correction = apply_cal_correction
self.cache_failed_obsids = cache_failed_obsids
self.show_pb = show_pb
self.run_method = run_method

root_dir: str
context_path: str
data_root: Optional[str] = None
raise_exceptions: bool = False
apply_cal_correction: bool = True
index_path: str = "det_cal.sqlite"
h5_path: str = "det_cal.h5"
cache_failed_obsids: bool = True
failed_cache_file: str = "failed_obsids.yaml"
show_pb: bool = True

run_method: str = "site"
nprocs_obs_info: int = 1
nprocs_result_set: int = 10

num_obs: Optional[int] = None
log_level: str = "DEBUG"
param_correction_config: tpc.AnalysisCfg = None

def __post_init__(self):
if self.run_method not in ["site", "nersc"]:
raise ValueError("run_method must be in: ['site', 'nersc']")

self.nprocs_obs_info = nprocs_obs_info
self.nprocs_result_set = nprocs_result_set
self.num_obs = num_obs
self.log_level = log_level

self.root_dir = os.path.expandvars(self.root_dir)
if not os.path.exists(self.root_dir):
raise ValueError(f"Root dir does not exist: {self.root_dir}")

self.context_path = os.path.expandvars(self.context_path)

def parse_path(path):
def parse_path(path: str) -> str:
"Expand vars and make path absolute"
p = os.path.expandvars(path)
if not os.path.isabs(p):
p = os.path.join(self.root_dir, p)
return p

self.index_path = parse_path(self.index_path)
self.h5_path = parse_path(self.h5_path)
self.failed_cache_file = parse_path(self.failed_cache_file)
self.index_path = parse_path(index_path)
self.h5_path = parse_path(h5_path)
self.failed_cache_file = parse_path(failed_cache_file)

kw = {"show_pb": False, "default_nprocs": self.nprocs_result_set}
if self.param_correction_config is None:
self.param_correction_config = tpc.AnalysisCfg(**kw)
elif isinstance(self.param_correction_config, dict):
kw.update(self.param_correction_config)
self.param_correction_config = tpc.AnalysisCfg(**kw)
if param_correction_config is None:
self.param_correction_config = tpc.AnalysisCfg(**kw) # type: ignore
elif isinstance(param_correction_config, dict):
kw.update(param_correction_config)
self.param_correction_config = tpc.AnalysisCfg(**kw) #type: ignore
else:
self.param_correction_config = param_correction_config

self.setup_files()

@classmethod
def from_yaml(cls, path):
def from_yaml(cls, path) -> "DetCalCfg":
with open(path, "r") as f:
return cls(**yaml.safe_load(f))
d = yaml.safe_load(f)
return cls(**d)

def setup(self):
def setup_files(self) -> None:
"""Create directories and databases if they don't exist"""
if not os.path.exists(self.failed_cache_file):
# If file doesn't exist yet, just create an empty one
Expand All @@ -141,9 +162,6 @@ def setup(self):
db = core.metadata.ManifestDb(scheme=scheme)
db.to_file(self.index_path)

if self.data_root is None:
ctx = core.Context(self.context_path)
self.data_root = get_data_root(ctx)


@dataclass
Expand Down Expand Up @@ -254,8 +272,7 @@ class ObsInfo:
class ObsInfoResult:
obs_id: str
success: bool = False
traceback: Optional[str] = None

traceback: str = ''
obs_info: Optional[ObsInfo] = None

def get_obs_info(cfg: DetCalCfg, obs_id: str) -> ObsInfoResult:
Expand Down Expand Up @@ -320,12 +337,12 @@ def get_obs_info(cfg: DetCalCfg, obs_id: str) -> ObsInfoResult:
raise ValueError(f"Bias step data not found for in cal obs {oid}")
else:
logger.debug("missing bias step data for %s", dset)

if rtm_bit_to_volt is None:
rtm_bit_to_volt = DEFAULT_RTM_BIT_TO_VOLT
if pA_per_phi0 is None:
pA_per_phi0 = DEFAULT_pA_per_phi0

res.obs_info = ObsInfo(
obs_id=obs_id, am=am, iv_obsids=iv_obsids,
bs_obsids=bias_step_obsids,
Expand All @@ -344,7 +361,6 @@ class CalRessetResult:
"""
Results object for the get_cal_resset function.
"""

obs_info: ObsInfo
success: bool = False
traceback: Optional[str] = None
Expand Down Expand Up @@ -384,22 +400,24 @@ def get_cal_resset(cfg: DetCalCfg, obs_info: ObsInfo, pool=None) -> CalRessetRes
am = obs_info.am

ivas = {
dset: np.load(iva_file, allow_pickle=True).item()
dset: IVAnalysis.load(iva_file)
for dset, iva_file in obs_info.iva_files.items()
}
bsas = {
dset: np.load(bsa_file, allow_pickle=True).item()
dset: BiasStepAnalysis.load(bsa_file)
for dset, bsa_file in obs_info.bsa_files.items()
}

for iva in ivas.values(): # Run R_L correction if analysis version is old...
if iva.get('analysis_version', 0) == 0:
# This will eddit IVA dicts in place
tpc.recompute_iv_pars(iva, cfg.param_correction_config)
if cfg.apply_cal_correction:
for iva in ivas.values(): # Run R_L correction if analysis version is old...
if getattr(iva, 'analysis_version', 0) == 0:
# This will edit IVA dicts in place
logger.debug("Recomputing IV analysis for %s", obs_id)
tpc.recompute_ivpars(iva, cfg.param_correction_config)

iva = list(ivas.values())[0]
rtm_bit_to_volt = iva["meta"]["rtm_bit_to_volt"]
pA_per_phi0 = iva["meta"]["pA_per_phi0"]
rtm_bit_to_volt = iva.meta["rtm_bit_to_volt"]
pA_per_phi0 = iva.meta["pA_per_phi0"]
cals = [CalInfo(rid) for rid in am.det_info.readout_id]
if len(cals) == 0:
raise ValueError(f"No detectors found for {obs_id}")
Expand All @@ -414,15 +432,15 @@ def get_cal_resset(cfg: DetCalCfg, obs_info: ObsInfo, pool=None) -> CalRessetRes
if iva is None: # No IV analysis for this detset
continue

ridx = np.where((iva["bands"] == band) & (iva["channels"] == chan))[0]
ridx = np.where((iva.bands == band) & (iva.channels == chan))[0]
if not ridx: # Channel doesn't exist in IV analysis
continue

ridx = ridx[0]
cal.bg = iva["bgmap"][ridx]
cal.polarity = iva["polarity"][ridx]
cal.r_n = iva["R_n"][ridx]
cal.p_sat = iva["p_sat"][ridx]
cal.bg = iva.bgmap[ridx]
cal.polarity = iva.polarity[ridx]
cal.r_n = iva.R_n[ridx] # type: ignore
cal.p_sat = iva.p_sat[ridx] # type: ignore

obs_biases = dict(
zip(am.bias_lines.vals, am.biases[:, 0] * 2 * rtm_bit_to_volt)
Expand All @@ -434,8 +452,8 @@ def get_cal_resset(cfg: DetCalCfg, obs_info: ObsInfo, pool=None) -> CalRessetRes
if bsa is None:
continue

for bg, vb_bsa in enumerate(bsa["Vbias"]):
bl_label = f"{bsa['meta']['stream_id']}_b{bg:0>2}"
for bg, vb_bsa in enumerate(bsa.Vbias):
bl_label = f"{bsa.meta['stream_id']}_b{bg:0>2}"
if np.isnan(vb_bsa):
bias_line_is_valid[bl_label] = False
continue
Expand All @@ -454,7 +472,7 @@ def get_cal_resset(cfg: DetCalCfg, obs_info: ObsInfo, pool=None) -> CalRessetRes
# logger.debug(f"Applying correction for {dset}")
rs = []
if pool is None:
for b, c in zip(ivas[dset]["bands"], ivas[dset]["channels"]):
for b, c in zip(ivas[dset].bands, ivas[dset].channels):
chdata = tpc.RpFitChanData.from_data(
ivas[dset], bsas[dset], b, c
)
Expand Down Expand Up @@ -489,7 +507,7 @@ def find_correction_results(band, chan, dset):
if not bias_line_is_valid[bl_label]:
continue

ridx = np.where((bsa["bands"] == band) & (bsa["channels"] == chan))[0]
ridx = np.where((bsa.bands == band) & (bsa.channels == chan))[0]
if not ridx: # Channel doesn't exist in bias step analysis
continue

Expand All @@ -506,30 +524,33 @@ def find_correction_results(band, chan, dset):
use_correction = False

ridx = ridx[0]
cal.tau_eff = bsa["tau_eff"][ridx]
cal.tau_eff = bsa.tau_eff[ridx]
if bg != -1:
cal.v_bias = bsa["Vbias"][bg]

if use_correction:
correction = cast(tpc.CorrectionResults, correction)
cal.r_tes = correction.corrected_R0
cal.r_frac = correction.corrected_R0 / cal.r_n
cal.s_i = correction.corrected_Si * 1e6
cal.p_bias = correction.corrected_Pj * 1e-12
cal.loopgain = correction.loopgain
cal.v_bias = bsa.Vbias[bg]

if use_correction and correction.corrected_params is not None:
cpars = correction.corrected_params
cal.r_tes = cpars.corrected_R0
cal.r_frac = cpars.corrected_R0 / cal.r_n
cal.s_i = cpars.corrected_Si * 1e6
cal.p_bias = cpars.corrected_Pj * 1e-12
cal.loopgain = cpars.loopgain
else:
cal.r_tes = bsa["R0"][ridx]
cal.r_frac = bsa["Rfrac"][ridx]
cal.p_bias = bsa["Pj"][ridx]
cal.s_i = bsa["Si"][ridx]
cal.phase_to_pW = pA_per_phi0 / (2 * np.pi) / cal.s_i * cal.polarity
cal.r_tes = bsa.R0[ridx]
cal.r_frac = bsa.Rfrac[ridx]
cal.p_bias = bsa.Pj[ridx]
cal.s_i = bsa.Si[ridx]

if cal.s_i == 0:
cal.phase_to_pW = np.nan
else:
cal.phase_to_pW = pA_per_phi0 / (2 * np.pi) / cal.s_i * cal.polarity

res.result_set = np.array([astuple(c) for c in cals], dtype=CalInfo.dtype())
res.success = True
except Exception as e:
res.traceback = traceback.format_exc()
res.fail_msg = res.traceback
# res.fail_msg = str(e)
if cfg.raise_exceptions:
raise
return res
Expand Down Expand Up @@ -620,7 +641,6 @@ def run_update_site(cfg: DetCalCfg):
for ch in logger.handlers:
ch.setLevel(getattr(logging, cfg.log_level.upper()))

cfg.setup()
obs_ids = get_obsids_to_run(cfg)

logger.info(f"Processing {len(obs_ids)} obsids...")
Expand All @@ -632,6 +652,9 @@ def run_update_site(cfg: DetCalCfg):
logger.info(f"Could not get obs info for obs id: {oid}")
logger.error(res.traceback)

if res.obs_info is None:
continue

result_set = get_cal_resset(cfg, res.obs_info, pool=pool)
handle_result(result_set, cfg)

Expand All @@ -656,7 +679,6 @@ def run_update_nersc(cfg: DetCalCfg):
for ch in logger.handlers:
ch.setLevel(getattr(logging, cfg.log_level.upper()))

cfg.setup()
obs_ids = get_obsids_to_run(cfg)
# obs_ids = ['obs_1713962395_satp1_0000100']
# obs_ids = ['obs_1713758716_satp1_1000000']
Expand Down

0 comments on commit 270a826

Please sign in to comment.