Skip to content

Commit

Permalink
Merge pull request #68 from simonsobs/state-passing
Browse files Browse the repository at this point in the history
State passing
  • Loading branch information
kmharrington authored Jul 9, 2024
2 parents 026bfc7 + 517bc62 commit e68c22a
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 47 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"jax[cpu]",
"equinox",
"so3g",
"pixell"
"pixell",
"dataclasses_json",
],
)
15 changes: 14 additions & 1 deletion src/schedlib/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import datetime as dt
from typing import List, Dict, Optional, Tuple, Any
from dataclasses import dataclass, field, replace as dc_replace
from dataclasses_json import dataclass_json

from abc import ABC, abstractmethod
import inspect

Expand All @@ -11,7 +13,7 @@

MIN_DURATION = 0.01


@dataclass_json
@dataclass(frozen=True)
class State:
"""
Expand Down Expand Up @@ -50,6 +52,17 @@ class State:
az_accel_now: Optional[float] = None
prev_state: Optional["State"] = field(default=None, repr=False)

def clear_history(self):
return dc_replace(self, **{"prev_state": None} )

def save(self, fname):
out = self.clear_history()
np.save( fname, out.to_dict() )

@classmethod
def load(cls, fname):
return cls.from_dict( np.load(fname, allow_pickle=True).item() )

def replace(self, **kwargs):
"""
Creates a new instance of the State class with specified attributes replaced with new values,
Expand Down
103 changes: 74 additions & 29 deletions src/schedlib/policies/sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import yaml
import os.path as op
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json

import datetime as dt
from typing import List, Union, Optional, Dict, Any, Tuple
import jax.tree_util as tu
Expand All @@ -17,6 +19,7 @@

logger = u.init_logger(__name__)

