diff --git a/unseen/independence.py b/unseen/independence.py index d89e25d..01696de 100644 --- a/unseen/independence.py +++ b/unseen/independence.py @@ -40,11 +40,8 @@ def run_tests( Returns ------- - mean_correlations : dict of - Mean correlation between all ensemble members for each lead time and initial month - null_correlation_bounds : dict of - Bounds on zero correlation for each lead time and initial month - + ds : xarray Dataset + Correlation ensemble mean and the corresponding 95% confidence interval """ months = np.unique(fcst[init_dim].dt.month.values) @@ -79,12 +76,37 @@ def run_tests( lead_dim=lead_dim, ensemble_dim=ensemble_dim, ) - - return mean_correlations, null_correlation_bounds + # Create Dataset + ds_corr = xr.Dataset( + { + "correlation_mean": xr.concat( + [da.assign_coords({"month": k}) for k, da in mean_correlations.items()], + dim="month", + ), + "correlation_ci": xr.concat( + [ + da.assign_coords({"month": k}) + for k, da in null_correlation_bounds.items() + ], + dim="month", + ), + } + ) + ds_corr["correlation_mean"].attrs = { + "standard_name": "correlation_coefficent", + "long_name": "Ensemble mean correlation coefficent", + "description": "Mean Spearman rank correlation coefficent between all ensemble members", + } + ds_corr["correlation_ci"].attrs = { + "standard_name": "confidence_interval", + "long_name": "95% confidence interval", + "description": "Spearman rank correlation coefficent confidence interval (2.5%-97.5%)", + } + return ds_corr def create_plot( - mean_correlations, null_correlation_bounds, outfile, lead_dim="lead_time" + mean_correlations, null_correlation_bounds, outfile=None, lead_dim="lead_time" ): """Create independence plot. @@ -94,7 +116,7 @@ def create_plot( Mean correlation (for each lead time) data null_correlation_bounds : list Bounds on zero correlation [lower_bound, upper_bound] - outfile : str + outfile : str, optional Path for output image file lead_dim: str, default 'lead_time' Name of the lead time dimension in mean_correlations @@ -103,12 +125,14 @@ def create_plot( fig, ax = plt.subplots() colors = iter(plt.cm.Set1(np.linspace(0, 1, 9))) - months = list(mean_correlations.keys()) + months = list(mean_correlations.month.values) months.sort() - for month in months: + for i, month in enumerate(months): color = next(colors) - mean_corr = mean_correlations[month].dropna(lead_dim) month_abbr = calendar.month_abbr[month] + + # Plot the ensemble mean correlation + mean_corr = mean_correlations.isel(month=i).dropna(lead_dim) mean_corr.plot.scatter( x=lead_dim, color=color, @@ -116,17 +140,20 @@ def create_plot( linestyle="None", label=f"{month_abbr} starts", ) - lower_bound, upper_bound = null_correlation_bounds[month].values - lead_time_bounds = [mean_corr[lead_dim].min(), mean_corr[lead_dim].max()] - plt.plot( - lead_time_bounds, [lower_bound, lower_bound], color=color, linestyle="--" - ) - plt.plot( - lead_time_bounds, [upper_bound, upper_bound], color=color, linestyle="--" - ) - plt.ylabel("correlation") - plt.legend() - plt.savefig(outfile, bbox_inches="tight", facecolor="white", dpi=200) + # Plot the null correlation bounds as dashed lines + bounds = null_correlation_bounds.isel(month=i).values + ax.axhline(bounds[0], c=color, ls="--") + ax.axhline(bounds[1], c=color, ls="--", label="95% CI") + + ax.set_xlabel(lead_dim.replace("_", " ").capitalize()) + ax.set_ylabel("Correlation coefficient") + ax.legend() + + if outfile: + plt.tight_layout() + plt.savefig(outfile, bbox_inches="tight", facecolor="white", dpi=200) + else: + plt.show() def _remove_ensemble_mean_trend(da, dim="init_date", ensemble_dim="ensemble"): @@ -327,7 +354,8 @@ def _parse_command_line(): parser.add_argument("fcst_file", type=str, help="Forecast file") parser.add_argument("var", type=str, help="Variable name") - parser.add_argument("outfile", type=str, help="Output file") + parser.add_argument("outfile", type=str, default=None, help="Data filename") + parser.add_argument("plot_outfile", type=str, default=None, help="Plot filename") parser.add_argument( "--dask_config", type=str, help="YAML file specifying dask client configuration" @@ -379,48 +407,32 @@ def _main(): ) da_fcst = ds_fcst[args.var] - mean_correlations, null_correlation_bounds = run_tests( + ds_corr = run_tests( da_fcst, init_dim=args.init_dim, lead_dim=args.lead_dim, ensemble_dim=args.ensemble_dim, ) - # # Create plot if the data dimensions are correct (i.e, no extra dimensions) - if all( - [ - dim in [args.ensemble_dim, args.init_dim, args.lead_dim] - for dim in da_fcst.dims - ] - ): - create_plot(mean_correlations, null_correlation_bounds, args.outfile) - - # Save mean_correlations and null_correlation_bounds to a file - else: - # Create Dataset - ds_independence = xr.Dataset( - { - "mean_correlations": xr.concat( - [ - da.assign_coords({"month": k}) - for k, da in mean_correlations.items() - ], - dim="month", - ), - "null_correlation_bounds": xr.concat( - [ - da.assign_coords({"month": k}) - for k, da in null_correlation_bounds.items() - ], - dim="month", - ), - } + if args.plot_outfile and len(da_fcst.dims) > 3: + # Scatter plot of correlation vs lead (if there are no extra dimensions) + create_plot( + ds_corr["correlation_mean"], + ds_corr["correlation_ci"], + args.plot_outfile, + args.lead_dim, ) - # todo: copy attributes from da_fcst - # Add history attribute + + if args.outfile: + # Save correlation coefficents and confidence intervals + ds_corr.update(ds_fcst.attrs) infile_logs = {args.fcst_file: ds_fcst.attrs["history"]} - ds_independence.attrs["history"] = fileio.get_new_log(infile_logs=infile_logs) - ds_independence.to_netcdf(args.outfile.replace(".png", ".nc")) + ds_corr.attrs["history"] = fileio.get_new_log(infile_logs=infile_logs) + + if "zarr" in args.outfile: + fileio.to_zarr(ds_corr, args.outfile) + else: + ds_corr.to_netcdf(args.outfile) if __name__ == "__main__":