Skip to content

Commit

Permalink
Validation for spline basis, default auto config
Browse files Browse the repository at this point in the history
  • Loading branch information
marcopeix committed Dec 12, 2024
1 parent 2aed153 commit 8a7f65d
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 20 deletions.
2 changes: 1 addition & 1 deletion nbs/models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@
" default_config = {\n",
" \"input_size_multiplier\": [1, 2, 3, 4, 5],\n",
" \"h\": None,\n",
" \"basis\": tune.choice([\"polynomial\", \"spline\"]),\n",
" \"basis\": tune.choice([\"polynomial\", \"changepoint\"]),\n",
" \"n_basis\": tune.choice([2, 5]),\n",
" \"learning_rate\": tune.loguniform(1e-4, 1e-1),\n",
" \"scaler_type\": tune.choice([None, 'robust', 'standard']),\n",
Expand Down
23 changes: 11 additions & 12 deletions nbs/models.nbeats.ipynb
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "75372ec2",
"metadata": {},
"outputs": [],
"source": [
"%set_env PYTORCH_ENABLE_MPS_FALLBACK=1"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -191,7 +181,8 @@
" Returns:\n",
" - spline_basis (ndarray): An array of cubic spline basis functions.\n",
" \"\"\"\n",
" n_basis = max(4, n_basis)\n",
" if n_basis < 4:\n",
" raise ValueError(f\"To use the spline basis, n_basis must be set to 4 or more. Current value is {n_basis}\")\n",
" x = np.linspace(0, 1, length)\n",
" knots = np.linspace(0, 1, n_basis - 2)\n",
" t = np.concatenate(([0, 0, 0], knots, [1, 1, 1]))\n",
Expand Down Expand Up @@ -817,7 +808,7 @@
"\n",
"model = NBEATS(h=12, input_size=24,\n",
" basis='polynomial',\n",
" n_basis=2,\n",
" n_basis=5,\n",
" loss=DistributionLoss(distribution='Poisson', level=[80, 90]),\n",
" stack_types = ['identity', 'trend', 'seasonality'],\n",
" max_steps=100,\n",
Expand Down Expand Up @@ -847,6 +838,14 @@
"plt.legend()\n",
"plt.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c87058ca",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion neuralforecast/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ class AutoNBEATS(BaseAuto):
default_config = {
"input_size_multiplier": [1, 2, 3, 4, 5],
"h": None,
"basis": tune.choice(["polynomial", "spline"]),
"basis": tune.choice(["polynomial", "changepoint"]),
"n_basis": tune.choice([2, 5]),
"learning_rate": tune.loguniform(1e-4, 1e-1),
"scaler_type": tune.choice([None, "robust", "standard"]),
Expand Down
15 changes: 9 additions & 6 deletions neuralforecast/models/nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# %% auto 0
__all__ = ['NBEATS']

# %% ../../nbs/models.nbeats.ipynb 6
# %% ../../nbs/models.nbeats.ipynb 5
from typing import Tuple, Optional

import numpy as np
Expand All @@ -16,7 +16,7 @@
from ..losses.pytorch import MAE
from ..common._base_windows import BaseWindows

# %% ../../nbs/models.nbeats.ipynb 8
# %% ../../nbs/models.nbeats.ipynb 7
def generate_legendre_basis(length, n_basis):
"""
Generates Legendre polynomial basis functions.
Expand Down Expand Up @@ -119,7 +119,10 @@ def generate_spline_basis(length, n_basis):
Returns:
- spline_basis (ndarray): An array of cubic spline basis functions.
"""
n_basis = max(4, n_basis)
if n_basis < 4:
raise ValueError(
f"To use the spline basis, n_basis must be set to 4 or more. Current value is {n_basis}"
)
x = np.linspace(0, 1, length)
knots = np.linspace(0, 1, n_basis - 2)
t = np.concatenate(([0, 0, 0], knots, [1, 1, 1]))
Expand Down Expand Up @@ -162,7 +165,7 @@ def get_basis(length, n_basis, basis):
}
return basis_dict[basis](length, n_basis + 1)

# %% ../../nbs/models.nbeats.ipynb 9
# %% ../../nbs/models.nbeats.ipynb 8
class IdentityBasis(nn.Module):
def __init__(self, backcast_size: int, forecast_size: int, out_features: int = 1):
super().__init__()
Expand Down Expand Up @@ -273,7 +276,7 @@ def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
forecast = torch.einsum("bpq,pt->btq", forecast_theta, self.forecast_basis)
return backcast, forecast

# %% ../../nbs/models.nbeats.ipynb 10
# %% ../../nbs/models.nbeats.ipynb 9
ACTIVATIONS = ["ReLU", "Softplus", "Tanh", "SELU", "LeakyReLU", "PReLU", "Sigmoid"]


Expand Down Expand Up @@ -319,7 +322,7 @@ def forward(self, insample_y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
backcast, forecast = self.basis(theta)
return backcast, forecast

# %% ../../nbs/models.nbeats.ipynb 11
# %% ../../nbs/models.nbeats.ipynb 10
class NBEATS(BaseWindows):
"""NBEATS
Expand Down

0 comments on commit 8a7f65d

Please sign in to comment.