Skip to content

Commit

Permalink
Update min_lead file cmd line option
Browse files: browse the repository at this point in the history
  • Loading branch information
stellema committed Sep 26, 2024
1 parent 164216b commit 6718c62
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 16 deletions.
9 changes: 6 additions & 3 deletions unseen/bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,13 @@ def _main():
if args.min_lead:
if isinstance(args.min_lead, str):
# Load min_lead from file
min_lead = fileio.open_dataset(args.min_lead, **args.min_lead_kwargs)
ds_min_lead = fileio.open_dataset(args.min_lead, **args.min_lead_kwargs)
min_lead = ds_min_lead["min_lead"].load()
# Assumes min_lead has only one init month
assert min_lead.month.size == 1, "Not implemented for multiple init months"
min_lead = min_lead.drop_vars("month")
min_lead = min_lead["min_lead"].load()

else:
min_lead = args.min_lead
da_fcst = da_fcst.where(da_fcst >= min_lead)

# Calculate bias
Expand All @@ -291,6 +292,8 @@ def _main():
args.fcst_file: ds_fcst.attrs["history"],
args.obs_file: ds_obs.attrs["history"],
}
if isinstance(args.min_lead, str):
infile_logs[args.min_lead] = ds_min_lead.attrs["history"]
ds_fcst_bc.attrs["history"] = fileio.get_new_log(infile_logs=infile_logs)

if args.output_chunks:
Expand Down
86 changes: 73 additions & 13 deletions unseen/moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import numpy as np
import scipy

from . import fileio
from . import eva
from . import fileio
from . import general_utils


logging.basicConfig(level=logging.INFO)
Expand All @@ -21,7 +22,18 @@


def calc_ci(data):
"""Calculate the 95% confidence interval"""
"""Calculate the 95% confidence interval
Parameters
----------
data : list
List of data values
Returns
-------
lower_ci, upper_ci : float
Lower and upper confidence interval bounds
"""

lower_ci = np.percentile(np.array(data), 2.5, axis=0)
upper_ci = np.percentile(np.array(data), 97.5, axis=0)
Expand All @@ -30,7 +42,20 @@ def calc_ci(data):


def calc_moments(sample_da, **kwargs):
"""Calculate all the moments for a given sample."""
"""Calculate all the moments for a given sample.
Parameters
----------
sample_da : xarray.DataArray
Sample data array
kwargs : dict
Keyword arguments for the GEV fit
Returns
-------
moments : dict
Dictionary of moments
"""

moments = {}
moments["mean"] = float(np.mean(sample_da))
Expand All @@ -46,7 +71,24 @@ def calc_moments(sample_da, **kwargs):


def log_results(moments_obs, model_lower_cis, model_upper_cis, bias_corrected=False):
"""Log the results"""
"""Log the moments test results.
Parameters
----------
moments_obs : dict
Dictionary of observed moments
model_lower_cis : dict
Dictionary of model lower confidence intervals
model_upper_cis : dict
Dictionary of model upper confidence intervals
bias_corrected : bool, default False
Flag for bias corrected model
Returns
-------
metadata : dict
Dictionary of logged metadata
"""

if bias_corrected:
text_insert = "Bias corrected model"
Expand Down Expand Up @@ -91,12 +133,12 @@ def create_plot(
outfile : str, optional
Path for output image file
units : str, optional
units for plot axis labels
ensemble_dim : str, default ensemble
Units for plot axis labels
ensemble_dim : str, default 'ensemble'
Name of ensemble member dimension
init_dim : str, default init_date
init_dim : str, default 'init_date'
Name of initial date dimension
lead_dim : str, default lead_time
lead_dim : str, default 'lead_time'
Name of lead time dimension
infile_logs : dict, optional
File names (keys) and history attributes (values) of input data files
Expand Down Expand Up @@ -254,7 +296,7 @@ def create_plot(


def _parse_command_line():
"""Parse the command line for input agruments"""
"""Parse the command line for input arguments."""

parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
Expand Down Expand Up @@ -285,9 +327,14 @@ def _parse_command_line():
)
parser.add_argument(
"--min_lead",
type=int,
default=None,
help="Minimum lead time",
help="Minimum lead time to include in analysis (int or filename)",
)
parser.add_argument(
"--min_lead_kwargs",
nargs="*",
action=general_utils.store_dict,
help="Optional fileio.open_dataset kwargs for lead independence (e.g., spatial_agg=median)",
)
parser.add_argument(
"--units",
Expand All @@ -307,8 +354,19 @@ def _main():

ds_fcst = fileio.open_dataset(args.fcst_file)
da_fcst = ds_fcst[args.var]
if args.min_lead is not None:
da_fcst = da_fcst.where(ds_fcst[args.lead_dim] >= args.min_lead)

# Mask lead times below min_lead
if args.min_lead:
if isinstance(args.min_lead, str):
# Load min_lead from file
ds_min_lead = fileio.open_dataset(args.min_lead, **args.min_lead_kwargs)
min_lead = ds_min_lead["min_lead"].load()
# Assumes min_lead has only one init month
assert min_lead.month.size == 1, "Not implemented for multiple init months"
min_lead = min_lead.drop_vars("month")
else:
min_lead = args.min_lead
da_fcst = da_fcst.where(da_fcst >= min_lead)

ds_obs = fileio.open_dataset(args.obs_file)
da_obs = ds_obs[args.var].dropna("time")
Expand All @@ -327,6 +385,8 @@ def _main():
infile_logs[args.bias_file] = ds_bc_fcst.attrs["history"]
else:
infile_logs[args.fcst_file] = ds_fcst.attrs["history"]
if isinstance(args.min_lead, str):
infile_logs[args.min_lead] = ds_min_lead.attrs["history"]
else:
infile_logs = None

Expand Down

0 comments on commit 6718c62

Please sign in to comment.