Skip to content

Commit

Permalink
Update independence.py
Browse files Browse the repository at this point in the history
  • Loading branch information
stellema committed Sep 12, 2024
1 parent aa904b3 commit 6409643
Showing 1 changed file with 70 additions and 58 deletions.
128 changes: 70 additions & 58 deletions unseen/independence.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,8 @@ def run_tests(
Returns
-------
mean_correlations : dict of
Mean correlation between all ensemble members for each lead time and initial month
null_correlation_bounds : dict of
Bounds on zero correlation for each lead time and initial month
ds : xarray Dataset
Correlation ensemble mean and the corresponding 95% confidence interval
"""

months = np.unique(fcst[init_dim].dt.month.values)
Expand Down Expand Up @@ -79,12 +76,37 @@ def run_tests(
lead_dim=lead_dim,
ensemble_dim=ensemble_dim,
)

return mean_correlations, null_correlation_bounds
# Create Dataset
ds_corr = xr.Dataset(
{
"correlation_mean": xr.concat(
[da.assign_coords({"month": k}) for k, da in mean_correlations.items()],
dim="month",
),
"correlation_ci": xr.concat(
[
da.assign_coords({"month": k})
for k, da in null_correlation_bounds.items()
],
dim="month",
),
}
)
ds_corr["correlation_mean"].attrs = {
"standard_name": "correlation_coefficent",
"long_name": "Ensemble mean correlation coefficent",
"description": "Mean Spearman rank correlation coefficent between all ensemble members",
}
ds_corr["correlation_ci"].attrs = {
"standard_name": "confidence_interval",
"long_name": "95% confidence interval",
"description": "Spearman rank correlation coefficent confidence interval (2.5%-97.5%)",
}
return ds_corr


def create_plot(
mean_correlations, null_correlation_bounds, outfile, lead_dim="lead_time"
mean_correlations, null_correlation_bounds, outfile=None, lead_dim="lead_time"
):
"""Create independence plot.
Expand All @@ -94,7 +116,7 @@ def create_plot(
Mean correlation (for each lead time) data
null_correlation_bounds : list
Bounds on zero correlation [lower_bound, upper_bound]
outfile : str
outfile : str, optional
Path for output image file
lead_dim: str, default 'lead_time'
Name of the lead time dimension in mean_correlations
Expand All @@ -103,30 +125,35 @@ def create_plot(
fig, ax = plt.subplots()
colors = iter(plt.cm.Set1(np.linspace(0, 1, 9)))

months = list(mean_correlations.keys())
months = list(mean_correlations.month.values)
months.sort()
for month in months:
for i, month in enumerate(months):
color = next(colors)
mean_corr = mean_correlations[month].dropna(lead_dim)
month_abbr = calendar.month_abbr[month]

# Plot the ensemble mean correlation
mean_corr = mean_correlations.isel(month=i).dropna(lead_dim)
mean_corr.plot.scatter(
x=lead_dim,
color=color,
marker="o",
linestyle="None",
label=f"{month_abbr} starts",
)
lower_bound, upper_bound = null_correlation_bounds[month].values
lead_time_bounds = [mean_corr[lead_dim].min(), mean_corr[lead_dim].max()]
plt.plot(
lead_time_bounds, [lower_bound, lower_bound], color=color, linestyle="--"
)
plt.plot(
lead_time_bounds, [upper_bound, upper_bound], color=color, linestyle="--"
)
plt.ylabel("correlation")
plt.legend()
plt.savefig(outfile, bbox_inches="tight", facecolor="white", dpi=200)
# Plot the null correlation bounds as dashed lines
bounds = null_correlation_bounds.isel(month=i).values
ax.axhline(bounds[0], c=color, ls="--")
ax.axhline(bounds[1], c=color, ls="--", label="95% CI")

ax.set_xlabel(lead_dim.replace("_", " ").capitalize())
ax.set_ylabel("Correlation coefficient")
ax.legend()

if outfile:
plt.tight_layout()
plt.savefig(outfile, bbox_inches="tight", facecolor="white", dpi=200)
else:
plt.show()


def _remove_ensemble_mean_trend(da, dim="init_date", ensemble_dim="ensemble"):
Expand Down Expand Up @@ -327,7 +354,8 @@ def _parse_command_line():

parser.add_argument("fcst_file", type=str, help="Forecast file")
parser.add_argument("var", type=str, help="Variable name")
parser.add_argument("outfile", type=str, help="Output file")
parser.add_argument("outfile", type=str, default=None, help="Data filename")
parser.add_argument("plot_outfile", type=str, default=None, help="Plot filename")

parser.add_argument(
"--dask_config", type=str, help="YAML file specifying dask client configuration"
Expand Down Expand Up @@ -379,48 +407,32 @@ def _main():
)
da_fcst = ds_fcst[args.var]

mean_correlations, null_correlation_bounds = run_tests(
ds_corr = run_tests(
da_fcst,
init_dim=args.init_dim,
lead_dim=args.lead_dim,
ensemble_dim=args.ensemble_dim,
)

# # Create plot if the data dimensions are correct (i.e, no extra dimensions)
if all(
[
dim in [args.ensemble_dim, args.init_dim, args.lead_dim]
for dim in da_fcst.dims
]
):
create_plot(mean_correlations, null_correlation_bounds, args.outfile)

# Save mean_correlations and null_correlation_bounds to a file
else:
# Create Dataset
ds_independence = xr.Dataset(
{
"mean_correlations": xr.concat(
[
da.assign_coords({"month": k})
for k, da in mean_correlations.items()
],
dim="month",
),
"null_correlation_bounds": xr.concat(
[
da.assign_coords({"month": k})
for k, da in null_correlation_bounds.items()
],
dim="month",
),
}
if args.plot_outfile and len(da_fcst.dims) > 3:
# Scatter plot of correlation vs lead (if there are no extra dimensions)
create_plot(
ds_corr["correlation_mean"],
ds_corr["correlation_ci"],
args.plot_outfile,
args.lead_dim,
)
# todo: copy attributes from da_fcst
# Add history attribute

if args.outfile:
# Save correlation coefficents and confidence intervals
ds_corr.update(ds_fcst.attrs)
infile_logs = {args.fcst_file: ds_fcst.attrs["history"]}
ds_independence.attrs["history"] = fileio.get_new_log(infile_logs=infile_logs)
ds_independence.to_netcdf(args.outfile.replace(".png", ".nc"))
ds_corr.attrs["history"] = fileio.get_new_log(infile_logs=infile_logs)

if "zarr" in args.outfile:
fileio.to_zarr(ds_corr, args.outfile)
else:
ds_corr.to_netcdf(args.outfile)


if __name__ == "__main__":
Expand Down

0 comments on commit 6409643

Please sign in to comment.