From 6718c62c3dddb38517cf4e1d82bd364f18c969a4 Mon Sep 17 00:00:00 2001 From: Annette Stellema <40450353+stellema@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:15:12 +1000 Subject: [PATCH] Update min_lead file cmd line option --- unseen/bias_correction.py | 9 ++-- unseen/moments.py | 86 +++++++++++++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/unseen/bias_correction.py b/unseen/bias_correction.py index 965ed88..83e9a9a 100644 --- a/unseen/bias_correction.py +++ b/unseen/bias_correction.py @@ -261,12 +261,13 @@ def _main(): if args.min_lead: if isinstance(args.min_lead, str): # Load min_lead from file - min_lead = fileio.open_dataset(args.min_lead, **args.min_lead_kwargs) + ds_min_lead = fileio.open_dataset(args.min_lead, **args.min_lead_kwargs) + min_lead = ds_min_lead["min_lead"].load() # Assumes min_lead has only one init month assert min_lead.month.size == 1, "Not implemented for multiple init months" min_lead = min_lead.drop_vars("month") - min_lead = min_lead["min_lead"].load() - + else: + min_lead = args.min_lead da_fcst = da_fcst.where(da_fcst >= min_lead) # Calculate bias @@ -291,6 +292,8 @@ def _main(): args.fcst_file: ds_fcst.attrs["history"], args.obs_file: ds_obs.attrs["history"], } + if isinstance(args.min_lead, str): + infile_logs[args.min_lead] = ds_min_lead.attrs["history"] ds_fcst_bc.attrs["history"] = fileio.get_new_log(infile_logs=infile_logs) if args.output_chunks: diff --git a/unseen/moments.py b/unseen/moments.py index 46caa02..532e3bf 100644 --- a/unseen/moments.py +++ b/unseen/moments.py @@ -8,8 +8,9 @@ import numpy as np import scipy -from . import fileio from . import eva +from . import fileio +from . import general_utils logging.basicConfig(level=logging.INFO) @@ -21,7 +22,18 @@ def calc_ci(data): - """Calculate the 95% confidence interval""" + """Calculate the 95% confidence interval + + Parameters + ---------- + data : list + List of data values + + Returns + ------- + lower_ci, upper_ci : float + Lower and upper confidence interval bounds + """ lower_ci = np.percentile(np.array(data), 2.5, axis=0) upper_ci = np.percentile(np.array(data), 97.5, axis=0) @@ -30,7 +42,20 @@ def calc_ci(data): def calc_moments(sample_da, **kwargs): - """Calculate all the moments for a given sample.""" + """Calculate all the moments for a given sample. + + Parameters + ---------- + sample_da : xarray.DataArray + Sample data array + kwargs : dict + Keyword arguments for the GEV fit + + Returns + ------- + moments : dict + Dictionary of moments + """ moments = {} moments["mean"] = float(np.mean(sample_da)) @@ -46,7 +71,24 @@ def calc_moments(sample_da, **kwargs): def log_results(moments_obs, model_lower_cis, model_upper_cis, bias_corrected=False): - """Log the results""" + """Log the moments test results. + + Parameters + ---------- + moments_obs : dict + Dictionary of observed moments + model_lower_cis : dict + Dictionary of model lower confidence intervals + model_upper_cis : dict + Dictionary of model upper confidence intervals + bias_corrected : bool, default False + Flag for bias corrected model + + Returns + ------- + metadata : dict + Dictionary of logged metadata + """ if bias_corrected: text_insert = "Bias corrected model" @@ -91,12 +133,12 @@ def create_plot( outfile : str, optional Path for output image file units : str, optional - units for plot axis labels - ensemble_dim : str, default ensemble + Units for plot axis labels + ensemble_dim : str, default 'ensemble' Name of ensemble member dimension - init_dim : str, default init_date + init_dim : str, default 'init_date' Name of initial date dimension - lead_dim : str, default lead_time + lead_dim : str, default 'lead_time' Name of lead time dimension infile_logs : dict, optional File names (keys) and history attributes (values) of input data files @@ -254,7 +296,7 @@ def create_plot( def _parse_command_line(): - """Parse the command line for input agruments""" + """Parse the command line for input arguments.""" parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter @@ -285,9 +327,14 @@ def _parse_command_line(): ) parser.add_argument( "--min_lead", - type=int, default=None, - help="Minimum lead time", + help="Minimum lead time to include in analysis (int or filename)", + ) + parser.add_argument( + "--min_lead_kwargs", + nargs="*", + action=general_utils.store_dict, + help="Optional fileio.open_dataset kwargs for lead independence (e.g., spatial_agg=median)", ) parser.add_argument( "--units", @@ -307,8 +354,19 @@ def _main(): ds_fcst = fileio.open_dataset(args.fcst_file) da_fcst = ds_fcst[args.var] - if args.min_lead is not None: - da_fcst = da_fcst.where(ds_fcst[args.lead_dim] >= args.min_lead) + + # Mask lead times below min_lead + if args.min_lead: + if isinstance(args.min_lead, str): + # Load min_lead from file + ds_min_lead = fileio.open_dataset(args.min_lead, **args.min_lead_kwargs) + min_lead = ds_min_lead["min_lead"].load() + # Assumes min_lead has only one init month + assert min_lead.month.size == 1, "Not implemented for multiple init months" + min_lead = min_lead.drop_vars("month") + else: + min_lead = args.min_lead + da_fcst = da_fcst.where(da_fcst >= min_lead) ds_obs = fileio.open_dataset(args.obs_file) da_obs = ds_obs[args.var].dropna("time") @@ -327,6 +385,8 @@ def _main(): infile_logs[args.bias_file] = ds_bc_fcst.attrs["history"] else: infile_logs[args.fcst_file] = ds_fcst.attrs["history"] + if isinstance(args.min_lead, str): + infile_logs[args.min_lead] = ds_min_lead.attrs["history"] else: infile_logs = None