Update reset_times calendar, min_lead option and add similarity spatial plot (AusClimateService#78)

* Update reset_times calendar and revert increased rounding of coords

* Update min_lead option and independence/similarity spatial plots
stellema authored Oct 8, 2024
1 parent ac6d82a commit 5727504
Showing 5 changed files with 164 additions and 83 deletions.
24 changes: 17 additions & 7 deletions unseen/bias_correction.py
@@ -240,7 +240,18 @@ def _parse_command_line():
default="conservative",
help="Regridding method for observational or forecast data [default=conservative]",
)

parser.add_argument(
"--lead_dim",
type=str,
default="lead_time",
help="Name of lead time dimension",
)
parser.add_argument(
"--init_dim",
type=str,
default="init_date",
help="Name of initial date dimension",
)
args = parser.parse_args()

return args
@@ -263,14 +274,13 @@ def _main():
# Load min_lead from file
ds_min_lead = fileio.open_dataset(args.min_lead, **args.min_lead_kwargs)
min_lead = ds_min_lead["min_lead"].load()
# Assumes min_lead has only one init month
assert min_lead.month.size == 1, "Not implemented for multiple init months"
min_lead = min_lead.drop_vars("month")
if min_lead.size == 1:
min_lead = min_lead.item()
da_fcst = da_fcst.groupby(f"{args.init_dim}.month").where(
da_fcst[args.lead_dim] >= min_lead
)
da_fcst = da_fcst.drop_vars("month")
else:
min_lead = args.min_lead
da_fcst = da_fcst.where(da_fcst[args.lead_dim] >= min_lead)
da_fcst = da_fcst.where(da_fcst[args.lead_dim] >= min_lead)

# Calculate bias
bias = get_bias(
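The block above masks out forecast lead times shorter than a minimum independent lead, either with a single scalar threshold or, via groupby, with a month-specific threshold for each start date. A minimal, self-contained sketch of both patterns with synthetic data (the dimension names follow the script defaults; the values are made up):

import numpy as np
import pandas as pd
import xarray as xr

# Toy forecast: four start dates (November and December starts), six lead times
init_date = pd.to_datetime(["2000-11-01", "2000-12-01", "2001-11-01", "2001-12-01"])
da_fcst = xr.DataArray(
    np.random.rand(4, 6),
    coords={"init_date": init_date, "lead_time": np.arange(6)},
    dims=["init_date", "lead_time"],
)

# Scalar threshold: every start date is masked the same way
da_scalar = da_fcst.where(da_fcst["lead_time"] >= 2)

# Month-specific thresholds (e.g. the min_lead variable written by independence.py)
min_lead = xr.DataArray([2, 3], coords={"month": [11, 12]}, dims=["month"])
da_monthly = da_fcst.groupby("init_date.month").where(
    da_fcst["lead_time"] >= min_lead
)
da_monthly = da_monthly.drop_vars("month")  # groupby attaches a helper month coordinate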
2 changes: 1 addition & 1 deletion unseen/fileio.py
@@ -496,7 +496,7 @@ def _fix_metadata(ds, metadata_file):

if "round_coords" in metadata_dict:
for coord in metadata_dict["round_coords"]:
ds = ds.assign_coords({coord: ds[coord].round(decimals=3)})
ds = ds.assign_coords({coord: ds[coord].round(decimals=6)})

if "units" in metadata_dict:
for var, units in metadata_dict["units"].items():
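The one-line change above moves the round_coords rounding from 3 to 6 decimal places, i.e. it rounds less aggressively (reverting the earlier increase in rounding). Rounding coordinates at all is typically about making indexes from different sources match exactly, so that later alignment does not silently drop grid points; a small illustration of the idea, independent of the package's _fix_metadata helper:

import numpy as np
import xarray as xr

# Two datasets whose longitudes differ only by floating-point noise
lon = np.array([110.0, 110.25, 110.5])
ds_a = xr.Dataset({"tas": ("lon", [1.0, 2.0, 3.0])}, coords={"lon": lon})
ds_b = xr.Dataset({"pr": ("lon", [4.0, 5.0, 6.0])}, coords={"lon": lon + 1e-9})

# Without rounding, an inner alignment finds no common points
assert xr.align(ds_a, ds_b, join="inner")[0].sizes["lon"] == 0

# Rounding both indexes to a shared precision restores the match
ds_a = ds_a.assign_coords({"lon": ds_a["lon"].round(decimals=6)})
ds_b = ds_b.assign_coords({"lon": ds_b["lon"].round(decimals=6)})
assert (ds_a["lon"] == ds_b["lon"]).all()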
92 changes: 47 additions & 45 deletions unseen/independence.py
@@ -4,11 +4,11 @@
import calendar
from cartopy.crs import PlateCarree
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import xskillscore as xs
import matplotlib.pyplot as plt

from . import dask_setup
from . import fileio
@@ -87,7 +87,7 @@ def run_tests(
n_resamples=n_resamples,
)

ds_corr = xr.Dataset(
ds = xr.Dataset(
{
"r": xr.concat(
[da.assign_coords({"month": k}) for k, da in r_mean.items()],
@@ -98,30 +98,29 @@
),
},
coords=fcst.coords,
attrs=fcst.attrs,
)
# Add minimum independent lead as variable
ds_corr["min_lead"] = min_independent_lead(ds_corr, lead_dim=lead_dim)
return ds_corr
ds["min_lead"] = min_independent_lead(ds, lead_dim=lead_dim)
return ds
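run_tests stacks the per-init-month mean correlations and confidence bounds along a new month dimension before attaching min_lead. The concat-with-a-scalar-coordinate pattern it uses is generic; a short sketch with hypothetical per-month results:

import numpy as np
import xarray as xr

# Hypothetical mean correlations keyed by init month number
r_mean = {
    11: xr.DataArray(np.random.rand(10), dims=["lead_time"]),
    12: xr.DataArray(np.random.rand(10), dims=["lead_time"]),
}

# Tag each array with its month, then concatenate along the new dimension
r = xr.concat(
    [da.assign_coords({"month": k}) for k, da in r_mean.items()],
    dim="month",
)
assert r.dims == ("month", "lead_time")
assert list(r["month"].values) == [11, 12]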


def min_independent_lead(ds_corr, lead_dim="lead_time"):
def min_independent_lead(ds, lead_dim="lead_time"):
"""Get the first lead time within confidence interval.
Parameters
----------
ds_corr : xarray.Dataset
ds : xarray.Dataset
Dataset of correlation coefficient and confidence interval
lead_dim : str, default 'lead_time'
Name of the lead time dimension in ds_corr
Name of the lead time dimension in ds
Returns
-------
lead : int
Index of the first lead time within the confidence interval
"""
mask = (ds_corr["r"] >= ds_corr["ci"].isel(quantile=0)) & (
ds_corr["r"] >= ds_corr["ci"].isel(quantile=-1)
mask = (ds["r"] >= ds["ci"].isel(quantile=0)) & (
ds["r"] >= ds["ci"].isel(quantile=-1)
)

min_lead = mask.rank(lead_dim).argmin(lead_dim)
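min_independent_lead finds the first lead time at which the correlation mask fails by ranking the boolean mask and taking argmin: with average ranking, every False entry shares the lowest rank, so argmin returns the earliest of them. A quick check of that behaviour on toy numbers (note that xarray's rank needs the optional bottleneck dependency):

import xarray as xr

# Toy ensemble-mean correlations and a single null bound
r = xr.DataArray([0.9, 0.8, 0.3, 0.2, 0.1], dims=["lead_time"])
mask = r >= 0.5                    # True while the correlation exceeds the bound
first = mask.rank("lead_time").argmin("lead_time")
print(int(first))                  # 2 -> index of the first lead where the mask is False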
@@ -133,8 +132,9 @@ def min_independent_lead(ds_corr, lead_dim="lead_time"):


def point_plot(
ds_corr,
ds,
outfile=None,
dataset_name=None,
lead_dim="lead_time",
confidence_interval=0.95,
**kwargs,
@@ -143,10 +143,12 @@
Parameters
----------
ds_corr : xarray.Dataset
Mean correlation (for each lead time) and confidence interval
ds : xarray.Dataset
Mean correlation (for each lead time) and confidence interval bounds
outfile : str, optional
Path for output image file
dataset_name : str, optional
Name of the dataset for plot suptitle
lead_dim : str, default 'lead_time'
Name of the lead time dimension
confidence_interval : float, default 0.95
@@ -158,14 +160,14 @@
fig, ax = plt.subplots(**kwargs)
colors = iter(plt.cm.Set1(np.linspace(0, 1, 9)))

months = list(ds_corr["r"].month.values)
months = list(ds["r"].month.values)
months.sort()
for i, month in enumerate(months):
color = next(colors)
month_abbr = calendar.month_abbr[month]

# Plot the ensemble mean correlation
mean_corr = ds_corr["r"].isel(month=i).dropna(lead_dim)
mean_corr = ds["r"].isel(month=i).dropna(lead_dim)
mean_corr.plot.scatter(
ax=ax,
x=lead_dim,
@@ -175,45 +177,48 @@
label=f"{month_abbr} starts",
)
# Plot the null correlation bounds as dashed lines
bounds = ds_corr["ci"].isel(month=i).values
bounds = ds["ci"].isel(month=i).values
ax.axhline(bounds[0], c=color, ls="--")
ax.axhline(
bounds[1], c=color, ls="--", label=f"{confidence_interval * 100: g}% CI"
)

ax.set_xlabel("Lead")
ax.set_ylabel(ds_corr["r"].attrs["long_name"])
ax.set_ylabel(ds["r"].attrs["long_name"])
ax.legend()

if dataset_name:
fig.suptitle(dataset_name, x=0.6, y=1.02)

if outfile:
plt.tight_layout()
plt.savefig(outfile, bbox_inches="tight", facecolor="white", dpi=200)
else:
plt.show()
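A hedged sketch of calling the updated point_plot directly, outside the command-line workflow. The file paths and dataset label below are placeholders, and the input is assumed to be a correlation dataset previously written by this script:

import xarray as xr
from unseen import independence

ds = xr.open_dataset("independence_results.nc")  # hypothetical output of independence.py

independence.point_plot(
    ds,
    outfile="independence_point.png",  # hypothetical image path
    dataset_name="Example hindcast",   # new argument: drawn as the figure suptitle
    lead_dim="lead_time",
    confidence_interval=0.95,
)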


def spatial_plot(ds_corr, outfile=None, **kwargs):
def spatial_plot(ds, dataset_name=None, outfile=None, kwargs=dict(figsize=[8, 5])):
"""Contour plot of the first independent lead time (for each init month).
Parameters
----------
ds_corr : xarray.Dataset
ds : xarray.Dataset
Index of the first independent lead time (for each init month)
dataset_name : str, optional
Name of the dataset for plot titles
outfile : str, optional
Path for output image file
kwargs : dict, optional
Additional keyword arguments for xarray.DataArray.plot
"""
# Discrete colour bar based on lead times
cbar_ticks = np.arange(ds_corr.min_lead.min(), ds_corr.min_lead.max() + 2)
cbar_ticks = np.arange(ds.min_lead.min(), ds.min_lead.max() + 2)
# Convert integer to month names for plot titles
ds_corr.coords["month"] = [
f"{calendar.month_name[m]} starts" for m in ds_corr.month.values
]
titles = [f"{calendar.month_name[m]} starts" for m in ds.month.values]

cm = ds_corr.min_lead.plot.pcolormesh(
cm = ds.min_lead.plot.pcolormesh(
col="month",
col_wrap=min(3, len(ds_corr.month)),
col_wrap=min(3, len(ds.month)),
transform=PlateCarree(),
subplot_kws=dict(projection=PlateCarree()),
levels=cbar_ticks,
@@ -222,27 +227,25 @@ def spatial_plot(ds_corr, outfile=None, **kwargs):
**kwargs,
)
# Fix hidden axis ticks and labels
for ax in cm.axs.flat:
for i, ax in enumerate(cm.axs.flat):
subplotspec = ax.get_subplotspec()
if subplotspec.is_last_row():
ax.xaxis.set_visible(True)
if subplotspec.is_first_col():
ax.yaxis.set_visible(True)
ax.coastlines()
ax.set_title(titles[i])
cm.fig.set_constrained_layout(True)
cm.fig.get_layout_engine().set(h_pad=0.2)
cm.add_colorbar()
cm.set_titles("{value}")

# Fix lat/lon axis labels
if all([dim in ds_corr.dims for dim in ["lat", "lon"]]):
cm.set_xlabels(
f"{ds_corr.lon.attrs['long_name']} [{ds_corr.lon.attrs['units']}]"
)
cm.set_ylabels(
f"{ds_corr.lat.attrs['long_name']} [{ds_corr.lat.attrs['units']}]"
)
if dataset_name:
cm.fig.suptitle(dataset_name, x=0.6, y=1.02)

# Fix lat/lon axis labels
if all([dim in ds.dims for dim in ["lat", "lon"]]):
cm.set_xlabels(f"{ds.lon.attrs['long_name']} [{ds.lon.attrs['units']}]")
cm.set_ylabels(f"{ds.lat.attrs['long_name']} [{ds.lat.attrs['units']}]")
if outfile:
plt.savefig(outfile, bbox_inches="tight", facecolor="white", dpi=200)
else:
@@ -535,36 +538,35 @@ def _main():
)
da_fcst = ds_fcst[args.var]

ds_corr = run_tests(
ds = run_tests(
da_fcst,
init_dim=args.init_dim,
lead_dim=args.lead_dim,
ensemble_dim=args.ensemble_dim,
confidence_interval=args.confidence_interval,
n_resamples=args.n_resamples,
)

ds.attrs = ds_fcst.attrs # Add forecast dataset attributes
# Save correlation coefficients, confidence intervals and minimum lead
infile_logs = {args.fcst_file: ds_fcst.attrs["history"]}
ds_corr.attrs["history"] = fileio.get_new_log(infile_logs=infile_logs)
ds.attrs["history"] = fileio.get_new_log(infile_logs=infile_logs)

if args.output_chunks:
ds_corr = ds_corr.chunk(args.output_chunks)

ds = ds.chunk(args.output_chunks)
if "zarr" in args.outfile:
fileio.to_zarr(ds_corr, args.outfile)
fileio.to_zarr(ds, args.outfile)
else:
ds_corr.to_netcdf(args.outfile, compute=True)
ds.to_netcdf(args.outfile, compute=True)

if args.plot_outfile and len(da_fcst.dims) <= 3:
# Scatter plot of correlation vs lead (if there are no extra dimensions)
point_plot(ds_corr, args.plot_outfile, args.lead_dim, args.confidence_interval)
point_plot(ds, outfile=args.plot_outfile, lead_dim=args.lead_dim, confidence_interval=args.confidence_interval)

elif args.plot_outfile:
# Spatial plot of minimum correlation
spatial_plot(
ds_corr,
ds,
args.plot_outfile,
figsize=[8, 5],
)


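In _main, small (point-based) results go to point_plot and gridded results go to spatial_plot. For completeness, a hedged sketch of calling the new spatial_plot directly; the path and dataset label are placeholders, and the input is assumed to hold min_lead on (month, lat, lon):

import xarray as xr
from unseen import independence

ds = xr.open_dataset("independence_gridded.nc")  # hypothetical gridded output of independence.py

independence.spatial_plot(
    ds,
    dataset_name="Example hindcast",  # new argument: drawn as the figure suptitle
    outfile="independence_map.png",   # hypothetical image path
    kwargs=dict(figsize=[8, 5]),      # forwarded to xarray's faceted pcolormesh call
)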