@dataclass_json
@dataclass(frozen=True)
class State(cmd.State):
"""
Expand All @@ -41,7 +44,12 @@ class State(cmd.State):
hwp_spinning: bool = False
last_ufm_relock: Optional[dt.datetime] = None
last_bias_step: Optional[dt.datetime] = None
last_bias_step_boresight: Optional[float] = None
last_bias_step_elevation: Optional[float] = None
last_iv: Optional[dt.datetime] = None
last_iv_boresight: Optional[float] = None
last_iv_elevation: Optional[float] = None
# relock sets to false, tracks if detectors are biased at all
is_det_setup: bool = False


Expand Down Expand Up @@ -93,7 +101,7 @@ class CalTarget:
# return state, 10, ["do something"]

@cmd.operation(name="sat.preamble", duration=0)
def preamble(hwp_cfg):
def preamble():
return [
"from nextline import disable_trace",
"import time",
Expand Down Expand Up @@ -122,7 +130,10 @@ def ufm_relock(state):
doit = False

if doit:
state = state.replace(last_ufm_relock=state.curr_time)
state = state.replace(
last_ufm_relock=state.curr_time,
is_det_setup=False,
)
return state, 15*u.minute, [
"############# Daily Relock",
"for smurf in pysmurfs:",
Expand All @@ -131,7 +142,7 @@ def ufm_relock(state):
" smurf.zero_biases.wait()",
"",
"time.sleep(120)",
"run.smurf.take_noise(concurrent=True, tag='oper,take_noise,res_check')",
"run.smurf.take_noise(concurrent=True, tag='res_check')",
"run.smurf.uxm_relock(concurrent=True)",
"",
]
Expand Down Expand Up @@ -177,29 +188,48 @@ def det_setup(state, block, apply_boresight_rot=True, iv_cadence=None):
# -> should always be done if det setup has not been done yet
# -> should be done at a regular interval if iv_cadence is not None
# -> should always be done if boresight rotation has changed
doit = (block.subtype == 'cal') or (block.alt != state.el_now)
doit = doit or (not state.is_det_setup)
doit = doit or (iv_cadence is not None and ((state.last_iv is None) or ((state.curr_time - state.last_iv).total_seconds() > iv_cadence)))
if apply_boresight_rot and (block.boresight_angle != state.boresight_rot_now):
doit = True
doit = (block.subtype == 'cal')
doit = doit or (not state.is_det_setup) or (state.last_iv is None)
if not doit:
if state.last_iv_elevation is not None:
doit = doit or (
not np.isclose(state.last_iv_elevation, block.alt, atol=1)
)
if apply_boresight_rot and state.last_iv_boresight is not None:
doit = doit or (
not np.isclose(
state.last_iv_boresight,
block.boresight_angle,
atol=1
)
)
if iv_cadence is not None:
time_since_last = (state.curr_time - state.last_iv).total_seconds()
doit = doit or (time_since_last > iv_cadence)

if doit:
commands = [
"",
"################### Detector Setup######################",
"run.smurf.take_bgmap(concurrent=True)",
"run.smurf.take_noise(concurrent=True, tag='res_check')",
"run.smurf.iv_curve(concurrent=True, ",
" iv_kwargs={'run_serially': False, 'cool_wait': 60*5})",
"run.smurf.bias_dets(concurrent=True)",
"time.sleep(180)",
"run.smurf.bias_step(concurrent=True)",
"run.smurf.take_noise(concurrent=True, tag='bias_check')",
"#################### Detector Setup Over ####################",
"",
]
state = state.replace(
last_bias_step=state.curr_time,
is_det_setup=True,
last_iv = state.curr_time,
last_bias_step=state.curr_time,
last_iv_elevation = block.alt,
last_iv_boresight = block.boresight_angle,
last_bias_step_elevation = block.alt,
last_bias_step_boresight = block.boresight_angle,
)
return state, 12*u.minute, commands
else:
Expand Down Expand Up @@ -298,9 +328,35 @@ def setup_boresight(state, block, apply_boresight_rot=True):

# passthrough any arguments, to be used in any sched-mode
@cmd.operation(name='sat.bias_step', return_duration=True)
def bias_step(state, min_interval=10*u.minute):
if state.last_bias_step is None or (state.curr_time - state.last_bias_step).total_seconds() > min_interval:
state = state.replace(last_bias_step=state.curr_time)
def bias_step(state, block, min_interval=15*u.minute):
doit = state.last_bias_step is None
if not doit:
time_since = (state.curr_time - state.last_bias_step).total_seconds()
doit = doit or (time_since > min_interval)

if state.last_bias_step_elevation is not None:
doit = doit or (
not np.isclose(
state.last_bias_step_elevation,
block.alt,
atol=1
)
)
if state.last_bias_step_boresight is not None:
doit = doit or (
not np.isclose(
state.last_bias_step_boresight,
block.boresight_angle,
atol=1
)
)

if doit :
state = state.replace(
last_bias_step=state.curr_time,
last_bias_step_elevation = block.alt,
last_bias_step_boresight = block.boresight_angle,
)
return state, 60, [ "run.smurf.bias_step(concurrent=True)", ]
else:
return state, 0, []
Expand Down Expand Up @@ -347,7 +403,6 @@ class SATPolicy:
cal_targets: List[CalTarget] = field(default_factory=list)
cal_policy: str = 'round-robin'
scan_tag: Optional[str] = None
iso_scan_speeds: Optional =None
boresight_override: Optional[float] = None
az_speed: float = 1. # deg / s
az_accel: float = 2. # deg / s^2
Expand Down Expand Up @@ -423,7 +478,7 @@ def construct_seq(loader_cfg):
# update az speed in scan blocks
blocks = core.seq_map_when(
lambda b: isinstance(b, inst.ScanBlock),
lambda b: b.replace(az_speed=self.az_speed),
lambda b: b.replace(az_speed=self.az_speed,az_accel=self.az_accel),
blocks
)

Expand Down Expand Up @@ -588,19 +643,6 @@ def apply(self, blocks: core.BlocksTree) -> core.BlocksTree:

blocks = core.seq_sort(blocks['baseline']['cmb'] + blocks['calibration'], flatten=True)

# alternate scan speeds for ISO testing
# run after flattening to presume order is set
if self.iso_scan_speeds is not None:
self.c = 0
def iso_set(block):
if block.subtype != 'cmb':
return block
v,a = self.iso_scan_speeds[ self.c % len(self.iso_scan_speeds)]
self.c += 1
return block.replace(az_speed=v, az_accel=a)
blocks = core.seq_map(iso_set, blocks)
del self.c

return blocks

def init_state(self, t0: dt.datetime) -> State:
Expand Down Expand Up @@ -631,7 +673,8 @@ def seq2cmd(
seq,
t0: dt.datetime,
t1: dt.datetime,
state: Optional[State] = None
state: Optional[State] = None,
return_state: bool = False,
) -> List[Any]:
"""
Converts a sequence of blocks into a list of commands to be executed
Expand Down Expand Up @@ -665,7 +708,9 @@ def seq2cmd(

# load building stage
build_op = get_build_stage('build_op', **self.stages.get('build_op', {}))
ops = build_op.apply(seq, t0, t1, state, self.operations)
ops, state = build_op.apply(seq, t0, t1, state, self.operations)
if return_state:
return ops, state
return ops

def cmd2txt(self, irs, t0, t1, state=None):
Expand Down
46 changes: 34 additions & 12 deletions src/schedlib/policies/satp1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from dataclasses import dataclass
import datetime as dt

from typing import Optional

from .. import source as src, utils as u
from .sat import SATPolicy, State, CalTarget
from ..commands import SchedMode
Expand Down Expand Up @@ -161,16 +163,19 @@ def make_blocks(master_file):
def make_operations(
az_speed, az_accel, disable_hwp=False,
apply_boresight_rot=True, hwp_cfg=None, hwp_dir=True,
iv_cadence=4*u.hour,
iv_cadence=4*u.hour, stow_at_end=False, run_relock=False
):
if hwp_cfg is None:
hwp_cfg = { 'iboot2': 'power-iboot-hwp-2', 'pid': 'hwp-pid', 'pmx': 'hwp-pmx', 'hwp-pmx': 'pmx', 'gripper': 'hwp-gripper', 'forward':hwp_dir }
pre_session_ops = [
{ 'name': 'sat.preamble' , 'sched_mode': SchedMode.PreSession, 'hwp_cfg': hwp_cfg, },
{ 'name': 'sat.preamble' , 'sched_mode': SchedMode.PreSession},
{ 'name': 'start_time' ,'sched_mode': SchedMode.PreSession},
{ 'name': 'sat.ufm_relock' , 'sched_mode': SchedMode.PreSession, },
{ 'name': 'set_scan_params' , 'sched_mode': SchedMode.PreSession, 'az_speed': az_speed, 'az_accel': az_accel, },
]
if run_relock:
pre_session_ops += [
{ 'name': 'sat.ufm_relock' , 'sched_mode': SchedMode.PreSession, }
]
cal_ops = [
{ 'name': 'sat.setup_boresight' , 'sched_mode': SchedMode.PreCal, 'apply_boresight_rot': apply_boresight_rot, },
{ 'name': 'sat.det_setup' , 'sched_mode': SchedMode.PreCal, 'apply_boresight_rot': apply_boresight_rot, 'iv_cadence':iv_cadence },
Expand All @@ -186,18 +191,20 @@ def make_operations(
{ 'name': 'sat.cmb_scan' , 'sched_mode': SchedMode.InObs, },
{ 'name': 'sat.bias_step' , 'sched_mode': SchedMode.PostObs, 'indent': 4, 'divider': ['']},
]
post_session_ops = [
{ 'name': 'sat.hwp_spin_down' , 'sched_mode': SchedMode.PostSession, 'disable_hwp': disable_hwp, },
{ 'name': 'sat.wrap_up' , 'sched_mode': SchedMode.PostSession, 'az_stow': 180, 'el_stow': 50},
]
if stow_at_end:
post_session_ops = [
{ 'name': 'sat.hwp_spin_down' , 'sched_mode': SchedMode.PostSession, 'disable_hwp': disable_hwp, },
{ 'name': 'sat.wrap_up' , 'sched_mode': SchedMode.PostSession, 'az_stow': 180, 'el_stow': 50},
]
else:
post_session_ops = []
return pre_session_ops + cal_ops + cmb_ops + post_session_ops

def make_config(
master_file,
az_speed,
az_accel,
cal_targets,
iso_scan_speeds=None,
boresight_override=None,
**op_cfg
):
Expand All @@ -222,7 +229,6 @@ def make_config(
'operations': operations,
'cal_targets': cal_targets,
'scan_tag': None,
'iso_scan_speeds': iso_scan_speeds,
'boresight_override': boresight_override,
'az_speed' : az_speed,
'az_accel' : az_accel,
Expand All @@ -248,20 +254,36 @@ def make_config(

@dataclass
class SATP1Policy(SATPolicy):
state_file: Optional[str] = None

@classmethod
def from_defaults(cls, master_file, az_speed=0.8, az_accel=1.5,
cal_targets=[], iso_scan_speeds=None, boresight_override=None, **op_cfg
cal_targets=[], boresight_override=None,
state_file=None, **op_cfg
):
return cls(**make_config(
x = cls(**make_config(
master_file, az_speed, az_accel,
cal_targets, iso_scan_speeds, boresight_override, **op_cfg
cal_targets, boresight_override, **op_cfg
))
x.state_file=state_file
return x

def add_cal_target(self, *args, **kwargs):
self.cal_targets.append(make_cal_target(*args, **kwargs))

def init_state(self, t0: dt.datetime) -> State:
"""customize typical initial state for satp1, if needed"""
if self.state_file is not None:
logger.info(f"using state from {self.state_file}")
state = State.load(self.state_file)
if state.curr_time < t0:
logger.info(
f"Loaded state is at {state.curr_time}. Updating time to"
f" {t0}"
)
state = state.replace(curr_time = t0)
return state

return State(
curr_time=t0,
az_now=180,
Expand Down
2 changes: 1 addition & 1 deletion src/schedlib/policies/satp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def make_operations(
if hwp_cfg is None:
hwp_cfg = { 'iboot2': 'power-iboot-hwp-2', 'pid': 'hwp-pid', 'pmx': 'hwp-pmx', 'hwp-pmx': 'pmx', 'gripper': 'hwp-gripper', 'forward':hwp_dir }
pre_session_ops = [
{ 'name': 'sat.preamble' , 'sched_mode': SchedMode.PreSession, 'hwp_cfg': hwp_cfg, },
{ 'name': 'sat.preamble' , 'sched_mode': SchedMode.PreSession, },
{ 'name': 'sat.ufm_relock' , 'sched_mode': SchedMode.PreSession, },
{ 'name': 'set_scan_params' , 'sched_mode': SchedMode.PreSession, 'az_speed': az_speed, 'az_accel': az_accel, },
]
Expand Down
6 changes: 3 additions & 3 deletions src/schedlib/policies/stages/build_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@ def apply(self, seq, t0, t1, state, operations):

# now we do lowering further into full ops
logger.info(f"================ lowering (ops) ================")
ir_ops = self.lower_ops(ir, init_state)
ir_ops, out_state = self.lower_ops(ir, init_state)
logger.info(u.pformat(ir_ops))

logger.info(f"================ done ================")

return ir_ops
return ir_ops, out_state

def lower(self, seq, t0, t1, state, operations):
ir = []
Expand Down Expand Up @@ -370,7 +370,7 @@ def resolve_block(state, ir):
for ir in irs:
state, op_blocks = resolve_block(state, ir)
ir_lowered += op_blocks
return ir_lowered
return ir_lowered, state

def _apply_ops(self, state, op_cfgs, block=None, az=None, alt=None):
"""
Expand Down

0 comments on commit e68c22a

Please sign in to comment.