Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interpolate healpix maps #784

Merged
merged 6 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/toast/ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ install(FILES
signal_diff_noise_model.py
gainscrambler.py
pointing.py
interpolate_healpix.py
scan_healpix.py
scan_wcs.py
mapmaker_binning.py
Expand Down
1 change: 1 addition & 0 deletions src/toast/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .run_spt3g import RunSpt3g
from .save_hdf5 import SaveHDF5
from .save_spt3g import SaveSpt3g
from .interpolate_healpix import InterpolateHealpixMap
from .scan_healpix import ScanHealpixMap, ScanHealpixMask
from .scan_map import ScanMap, ScanMask, ScanScale
from .scan_wcs import ScanWCSMap, ScanWCSMask
Expand Down
247 changes: 247 additions & 0 deletions src/toast/ops/interpolate_healpix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# Copyright (c) 2024 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

import numpy as np
import healpy as hp
import traitlets
from astropy import units as u
from pshmem import MPIShared

import toast.qarray as qa
from ..observation import default_values as defaults
from ..pixels_io_healpix import read_healpix
from ..timing import function_timer
from ..traits import Bool, Instance, Int, Unicode, Unit, trait_docs
from ..utils import Logger
from .operator import Operator


@trait_docs
class InterpolateHealpixMap(Operator):
"""Operator which reads a HEALPix format map from disk and
interpolates it to a timestream.

The map file is loaded and placed in shared memory on every
participating node. For each observation, the pointing model is
used to expand the pointing and bilinearly interpolate the map
values into detector data.

"""

# Class traits

API = Int(0, help="Internal interface version for this operator")

file = Unicode(
None,
allow_none=True,
help="Path to healpix FITS file. Use ';' if providing multiple files",
)

det_data = Unicode(
defaults.det_data,
help="Observation detdata key for accumulating output. Use ';' if different "
"files are applied to different flavors",
)

det_data_units = Unit(
defaults.det_data_units, help="Output units if creating detector data"
)

det_mask = Int(
defaults.det_mask_invalid,
help="Bit mask value for per-detector flagging",
)

subtract = Bool(
False, help="If True, subtract the map timestream instead of accumulating"
)

zero = Bool(False, help="If True, zero the data before accumulating / subtracting")

detector_pointing = Instance(
klass=Operator,
allow_none=True,
help="Operator that translates boresight pointing into detector frame",
)

stokes_weights = Instance(
klass=Operator,
allow_none=True,
help="This must be an instance of a Stokes weights operator",
)

save_map = Bool(False, help="If True, do not delete map during finalize")

@traitlets.validate("det_mask")
def _check_det_mask(self, proposal):
check = proposal["value"]
if check < 0:
raise traitlets.TraitError("Det mask should be a positive integer")
return check

@traitlets.validate("detector_pointing")
def _check_detector_pointing(self, proposal):
detpointing = proposal["value"]
if detpointing is not None:
if not isinstance(detpointing, Operator):
raise traitlets.TraitError(
"detector_pointing should be an Operator instance"
)
# Check that this operator has the traits we expect
for trt in [
"view",
"boresight",
"shared_flags",
"shared_flag_mask",
"quats",
"coord_in",
"coord_out",
]:
if not detpointing.has_trait(trt):
msg = f"detector_pointing operator should have a '{trt}' trait"
raise traitlets.TraitError(msg)
return detpointing

@traitlets.validate("stokes_weights")
def _check_stokes_weights(self, proposal):
weights = proposal["value"]
if weights is not None:
if not isinstance(weights, Operator):
raise traitlets.TraitError(
"stokes_weights should be an Operator instance"
)
# Check that this operator has the traits we expect
for trt in ["weights", "view"]:
if not weights.has_trait(trt):
msg = f"stokes_weights operator should have a '{trt}' trait"
raise traitlets.TraitError(msg)
return weights

def __init__(self, **kwargs):
self.map_names = []
self.maps = {}
super().__init__(**kwargs)

@function_timer
def _exec(self, data, detectors=None, **kwargs):
log = Logger.get()

for trait in ("file", "detector_pointing", "stokes_weights"):
if getattr(self, trait) is None:
msg = f"You must set the '{trait}' trait before calling exec()"
raise RuntimeError(msg)

# Split up the file and map names
self.file_names = self.file.split(";")
nmap = len(self.file_names)
self.det_data_keys = self.det_data.split(";")
nkey = len(self.det_data_keys)
if nkey != 1 and (nmap != nkey):
msg = "If multiple detdata keys are provided, each must have its own map"
raise RuntimeError(msg)
self.map_names = [f"{self.name}_map{i}" for i in range(nmap)]

# Determine the number of non-zeros from the Stokes weights
nnz = None
if self.stokes_weights is None or self.stokes_weights.mode == "I":
nnz = 1
elif self.stokes_weights.mode == "IQU":
nnz = 3
else:
msg = f"Unknown Stokes weights mode '{self.stokes_weights.mode}'"
raise RuntimeError(msg)

# Create our map(s) to scan named after our own operator name. Generally the
# files on disk are stored as float32, but even if not there is no real benefit
# to having higher precision to simulated map signal that is projected into
# timestreams.

world_comm = data.comm.comm_world
if world_comm is None:
world_rank = 0
else:
world_rank = world_comm.rank

for file_name, map_name in zip(self.file_names, self.map_names):
if map_name not in self.maps:
if world_rank == 0:
m = np.atleast_2d(read_healpix(file_name, None, dtype=np.float32))
map_shape = m.shape
else:
m = None
map_shape = None
if world_comm is not None:
map_shape = world_comm.bcast(map_shape)
self.maps[map_name] = MPIShared(map_shape, np.float32, world_comm)
self.maps[map_name].set(m)

# Loop over all observations and local detectors, interpolating each map
for ob in data.obs:
# Get the detectors we are using for this observation
dets = ob.select_local_detectors(detectors, flagmask=self.det_mask)
if len(dets) == 0:
# Nothing to do for this observation
continue
for key in self.det_data_keys:
# If our output detector data does not yet exist, create it
exists_data = ob.detdata.ensure(
key, detectors=dets, create_units=self.det_data_units
)
if self.zero:
ob.detdata[key][:] = 0

ob_data = data.select(obs_name=ob.name)
current_ob = ob_data.obs[0]
for idet, det in enumerate(dets):
self.detector_pointing.apply(ob_data, detectors=[det])
self.stokes_weights.apply(ob_data, detectors=[det])
det_quat = current_ob.detdata[self.detector_pointing.quats][det]
# Convert pointing quaternion into angles
theta, phi, _ = qa.to_iso_angles(det_quat)
# Get pointing weights
weights = current_ob.detdata[self.stokes_weights.weights][det]

# Interpolate the provided maps and accumulate the
# appropriate timestreams in the original observation
for map_name, map_value in self.maps.items():
if len(self.det_data_keys) == 1:
det_data_key = self.det_data_keys[0]
else:
det_data_key = self.det_data_keys[imap]
ref = ob.detdata[det_data_key][det]
nside = hp.get_nside(map_value)
interp_pix, interp_weight = hp.pixelfunc.get_interp_weights(
nside, theta, phi, nest=False, lonlat=False,
)
sig = np.zeros_like(ref)
for inz, map_column in enumerate(map_value):
sig += weights[:, inz] * np.sum(
map_column[interp_pix] * interp_weight, 0
)
if self.subtract:
ref -= sig
else:
ref += sig

# Clean up our map, if needed
if not self.save_map:
for map_name in self.map_names:
del self.maps[map_name]

return

def _finalize(self, data, **kwargs):
return

def _requires(self):
req = self.detector_pointing.requires()
req.update(self.stokes_weights.requires())
return req

def _provides(self):
prov = {"global": list(), "detdata": [self.det_data]}
if self.save_map:
prov["global"] = self.map_names
return prov
1 change: 1 addition & 0 deletions src/toast/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ install(FILES
ops_pointing_wcs.py
ops_memory_counter.py
ops_scan_map.py
ops_interpolate_healpix.py
ops_scan_healpix.py
ops_scan_wcs.py
ops_madam.py
Expand Down
90 changes: 90 additions & 0 deletions src/toast/tests/ops_interpolate_healpix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2024 by the parties listed in the AUTHORS file.
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

import os

import healpy as hp
import numpy as np
import numpy.testing as nt
from astropy import units as u

from .. import ops as ops
from ..observation import default_values as defaults
from ..pixels_io_healpix import write_healpix_fits, write_healpix_hdf5
from ._helpers import (
close_data,
create_fake_mask,
create_fake_sky,
create_outdir,
create_satellite_data,
)
from .mpi import MPITestCase


class InterpolateHealpixTest(MPITestCase):
def setUp(self):
fixture_name = os.path.splitext(os.path.basename(__file__))[0]
self.outdir = create_outdir(self.comm, fixture_name)
np.random.seed(123456)

def test_interpolate(self):
# Create a fake satellite data set for testing
data = create_satellite_data(self.comm)

# Create some detector pointing matrices
detpointing = ops.PointingDetectorSimple()
pixels = ops.PixelsHealpix(
nside=256,
create_dist="pixel_dist",
detector_pointing=detpointing,
)
pixels.apply(data)
weights = ops.StokesWeights(
mode="IQU",
hwp_angle=defaults.hwp_angle,
detector_pointing=detpointing,
)
weights.apply(data)

hpix_file = os.path.join(self.outdir, "fake.fits")
if data.comm.comm_world is None or data.comm.comm_world.rank == 0:
# Create a smooth sky
lmax = 3 * pixels.nside
cls = np.ones([4, lmax + 1])
np.random.seed(98776)
fake_sky = hp.synfast(cls, pixels.nside, fwhm=np.radians(30))
# Write this to a file
hp.write_map(hpix_file, fake_sky)

# Scan the map from the file

scan_hpix = ops.ScanHealpixMap(
file=hpix_file,
det_data="scan_data",
pixel_pointing=pixels,
stokes_weights=weights,
)
scan_hpix.apply(data)

# Interpolate the map from the file

interp_hpix = ops.InterpolateHealpixMap(
file=hpix_file,
det_data="interp_data",
detector_pointing=detpointing,
stokes_weights=weights,
)
interp_hpix.apply(data)

# Check that the sets of timestreams match.

for ob in data.obs:
for det in ob.select_local_detectors(flagmask=defaults.det_mask_invalid):
np.testing.assert_almost_equal(
ob.detdata["scan_data"][det],
ob.detdata["interp_data"][det],
decimal=1,
)

close_data(data)
2 changes: 2 additions & 0 deletions src/toast/tests/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from . import ops_pointing_healpix as test_ops_pointing_healpix
from . import ops_pointing_wcs as test_ops_pointing_wcs
from . import ops_polyfilter as test_ops_polyfilter
from . import ops_interpolate_healpix as test_ops_interpolate_healpix
from . import ops_scan_healpix as test_ops_scan_healpix
from . import ops_scan_map as test_ops_scan_map
from . import ops_scan_wcs as test_ops_scan_wcs
Expand Down Expand Up @@ -199,6 +200,7 @@ def test(name=None, verbosity=2):
suite.addTest(loader.loadTestsFromModule(test_ops_mapmaker_solve))
suite.addTest(loader.loadTestsFromModule(test_ops_mapmaker))
suite.addTest(loader.loadTestsFromModule(test_ops_scan_map))
suite.addTest(loader.loadTestsFromModule(test_ops_interpolate_healpix))
suite.addTest(loader.loadTestsFromModule(test_ops_scan_healpix))
suite.addTest(loader.loadTestsFromModule(test_ops_scan_wcs))
suite.addTest(loader.loadTestsFromModule(test_ops_madam))
Expand Down