Commit

Merge branch 'main' into mlpmultivariate
cchallu authored Apr 3, 2024
2 parents e5d7b9e + 9e0efab commit dc2b40d
Showing 8 changed files with 154 additions and 25 deletions.
116 changes: 105 additions & 11 deletions nbs/core.ipynb
@@ -851,7 +851,13 @@
" raise Exception('you must define `n_windows` or `test_size` but not both') \n",
" # Recover initial model if use_init_models.\n",
" if use_init_models:\n",
" self._reset_models() \n",
" self._reset_models()\n",
" if isinstance(df, pd.DataFrame) and df.index.name == id_col:\n",
" warnings.warn(\n",
" \"Passing the id as index is deprecated, please provide it as a column instead.\",\n",
" FutureWarning,\n",
" )\n",
" df = df.reset_index(id_col) \n",
" if not refit:\n",
" return self._no_refit_cross_validation(\n",
" df=df,\n",
@@ -964,13 +970,19 @@
"\n",
" # Remove test set from dataset and last dates\n",
" test_size = self.models[0].get_test_size()\n",
" if test_size>0:\n",
"\n",
" # trim the forefront period to ensure `test_size - h` should be module `step_size\n",
" # Note: current constraint imposes that all series lengths are equal, so we can take the first series length as sample\n",
" series_length = self.dataset.indptr[1] - self.dataset.indptr[0]\n",
" _, forefront_offset = np.divmod((series_length - test_size - self.h), step_size)\n",
"\n",
" if test_size>0 or forefront_offset>0:\n",
" trimmed_dataset = TimeSeriesDataset.trim_dataset(dataset=self.dataset,\n",
" right_trim=test_size,\n",
" left_trim=0)\n",
" left_trim=forefront_offset)\n",
" new_idxs = np.hstack(\n",
" [\n",
" np.arange(self.dataset.indptr[i], self.dataset.indptr[i + 1] - test_size)\n",
" np.arange(self.dataset.indptr[i] + forefront_offset, self.dataset.indptr[i + 1] - test_size)\n",
" for i in range(self.dataset.n_groups)\n",
" ]\n",
" )\n",
@@ -1364,20 +1376,25 @@
"source": [
"#| hide\n",
"# id as index warnings\n",
"df_with_idx = AirPassengersPanel_train.set_index('unique_id')\n",
"models = [\n",
" NHITS(h=12, input_size=12, max_steps=1)\n",
"]\n",
"nf = NeuralForecast(models=models, freq='M')\n",
"nf.fit(df=AirPassengersPanel_train)\n",
"with warnings.catch_warnings(record=True) as issued_warnings:\n",
" warnings.simplefilter('always', category=FutureWarning)\n",
" nf.fit(df=df_with_idx) \n",
" nf.predict()\n",
" nf.predict_insample()\n",
" nf.cross_validation(df=AirPassengersPanel_train)\n",
"id_warnings = [\n",
" nf.cross_validation(df=df_with_idx)\n",
"input_id_warnings = [\n",
" w for w in issued_warnings if 'Passing the id as index is deprecated' in str(w.message)\n",
"]\n",
"assert len(input_id_warnings) == 2\n",
"output_id_warnings = [\n",
" w for w in issued_warnings if 'the predictions will have the id as a column' in str(w.message)\n",
"]\n",
"assert len(id_warnings) == 3"
"assert len(output_id_warnings) == 3"
]
},
{
@@ -1683,7 +1700,41 @@
"forecasts = nf.predict_insample(step_size=1)\n",
"\n",
"expected_size = n_series*((len(AirPassengersPanel_train)//n_series-test_size)-h+1)*h\n",
"assert len(forecasts) == expected_size, 'Shape mistmach in predict_insample'"
"assert len(forecasts) == expected_size, f'Shape mismatch in predict_insample: {len(forecasts)=}, {expected_size=}'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23fb98f8-0e27-44b2-a8a9-3b551986d53f",
"metadata": {},
"outputs": [],
"source": [
"#| hide,\n",
"# Test predict_insample step_size\n",
"\n",
"h = 12\n",
"train_end = AirPassengersPanel_train['ds'].max()\n",
"sizes = AirPassengersPanel_train['unique_id'].value_counts().to_numpy()\n",
"for step_size, test_size in [(7, 0), (9, 0), (7, 5), (9, 5)]:\n",
" models = [NHITS(h=h, input_size=12, max_steps=1)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" nf.fit(AirPassengersPanel_train)\n",
" # Note: only apply set_test_size() upon nf.fit(), otherwise it would have set the test_size = 0\n",
" nf.models[0].set_test_size(test_size)\n",
" \n",
" forecasts = nf.predict_insample(step_size=step_size)\n",
" last_cutoff = train_end - test_size * pd.offsets.MonthEnd() - h * pd.offsets.MonthEnd()\n",
" n_expected_cutoffs = (sizes[0] - test_size - nf.h + step_size) // step_size\n",
"\n",
" # compare cutoff values\n",
" expected_cutoffs = np.flip(np.array([last_cutoff - step_size * i * pd.offsets.MonthEnd() for i in range(n_expected_cutoffs)]))\n",
" actual_cutoffs = np.array([pd.Timestamp(x) for x in forecasts[forecasts['unique_id']==nf.uids[1]]['cutoff'].unique()])\n",
" np.testing.assert_array_equal(expected_cutoffs, actual_cutoffs, err_msg=f\"{step_size=},{expected_cutoffs=},{actual_cutoffs=}\")\n",
" \n",
" # check forecast-points count per series\n",
" cutoffs_by_series = forecasts.groupby(['unique_id', 'cutoff']).size().unstack('unique_id')\n",
" pd.testing.assert_series_equal(cutoffs_by_series['Airline1'], cutoffs_by_series['Airline2'], check_names=False)"
]
},
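
For context, a hedged sanity check of the cutoff-count formula used in this test; the per-series training length of 132 observations is an assumption based on the usual AirPassengersPanel train split:

    # hypothetical numbers: 132 training points per series, h=12
    series_len, h = 132, 12
    for step_size, test_size in [(7, 0), (9, 0), (7, 5), (9, 5)]:
        n_expected_cutoffs = (series_len - test_size - h + step_size) // step_size
        print(step_size, test_size, n_expected_cutoffs)
    # -> 18, 14, 17 and 13 cutoffs respectively
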
{
@@ -2296,7 +2347,8 @@
"source": [
"#| hide\n",
"#| polars\n",
"import polars"
"import polars\n",
"from polars.testing import assert_frame_equal"
]
},
{
@@ -2351,6 +2403,48 @@
"assert_equal_dfs(cv_res, cv_res_pl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "906fd509-82f5-431e-86df-d73e928ddeef",
"metadata": {},
"outputs": [],
"source": [
"#| hide,\n",
"#| polars\n",
"# Test predict_insample step_size\n",
"\n",
"h = 12\n",
"train_end = AirPassengers_pl['time'].max()\n",
"sizes = AirPassengers_pl['uid'].value_counts().to_numpy()\n",
"\n",
"for step_size, test_size in [(7, 0), (9, 0), (7, 5), (9, 5)]:\n",
" models = [NHITS(h=h, input_size=12, max_steps=1)]\n",
" nf = NeuralForecast(models=models, freq='1mo')\n",
" nf.fit(\n",
" AirPassengers_pl,\n",
" id_col='uid',\n",
" time_col='time',\n",
" target_col='target', \n",
" )\n",
" # Note: only apply set_test_size() upon nf.fit(), otherwise it would have set the test_size = 0\n",
" nf.models[0].set_test_size(test_size) \n",
" \n",
" forecasts = nf.predict_insample(step_size=step_size)\n",
" n_expected_cutoffs = (sizes[0][1] - test_size - nf.h + step_size) // step_size\n",
"\n",
" # compare cutoff values\n",
" last_cutoff = train_end - test_size * pd.offsets.MonthEnd() - h * pd.offsets.MonthEnd()\n",
" expected_cutoffs = np.flip(np.array([last_cutoff - step_size * i * pd.offsets.MonthEnd() for i in range(n_expected_cutoffs)]))\n",
" pl_cutoffs = forecasts.filter(polars.col('uid') ==nf.uids[1]).select('cutoff').unique(maintain_order=True)\n",
" actual_cutoffs = np.array([pd.Timestamp(x['cutoff']) for x in pl_cutoffs.rows(named=True)])\n",
" np.testing.assert_array_equal(expected_cutoffs, actual_cutoffs, err_msg=f\"{step_size=},{expected_cutoffs=},{actual_cutoffs=}\")\n",
"\n",
" # check forecast-points count per series\n",
" cutoffs_by_series = forecasts.groupby(['uid', 'cutoff']).count()\n",
" assert_frame_equal(cutoffs_by_series.filter(polars.col('uid') == \"Airline1\").select(['cutoff', 'count']), cutoffs_by_series.filter(polars.col('uid') == \"Airline2\").select(['cutoff', 'count'] ), check_row_order=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -2491,7 +2585,7 @@
"source": [
"#| hide\n",
"# test that if the user-defined optimizer is not a subclass of torch.optim.optimizer, failed with exception\n",
"# tests cover different types of base classes such as basewindows, baserecurrent, basemultivariate\n",
"# tests cover different types of base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
"test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
"test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n"
8 changes: 7 additions & 1 deletion nbs/models.nlinear.ipynb
@@ -109,6 +109,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
" `alias`: str, optional, Custom name of the model.<br>\n",
" `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br> \n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"\t*References*<br>\n",
@@ -141,6 +143,8 @@
" random_seed: int = 1,\n",
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
" optimizer = None,\n",
" optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(NLinear, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -165,6 +169,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
"\n",
" # Architecture\n",
@@ -206,7 +212,7 @@
" # Final\n",
" forecast = self.linear(norm_insample_y) + last_value\n",
" forecast = forecast.reshape(batch_size, self.h, self.loss.outputsize_multiplier)\n",
" forecast = self.loss.domain_map(forecast)\n",
" forecast = self.loss.domain_map(forecast)\n",
" return forecast"
]
},
13 changes: 9 additions & 4 deletions nbs/models.timellm.ipynb
@@ -67,12 +67,11 @@
"\n",
"from neuralforecast.losses.pytorch import MAE\n",
"\n",
"IS_TRANSFORMERS_INSTALLED = True\n",
"try:\n",
" from transformers import GPT2Config, GPT2Model, GPT2Tokenizer\n",
" from transformers import GPT2Config, GPT2Model, GPT2Tokenizer\n",
" IS_TRANSFORMERS_INSTALLED = True\n",
"except ImportError:\n",
" IS_TRANSFORMERS_INSTALLED = False\n",
" print('The transformers library is required for Time-LLM to work')"
" IS_TRANSFORMERS_INSTALLED = False"
]
},
{
@@ -338,6 +337,8 @@
" `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
" `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
" `alias`: str, optional, Custom name of the model.<br>\n",
" `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
" `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br> \n",
" `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br> \n",
"\n",
" **References:**<br>\n",
@@ -387,6 +388,8 @@
" num_workers_loader: int = 0,\n",
" drop_last_loader: bool = False,\n",
" random_seed: int = 1,\n",
" optimizer = None,\n",
" optimizer_kwargs = None,\n",
" **trainer_kwargs):\n",
" super(TimeLLM, self).__init__(h=h,\n",
" input_size=input_size,\n",
@@ -410,6 +413,8 @@
" num_workers_loader=num_workers_loader,\n",
" drop_last_loader=drop_last_loader,\n",
" random_seed=random_seed,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" **trainer_kwargs)\n",
" \n",
" # Asserts\n",
2 changes: 1 addition & 1 deletion neuralforecast/__init__.py
@@ -1,3 +1,3 @@
__version__ = "1.6.4"
__version__ = "1.7.0"
__all__ = ['NeuralForecast']
from .core import NeuralForecast
19 changes: 16 additions & 3 deletions neuralforecast/core.py
@@ -786,6 +786,12 @@ def cross_validation(
# Recover initial model if use_init_models.
if use_init_models:
self._reset_models()
if isinstance(df, pd.DataFrame) and df.index.name == id_col:
warnings.warn(
"Passing the id as index is deprecated, please provide it as a column instead.",
FutureWarning,
)
df = df.reset_index(id_col)
if not refit:
return self._no_refit_cross_validation(
df=df,
@@ -904,14 +910,21 @@ def predict_insample(self, step_size: int = 1):

# Remove test set from dataset and last dates
test_size = self.models[0].get_test_size()
if test_size > 0:

        # trim the front of each series so that (series_length - test_size - h) is a multiple of `step_size`
        # Note: the current constraint requires all series to have the same length, so the first series is used as reference
series_length = self.dataset.indptr[1] - self.dataset.indptr[0]
_, forefront_offset = np.divmod((series_length - test_size - self.h), step_size)

if test_size > 0 or forefront_offset > 0:
trimmed_dataset = TimeSeriesDataset.trim_dataset(
dataset=self.dataset, right_trim=test_size, left_trim=0
dataset=self.dataset, right_trim=test_size, left_trim=forefront_offset
)
new_idxs = np.hstack(
[
np.arange(
self.dataset.indptr[i], self.dataset.indptr[i + 1] - test_size
self.dataset.indptr[i] + forefront_offset,
self.dataset.indptr[i + 1] - test_size,
)
for i in range(self.dataset.n_groups)
]
6 changes: 6 additions & 0 deletions neuralforecast/models/nlinear.py
@@ -39,6 +39,8 @@ class NLinear(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>
`alias`: str, optional, Custom name of the model.<br>
    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer instead of the default choice (Adam).<br>
    `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.<br>
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>
*References*<br>
@@ -73,6 +75,8 @@ def __init__(
random_seed: int = 1,
num_workers_loader: int = 0,
drop_last_loader: bool = False,
optimizer=None,
optimizer_kwargs=None,
**trainer_kwargs
):
super(NLinear, self).__init__(
@@ -99,6 +103,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)

13 changes: 9 additions & 4 deletions neuralforecast/models/timellm.py
@@ -1,8 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.timellm.ipynb.

# %% auto 0
__all__ = ['IS_TRANSFORMERS_INSTALLED', 'ReplicationPad1d', 'TokenEmbedding', 'PatchEmbedding', 'FlattenHead',
'ReprogrammingLayer', 'Normalize', 'TimeLLM']
__all__ = ['ReplicationPad1d', 'TokenEmbedding', 'PatchEmbedding', 'FlattenHead', 'ReprogrammingLayer', 'Normalize', 'TimeLLM']

# %% ../../nbs/models.timellm.ipynb 6
import math
@@ -15,12 +14,12 @@

from ..losses.pytorch import MAE

IS_TRANSFORMERS_INSTALLED = True
try:
from transformers import GPT2Config, GPT2Model, GPT2Tokenizer

IS_TRANSFORMERS_INSTALLED = True
except ImportError:
IS_TRANSFORMERS_INSTALLED = False
print("The transformers library is required for Time-LLM to work")

# %% ../../nbs/models.timellm.ipynb 9
class ReplicationPad1d(nn.Module):
@@ -267,6 +266,8 @@ class TimeLLM(BaseWindows):
`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>
`alias`: str, optional, Custom name of the model.<br>
    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user-specified optimizer instead of the default choice (Adam).<br>
    `optimizer_kwargs`: dict, optional, keyword arguments used by the user-specified `optimizer`.<br>
`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>
**References:**<br>
@@ -317,6 +318,8 @@ def __init__(
num_workers_loader: int = 0,
drop_last_loader: bool = False,
random_seed: int = 1,
optimizer=None,
optimizer_kwargs=None,
**trainer_kwargs,
):
super(TimeLLM, self).__init__(
Expand All @@ -342,6 +345,8 @@ def __init__(
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion settings.ini
@@ -8,7 +8,7 @@ author = Nixtla
author_email = [email protected]
copyright = Nixtla Inc.
branch = main
version = 1.6.4
version = 1.7.0
min_python = 3.8
audience = Developers
language = English
