Gp/fix/act flags #947

Open · wants to merge 11 commits into base: master
6 changes: 5 additions & 1 deletion sotodlib/core/context.py
@@ -5,12 +5,15 @@
import logging
import numpy as np

from typing import Union, Dict, Tuple, List

from . import metadata
from .util import tag_substr
from .axisman import AxisManager, OffsetAxis, AxisInterface

logger = logging.getLogger(__name__)


class Context(odict):
# Sets of special handlers may be registered in this class variable, then
# requested by name in the context.yaml key "context_hooks".
@@ -322,7 +325,8 @@ def get_meta(self,
check=False,
ignore_missing=False,
on_missing=None,
det_info_scan=False):
det_info_scan=False
):
"""Load supporting metadata for an observation and return it in an
AxisManager.

2 changes: 1 addition & 1 deletion sotodlib/core/g3_core.py
@@ -6,7 +6,7 @@

"""

from spt3g import core
from so3g.spt3g import core


class DataG3Module(object):
63 changes: 39 additions & 24 deletions sotodlib/mapmaking/ml_mapmaker.py
@@ -1,12 +1,20 @@
import numpy as np
from pixell import enmap, utils, tilemap, bunch
import h5py
import so3g
from typing import Optional
from pixell import bunch, enmap, tilemap
from pixell import utils as putils

from .. import coords
from .utilities import *
from .pointing_matrix import *
from .pointing_matrix import PmatCut
from .utilities import (MultiZipper, get_flags_from_path, recentering_to_quat_lonlat,
evaluate_recentering, TileMapZipper, MapZipper,
safe_invert_div, unarr, ArrayZipper)
from .noise_model import NmatUncorr


class MLMapmaker:
def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False):
def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False, glitch_flags: str = "flags.glitch_flags"):
"""Initialize a Maximum Likelihood Mapmaker.
Arguments:
* signals: List of Signal-objects representing the models that will be solved
@@ -26,6 +34,7 @@ def __init__(self, signals=[], noise_model=None, dtype=np.float32, verbose=False
self.data = []
self.dof = MultiZipper()
self.ready = False
self.glitch_flags_path = glitch_flags
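A hedged usage sketch of the new glitch_flags argument (assuming MLMapmaker and NmatUncorr are importable from sotodlib.mapmaking; the flag path shown is just the default):

import numpy as np
from sotodlib.mapmaking import MLMapmaker, NmatUncorr

# The default path reproduces the previous hard-coded behaviour; pointing it
# elsewhere lets observations carry their cuts in a non-standard location.
mapmaker = MLMapmaker(
    signals=[],                         # usually a SignalMap and a SignalCut
    noise_model=NmatUncorr(),
    dtype=np.float32,
    glitch_flags="flags.glitch_flags",  # dotted path resolved per observation
)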

def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None):
# Prepare our tod
@@ -36,7 +45,7 @@ def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None)
# the noise model, if available
if signal_estimate is not None: tod -= signal_estimate
if deslope:
utils.deslope(tod, w=5, inplace=True)
putils.deslope(tod, w=5, inplace=True)
# Allow the user to override the noise model on a per-obs level
if noise_model is None: noise_model = self.noise_model
# Build the noise model from the obs unless a fully
@@ -55,12 +64,12 @@ def add_obs(self, id, obs, deslope=True, noise_model=None, signal_estimate=None)
# The signal estimate might not be desloped, so
# adding it back can reintroduce a slope. Fix that here.
if deslope:
utils.deslope(tod, w=5, inplace=True)
putils.deslope(tod, w=5, inplace=True)
# And apply it to the tod
tod = nmat.apply(tod)
# Add the observation to each of our signals
for signal in self.signals:
signal.add_obs(id, obs, nmat, tod)
signal.add_obs(id, obs, nmat, tod, glitch_flags=self.glitch_flags_path)
# Save what we need about this observation
self.data.append(bunch.Bunch(id=id, ndet=obs.dets.count, nsamp=len(ctime),
dets=obs.dets.vals, nmat=nmat))
@@ -119,7 +128,7 @@ def solve(self, maxiter=500, maxerr=1e-6, x0=None):
self.prepare()
rhs = self.dof.zip(*[signal.rhs for signal in self.signals])
if x0 is not None: x0 = self.dof.zip(*x0)
solver = utils.CG(self.A, rhs, M=self.M, dot=self.dof.dot, x0=x0)
solver = putils.CG(self.A, rhs, M=self.M, dot=self.dof.dot, x0=x0)
while solver.i < maxiter and solver.err > maxerr:
solver.step()
yield bunch.Bunch(i=solver.i, err=solver.err, x=self.dof.unzip(solver.x))
@@ -146,7 +155,7 @@ def transeval(self, id, obs, other, x, tod=None):

class Signal:
"""This class represents a thing we want to solve for, e.g. the sky, ground, cut samples, etc."""
def __init__(self, name, ofmt, output, ext):
def __init__(self, name, ofmt, output, ext, glitch_flags: str = "flags.glitch_flags"):
"""Initialize a Signal. It probably doesn't make sense to construct a generic signal
directly, though. Use one of the subclasses.
Arguments:
@@ -161,7 +170,8 @@ def __init__(self, name, ofmt, output, ext):
self.ext = ext
self.dof = None
self.ready = False
def add_obs(self, id, obs, nmat, Nd): pass
self.glitch_flags = glitch_flags
def add_obs(self, id, obs, nmat, Nd, glitch_flags: Optional[str] = None): pass
def prepare(self): self.ready = True
def forward (self, id, tod, x): pass
def backward(self, id, tod, x): pass
@@ -176,12 +186,12 @@ class SignalMap(Signal):
"""Signal describing a non-distributed sky map."""
def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", output=True,
ext="fits", dtype=np.float32, sys=None, recenter=None, tile_shape=(500,500), tiled=False,
interpol=None):
interpol=None, glitch_flags: str = "flags.glitch_flags"):
"""Signal describing a sky map in the coordinate system given by "sys", which defaults
to equatorial coordinates. If tiled==True, then this will be a distributed map with
the given tile_shape, otherwise it will be a plain enmap. interpol controls the
pointing matrix interpolation mode. See so3g's Projectionist docstring for details."""
Signal.__init__(self, name, ofmt, output, ext)
Signal.__init__(self, name, ofmt, output, ext, glitch_flags)
self.comm = comm
self.comps = comps
self.sys = sys
@@ -202,15 +212,16 @@ def __init__(self, shape, wcs, comm, comps="TQU", name="sky", ofmt="{name}", out
self.div = enmap.zeros((ncomp,ncomp)+shape, wcs, dtype=dtype)
self.hits= enmap.zeros( shape, wcs, dtype=dtype)

def add_obs(self, id, obs, nmat, Nd, pmap=None):
def add_obs(self, id, obs, nmat, Nd, pmap=None, glitch_flags: Optional[str] = None):
"""Add and process an observation, building the pointing matrix
and our part of the RHS. "obs" should be an Observation axis manager,
nmat a noise model, representing the inverse noise covariance matrix,
and Nd the result of applying the noise model to the detector time-ordered data.
"""
Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts
ctime = obs.timestamps
pcut = PmatCut(obs.flags.glitch_flags) # could pass this in, but fast to construct
gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
pcut = PmatCut(get_flags_from_path(obs, gflags)) # could pass this in, but fast to construct
if pmap is None:
# Build the local geometry and pointing matrix for this observation
if self.recenter:
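As a usage note, a hedged sketch of how the stored path and the per-observation override interact (the geometry is illustrative, and the add_obs call is left as a comment because it needs a real observation):

import numpy as np
from pixell import enmap
from sotodlib.mapmaking import SignalMap

# Assumed: a small CAR patch; comm=None means a serial (no-MPI) run.
shape, wcs = enmap.geometry(pos=[[-0.1, -0.1], [0.1, 0.1]], res=0.01)
sky = SignalMap(shape, wcs, comm=None, dtype=np.float32,
                glitch_flags="flags.glitch_flags")

# Per observation, add_obs may still override the stored path, e.g.
#   sky.add_obs(obs_id, obs, nmat, Nd, glitch_flags="preprocess.glitch_flags")
# With glitch_flags=None (the default), self.glitch_flags is used instead.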
@@ -261,9 +272,9 @@ def prepare(self):
self.dof = TileMapZipper(self.rhs.geometry, dtype=self.dtype, comm=self.comm)
else:
if self.comm is not None:
self.rhs = utils.allreduce(self.rhs, self.comm)
self.div = utils.allreduce(self.div, self.comm)
self.hits = utils.allreduce(self.hits, self.comm)
self.rhs = putils.allreduce(self.rhs, self.comm)
self.div = putils.allreduce(self.div, self.comm)
self.hits = putils.allreduce(self.hits, self.comm)
self.dof = MapZipper(*self.rhs.geometry, dtype=self.dtype)
self.idiv = safe_invert_div(self.div)
self.ready = True
@@ -300,7 +311,7 @@ def from_work(self, map):
return tilemap.redistribute(map, self.comm, self.rhs.geometry.active)
else:
if self.comm is None: return map
else: return utils.allreduce(map, self.comm)
else: return putils.allreduce(map, self.comm)

def write(self, prefix, tag, m):
if not self.output: return
@@ -347,6 +358,7 @@ def transeval(self, id, obs, other, map, tod):
# Currently we don't support any actual translation, but could handle
# resolution changes in the future (probably not useful though)
self._checkcompat(other)
ctime = obs.timestamps
# Build the local geometry and pointing matrix for this observation
if self.recenter:
rot = recentering_to_quat_lonlat(*evaluate_recentering(self.recenter,
@@ -361,9 +373,9 @@

class SignalCut(Signal):
def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32,
output=False, cut_type=None):
output=False, cut_type=None, glitch_flags: str = "flags.glitch_flags"):
"""Signal for handling the ML solution for the values of the cut samples."""
Signal.__init__(self, name, ofmt, output, ext="hdf")
Signal.__init__(self, name, ofmt, output, ext="hdf", glitch_flags=glitch_flags)
self.comm = comm
self.data = {}
self.dtype = dtype
@@ -372,12 +384,14 @@ def __init__(self, comm, name="cut", ofmt="{name}_{rank:02}", dtype=np.float32,
self.rhs = []
self.div = []

def add_obs(self, id, obs, nmat, Nd):
def add_obs(self, id, obs, nmat, Nd, glitch_flags: Optional[str] = None):
"""Add and process an observation. "obs" should be an Observation axis manager,
nmat a noise model, representing the inverse noise covariance matrix,
and Nd the result of applying the noise model to the detector time-ordered data."""
Nd = Nd.copy() # This copy can be avoided if build_obs is split into two parts
pcut = PmatCut(obs.flags.glitch_flags, model=self.cut_type)

gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
pcut = PmatCut(get_flags_from_path(obs, gflags), model=self.cut_type)
# Build our RHS
obs_rhs = np.zeros(pcut.njunk, self.dtype)
pcut.backward(Nd, obs_rhs)
@@ -441,15 +455,16 @@ def translate(self, other, junk):
so3g.translate_cuts(odata.pcut.cuts, sdata.pcut.cuts, sdata.pcut.model, sdata.pcut.params, junk[odata.i1:odata.i2], res[sdata.i1:sdata.i2])
return res

def transeval(self, id, obs, other, junk, tod):
def transeval(self, id, obs, other, junk, tod, glitch_flags: Optional[str] = None):
"""Translate data junk from SignalCut other to the current SignalCut,
and then evaluate it for the given observation, returning a tod.
This is used when building a signal-free tod for the noise model
in multipass mapmaking."""
self._checkcompat(other)
# We have to make a pointing matrix from scratch because add_obs
# won't have been called yet at this point
spcut = PmatCut(obs.flags.glitch_flags, model=self.cut_type)
gflags = glitch_flags if glitch_flags is not None else self.glitch_flags
spcut = PmatCut(get_flags_from_path(obs, gflags), model=self.cut_type)
# We do have one for other though, since that will be the output
# from the previous round of multipass mapmaking.
odata = other.data[id]
70 changes: 59 additions & 11 deletions sotodlib/mapmaking/utilities.py
@@ -1,10 +1,11 @@
from typing import Any, Union, Optional

import numpy as np
from pixell import enmap, utils, fft, tilemap, resample
import so3g
from pixell import enmap, fft, resample, tilemap, utils

from .. import coords, core, tod_ops

from .. import core
from .. import tod_ops
from .. import coords

def deslope_el(tod, el, srate, inplace=False):
if not inplace: tod = tod.copy()
@@ -136,7 +137,6 @@ def safe_invert_div(div, lim=1e-2, lim0=np.finfo(np.float32).tiny**0.5):
return idiv



def measure_cov(d, nmax=10000):
d = d[:,::max(1,d.shape[1]//nmax)]
n,m = d.shape
@@ -339,6 +339,7 @@ def evaluate_recentering(info, ctime, geom=None, site=None, weather="typical"):
"""Evaluate the quaternion that performs the coordinate recentering specified in
info, which can be obtained from parse_recentering."""
import ephem

# Get the coordinates of the from, to and up points. This was a bit involved...
def to_cel(lonlat, sys, ctime=None, site=None, weather=None):
# Convert lonlat from sys to celestial coordinates. Maybe polish and put elsewhere
@@ -370,6 +371,7 @@ def recentering_to_quat_lonlat(p1, p2, pu):
"""Return the quaternion that represents the rotation that takes point p1
to p2, with the up direction pointing towards the point pu, all given as lonlat pairs"""
from so3g.proj import quat

# 1. First rotate our point to the north pole: Ry(-(90-dec1))Rz(-ra1)
# 2. Apply the same rotation to the up point.
# 3. We want the up point to be upwards, so rotate it to ra = 180°: Rz(pi-rau2)
@@ -439,8 +441,48 @@ def rangemat_sum(rangemat):
res[i] = np.sum(ra[:,1]-ra[:,0])
return res

def find_usable_detectors(obs, maxcut=0.1):
ncut = rangemat_sum(obs.flags.glitch_flags)
def flags_in_path(
aman: core.AxisManager, rpath: str, sep: str = "."
) -> bool:
"""
Check whether a (possibly nested) field exists in an AxisManager at the given path.
Parameters:
- aman: An AxisManager object
- rpath: a string with the recursive path to the field, with levels separated by sep.
For example 'flags.glitch_flags'
- sep: separator. Defaults to `.`
"""

flags = aman
for path in rpath.split(sep=sep):
try:
flags = flags[path]
except (KeyError, TypeError):
return False

return flags is not None


def get_flags_from_path(
aman: core.AxisManager, rpath: str, sep: str = "."
) -> Union[so3g.proj.RangesMatrix, Any]:
"""
Pull a (possibly nested) field out of an AxisManager based on a path.
Parameters:
- aman: An AxisManager object
- rpath: a string with the recursive path to the field, with levels separated by sep.
For example 'flags.glitch_flags'
- sep: separator. Defaults to `.`
"""

flags = aman  # no copy needed; we only read fields along the path
for path in rpath.split(sep=sep):
flags = flags[path]

return flags
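As a reference for how the two helpers are meant to be used, a minimal self-contained sketch (the AxisManager layout and field contents are illustrative, and the import assumes this PR's placement of the helpers in sotodlib.mapmaking.utilities):

import numpy as np
from sotodlib import core
from sotodlib.mapmaking.utilities import flags_in_path, get_flags_from_path

# Build a toy observation with a nested flags sub-AxisManager.
dets = core.LabelAxis("dets", ["det0", "det1"])
samps = core.OffsetAxis("samps", 100)
flags = core.AxisManager(dets, samps)
flags.wrap("glitch_flags", np.zeros((2, 100), dtype=bool), [(0, "dets"), (1, "samps")])
obs = core.AxisManager(dets, samps)
obs.wrap("flags", flags)

flags_in_path(obs, "flags.glitch_flags")             # True
flags_in_path(obs, "flags.source_flags")             # False
gf = get_flags_from_path(obs, "flags.glitch_flags")  # same field as obs.flags.glitch_flags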


def find_usable_detectors(obs, maxcut=0.1, glitch_flags: str = "flags.glitch_flags"):
ncut = rangemat_sum(get_flags_from_path(obs, glitch_flags))
good = ncut < obs.samps.count * maxcut
return obs.dets.vals[good]
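For intuition, a small self-contained sketch of the maxcut criterion above; the counts are made up:

import numpy as np

nsamp = 4000                     # samples in the observation
ncut = np.array([50, 500, 900])  # flagged samples per detector (made up)
maxcut = 0.1                     # keep detectors with less than 10% of samples cut
keep = ncut < maxcut * nsamp     # array([ True, False, False])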

@@ -499,7 +541,7 @@ def downsample_obs(obs, down):
if isinstance(val, core.AxisManager):
res.wrap(key, val)
else:
axdesc = [(k,v) for k,v in enumerate(axes) if v is not None]
axdesc = [(k, v) for k, v in enumerate(axes) if v is not None]
res.wrap(key, val, axdesc)
# The normal sample stuff
res.wrap("timestamps", obs.timestamps[::down], [(0, "samps")])
@@ -511,16 +553,22 @@

# The cuts
# obs.flags will contain all types of flags. We should query it for glitch_flags and source_flags
cut_keys = ["glitch_flags"]
cut_keys = []
if flags_in_path(obs, "glitch_flags"):
cut_keys.append("glitch_flags")
elif flags_in_path(obs, "flags.glitch_flags"):
cut_keys.append("flags.glitch_flags")

if "source_flags" in obs.flags:
if flags_in_path(obs, "source_flags"):
cut_keys.append("source_flags")
elif flags_in_path(obs, "flags.source_flags"):
cut_keys.append("flags.source_flags")

# We need to add a res.flags FlagManager to res
res = res.wrap('flags', core.FlagManager.for_tod(res))

for key in cut_keys:
res.flags.wrap(key, downsample_cut(getattr(obs.flags, key), down), [(0,"dets"),(1,"samps")])
# Wrap under the leaf name so the cuts stay addressable as e.g. "flags.glitch_flags"
res.flags.wrap(key.split(".")[-1], downsample_cut(get_flags_from_path(obs, key), down), [(0,"dets"),(1,"samps")])
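To make the lookup order explicit, a small self-contained illustration (a plain dict stands in for the observation; this is not sotodlib code):

def has_path(tree, rpath, sep="."):
    # Same idea as flags_in_path, but on a plain dict for illustration.
    for part in rpath.split(sep):
        if not isinstance(tree, dict) or part not in tree:
            return False
        tree = tree[part]
    return True

fake_obs = {"flags": {"glitch_flags": "...", "source_flags": "..."}}
cut_keys = []
for candidates in (("glitch_flags", "flags.glitch_flags"),
                   ("source_flags", "flags.source_flags")):
    for key in candidates:
        if has_path(fake_obs, key):   # the top-level name wins, then flags.<name>
            cut_keys.append(key)
            break
print(cut_keys)  # ['flags.glitch_flags', 'flags.source_flags']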

# Not sure how to deal with flags. Some sort of or-binning operation? But it
# doesn't matter anyway