Skip to content

Commit

Permalink
[FIX] HINT not producing coherent forecasts (#964)
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint authored Apr 10, 2024
1 parent 63ebffd commit 4435b12
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 18 deletions.
25 changes: 9 additions & 16 deletions nbs/models.hint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,15 @@
" \n",
" # Bootstrap Sample Reconciliation\n",
" # Default output [mean, quantiles]\n",
" samples = np.einsum('ij,jwhp->iwhp', self.SP, samples)\n",
" samples = np.einsum('ij, jwhp -> iwhp', self.SP, samples)\n",
"\n",
" sample_mean = np.mean(samples, axis=-1, keepdims=True)\n",
" sample_mean = sample_mean.reshape(-1, 1)\n",
"\n",
" forecasts = np.quantile(samples, self.model.loss.quantiles, axis=-1)\n",
" forecasts = forecasts.transpose(1,2,3,0) # [...,samples]\n",
" forecasts = forecasts.reshape(-1, len(self.model.loss.quantiles))\n",
" \n",
" sample_mean = np.mean(forecasts, axis=-1, keepdims=True)\n",
"\n",
" forecasts = np.concatenate([sample_mean, forecasts], axis=-1)\n",
" return forecasts\n",
"\n",
Expand Down Expand Up @@ -486,16 +489,13 @@
"\n",
"# ---Check Hierarchical Coherence---\n",
"parent_children_dict = {0: [1, 2], 1: [3, 4], 2: [5, 6]}\n",
"eps = 0.03\n",
"# check coherence for each horizon time step\n",
"for _, df in forecasts.groupby('ds'):\n",
" hint_mean = df['HINT'].values\n",
" for parent_idx, children_list in parent_children_dict.items():\n",
" parent_value = hint_mean[parent_idx]\n",
" children_sum = hint_mean[children_list].sum()\n",
" percent_diff = np.round(abs(parent_value-children_sum)/parent_value * 100, 2)\n",
" print(f\"Percentage Difference: {percent_diff}\")\n",
" assert percent_diff < eps"
" np.testing.assert_allclose(children_sum, parent_value)"
]
},
{
Expand Down Expand Up @@ -564,7 +564,6 @@
"nhits = NHITS(h=horizon,\n",
" input_size=24,\n",
" loss=GMM(n_components=10, level=level),\n",
" hist_exog_list=['month'],\n",
" max_steps=2000,\n",
" early_stop_patience_steps=10,\n",
" val_check_steps=50,\n",
Expand All @@ -577,7 +576,8 @@
"\n",
"# Fit and Predict\n",
"nf = NeuralForecast(models=[model], freq='MS')\n",
"Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)"
"Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n",
"Y_hat_df = Y_hat_df.reset_index()"
]
},
{
Expand Down Expand Up @@ -605,13 +605,6 @@
"plt.grid()\n",
"plt.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
7 changes: 5 additions & 2 deletions neuralforecast/models/hint.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,15 @@ def predict(self, dataset, step_size=1, random_seed=None, **data_module_kwargs):

# Bootstrap Sample Reconciliation
# Default output [mean, quantiles]
samples = np.einsum("ij,jwhp->iwhp", self.SP, samples)
samples = np.einsum("ij, jwhp -> iwhp", self.SP, samples)

sample_mean = np.mean(samples, axis=-1, keepdims=True)
sample_mean = sample_mean.reshape(-1, 1)

forecasts = np.quantile(samples, self.model.loss.quantiles, axis=-1)
forecasts = forecasts.transpose(1, 2, 3, 0) # [...,samples]
forecasts = forecasts.reshape(-1, len(self.model.loss.quantiles))

sample_mean = np.mean(forecasts, axis=-1, keepdims=True)
forecasts = np.concatenate([sample_mean, forecasts], axis=-1)
return forecasts

Expand Down

0 comments on commit 4435b12

Please sign in to comment.