Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autos only abs amp logcal #832

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 68 additions & 16 deletions hera_cal/abscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3179,7 +3179,7 @@ def abscal_step(gains_to_update, AC, AC_func, AC_kwargs, gain_funcs, gain_args_l


def match_baselines(data_bls, model_bls, data_antpos, model_antpos=None, pols=[], data_is_redsol=False,
model_is_redundant=False, tol=1.0, min_bl_cut=None, max_bl_cut=None, max_dims=2, verbose=False):
model_is_redundant=False, tol=1.0, min_bl_cut=None, max_bl_cut=None, max_dims=2, autos_only_abs_amp_logcal=False, verbose=False):
'''Figure out which baselines to use in the data and the model for abscal and their correspondence.

Arguments:
Expand Down Expand Up @@ -3212,6 +3212,13 @@ def match_baselines(data_bls, model_bls, data_antpos, model_antpos=None, pols=[]
data_bl_to_load = set(utils.filter_bls(data_bls, pols=pols, antpos=data_antpos, min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut))
model_bl_to_load = set(utils.filter_bls(model_bls, pols=pols, antpos=model_antpos, min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut))

#Add autos to bl_to_load if we are running with autos_only_abs_amp_logcal:
if autos_only_abs_amp_logcal:
autos_data = [bl for bl in data_bls if bl[0]==bl[1] and bl[2] in pols]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

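I'd be careful about using bl[0] == bl[1] to find autos, since you might end up identifying (1, 1, 'ne') as an auto, which it is not: it is a cross-polarization visibility from a single antenna. Instead, use utils.split_bl() and compare the two resulting antenna-polarization tuples.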

autos_model = [bl for bl in model_bls if bl[0]==bl[1] and bl[2] in pols]
data_bl_to_load = set(list(data_bl_to_load) + autos_data)
model_bl_to_load = set(list(model_bl_to_load) + autos_model)

# If we're working with full data sets, only pick out matching keys (or ones that work reversably)
if not data_is_redsol and not model_is_redundant:
data_bl_to_load = [bl for bl in data_bl_to_load if (bl in model_bl_to_load) or (reverse_bl(bl) in model_bl_to_load)]
Expand All @@ -3225,7 +3232,7 @@ def match_baselines(data_bls, model_bls, data_antpos, model_antpos=None, pols=[]
# increase all antenna indices in the model by model_offset to distinguish them from data antennas
model_offset = np.max(list(data_antpos.keys())) + 1
joint_antpos = {**data_antpos, **{ant + model_offset: pos for ant, pos in model_antpos.items()}}
joint_reds = redcal.get_reds(joint_antpos, pols=pols, bl_error_tol=tol)
joint_reds = redcal.get_reds(joint_antpos, pols=pols, bl_error_tol=tol, include_autos=autos_only_abs_amp_logcal)

# filter out baselines not in data or model or between data and model
joint_reds = [[bl for bl in red if not ((bl[0] < model_offset) ^ (bl[1] < model_offset))] for red in joint_reds]
Expand Down Expand Up @@ -3405,7 +3412,7 @@ def _get_idealized_antpos(cal_flags, antpos, pols, tol=1.0, keep_flagged_ants=Tr


def post_redcal_abscal(model, data, data_wgts, rc_flags, edge_cut=0, tol=1.0, kernel=(1, 15),
phs_max_iter=100, phs_conv_crit=1e-6, verbose=True, use_abs_amp_lincal=True):
phs_max_iter=100, phs_conv_crit=1e-6, verbose=True, autos_only_abs_amp_logcal=False, use_abs_amp_lincal=True):
'''Performs Abscal for data that has already been redundantly calibrated.

Arguments:
Expand Down Expand Up @@ -3436,29 +3443,62 @@ def post_redcal_abscal(model, data, data_wgts, rc_flags, edge_cut=0, tol=1.0, ke
data_wgts=data_wgts, tol=tol, keep_flagged_ants=True)
reds = redcal.get_reds(idealized_antpos, pols=data.pols(), bl_error_tol=redcal.IDEALIZED_BL_TOL)

# separate autos and crosses if running amplitude calibration with autos only:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you've made this function way more complicated than it needs to be. Instead of splitting up the passed data and model, one should pass in a separate autos_model and autos_data, which default to None but must be specified if autos_only_abs_amp_logcal is True.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this makes sense, and I will make this change.

if autos_only_abs_amp_logcal:
autos = [antpol for antpol in data.keys() if antpol[0] == antpol[1]]
autos_data = DataContainer({antpol:data[antpol] for antpol in autos})
autos_data_wgts = DataContainer({antpol:data_wgts[antpol] for antpol in autos})
autos_model = DataContainer({antpol:model[antpol] for antpol in autos})

#separate crosses from autos:
crosses = [antpol for antpol in data.keys() if antpol[0] != antpol[1]]
crosses_data = DataContainer({antpol:data[antpol] for antpol in crosses})
crosses_data_wgts = DataContainer({antpol:data_wgts[antpol] for antpol in crosses})
crosses_model = DataContainer({antpol:model[antpol] for antpol in crosses})

# Abscal Step 1: Per-Channel Logarithmic Absolute Amplitude Calibration
gains_here = abs_amp_logcal(model, data, wgts=data_wgts, verbose=verbose, return_gains=True, gain_ants=ants)
if autos_only_abs_amp_logcal:
gains_here = abs_amp_logcal(autos_model, autos_data, wgts=autos_data_wgts, verbose=verbose, return_gains=True, gain_ants=ants)
else:
gains_here = abs_amp_logcal(model, data, wgts=data_wgts, verbose=verbose, return_gains=True, gain_ants=ants)

abscal_delta_gains = {ant: gains_here[ant] for ant in ants}
apply_cal.calibrate_in_place(data, gains_here)


# Abscal Step 2: Global Delay Slope Calibration
binary_wgts = DataContainer({bl: (data_wgts[bl] > 0).astype(float) for bl in data_wgts})
if autos_only_abs_amp_logcal:
binary_wgts = DataContainer({bl: (data_wgts[bl] > 0).astype(float) for bl in crosses_data_wgts})
else:
binary_wgts = DataContainer({bl: (data_wgts[bl] > 0).astype(float) for bl in data_wgts})
df = np.median(np.diff(data.freqs))
for time_avg in [True, False]: # first use the time-averaged solution to try to avoid false minima
gains_here = delay_slope_lincal(model, data, idealized_antpos, wgts=binary_wgts, df=df, f0=data.freqs[0], medfilt=True, kernel=kernel,
if autos_only_abs_amp_logcal:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having a separate call to each calibration function, depending on whether autos_only_abs_amp_logcal is True or not, makes the code a lot more complex and annoying to test and maintain.

gains_here = delay_slope_lincal(crosses_model, crosses_data, idealized_antpos, wgts=binary_wgts, df=df, f0=data.freqs[0], medfilt=True, kernel=kernel,
assume_2D=False, time_avg=time_avg, verbose=verbose, edge_cut=edge_cut, return_gains=True, gain_ants=ants)
else:
gains_here = delay_slope_lincal(model, data, idealized_antpos, wgts=binary_wgts, df=df, f0=data.freqs[0], medfilt=True, kernel=kernel,
assume_2D=False, time_avg=time_avg, verbose=verbose, edge_cut=edge_cut, return_gains=True, gain_ants=ants)
abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants}
apply_cal.calibrate_in_place(data, gains_here)

# Abscal Step 3: Global Phase Slope Calibration (first using ndim_fft, then using linfit)
for time_avg in [True, False]:
gains_here = global_phase_slope_logcal(model, data, idealized_antpos, reds=reds, solver='ndim_fft', wgts=binary_wgts, verbose=verbose, assume_2D=False,
if autos_only_abs_amp_logcal:
gains_here = global_phase_slope_logcal(crosses_model, crosses_data, idealized_antpos, reds=reds, solver='ndim_fft', wgts=binary_wgts, verbose=verbose, assume_2D=False,
tol=redcal.IDEALIZED_BL_TOL, edge_cut=edge_cut, time_avg=time_avg, return_gains=True, gain_ants=ants)
else:
gains_here = global_phase_slope_logcal(model, data, idealized_antpos, reds=reds, solver='ndim_fft', wgts=binary_wgts, verbose=verbose, assume_2D=False,
tol=redcal.IDEALIZED_BL_TOL, edge_cut=edge_cut, time_avg=time_avg, return_gains=True, gain_ants=ants)
abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants}
apply_cal.calibrate_in_place(data, gains_here)
for time_avg in [True, False]:
for i in range(phs_max_iter):
gains_here = global_phase_slope_logcal(model, data, idealized_antpos, reds=reds, solver='linfit', wgts=binary_wgts, verbose=verbose, assume_2D=False,
if autos_only_abs_amp_logcal:
gains_here = global_phase_slope_logcal(crosses_model, crosses_data, idealized_antpos, reds=reds, solver='linfit', wgts=binary_wgts, verbose=verbose, assume_2D=False,
tol=redcal.IDEALIZED_BL_TOL, edge_cut=edge_cut, time_avg=time_avg, return_gains=True, gain_ants=ants)
else:
gains_here = global_phase_slope_logcal(model, data, idealized_antpos, reds=reds, solver='linfit', wgts=binary_wgts, verbose=verbose, assume_2D=False,
tol=redcal.IDEALIZED_BL_TOL, edge_cut=edge_cut, time_avg=time_avg, return_gains=True, gain_ants=ants)
abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants}
apply_cal.calibrate_in_place(data, gains_here)
Expand All @@ -3468,11 +3508,17 @@ def post_redcal_abscal(model, data, data_wgts, rc_flags, edge_cut=0, tol=1.0, ke
break

# Abscal Step 4: Per-Channel Tip-Tilt Phase Calibration
angle_wgts = DataContainer({bl: 2 * np.abs(model[bl])**2 * data_wgts[bl] for bl in model})
if autos_only_abs_amp_logcal:
angle_wgts = DataContainer({bl: 2 * np.abs(model[bl])**2 * data_wgts[bl] for bl in crosses_model})
else:
angle_wgts = DataContainer({bl: 2 * np.abs(model[bl])**2 * data_wgts[bl] for bl in model})
# This is because, in the high SNR limit, if Var(model) = 0 and Var(data) = Var(noise),
# then Var(angle(data / model)) = Var(noise) / (2 |model|^2). Here data_wgts = Var(noise)^-1.
for i in range(phs_max_iter):
gains_here = TT_phs_logcal(model, data, idealized_antpos, wgts=angle_wgts, verbose=verbose, assume_2D=False, return_gains=True, gain_ants=ants)
if autos_only_abs_amp_logcal:
gains_here = TT_phs_logcal(crosses_model, crosses_data, idealized_antpos, wgts=angle_wgts, verbose=verbose, assume_2D=False, return_gains=True, gain_ants=ants)
else:
gains_here = TT_phs_logcal(model, data, idealized_antpos, wgts=angle_wgts, verbose=verbose, assume_2D=False, return_gains=True, gain_ants=ants)
abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants}
apply_cal.calibrate_in_place(data, gains_here)
crit = np.median(np.linalg.norm([gains_here[k] - 1.0 for k in gains_here.keys()], axis=(0, 1)))
Expand All @@ -3482,16 +3528,19 @@ def post_redcal_abscal(model, data, data_wgts, rc_flags, edge_cut=0, tol=1.0, ke

# Abscal Step 5: Per-Channel Linear Absolute Amplitude Calibration
if use_abs_amp_lincal:
gains_here = abs_amp_lincal(model, data, wgts=data_wgts, verbose=verbose, return_gains=True, gain_ants=ants)
abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants}
if autos_only_abs_amp_logcal:
echo("Skipping abs_amp_lincal. autos_only_abs_amp_logcal=True inconsistent with lincal", verbose=verbose)
else:
gains_here = abs_amp_lincal(model, data, wgts=data_wgts, verbose=verbose, return_gains=True, gain_ants=ants)
abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants}

return abscal_delta_gains


def post_redcal_abscal_run(data_file, redcal_file, model_files, raw_auto_file=None, data_is_redsol=False, model_is_redundant=False, output_file=None,
nInt_to_load=None, data_solar_horizon=90, model_solar_horizon=90, extrap_limit=.5, min_bl_cut=1.0, max_bl_cut=None,
edge_cut=0, tol=1.0, phs_max_iter=100, phs_conv_crit=1e-6, refant=None, clobber=True, add_to_history='', verbose=True, skip_abs_amp_lincal=False,
write_delta_gains=False, output_file_delta=None):
edge_cut=0, tol=1.0, phs_max_iter=100, phs_conv_crit=1e-6, refant=None, clobber=True, add_to_history='', verbose=True,
autos_only_abs_amp_logcal=False, skip_abs_amp_lincal=False, write_delta_gains=False, output_file_delta=None):
'''Perform abscal on entire data files, picking relevant model_files from a list and doing partial data loading.
Does not work on data (or models) with baseline-dependant averaging.

Expand Down Expand Up @@ -3608,7 +3657,8 @@ def post_redcal_abscal_run(data_file, redcal_file, model_files, raw_auto_file=No
model_bl_to_load,
data_to_model_bl_map) = match_baselines(hd.bls, model_bls, hd.data_antpos, model_antpos=model_antpos, pols=[pol],
data_is_redsol=data_is_redsol, model_is_redundant=model_is_redundant,
tol=tol, min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut, verbose=verbose)
tol=tol, min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut,
autos_only_abs_amp_logcal=autos_only_abs_amp_logcal, verbose=verbose)
if (len(data_bl_to_load) == 0) or (len(model_bl_to_load) == 0):
echo("No baselines in the data match baselines in the model. Results for this polarization will be fully flagged.", verbose=verbose)
else:
Expand Down Expand Up @@ -3657,7 +3707,8 @@ def post_redcal_abscal_run(data_file, redcal_file, model_files, raw_auto_file=No

# run absolute calibration to get the gain updates
delta_gains = post_redcal_abscal(model, data, data_wgts, rc_flags_subset, edge_cut=edge_cut, tol=tol,
phs_max_iter=phs_max_iter, phs_conv_crit=phs_conv_crit, verbose=verbose, use_abs_amp_lincal=not(skip_abs_amp_lincal))
phs_max_iter=phs_max_iter, phs_conv_crit=phs_conv_crit, verbose=verbose,
autos_only_abs_amp_logcal=autos_only_abs_amp_logcal, use_abs_amp_lincal=not(skip_abs_amp_lincal))

# abscal autos, rebuild weights, and generate abscal Chi^2
calibrate_in_place(autocorrs, delta_gains)
Expand Down Expand Up @@ -4005,6 +4056,7 @@ def post_redcal_abscal_argparser():
a.add_argument("--phs_conv_crit", default=1e-6, type=float, help="convergence criterion for updates to iterative phase calibration that compares them to all 1.0s.")
a.add_argument("--clobber", default=False, action="store_true", help="overwrites existing abscal calfits file at the output path")
a.add_argument("--verbose", default=False, action="store_true", help="print calibration progress updates")
a.add_argument("--autos_only_abs_amp_logcal", default=False, action="store_true", help="run amplitude calibration with only autos")
a.add_argument("--skip_abs_amp_lincal", default=False, action="store_true", help="finish calibration with an unbiased amplitude lincal step")
a.add_argument("--write_delta_gains", default=False, action="store_true", help="Write degenerate abscal component of gains separately.")
a.add_argument("--output_file_delta", type=str, default=None, help="Filename to write delta gains too.")
Expand Down
3 changes: 2 additions & 1 deletion scripts/post_redcal_abscal_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
data_solar_horizon=a.data_solar_horizon, model_solar_horizon=a.model_solar_horizon,
min_bl_cut=a.min_bl_cut, max_bl_cut=a.max_bl_cut, edge_cut=a.edge_cut, tol=a.tol,
phs_max_iter=a.phs_max_iter, phs_conv_crit=a.phs_conv_crit, clobber=a.clobber,
add_to_history=' '.join(sys.argv), verbose=a.verbose, skip_abs_amp_lincal=a.skip_abs_amp_lincal)
add_to_history=' '.join(sys.argv), verbose=a.verbose,
autos_only_abs_amp_logcal = a.autos_only_abs_amp_logcal, skip_abs_amp_lincal=a.skip_abs_amp_lincal)