diff --git a/hera_cal/abscal.py b/hera_cal/abscal.py index d88db16d8..40c594a0d 100644 --- a/hera_cal/abscal.py +++ b/hera_cal/abscal.py @@ -3179,7 +3179,7 @@ def abscal_step(gains_to_update, AC, AC_func, AC_kwargs, gain_funcs, gain_args_l def match_baselines(data_bls, model_bls, data_antpos, model_antpos=None, pols=[], data_is_redsol=False, - model_is_redundant=False, tol=1.0, min_bl_cut=None, max_bl_cut=None, max_dims=2, verbose=False): + model_is_redundant=False, tol=1.0, min_bl_cut=None, max_bl_cut=None, max_dims=2, autos_only_abs_amp_logcal=False, verbose=False): '''Figure out which baselines to use in the data and the model for abscal and their correspondence. Arguments: @@ -3212,6 +3212,13 @@ def match_baselines(data_bls, model_bls, data_antpos, model_antpos=None, pols=[] data_bl_to_load = set(utils.filter_bls(data_bls, pols=pols, antpos=data_antpos, min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut)) model_bl_to_load = set(utils.filter_bls(model_bls, pols=pols, antpos=model_antpos, min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut)) +#Add autos to bl_to_load if we are running with autos_only_abs_amp_logcal: + if autos_only_abs_amp_logcal: + autos_data = [bl for bl in data_bls if bl[0]==bl[1] and bl[2] in pols] + autos_model = [bl for bl in model_bls if bl[0]==bl[1] and bl[2] in pols] + data_bl_to_load = set(list(data_bl_to_load) + autos_data) + model_bl_to_load = set(list(model_bl_to_load) + autos_model) + # If we're working with full data sets, only pick out matching keys (or ones that work reversably) if not data_is_redsol and not model_is_redundant: data_bl_to_load = [bl for bl in data_bl_to_load if (bl in model_bl_to_load) or (reverse_bl(bl) in model_bl_to_load)] @@ -3225,7 +3232,7 @@ def match_baselines(data_bls, model_bls, data_antpos, model_antpos=None, pols=[] # increase all antenna indices in the model by model_offset to distinguish them from data antennas model_offset = np.max(list(data_antpos.keys())) + 1 joint_antpos = {**data_antpos, **{ant + model_offset: pos for ant, pos in model_antpos.items()}} - joint_reds = redcal.get_reds(joint_antpos, pols=pols, bl_error_tol=tol) + joint_reds = redcal.get_reds(joint_antpos, pols=pols, bl_error_tol=tol, include_autos=autos_only_abs_amp_logcal) # filter out baselines not in data or model or between data and model joint_reds = [[bl for bl in red if not ((bl[0] < model_offset) ^ (bl[1] < model_offset))] for red in joint_reds] @@ -3405,7 +3412,7 @@ def _get_idealized_antpos(cal_flags, antpos, pols, tol=1.0, keep_flagged_ants=Tr def post_redcal_abscal(model, data, data_wgts, rc_flags, edge_cut=0, tol=1.0, kernel=(1, 15), - phs_max_iter=100, phs_conv_crit=1e-6, verbose=True, use_abs_amp_lincal=True): + phs_max_iter=100, phs_conv_crit=1e-6, verbose=True, autos_only_abs_amp_logcal=False, use_abs_amp_lincal=True): '''Performs Abscal for data that has already been redundantly calibrated. Arguments: @@ -3436,29 +3443,62 @@ def post_redcal_abscal(model, data, data_wgts, rc_flags, edge_cut=0, tol=1.0, ke data_wgts=data_wgts, tol=tol, keep_flagged_ants=True) reds = redcal.get_reds(idealized_antpos, pols=data.pols(), bl_error_tol=redcal.IDEALIZED_BL_TOL) + # separate autos and crosses if running amplitude calibration with autos only: + if autos_only_abs_amp_logcal: + autos = [antpol for antpol in data.keys() if antpol[0] == antpol[1]] + autos_data = DataContainer({antpol:data[antpol] for antpol in autos}) + autos_data_wgts = DataContainer({antpol:data_wgts[antpol] for antpol in autos}) + autos_model = DataContainer({antpol:model[antpol] for antpol in autos}) + + #separate crosses from autos: + crosses = [antpol for antpol in data.keys() if antpol[0] != antpol[1]] + crosses_data = DataContainer({antpol:data[antpol] for antpol in crosses}) + crosses_data_wgts = DataContainer({antpol:data_wgts[antpol] for antpol in crosses}) + crosses_model = DataContainer({antpol:model[antpol] for antpol in crosses}) + # Abscal Step 1: Per-Channel Logarithmic Absolute Amplitude Calibration - gains_here = abs_amp_logcal(model, data, wgts=data_wgts, verbose=verbose, return_gains=True, gain_ants=ants) + if autos_only_abs_amp_logcal: + gains_here = abs_amp_logcal(autos_model, autos_data, wgts=autos_data_wgts, verbose=verbose, return_gains=True, gain_ants=ants) + else: + gains_here = abs_amp_logcal(model, data, wgts=data_wgts, verbose=verbose, return_gains=True, gain_ants=ants) + abscal_delta_gains = {ant: gains_here[ant] for ant in ants} apply_cal.calibrate_in_place(data, gains_here) + # Abscal Step 2: Global Delay Slope Calibration - binary_wgts = DataContainer({bl: (data_wgts[bl] > 0).astype(float) for bl in data_wgts}) + if autos_only_abs_amp_logcal: + binary_wgts = DataContainer({bl: (data_wgts[bl] > 0).astype(float) for bl in crosses_data_wgts}) + else: + binary_wgts = DataContainer({bl: (data_wgts[bl] > 0).astype(float) for bl in data_wgts}) df = np.median(np.diff(data.freqs)) for time_avg in [True, False]: # first use the time-averaged solution to try to avoid false minima - gains_here = delay_slope_lincal(model, data, idealized_antpos, wgts=binary_wgts, df=df, f0=data.freqs[0], medfilt=True, kernel=kernel, + if autos_only_abs_amp_logcal: + gains_here = delay_slope_lincal(crosses_model, crosses_data, idealized_antpos, wgts=binary_wgts, df=df, f0=data.freqs[0], medfilt=True, kernel=kernel, + assume_2D=False, time_avg=time_avg, verbose=verbose, edge_cut=edge_cut, return_gains=True, gain_ants=ants) + else: + gains_here = delay_slope_lincal(model, data, idealized_antpos, wgts=binary_wgts, df=df, f0=data.freqs[0], medfilt=True, kernel=kernel, assume_2D=False, time_avg=time_avg, verbose=verbose, edge_cut=edge_cut, return_gains=True, gain_ants=ants) abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants} apply_cal.calibrate_in_place(data, gains_here) # Abscal Step 3: Global Phase Slope Calibration (first using ndim_fft, then using linfit) for time_avg in [True, False]: - gains_here = global_phase_slope_logcal(model, data, idealized_antpos, reds=reds, solver='ndim_fft', wgts=binary_wgts, verbose=verbose, assume_2D=False, + if autos_only_abs_amp_logcal: + gains_here = global_phase_slope_logcal(crosses_model, crosses_data, idealized_antpos, reds=reds, solver='ndim_fft', wgts=binary_wgts, verbose=verbose, assume_2D=False, + tol=redcal.IDEALIZED_BL_TOL, edge_cut=edge_cut, time_avg=time_avg, return_gains=True, gain_ants=ants) + else: + gains_here = global_phase_slope_logcal(model, data, idealized_antpos, reds=reds, solver='ndim_fft', wgts=binary_wgts, verbose=verbose, assume_2D=False, tol=redcal.IDEALIZED_BL_TOL, edge_cut=edge_cut, time_avg=time_avg, return_gains=True, gain_ants=ants) abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants} apply_cal.calibrate_in_place(data, gains_here) for time_avg in [True, False]: for i in range(phs_max_iter): - gains_here = global_phase_slope_logcal(model, data, idealized_antpos, reds=reds, solver='linfit', wgts=binary_wgts, verbose=verbose, assume_2D=False, + if autos_only_abs_amp_logcal: + gains_here = global_phase_slope_logcal(crosses_model, crosses_data, idealized_antpos, reds=reds, solver='linfit', wgts=binary_wgts, verbose=verbose, assume_2D=False, + tol=redcal.IDEALIZED_BL_TOL, edge_cut=edge_cut, time_avg=time_avg, return_gains=True, gain_ants=ants) + else: + gains_here = global_phase_slope_logcal(model, data, idealized_antpos, reds=reds, solver='linfit', wgts=binary_wgts, verbose=verbose, assume_2D=False, tol=redcal.IDEALIZED_BL_TOL, edge_cut=edge_cut, time_avg=time_avg, return_gains=True, gain_ants=ants) abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants} apply_cal.calibrate_in_place(data, gains_here) @@ -3468,11 +3508,17 @@ def post_redcal_abscal(model, data, data_wgts, rc_flags, edge_cut=0, tol=1.0, ke break # Abscal Step 4: Per-Channel Tip-Tilt Phase Calibration - angle_wgts = DataContainer({bl: 2 * np.abs(model[bl])**2 * data_wgts[bl] for bl in model}) + if autos_only_abs_amp_logcal: + angle_wgts = DataContainer({bl: 2 * np.abs(model[bl])**2 * data_wgts[bl] for bl in crosses_model}) + else: + angle_wgts = DataContainer({bl: 2 * np.abs(model[bl])**2 * data_wgts[bl] for bl in model}) # This is because, in the high SNR limit, if Var(model) = 0 and Var(data) = Var(noise), # then Var(angle(data / model)) = Var(noise) / (2 |model|^2). Here data_wgts = Var(noise)^-1. for i in range(phs_max_iter): - gains_here = TT_phs_logcal(model, data, idealized_antpos, wgts=angle_wgts, verbose=verbose, assume_2D=False, return_gains=True, gain_ants=ants) + if autos_only_abs_amp_logcal: + gains_here = TT_phs_logcal(crosses_model, crosses_data, idealized_antpos, wgts=angle_wgts, verbose=verbose, assume_2D=False, return_gains=True, gain_ants=ants) + else: + gains_here = TT_phs_logcal(model, data, idealized_antpos, wgts=angle_wgts, verbose=verbose, assume_2D=False, return_gains=True, gain_ants=ants) abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants} apply_cal.calibrate_in_place(data, gains_here) crit = np.median(np.linalg.norm([gains_here[k] - 1.0 for k in gains_here.keys()], axis=(0, 1))) @@ -3482,16 +3528,19 @@ def post_redcal_abscal(model, data, data_wgts, rc_flags, edge_cut=0, tol=1.0, ke # Abscal Step 5: Per-Channel Linear Absolute Amplitude Calibration if use_abs_amp_lincal: - gains_here = abs_amp_lincal(model, data, wgts=data_wgts, verbose=verbose, return_gains=True, gain_ants=ants) - abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants} + if autos_only_abs_amp_logcal: + echo("Skipping abs_amp_lincal. autos_only_abs_amp_logcal=True inconsistent with lincal", verbose=verbose) + else: + gains_here = abs_amp_lincal(model, data, wgts=data_wgts, verbose=verbose, return_gains=True, gain_ants=ants) + abscal_delta_gains = {ant: abscal_delta_gains[ant] * gains_here[ant] for ant in ants} return abscal_delta_gains def post_redcal_abscal_run(data_file, redcal_file, model_files, raw_auto_file=None, data_is_redsol=False, model_is_redundant=False, output_file=None, nInt_to_load=None, data_solar_horizon=90, model_solar_horizon=90, extrap_limit=.5, min_bl_cut=1.0, max_bl_cut=None, - edge_cut=0, tol=1.0, phs_max_iter=100, phs_conv_crit=1e-6, refant=None, clobber=True, add_to_history='', verbose=True, skip_abs_amp_lincal=False, - write_delta_gains=False, output_file_delta=None): + edge_cut=0, tol=1.0, phs_max_iter=100, phs_conv_crit=1e-6, refant=None, clobber=True, add_to_history='', verbose=True, + autos_only_abs_amp_logcal=False, skip_abs_amp_lincal=False, write_delta_gains=False, output_file_delta=None): '''Perform abscal on entire data files, picking relevant model_files from a list and doing partial data loading. Does not work on data (or models) with baseline-dependant averaging. @@ -3608,7 +3657,8 @@ def post_redcal_abscal_run(data_file, redcal_file, model_files, raw_auto_file=No model_bl_to_load, data_to_model_bl_map) = match_baselines(hd.bls, model_bls, hd.data_antpos, model_antpos=model_antpos, pols=[pol], data_is_redsol=data_is_redsol, model_is_redundant=model_is_redundant, - tol=tol, min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut, verbose=verbose) + tol=tol, min_bl_cut=min_bl_cut, max_bl_cut=max_bl_cut, + autos_only_abs_amp_logcal=autos_only_abs_amp_logcal, verbose=verbose) if (len(data_bl_to_load) == 0) or (len(model_bl_to_load) == 0): echo("No baselines in the data match baselines in the model. Results for this polarization will be fully flagged.", verbose=verbose) else: @@ -3657,7 +3707,8 @@ def post_redcal_abscal_run(data_file, redcal_file, model_files, raw_auto_file=No # run absolute calibration to get the gain updates delta_gains = post_redcal_abscal(model, data, data_wgts, rc_flags_subset, edge_cut=edge_cut, tol=tol, - phs_max_iter=phs_max_iter, phs_conv_crit=phs_conv_crit, verbose=verbose, use_abs_amp_lincal=not(skip_abs_amp_lincal)) + phs_max_iter=phs_max_iter, phs_conv_crit=phs_conv_crit, verbose=verbose, + autos_only_abs_amp_logcal=autos_only_abs_amp_logcal, use_abs_amp_lincal=not(skip_abs_amp_lincal)) # abscal autos, rebuild weights, and generate abscal Chi^2 calibrate_in_place(autocorrs, delta_gains) @@ -4005,6 +4056,7 @@ def post_redcal_abscal_argparser(): a.add_argument("--phs_conv_crit", default=1e-6, type=float, help="convergence criterion for updates to iterative phase calibration that compares them to all 1.0s.") a.add_argument("--clobber", default=False, action="store_true", help="overwrites existing abscal calfits file at the output path") a.add_argument("--verbose", default=False, action="store_true", help="print calibration progress updates") + a.add_argument("--autos_only_abs_amp_logcal", default=False, action="store_true", help="run amplitude calibration with only autos") a.add_argument("--skip_abs_amp_lincal", default=False, action="store_true", help="finish calibration with an unbiased amplitude lincal step") a.add_argument("--write_delta_gains", default=False, action="store_true", help="Write degenerate abscal component of gains separately.") a.add_argument("--output_file_delta", type=str, default=None, help="Filename to write delta gains too.") diff --git a/scripts/post_redcal_abscal_run.py b/scripts/post_redcal_abscal_run.py index fa128dabb..3c88da5a4 100755 --- a/scripts/post_redcal_abscal_run.py +++ b/scripts/post_redcal_abscal_run.py @@ -17,4 +17,5 @@ data_solar_horizon=a.data_solar_horizon, model_solar_horizon=a.model_solar_horizon, min_bl_cut=a.min_bl_cut, max_bl_cut=a.max_bl_cut, edge_cut=a.edge_cut, tol=a.tol, phs_max_iter=a.phs_max_iter, phs_conv_crit=a.phs_conv_crit, clobber=a.clobber, - add_to_history=' '.join(sys.argv), verbose=a.verbose, skip_abs_amp_lincal=a.skip_abs_amp_lincal) + add_to_history=' '.join(sys.argv), verbose=a.verbose, + autos_only_abs_amp_logcal = a.autos_only_abs_amp_logcal, skip_abs_amp_lincal=a.skip_abs_amp_lincal)