Skip to content

Commit

Permalink
implement frequency fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Wilensky committed Feb 2, 2024
1 parent 0a5ca70 commit 3244cdc
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions SSINS/incoherent_noise_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ def set_extra_params(self, time_order=0, freq_order=None,
Nsb = len(self.subband_freq_chans)
self.Nsubband = Nsb
"""Number of subbands"""
self.Nfreq_sb = self.Nfreqs // self.Nsubband
"""Number of frequencies per subband"""

self.use_integration_weights = use_integration_weights
"""Whether to use integration time to weight the spectrum"""
Expand Down Expand Up @@ -311,9 +313,8 @@ def get_dmatr(self):
if self.freq_order is not None:
Npoly_freq = self.freq_order + 1
# FIXME: Ragged subbands won't handle this nicely
Nfreq_sb = self.Nfreqs // self.Nsubband
f = np.linspace(-1, 1, num=Nfreq_sb)
fmatr = np.zeros(Nfreq_sb, Npoly_freq)
f = np.linspace(-1, 1, num=self.Nfreq_sb)
fmatr = np.zeros(self.Nfreq_sb, Npoly_freq)
for order in range(Npoly_freq):
fmatr[:, order] = legendre(order)(f)
else:
Expand Down Expand Up @@ -454,19 +455,48 @@ def mean_subtract(self, freq_slice=slice(None), return_coeffs=False):

wt_data = wt * data # shape tfp

ttmatr = tmatr[:, np.newaxis] * tmatr[:, :, np.newaxis] # shape tAa

if fmatr is None: # Separates over frequency

# make the left-hand-side of lsq operator
ttmatr = tmatr[:, np.newaxis] * tmatr[:, :, np.newaxis] # shape tAa
lhs_op = np.tensordot(w, ttmatr, axes=((0, ), (0, ))) # shape fpAa
# make the operator on the left-hand-side of normal equations
lhs_op = np.tensordot(wt, ttmatr, axes=((0, ), (0, ))) # shape fpAa

# Make the vector on the rhs
rhs_vec = np.tensordot(tmatr.T, wt_data, axes=1) # shape afp
rhs_vec = np.tensordot(wt_data, tmatr, axes=((0, ), (0, ))) # shape fpa

soln = np.linalg.solve(lhs_op, rhs_vec) # shape afp
fitspec = np.tensordot(tmatr, soln)
soln = np.linalg.solve(lhs_op, rhs_vec) # shape fpa
fitspec = np.tensordot(tmatr, soln, axes=((-1,), (-1,))) # shape tfp
else:
raise NotImplementedError("Frequency fitting not available yet.")

new_shape = (self.Ntimes, self.Nsubband, self.Nfreq_sb,
self.Npols)

# Make RHS vec by multiplying by design matrix transpose
wt_data_res = wt_data.reshape(new_shape)
rhs_tmult = np.tensordot(wt_data_res, tmatr, axes=((0, ), (0, ))) # shape Nwpa
rhs_vec = np.tensordot(rhs_tmult, fmatr, axes=((1,), (0,))) # shape Npab
Ncoeff = (self.time_order + 1) * (self.freq_order + 1)
rhs_vec = rhs_vec.reshape(self.Nsubband, self.Npols, Ncoeff)

# Make the lhs_op as above but with extra steps for freq axis
wt_res = wt.reshape(new_shape)
ffmatr = fmatr[:, np.newaxis] * fmatr[:, :, np.newaxis] # shape fBb

lhs_tmult = np.tensordot(wt_res, ttmatr, axes=((0,), (0,))) # shape NwpAa
lhs_op = np.tensordot(lhs_tmult, ffmatr, axes=((1,), (0, ))) # shape NpAaBb
lhs_op = lhs_op.swapaxes(3, 4) # shape NpABab
lhs_op = lhs_op.reshape(self.Nsubband, self.Npols, Ncoeff, Ncoeff)

soln = np.linalg.solve(lhs_op, rhs_vec)
soln = soln.reshape(self.Nsubband, self.Npols,
self.time_order + 1, self.freq_order + 1)
fitspec_tmult = np.tensordot(tmatr, soln, axes=((1, ), (2, ))) # Shape tNpb
fitspec_res = np.tensordot(fmatr, fitspec_tmult, axes=((1, ), (3, ))) # shape wtNp
fitspec_res = fitspec_res.transpose(1, 2, 0, 3)

fitspec = fitspec.reshape(self.Ntimes, self.Nfreqs, self.Npols)

MS = (self.metric_array / fitspec - 1) * weights_factor
else: # Whole slice has been flagged. Don't rely on solve returning 0.
MS[:] = np.ma.masked
Expand Down

0 comments on commit 3244cdc

Please sign in to comment.