Skip to content

Commit

Permalink
moving glitch flags path to init
Browse files Browse the repository at this point in the history
  • Loading branch information
iparask committed Sep 23, 2024
1 parent 7c44e53 commit 51b0e3b
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions sotodlib/mapmaking/ml_mapmaker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import numpy as np
from pixell import bunch, enmap, tilemap, utils
import h5py
import so3g
from pixell import bunch, enmap, tilemap, putils

from .. import coords
from .pointing_matrix import *
from .utilities import *
from .pointing_matrix import PmatCut
from .utilities import (MultiZipper, get_flags_from_path, recentering_to_quat_lonlat,
evaluate_recentering, TileMapZipper, MapZipper,
safe_invert_div, unarr, ArrayZipper)
from .noise_model import NmatUncorr


class MLMapmaker:
Expand Down Expand Up @@ -147,7 +152,7 @@ def transeval(self, id, obs, other, x, tod=None):

class Signal:
"""This class represents a thing we want to solve for, e.g. the sky, ground, cut samples, etc."""
def __init__(self, name, ofmt, output, ext):
def __init__(self, name, ofmt, output, ext, glitch_flags: str = "flags.glitch_flags"):
"""Initialize a Signal. It probably doesn't make sense to construct a generic signal
directly, though. Use one of the subclasses.
Arguments:
Expand All @@ -162,6 +167,7 @@ def __init__(self, name, ofmt, output, ext):
self.ext = ext
self.dof = None
self.ready = False
self.glitch_flags = glitch_flags
def add_obs(self, id, obs, nmat, Nd): pass
def prepare(self): self.ready = True
def forward (self, id, tod, x): pass
Expand All @@ -177,12 +183,12 @@ class SignalMap(Signal):
"""Signal describing a non-distributed sky map."""
def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", output=True,
ext="fits", dtype=np.float32, sys=None, recenter=None, tile_shape=(500,500), tiled=False,
interpol=None):
interpol=None, glitch_flags: str = "flags.glitch_flags"):
"""Signal describing a sky map in the coordinate system given by "sys", which defaults
to equatorial coordinates. If tiled==True, then this will be a distributed map with
the given tile_shape, otherwise it will be a plain enmap. interpol controls the
pointing matrix interpolation mode. See so3g's Projectionist docstring for details."""
Signal.__init__(self, name, ofmt, output, ext)
Signal.__init__(self, name, ofmt, output, ext, glitch_flags)
self.comm = comm
self.comps = comps
self.sys = sys
Expand All @@ -203,15 +209,15 @@ def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", out
self.div = enmap.zeros((ncomp,ncomp)+shape, wcs, dtype=dtype)
self.hits= enmap.zeros( shape, wcs, dtype=dtype)

def add_obs(self, id, obs, nmat, Nd, pmap=None, glitch_flags:str ="flags.glitch_flags"):
def add_obs(self, id, obs, nmat, Nd, pmap=None):
"""Add and process an observation, building the pointing matrix
and our part of the RHS. "obs" should be an Observation axis manager,
nmat a noise model, representing the inverse noise covariance matrix,
and Nd the result of applying the noise model to the detector time-ordered data.
"""
Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts
ctime = obs.timestamps
pcut = PmatCut(get_flags_from_path(obs, glitch_flags)) # could pass this in, but fast to construct
pcut = PmatCut(get_flags_from_path(obs, self.glitch_flags)) # could pass this in, but fast to construct
if pmap is None:
# Build the local geometry and pointing matrix for this observation
if self.recenter:
Expand Down Expand Up @@ -348,6 +354,7 @@ def transeval(self, id, obs, other, map, tod):
# Currently we don't support any actual translation, but could handle
# resolution changes in the future (probably not useful though)
self._checkcompat(other)
ctime = obs.timestamp
# Build the local geometry and pointing matrix for this observation
if self.recenter:
rot = recentering_to_quat_lonlat(*evaluate_recentering(self.recenter,
Expand All @@ -362,9 +369,9 @@ def transeval(self, id, obs, other, map, tod):

class SignalCut(Signal):
def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32,
output=False, cut_type=None):
output=False, cut_type=None, glitch_flags:str ="flags.glitch_flags"):
"""Signal for handling the ML solution for the values of the cut samples."""
Signal.__init__(self, name, ofmt, output, ext="hdf")
Signal.__init__(self, name, ofmt, output, ext="hdf", glitch_flags=glitch_flags)
self.comm = comm
self.data = {}
self.dtype = dtype
Expand All @@ -373,12 +380,12 @@ def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32,
self.rhs = []
self.div = []

def add_obs(self, id, obs, nmat, Nd, glitch_flags="flags.glitch_flags"):
def add_obs(self, id, obs, nmat, Nd):
"""Add and process an observation. "obs" should be an Observation axis manager,
nmat a noise model, representing the inverse noise covariance matrix,
and Nd the result of applying the noise model to the detector time-ordered data."""
Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts
pcut = PmatCut(get_flags_from_path(obs, glitch_flags), model=self.cut_type)
pcut = PmatCut(get_flags_from_path(obs, self.glitch_flags), model=self.cut_type)
# Build our RHS
obs_rhs = np.zeros(pcut.njunk, self.dtype)
pcut.backward(Nd, obs_rhs)
Expand Down Expand Up @@ -442,15 +449,15 @@ def translate(self, other, junk):
so3g.translate_cuts(odata.pcut.cuts, sdata.pcut.cuts, sdata.pcut.model, sdata.pcut.params, junk[odata.i1:odata.i2], res[sdata.i1:sdata.i2])
return res

def transeval(self, id, obs, other, junk, tod, glitch_flags: str="flags.glitch_flags"):
def transeval(self, id, obs, other, junk, tod):
"""Translate data junk from SignalCut other to the current SignalCut,
and then evaluate it for the given observation, returning a tod.
This is used when building a signal-free tod for the noise model
in multipass mapmaking."""
self._checkcompat(other)
# We have to make a pointing matrix from scratch because add_obs
# won't have been called yet at this point
spcut = PmatCut(get_flags_from_path(obs, glitch_flags), model=self.cut_type)
spcut = PmatCut(get_flags_from_path(obs, self.glitch_flags), model=self.cut_type)
# We do have one for other though, since that will be the output
# from the previous round of multiplass mapmaking.
odata = other.data[id]
Expand Down

0 comments on commit 51b0e3b

Please sign in to comment.