diff --git a/action_files/test_models/src/evaluation.py b/action_files/test_models/src/evaluation.py index 9264ba00b..cbe4e35c6 100644 --- a/action_files/test_models/src/evaluation.py +++ b/action_files/test_models/src/evaluation.py @@ -42,17 +42,18 @@ def evaluate(model: str, dataset: str, group: str): if __name__ == '__main__': groups = ['Monthly'] models = ['AutoDilatedRNN', 'RNN', 'TCN', 'DeepAR', - 'NHITS', 'TFT', 'AutoMLP', 'DLinear', 'VanillaTransformer'] + 'NHITS', 'TFT', 'AutoMLP', 'DLinear', 'VanillaTransformer', + 'BiTCN', 'TiDE'] datasets = ['M3'] evaluation = [evaluate(model, dataset, group) for model, group in product(models, groups) for dataset in datasets] evaluation = [eval_ for eval_ in evaluation if eval_ is not None] - evaluation = pd.concat(evaluation) - evaluation = evaluation[['dataset', 'model', 'time', 'mae', 'smape']] - evaluation['time'] /= 60 #minutes - evaluation = evaluation.set_index(['dataset', 'model']).stack().reset_index() - evaluation.columns = ['dataset', 'model', 'metric', 'val'] - evaluation = evaluation.set_index(['dataset', 'metric', 'model']).unstack().round(3) - evaluation = evaluation.droplevel(0, 1).reset_index() - evaluation['AutoARIMA'] = [666.82, 15.35, 3.000] - evaluation.to_csv('data/evaluation.csv') - print(evaluation.T) + df_evaluation = pd.concat(evaluation) + df_evaluation = df_evaluation.loc[:, ['dataset', 'model', 'time', 'mae', 'smape']] + df_evaluation['time'] /= 60 #minutes + df_evaluation = df_evaluation.set_index(['dataset', 'model']).stack().reset_index() + df_evaluation.columns = ['dataset', 'model', 'metric', 'val'] + df_evaluation = df_evaluation.set_index(['dataset', 'metric', 'model']).unstack().round(3) + df_evaluation = df_evaluation.droplevel(0, 1).reset_index() + df_evaluation['AutoARIMA'] = [666.82, 15.35, 3.000] + df_evaluation.to_csv('data/evaluation.csv') + print(df_evaluation.T) diff --git a/action_files/test_models/src/models.py b/action_files/test_models/src/models.py index e03f5e69c..7fb66f2d2 100644 --- a/action_files/test_models/src/models.py +++ b/action_files/test_models/src/models.py @@ -2,33 +2,39 @@ import time import fire -import numpy as np +# import numpy as np import pandas as pd -import pytorch_lightning as pl -import torch +# import pytorch_lightning as pl +# import torch -import neuralforecast +# import neuralforecast from neuralforecast.core import NeuralForecast -from neuralforecast.models.gru import GRU +# from neuralforecast.models.gru import GRU from neuralforecast.models.rnn import RNN from neuralforecast.models.tcn import TCN -from neuralforecast.models.lstm import LSTM -from neuralforecast.models.dilated_rnn import DilatedRNN +# from neuralforecast.models.lstm import LSTM +# from neuralforecast.models.dilated_rnn import DilatedRNN from neuralforecast.models.deepar import DeepAR -from neuralforecast.models.mlp import MLP +# from neuralforecast.models.mlp import MLP from neuralforecast.models.nhits import NHITS -from neuralforecast.models.nbeats import NBEATS -from neuralforecast.models.nbeatsx import NBEATSx +# from neuralforecast.models.nbeats import NBEATS +# from neuralforecast.models.nbeatsx import NBEATSx from neuralforecast.models.tft import TFT from neuralforecast.models.vanillatransformer import VanillaTransformer -from neuralforecast.models.informer import Informer -from neuralforecast.models.autoformer import Autoformer -from neuralforecast.models.patchtst import PatchTST +# from neuralforecast.models.informer import Informer +# from neuralforecast.models.autoformer import 
Autoformer +# from neuralforecast.models.patchtst import PatchTST from neuralforecast.models.dlinear import DLinear +from neuralforecast.models.bitcn import BiTCN +from neuralforecast.models.tide import TiDE from neuralforecast.auto import ( - AutoMLP, AutoNHITS, AutoNBEATS, AutoDilatedRNN, AutoTFT + AutoMLP, + # AutoNHITS, + # AutoNBEATS, + AutoDilatedRNN, + # AutoTFT ) from neuralforecast.losses.pytorch import SMAPE, MAE @@ -43,13 +49,6 @@ def main(dataset: str = 'M3', group: str = 'Monthly') -> None: train, horizon, freq, seasonality = get_data('data/', dataset, group) train['ds'] = pd.to_datetime(train['ds']) - config_nbeats = { - "input_size": tune.choice([2 * horizon]), - "max_steps": 1000, - "val_check_steps": 300, - "scaler_type": "minmax1", - "random_seed": tune.choice([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), - } config = { "hidden_size": tune.choice([256, 512]), "num_layers": tune.choice([2, 4]), @@ -64,6 +63,7 @@ def main(dataset: str = 'M3', group: str = 'Monthly') -> None: "max_steps": 300, "val_check_steps": 100, "random_seed": tune.choice([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),} + models = [ AutoDilatedRNN(h=horizon, loss=MAE(), config=config_drnn, num_samples=2, cpus=1), RNN(h=horizon, input_size=2 * horizon, encoder_hidden_size=50, max_steps=300), @@ -74,10 +74,12 @@ def main(dataset: str = 'M3', group: str = 'Monthly') -> None: TFT(h=horizon, input_size=2 * horizon, loss=SMAPE(), hidden_size=64, scaler_type='robust', windows_batch_size=512, max_steps=1500, val_check_steps=500), VanillaTransformer(h=horizon, input_size=2 * horizon, loss=MAE(), hidden_size=64, scaler_type='minmax1', windows_batch_size=512, max_steps=1500, val_check_steps=500), DeepAR(h=horizon, input_size=2 * horizon, scaler_type='minmax1', max_steps=1000), + BiTCN(h=horizon, input_size=2 * horizon, loss=MAE(), dropout=0.0, max_steps=1000, val_check_steps=500), + TiDE(h=horizon, input_size=2 * horizon, loss=MAE(), max_steps=1000, val_check_steps=500), ] # Models - for model in models[:-1]: + for model in models: model_name = type(model).__name__ print(50*'-', model_name, 50*'-') start = time.time() @@ -87,26 +89,13 @@ def main(dataset: str = 'M3', group: str = 'Monthly') -> None: end = time.time() print(end - start) + if model_name == 'DeepAR': + forecasts = forecasts[['unique_id', 'ds', 'DeepAR-median']] + forecasts.columns = ['unique_id', 'ds', model_name] forecasts.to_csv(f'data/{model_name}-forecasts-{dataset}-{group}.csv', index=False) time_df = pd.DataFrame({'time': [end - start], 'model': [model_name]}) time_df.to_csv(f'data/{model_name}-time-{dataset}-{group}.csv', index=False) - # DeepAR - model_name = type(models[-1]).__name__ - start = time.time() - fcst = NeuralForecast(models=[models[-1]], freq=freq) - fcst.fit(train) - forecasts = fcst.predict() - end = time.time() - print(end - start) - - forecasts = forecasts[['unique_id', 'ds', 'DeepAR-median']] - forecasts.columns = ['unique_id', 'ds', 'DeepAR'] - forecasts.to_csv(f'data/{model_name}-forecasts-{dataset}-{group}.csv', index=False) - time_df = pd.DataFrame({'time': [end - start], 'model': [model_name]}) - time_df.to_csv(f'data/{model_name}-time-{dataset}-{group}.csv', index=False) - - if __name__ == '__main__': fire.Fire(main) diff --git a/nbs/core.ipynb b/nbs/core.ipynb index eb1f7be7f..ea9b90a14 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -91,7 +91,7 @@ " Informer, Autoformer, FEDformer,\n", " StemGNN, PatchTST, TimesNet, TimeLLM, TSMixer, TSMixerx,\n", " MLPMultivariate, iTransformer,\n", - " BiTCN,\n", + " BiTCN, TiDE\n", ")" ] }, @@ -238,6 
+238,7 @@ " 'mlpmultivariate': MLPMultivariate, 'automlpmultivariate': MLPMultivariate,\n", " 'itransformer': iTransformer, 'autoitransformer': iTransformer,\n", " 'bitcn': BiTCN, 'autobitcn': BiTCN,\n", + " 'tide': TiDE, 'autotide': TiDE,\n", "}" ] }, diff --git a/nbs/imgs_models/tide.png b/nbs/imgs_models/tide.png new file mode 100644 index 000000000..d10b5b01b Binary files /dev/null and b/nbs/imgs_models/tide.png differ diff --git a/nbs/models.ipynb b/nbs/models.ipynb index 9e437cea8..428eeabc1 100644 --- a/nbs/models.ipynb +++ b/nbs/models.ipynb @@ -53,6 +53,7 @@ "from neuralforecast.models.nhits import NHITS\n", "from neuralforecast.models.dlinear import DLinear\n", "from neuralforecast.models.nlinear import NLinear\n", + "from neuralforecast.models.tide import TiDE\n", "\n", "from neuralforecast.models.tft import TFT\n", "from neuralforecast.models.vanillatransformer import VanillaTransformer\n", @@ -2013,6 +2014,401 @@ "model.fit(dataset=dataset)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "973a470e", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class AutoTiDE(BaseAuto):\n", + "\n", + " default_config = {\n", + " \"input_size_multiplier\": [1, 2, 3, 4, 5],\n", + " \"h\": None,\n", + " \"hidden_size\": tune.choice([256, 512, 1024]),\n", + " \"decoder_output_dim\": tune.choice([8, 16, 32]),\n", + " \"temporal_decoder_dim\": tune.choice([32, 64, 128]),\n", + " \"num_encoder_layers\": tune.choice([1, 2, 3]),\n", + " \"num_decoder_layers\": tune.choice([1, 2, 3]),\n", + " \"temporal_width\": tune.choice([4, 8, 16]),\n", + " \"dropout\":tune.choice([0.0, 0.1, 0.2, 0.3, 0.5]),\n", + " \"layernorm\": tune.choice([True, False]),\n", + " \"learning_rate\": tune.loguniform(1e-5, 1e-2),\n", + " \"scaler_type\": tune.choice([None, 'robust', 'standard']),\n", + " \"max_steps\": tune.quniform(lower=500, upper=1500, q=100),\n", + " \"batch_size\": tune.choice([32, 64, 128, 256]),\n", + " \"windows_batch_size\": tune.choice([128, 256, 512, 1024]),\n", + " \"loss\": None,\n", + " \"random_seed\": tune.randint(lower=1, upper=20),\n", + " }\n", + "\n", + " def __init__(self,\n", + " h,\n", + " loss=MAE(),\n", + " valid_loss=None,\n", + " config=None, \n", + " search_alg=BasicVariantGenerator(random_state=1),\n", + " num_samples=10,\n", + " refit_with_val=False,\n", + " cpus=cpu_count(),\n", + " gpus=torch.cuda.device_count(),\n", + " verbose=False,\n", + " alias=None,\n", + " backend='ray',\n", + " callbacks=None,\n", + " ):\n", + "\n", + " # Define search space, input/output sizes\n", + " if config is None:\n", + " config = self.get_default_config(h=h, backend=backend) \n", + "\n", + " super(AutoTiDE, self).__init__(\n", + " cls_model=TiDE, \n", + " h=h,\n", + " loss=loss,\n", + " valid_loss=valid_loss,\n", + " config=config,\n", + " search_alg=search_alg,\n", + " num_samples=num_samples,\n", + " refit_with_val=refit_with_val,\n", + " cpus=cpus,\n", + " gpus=gpus,\n", + " verbose=verbose,\n", + " alias=alias,\n", + " backend=backend,\n", + " callbacks=callbacks,\n", + " )\n", + "\n", + " @classmethod\n", + " def get_default_config(cls, h, backend, n_series=None):\n", + " config = cls.default_config.copy()\n", + " config['input_size'] = tune.choice([h*x \\\n", + " for x in config['input_size_multiplier']])\n", + " config['step_size'] = tune.choice([1, h]) \n", + " del config['input_size_multiplier']\n", + " if backend == 'optuna':\n", + " config = cls._ray_config_to_optuna(config) \n", + "\n", + " return config " + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "d31d3bfa", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "### AutoTiDE\n", + "\n", + "> AutoTiDE (h, loss=MAE(), valid_loss=None, config=None,\n", + "> search_alg= object at 0x0000022D7EF8FC10>, num_samples=10,\n", + "> refit_with_val=False, cpus=20, gpus=1, verbose=False,\n", + "> alias=None, backend='ray', callbacks=None)\n", + "\n", + "Class for Automatic Hyperparameter Optimization, it builds on top of `ray` to\n", + "give access to a wide variety of hyperparameter optimization tools ranging\n", + "from classic grid search, to Bayesian optimization and HyperBand algorithm.\n", + "\n", + "The validation loss to be optimized is defined by the `config['loss']` dictionary\n", + "value, the config also contains the rest of the hyperparameter search space.\n", + "\n", + "It is important to note that the success of this hyperparameter optimization\n", + "heavily relies on a strong correlation between the validation and test periods.\n", + "\n", + "| | **Type** | **Default** | **Details** |\n", + "| -- | -------- | ----------- | ----------- |\n", + "| h | int | | Forecast horizon |\n", + "| loss | MAE | MAE() | Instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html). |\n", + "| valid_loss | NoneType | None | Instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html). |\n", + "| config | NoneType | None | Dictionary with ray.tune defined search space or function that takes an optuna trial and returns a configuration dict. |\n", + "| search_alg | BasicVariantGenerator | | For ray see https://docs.ray.io/en/latest/tune/api_docs/suggestion.html
For optuna see https://optuna.readthedocs.io/en/stable/reference/samplers/index.html. |\n", + "| num_samples | int | 10 | Number of hyperparameter optimization steps/samples. |\n", + "| refit_with_val | bool | False | Refit of best model should preserve val_size. |\n", + "| cpus | int | 20 | Number of cpus to use during optimization. Only used with ray tune. |\n", + "| gpus | int | 1 | Number of gpus to use during optimization, default all available. Only used with ray tune. |\n", + "| verbose | bool | False | Track progress. |\n", + "| alias | NoneType | None | Custom name of the model. |\n", + "| backend | str | ray | Backend to use for searching the hyperparameter space, can be either 'ray' or 'optuna'. |\n", + "| callbacks | NoneType | None | List of functions to call during the optimization process.
ray reference: https://docs.ray.io/en/latest/tune/tutorials/tune-metrics.html
optuna reference: https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/007_optuna_callback.html |" + ], + "text/plain": [ + "---\n", + "\n", + "### AutoTiDE\n", + "\n", + "> AutoTiDE (h, loss=MAE(), valid_loss=None, config=None,\n", + "> search_alg= object at 0x0000022D7EF8FC10>, num_samples=10,\n", + "> refit_with_val=False, cpus=20, gpus=1, verbose=False,\n", + "> alias=None, backend='ray', callbacks=None)\n", + "\n", + "Class for Automatic Hyperparameter Optimization, it builds on top of `ray` to\n", + "give access to a wide variety of hyperparameter optimization tools ranging\n", + "from classic grid search, to Bayesian optimization and HyperBand algorithm.\n", + "\n", + "The validation loss to be optimized is defined by the `config['loss']` dictionary\n", + "value, the config also contains the rest of the hyperparameter search space.\n", + "\n", + "It is important to note that the success of this hyperparameter optimization\n", + "heavily relies on a strong correlation between the validation and test periods.\n", + "\n", + "| | **Type** | **Default** | **Details** |\n", + "| -- | -------- | ----------- | ----------- |\n", + "| h | int | | Forecast horizon |\n", + "| loss | MAE | MAE() | Instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html). |\n", + "| valid_loss | NoneType | None | Instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html). |\n", + "| config | NoneType | None | Dictionary with ray.tune defined search space or function that takes an optuna trial and returns a configuration dict. |\n", + "| search_alg | BasicVariantGenerator | | For ray see https://docs.ray.io/en/latest/tune/api_docs/suggestion.html
For optuna see https://optuna.readthedocs.io/en/stable/reference/samplers/index.html. |\n", + "| num_samples | int | 10 | Number of hyperparameter optimization steps/samples. |\n", + "| refit_with_val | bool | False | Refit of best model should preserve val_size. |\n", + "| cpus | int | 20 | Number of cpus to use during optimization. Only used with ray tune. |\n", + "| gpus | int | 1 | Number of gpus to use during optimization, default all available. Only used with ray tune. |\n", + "| verbose | bool | False | Track progress. |\n", + "| alias | NoneType | None | Custom name of the model. |\n", + "| backend | str | ray | Backend to use for searching the hyperparameter space, can be either 'ray' or 'optuna'. |\n", + "| callbacks | NoneType | None | List of functions to call during the optimization process.
ray reference: https://docs.ray.io/en/latest/tune/tutorials/tune-metrics.html
optuna reference: https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/007_optuna_callback.html |" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "show_doc(AutoTiDE, title_level=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ae8f192", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-15 19:19:42,074\tINFO worker.py:1752 -- Started a local Ray instance.\n", + "2024-04-15 19:19:43,810\tINFO tune.py:263 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `Tuner(...)`.\n", + "2024-04-15 19:19:43,813\tINFO tune.py:613 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949\n", + "2024-04-15 19:19:50,851\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to 'C:/Users/ospra/ray_results/_train_tune_2024-04-15_19-19-40' in 0.0053s.\n", + "Seed set to 1\n" + ] + } + ], + "source": [ + "%%capture\n", + "# Use your own config or AutoTiDE.default_config\n", + "config = dict(max_steps=2, val_check_steps=1, input_size=12)\n", + "model = AutoTiDE(h=12, config=config, num_samples=1, cpus=1)\n", + "\n", + "# Fit and predict\n", + "model.fit(dataset=dataset)\n", + "y_hat = model.predict(dataset=dataset)\n", + "\n", + "# Optuna\n", + "model = AutoTiDE(h=12, config=None, backend='optuna')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d66600b9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[36m(_train_tune pid=30124)\u001b[0m c:\\Users\\ospra\\miniconda3\\envs\\neuralforecast\\lib\\site-packages\\ray\\tune\\integration\\pytorch_lightning.py:194: `ray.tune.integration.pytorch_lightning.TuneReportCallback` is deprecated. Use `ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback` instead.\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m c:\\Users\\ospra\\miniconda3\\envs\\neuralforecast\\lib\\site-packages\\pytorch_lightning\\utilities\\parsing.py:199: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m c:\\Users\\ospra\\miniconda3\\envs\\neuralforecast\\lib\\site-packages\\pytorch_lightning\\utilities\\parsing.py:199: Attribute 'valid_loss' is an instance of `nn.Module` and is already saved during checkpointing. 
It is recommended to ignore them using `self.save_hyperparameters(ignore=['valid_loss'])`.\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m Seed set to 11\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m GPU available: True (cuda), used: True\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m TPU available: False, using: 0 TPU cores\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m IPU available: False, using: 0 IPUs\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m HPU available: False, using: 0 HPUs\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m `Trainer(val_check_interval=1)` was configured so validation will run after every batch.\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m Missing logger folder: C:\\Users\\ospra\\AppData\\Local\\Temp\\ray\\session_2024-04-15_19-19-40_426885_27112\\artifacts\\2024-04-15_19-19-55\\_train_tune_2024-04-15_19-19-55\\working_dirs\\_train_tune_55d90_00000\\lightning_logs\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m \n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m | Name | Type | Params\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m ---------------------------------------------------\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 0 | loss | MAE | 0 \n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 1 | padder_train | ConstantPad1d | 0 \n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 2 | scaler | TemporalNorm | 0 \n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 3 | dense_encoder | Sequential | 1.1 M \n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 4 | dense_decoder | Sequential | 361 K \n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 5 | temporal_decoder | MLPResidual | 1.3 K \n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 6 | global_skip | Linear | 156 \n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m ---------------------------------------------------\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 1.4 M Trainable params\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 0 Non-trainable params\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 1.4 M Total params\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m 5.706 Total estimated model params size (MB)\n", + "\u001b[36m(_train_tune pid=30124)\u001b[0m c:\\Users\\ospra\\miniconda3\\envs\\neuralforecast\\lib\\site-packages\\pytorch_lightning\\trainer\\connectors\\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sanity Checking: | | 0/? [00:00 Time-series Dense Encoder (`TiDE`) is a MLP-based univariate time-series forecasting model. `TiDE` uses Multi-layer Perceptrons (MLPs) in an encoder-decoder model for long-term time-series forecasting. In addition, this model can handle exogenous inputs.\n", + "\n", + "

**References**
- [Das, Abhimanyu, Weihao Kong, Andrew Leach, Shaan Mathur, Rajat Sen, and Rose Yu (2024). \"Long-term Forecasting with TiDE: Time-series Dense Encoder.\"](http://arxiv.org/abs/2304.08424)<br>
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Figure 1. TiDE architecture.](imgs_models/tide.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "from fastcore.test import test_eq\n", + "from nbdev.showdoc import show_doc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from typing import Optional\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "from neuralforecast.losses.pytorch import MAE\n", + "from neuralforecast.common._base_windows import BaseWindows" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Auxiliary Functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1.1 MLP residual\n", + "An MLP block with a residual connection." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", + "class MLPResidual(nn.Module):\n", + " def __init__(self, input_dim, hidden_size, output_dim, dropout, layernorm):\n", + " super().__init__()\n", + " self.layernorm = layernorm\n", + " if layernorm:\n", + " self.norm = nn.LayerNorm(output_dim)\n", + "\n", + " self.drop = nn.Dropout(dropout)\n", + " self.lin1 = nn.Linear(input_dim, hidden_size)\n", + " self.lin2 = nn.Linear(hidden_size, output_dim)\n", + " self.skip = nn.Linear(input_dim, output_dim)\n", + "\n", + " def forward(self, input):\n", + " # MLP dense\n", + " x = F.relu(self.lin1(input)) \n", + " x = self.lin2(x)\n", + " x = self.drop(x)\n", + "\n", + " # Skip connection\n", + " x_skip = self.skip(input)\n", + "\n", + " # Combine\n", + " x = x + x_skip\n", + "\n", + " if self.layernorm:\n", + " return self.norm(x)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class TiDE(BaseWindows):\n", + " \"\"\" TiDE\n", + "\n", + " Time-series Dense Encoder (`TiDE`) is a MLP-based univariate time-series forecasting model. `TiDE` uses Multi-layer Perceptrons (MLPs) in an encoder-decoder model for long-term time-series forecasting.\n", + "\n", + " **Parameters:**
\n", + " `h`: int, forecast horizon.
\n", + " `input_size`: int, considered autorregresive inputs (lags), y=[1,2,3,4] input_size=2 -> lags=[1,2].
\n", + " `hidden_size`: int=1024, number of units for the dense MLPs.
\n", + " `decoder_output_dim`: int=32, number of units for the output of the decoder.
\n", + " `temporal_decoder_dim`: int=128, number of units for the hidden sizeof the temporal decoder.
\n", + " `dropout`: float=0.0, dropout rate between (0, 1) .
\n", + " `layernorm`: bool=True, if True uses Layer Normalization on the MLP residual block outputs.
\n", + " `num_encoder_layers`: int=1, number of encoder layers.
\n", + " `num_decoder_layers`: int=1, number of decoder layers.
\n", + " `temporal_width`: int=4, lower temporal projected dimension.
\n", + " `futr_exog_list`: str list, future exogenous columns.
\n", + " `hist_exog_list`: str list, historic exogenous columns.
\n", + " `stat_exog_list`: str list, static exogenous columns.
\n", + " `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", + " `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", + " `max_steps`: int=1000, maximum number of training steps.
\n", + " `learning_rate`: float=1e-3, Learning rate between (0, 1).
\n", + " `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
\n", + " `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
\n", + " `val_check_steps`: int=100, Number of training steps between every validation loss check.
\n", + " `batch_size`: int=32, number of different series in each batch.
\n", + " `step_size`: int=1, step size between each window of temporal data.
\n", + " `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
\n", + " `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
\n", + " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", + " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", + " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n", + " `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
\n", + " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", + "\n", + " **References:**
\n", + " - [Das, Abhimanyu, Weihao Kong, Andrew Leach, Shaan Mathur, Rajat Sen, and Rose Yu (2024). \"Long-term Forecasting with TiDE: Time-series Dense Encoder.\"](http://arxiv.org/abs/2304.08424)\n", + "\n", + " \"\"\"\n", + " # Class attributes\n", + " SAMPLING_TYPE = 'windows'\n", + " \n", + " def __init__(self,\n", + " h,\n", + " input_size, \n", + " hidden_size = 512,\n", + " decoder_output_dim = 32,\n", + " temporal_decoder_dim = 128,\n", + " dropout = 0.3,\n", + " layernorm=True,\n", + " num_encoder_layers = 1,\n", + " num_decoder_layers = 1,\n", + " temporal_width = 4,\n", + " futr_exog_list = None,\n", + " hist_exog_list = None,\n", + " stat_exog_list = None,\n", + " exclude_insample_y = False,\n", + " loss = MAE(),\n", + " valid_loss = None,\n", + " max_steps: int = 1000,\n", + " learning_rate: float = 1e-3,\n", + " num_lr_decays: int = -1,\n", + " early_stop_patience_steps: int =-1,\n", + " val_check_steps: int = 100,\n", + " batch_size: int = 32,\n", + " valid_batch_size: Optional[int] = None,\n", + " windows_batch_size = 1024,\n", + " inference_windows_batch_size = 1024,\n", + " start_padding_enabled = False,\n", + " step_size: int = 1,\n", + " scaler_type: str = 'identity',\n", + " random_seed: int = 1,\n", + " num_workers_loader: int = 0,\n", + " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", + " **trainer_kwargs):\n", + "\n", + " # Inherit BaseWindows class\n", + " super(TiDE, self).__init__(\n", + " h=h,\n", + " input_size=input_size,\n", + " futr_exog_list=futr_exog_list,\n", + " hist_exog_list=hist_exog_list,\n", + " stat_exog_list=stat_exog_list,\n", + " exclude_insample_y = exclude_insample_y,\n", + " loss=loss,\n", + " valid_loss=valid_loss,\n", + " max_steps=max_steps,\n", + " learning_rate=learning_rate,\n", + " num_lr_decays=num_lr_decays,\n", + " early_stop_patience_steps=early_stop_patience_steps,\n", + " val_check_steps=val_check_steps,\n", + " batch_size=batch_size,\n", + " valid_batch_size=valid_batch_size,\n", + " windows_batch_size=windows_batch_size,\n", + " inference_windows_batch_size=inference_windows_batch_size,\n", + " start_padding_enabled=start_padding_enabled,\n", + " step_size=step_size,\n", + " scaler_type=scaler_type,\n", + " random_seed=random_seed,\n", + " num_workers_loader=num_workers_loader,\n", + " drop_last_loader=drop_last_loader,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", + " **trainer_kwargs\n", + " ) \n", + " self.h = h\n", + "\n", + " self.futr_exog_size = len(self.futr_exog_list)\n", + " self.hist_exog_size = len(self.hist_exog_list)\n", + " self.stat_exog_size = len(self.stat_exog_list) \n", + "\n", + " if self.hist_exog_size > 0 or self.futr_exog_size > 0:\n", + " self.hist_exog_projection = MLPResidual(input_dim = self.hist_exog_size,\n", + " hidden_size=hidden_size,\n", + " output_dim=temporal_width,\n", + " dropout=dropout,\n", + " layernorm=layernorm) \n", + " if self.futr_exog_size > 0:\n", + " self.futr_exog_projection = MLPResidual(input_dim = self.futr_exog_size,\n", + " hidden_size = hidden_size,\n", + " output_dim=temporal_width,\n", + " dropout=dropout,\n", + " layernorm=layernorm)\n", + "\n", + " # Encoder\n", + " dense_encoder_input_size = input_size + \\\n", + " input_size * (self.hist_exog_size > 0) * temporal_width + \\\n", + " (input_size + h) * (self.futr_exog_size > 0) * temporal_width + \\\n", + " (self.stat_exog_size > 0) * self.stat_exog_size\n", + "\n", + " dense_encoder_layers = [MLPResidual(input_dim=dense_encoder_input_size 
if i == 0 else hidden_size,\n", + " hidden_size=hidden_size,\n", + " output_dim=hidden_size,\n", + " dropout=dropout,\n", + " layernorm=layernorm) for i in range(num_encoder_layers)]\n", + " self.dense_encoder = nn.Sequential(*dense_encoder_layers)\n", + "\n", + " # Decoder\n", + " decoder_output_size = decoder_output_dim * h\n", + " dense_decoder_layers = [MLPResidual(input_dim=hidden_size,\n", + " hidden_size=hidden_size,\n", + " output_dim=decoder_output_size if i == num_decoder_layers - 1 else hidden_size,\n", + " dropout=dropout,\n", + " layernorm=layernorm) for i in range(num_decoder_layers)]\n", + " self.dense_decoder = nn.Sequential(*dense_decoder_layers)\n", + "\n", + " # Temporal decoder with loss dependent dimensions\n", + " self.temporal_decoder = MLPResidual(input_dim = decoder_output_dim + (self.futr_exog_size > 0) * temporal_width,\n", + " hidden_size = temporal_decoder_dim,\n", + " output_dim=self.loss.outputsize_multiplier,\n", + " dropout=dropout,\n", + " layernorm=layernorm)\n", + "\n", + "\n", + " # Global skip connection\n", + " self.global_skip = nn.Linear(in_features = input_size,\n", + " out_features = h * self.loss.outputsize_multiplier)\n", + "\n", + " def forward(self, windows_batch):\n", + " # Parse windows_batch\n", + " x = windows_batch['insample_y'].unsqueeze(-1) # [B, L, 1]\n", + " hist_exog = windows_batch['hist_exog'] # [B, L, X]\n", + " futr_exog = windows_batch['futr_exog'] # [B, L + h, F]\n", + " stat_exog = windows_batch['stat_exog'] # [B, S]\n", + " batch_size, seq_len = x.shape[:2] # B = batch_size, L = seq_len\n", + "\n", + " # Flatten insample_y\n", + " x = x.reshape(batch_size, -1) # [B, L, 1] -> [B, L]\n", + "\n", + " # Global skip connection\n", + " x_skip = self.global_skip(x) # [B, L] -> [B, h * n_outputs]\n", + " x_skip = x_skip.reshape(batch_size, self.h, -1) # [B, h * n_outputs] -> [B, h, n_outputs]\n", + "\n", + " # Concatenate x with flattened historical exogenous\n", + " if self.hist_exog_size > 0:\n", + " x_hist_exog = self.hist_exog_projection(hist_exog) # [B, L, X] -> [B, L, temporal_width]\n", + " x_hist_exog = x_hist_exog.reshape(batch_size, -1) # [B, L, temporal_width] -> [B, L * temporal_width]\n", + " x = torch.cat((x, x_hist_exog), dim=1) # [B, L] + [B, L * temporal_width] -> [B, L * (1 + temporal_width)]\n", + "\n", + " # Concatenate x with flattened future exogenous\n", + " if self.futr_exog_size > 0:\n", + " x_futr_exog = self.futr_exog_projection(futr_exog) # [B, L + h, F] -> [B, L + h, temporal_width]\n", + " x_futr_exog_flat = x_futr_exog.reshape(batch_size, -1) # [B, L + h, temporal_width] -> [B, (L + h) * temporal_width]\n", + " x = torch.cat((x, x_futr_exog_flat), dim=1) # [B, L * (1 + temporal_width)] + [B, (L + h) * temporal_width] -> [B, L * (1 + 2 * temporal_width) + h * temporal_width]\n", + "\n", + " # Concatenate x with static exogenous\n", + " if self.stat_exog_size > 0:\n", + " x = torch.cat((x, stat_exog), dim=1) # [B, L * (1 + 2 * temporal_width) + h * temporal_width] + [B, S] -> [B, L * (1 + 2 * temporal_width) + h * temporal_width + S]\n", + "\n", + " # Dense encoder\n", + " x = self.dense_encoder(x) # [B, L * (1 + 2 * temporal_width) + h * temporal_width + S] -> [B, hidden_size]\n", + "\n", + " # Dense decoder\n", + " x = self.dense_decoder(x) # [B, hidden_size] -> [B, decoder_output_dim * h]\n", + " x = x.reshape(batch_size, self.h, -1) # [B, decoder_output_dim * h] -> [B, h, decoder_output_dim]\n", + "\n", + " # Stack with futr_exog for horizon part of futr_exog\n", + " if self.futr_exog_size > 0:\n", + 
" x_futr_exog_h = x_futr_exog[:, seq_len:] # [B, L + h, temporal_width] -> [B, h, temporal_width]\n", + " x = torch.cat((x, x_futr_exog_h), dim=2) # [B, h, decoder_output_dim] + [B, h, temporal_width] -> [B, h, temporal_width + decoder_output_dim]\n", + "\n", + " # Temporal decoder\n", + " x = self.temporal_decoder(x) # [B, h, temporal_width + decoder_output_dim] -> [B, h, n_outputs]\n", + "\n", + " # Map to output domain\n", + " forecast = self.loss.domain_map(x + x_skip)\n", + " \n", + " return forecast\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(TiDE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(TiDE.fit, name='TiDE.fit')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(TiDE.predict, name='TiDE.predict')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Usage Examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Train model and forecast future values with `predict` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| eval: false\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from neuralforecast.utils import AirPassengersDF as Y_df\n", + "from neuralforecast.tsdataset import TimeSeriesDataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| eval: false\n", + "Y_train_df = Y_df[Y_df.ds<='1959-12-31'] # 132 train\n", + "Y_test_df = Y_df[Y_df.ds>'1959-12-31'] # 12 test\n", + "\n", + "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n", + "model = TiDE(h=12, input_size=24, max_steps=500, scaler_type='standard')\n", + "model.fit(dataset=dataset)\n", + "y_hat = model.predict(dataset=dataset)\n", + "Y_test_df['TiDE'] = y_hat\n", + "\n", + "#test we recover the same forecast\n", + "y_hat2 = model.predict(dataset=dataset)\n", + "test_eq(y_hat, y_hat2)\n", + "\n", + "pd.concat([Y_train_df, Y_test_df]).drop('unique_id', axis=1).set_index('ds').plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Creating probabilistic forecasts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| eval: false\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pytorch_lightning as pl\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from neuralforecast import NeuralForecast\n", + "from neuralforecast.losses.pytorch import GMM, DistributionLoss\n", + "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| eval: false\n", + "# Plot predictions\n", + "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n", + "\n", + "fcst = NeuralForecast(\n", + " models=[\n", + " TiDE(h=12,\n", + " input_size=24,\n", + " loss=GMM(n_components=7, return_params=True, level=[80,90]),\n", + " max_steps=500,\n", + " scaler_type='standard',\n", + " futr_exog_list=['y_[lag12]'],\n", + " hist_exog_list=None,\n", + " stat_exog_list=['airline1'],\n", + " ), \n", + " ],\n", + " freq='M'\n", + ")\n", + "fcst.fit(df=Y_train_df, 
static_df=AirPassengersStatic)\n", + "forecasts = fcst.predict(futr_df=Y_test_df)\n", + "\n", + "# Plot quantile predictions\n", + "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n", + "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n", + "plot_df = pd.concat([Y_train_df, plot_df])\n", + "\n", + "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n", + "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n", + "plt.plot(plot_df['ds'], plot_df['TiDE-median'], c='blue', label='median')\n", + "plt.fill_between(x=plot_df['ds'][-12:], \n", + " y1=plot_df['TiDE-lo-90'][-12:].values,\n", + " y2=plot_df['TiDE-hi-90'][-12:].values,\n", + " alpha=0.4, label='level 90')\n", + "plt.legend()\n", + "plt.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/neuralforecast/_modidx.py b/neuralforecast/_modidx.py index 5c10130b3..4bcbdabad 100644 --- a/neuralforecast/_modidx.py +++ b/neuralforecast/_modidx.py @@ -114,6 +114,10 @@ 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoTSMixerx.get_default_config': ( 'models.html#autotsmixerx.get_default_config', 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoTiDE': ('models.html#autotide', 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoTiDE.__init__': ('models.html#autotide.__init__', 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoTiDE.get_default_config': ( 'models.html#autotide.get_default_config', + 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoTimesNet': ('models.html#autotimesnet', 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoTimesNet.__init__': ( 'models.html#autotimesnet.__init__', 'neuralforecast/auto.py'), @@ -1021,6 +1025,17 @@ 'neuralforecast/models/tft.py'), 'neuralforecast.models.tft.VariableSelectionNetwork.forward': ( 'models.tft.html#variableselectionnetwork.forward', 'neuralforecast/models/tft.py')}, + 'neuralforecast.models.tide': { 'neuralforecast.models.tide.MLPResidual': ( 'models.tide.html#mlpresidual', + 'neuralforecast/models/tide.py'), + 'neuralforecast.models.tide.MLPResidual.__init__': ( 'models.tide.html#mlpresidual.__init__', + 'neuralforecast/models/tide.py'), + 'neuralforecast.models.tide.MLPResidual.forward': ( 'models.tide.html#mlpresidual.forward', + 'neuralforecast/models/tide.py'), + 'neuralforecast.models.tide.TiDE': ('models.tide.html#tide', 'neuralforecast/models/tide.py'), + 'neuralforecast.models.tide.TiDE.__init__': ( 'models.tide.html#tide.__init__', + 'neuralforecast/models/tide.py'), + 'neuralforecast.models.tide.TiDE.forward': ( 'models.tide.html#tide.forward', + 'neuralforecast/models/tide.py')}, 'neuralforecast.models.timellm': { 'neuralforecast.models.timellm.FlattenHead': ( 'models.timellm.html#flattenhead', 'neuralforecast/models/timellm.py'), 'neuralforecast.models.timellm.FlattenHead.__init__': ( 'models.timellm.html#flattenhead.__init__', diff --git a/neuralforecast/auto.py b/neuralforecast/auto.py index ebfe1398c..7b898815f 100644 --- a/neuralforecast/auto.py +++ b/neuralforecast/auto.py @@ -2,7 +2,7 @@ # %% auto 0 __all__ = ['AutoRNN', 'AutoLSTM', 'AutoGRU', 'AutoTCN', 'AutoDeepAR', 'AutoDilatedRNN', 'AutoBiTCN', 'AutoMLP', 'AutoNBEATS', - 'AutoNBEATSx', 'AutoNHITS', 'AutoDLinear', 'AutoNLinear', 'AutoTFT', 'AutoVanillaTransformer', + 'AutoNBEATSx', 'AutoNHITS', 'AutoDLinear', 'AutoNLinear', 'AutoTiDE', 'AutoTFT', 'AutoVanillaTransformer', 'AutoInformer', 
'AutoAutoformer', 'AutoFEDformer', 'AutoPatchTST', 'AutoiTransformer', 'AutoTimesNet', 'AutoStemGNN', 'AutoHINT', 'AutoTSMixer', 'AutoTSMixerx', 'AutoMLPMultivariate'] @@ -30,6 +30,7 @@ from .models.nhits import NHITS from .models.dlinear import DLinear from .models.nlinear import NLinear +from .models.tide import TiDE from .models.tft import TFT from .models.vanillatransformer import VanillaTransformer @@ -958,7 +959,81 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 67 +# %% ../nbs/models.ipynb 66 +class AutoTiDE(BaseAuto): + + default_config = { + "input_size_multiplier": [1, 2, 3, 4, 5], + "h": None, + "hidden_size": tune.choice([256, 512, 1024]), + "decoder_output_dim": tune.choice([8, 16, 32]), + "temporal_decoder_dim": tune.choice([32, 64, 128]), + "num_encoder_layers": tune.choice([1, 2, 3]), + "num_decoder_layers": tune.choice([1, 2, 3]), + "temporal_width": tune.choice([4, 8, 16]), + "dropout": tune.choice([0.0, 0.1, 0.2, 0.3, 0.5]), + "layernorm": tune.choice([True, False]), + "learning_rate": tune.loguniform(1e-5, 1e-2), + "scaler_type": tune.choice([None, "robust", "standard"]), + "max_steps": tune.quniform(lower=500, upper=1500, q=100), + "batch_size": tune.choice([32, 64, 128, 256]), + "windows_batch_size": tune.choice([128, 256, 512, 1024]), + "loss": None, + "random_seed": tune.randint(lower=1, upper=20), + } + + def __init__( + self, + h, + loss=MAE(), + valid_loss=None, + config=None, + search_alg=BasicVariantGenerator(random_state=1), + num_samples=10, + refit_with_val=False, + cpus=cpu_count(), + gpus=torch.cuda.device_count(), + verbose=False, + alias=None, + backend="ray", + callbacks=None, + ): + + # Define search space, input/output sizes + if config is None: + config = self.get_default_config(h=h, backend=backend) + + super(AutoTiDE, self).__init__( + cls_model=TiDE, + h=h, + loss=loss, + valid_loss=valid_loss, + config=config, + search_alg=search_alg, + num_samples=num_samples, + refit_with_val=refit_with_val, + cpus=cpus, + gpus=gpus, + verbose=verbose, + alias=alias, + backend=backend, + callbacks=callbacks, + ) + + @classmethod + def get_default_config(cls, h, backend, n_series=None): + config = cls.default_config.copy() + config["input_size"] = tune.choice( + [h * x for x in config["input_size_multiplier"]] + ) + config["step_size"] = tune.choice([1, h]) + del config["input_size_multiplier"] + if backend == "optuna": + config = cls._ray_config_to_optuna(config) + + return config + +# %% ../nbs/models.ipynb 71 class AutoTFT(BaseAuto): default_config = { @@ -1026,7 +1101,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 71 +# %% ../nbs/models.ipynb 75 class AutoVanillaTransformer(BaseAuto): default_config = { @@ -1094,7 +1169,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 75 +# %% ../nbs/models.ipynb 79 class AutoInformer(BaseAuto): default_config = { @@ -1162,7 +1237,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 79 +# %% ../nbs/models.ipynb 83 class AutoAutoformer(BaseAuto): default_config = { @@ -1230,7 +1305,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 83 +# %% ../nbs/models.ipynb 87 class AutoFEDformer(BaseAuto): default_config = { @@ -1297,7 +1372,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 87 +# %% ../nbs/models.ipynb 91 class 
AutoPatchTST(BaseAuto): default_config = { @@ -1367,7 +1442,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 91 +# %% ../nbs/models.ipynb 95 class AutoiTransformer(BaseAuto): default_config = { @@ -1452,7 +1527,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 96 +# %% ../nbs/models.ipynb 100 class AutoTimesNet(BaseAuto): default_config = { @@ -1520,7 +1595,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 101 +# %% ../nbs/models.ipynb 105 class AutoStemGNN(BaseAuto): default_config = { @@ -1605,7 +1680,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 105 +# %% ../nbs/models.ipynb 109 class AutoHINT(BaseAuto): def __init__( @@ -1677,7 +1752,7 @@ def _fit_model( def get_default_config(cls, h, backend, n_series=None): raise Exception("AutoHINT has no default configuration.") -# %% ../nbs/models.ipynb 110 +# %% ../nbs/models.ipynb 114 class AutoTSMixer(BaseAuto): default_config = { @@ -1763,7 +1838,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 114 +# %% ../nbs/models.ipynb 118 class AutoTSMixerx(BaseAuto): default_config = { @@ -1849,7 +1924,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 118 +# %% ../nbs/models.ipynb 122 class AutoMLPMultivariate(BaseAuto): default_config = { diff --git a/neuralforecast/core.py b/neuralforecast/core.py index 6725a69c4..b13338d4c 100644 --- a/neuralforecast/core.py +++ b/neuralforecast/core.py @@ -57,6 +57,7 @@ MLPMultivariate, iTransformer, BiTCN, + TiDE, ) # %% ../nbs/core.ipynb 5 @@ -170,6 +171,8 @@ def _insample_times( "autoitransformer": iTransformer, "bitcn": BiTCN, "autobitcn": BiTCN, + "tide": TiDE, + "autotide": TiDE, } # %% ../nbs/core.ipynb 8 diff --git a/neuralforecast/models/__init__.py b/neuralforecast/models/__init__.py index d4a6ead9d..fbca72d6e 100644 --- a/neuralforecast/models/__init__.py +++ b/neuralforecast/models/__init__.py @@ -2,7 +2,7 @@ 'MLP', 'NHITS', 'NBEATS', 'NBEATSx', 'DLinear', 'NLinear', 'TFT', 'VanillaTransformer', 'Informer', 'Autoformer', 'PatchTST', 'FEDformer', 'StemGNN', 'HINT', 'TimesNet', 'TimeLLM', 'TSMixer', 'TSMixerx', 'MLPMultivariate', - 'iTransformer', 'BiTCN', + 'iTransformer', 'BiTCN', 'TiDE', ] from .rnn import RNN @@ -32,4 +32,5 @@ from .mlpmultivariate import MLPMultivariate from .itransformer import iTransformer from .bitcn import BiTCN +from .tide import TiDE diff --git a/neuralforecast/models/tide.py b/neuralforecast/models/tide.py new file mode 100644 index 000000000..85f33e49f --- /dev/null +++ b/neuralforecast/models/tide.py @@ -0,0 +1,307 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.tide.ipynb. 
+ +# %% auto 0 +__all__ = ['TiDE'] + +# %% ../../nbs/models.tide.ipynb 5 +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..losses.pytorch import MAE +from ..common._base_windows import BaseWindows + +# %% ../../nbs/models.tide.ipynb 8 +class MLPResidual(nn.Module): + def __init__(self, input_dim, hidden_size, output_dim, dropout, layernorm): + super().__init__() + self.layernorm = layernorm + if layernorm: + self.norm = nn.LayerNorm(output_dim) + + self.drop = nn.Dropout(dropout) + self.lin1 = nn.Linear(input_dim, hidden_size) + self.lin2 = nn.Linear(hidden_size, output_dim) + self.skip = nn.Linear(input_dim, output_dim) + + def forward(self, input): + # MLP dense + x = F.relu(self.lin1(input)) + x = self.lin2(x) + x = self.drop(x) + + # Skip connection + x_skip = self.skip(input) + + # Combine + x = x + x_skip + + if self.layernorm: + return self.norm(x) + + return x + +# %% ../../nbs/models.tide.ipynb 10 +class TiDE(BaseWindows): + """TiDE + + Time-series Dense Encoder (`TiDE`) is a MLP-based univariate time-series forecasting model. `TiDE` uses Multi-layer Perceptrons (MLPs) in an encoder-decoder model for long-term time-series forecasting. + + **Parameters:**
+ `h`: int, forecast horizon.
+ `input_size`: int, considered autoregressive inputs (lags), y=[1,2,3,4] input_size=2 -> lags=[1,2].<br>
+ `hidden_size`: int=1024, number of units for the dense MLPs.
+ `decoder_output_dim`: int=32, number of units for the output of the decoder.
+ `temporal_decoder_dim`: int=128, number of units for the hidden size of the temporal decoder.<br>
+ `dropout`: float=0.0, dropout rate between (0, 1).<br>
+ `layernorm`: bool=True, if True uses Layer Normalization on the MLP residual block outputs.
+ `num_encoder_layers`: int=1, number of encoder layers.
+ `num_decoder_layers`: int=1, number of decoder layers.
+ `temporal_width`: int=4, lower-dimensional size to which the temporal exogenous features are projected.<br>
+ `futr_exog_list`: str list, future exogenous columns.
+ `hist_exog_list`: str list, historic exogenous columns.
+ `stat_exog_list`: str list, static exogenous columns.
+ `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
+ `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
+ `max_steps`: int=1000, maximum number of training steps.
+ `learning_rate`: float=1e-3, Learning rate between (0, 1).
+ `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
+ `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
+ `val_check_steps`: int=100, Number of training steps between every validation loss check.
+ `batch_size`: int=32, number of different series in each batch.
+ `step_size`: int=1, step size between each window of temporal data.
+ `scaler_type`: str='identity', type of scaler for temporal inputs normalization, see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>
+ `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
+ `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
+ `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
+ `alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
+ `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>
+ + **References:**
+ - [Das, Abhimanyu, Weihao Kong, Andrew Leach, Shaan Mathur, Rajat Sen, and Rose Yu (2024). "Long-term Forecasting with TiDE: Time-series Dense Encoder."](http://arxiv.org/abs/2304.08424) + + """ + + # Class attributes + SAMPLING_TYPE = "windows" + + def __init__( + self, + h, + input_size, + hidden_size=512, + decoder_output_dim=32, + temporal_decoder_dim=128, + dropout=0.3, + layernorm=True, + num_encoder_layers=1, + num_decoder_layers=1, + temporal_width=4, + futr_exog_list=None, + hist_exog_list=None, + stat_exog_list=None, + exclude_insample_y=False, + loss=MAE(), + valid_loss=None, + max_steps: int = 1000, + learning_rate: float = 1e-3, + num_lr_decays: int = -1, + early_stop_patience_steps: int = -1, + val_check_steps: int = 100, + batch_size: int = 32, + valid_batch_size: Optional[int] = None, + windows_batch_size=1024, + inference_windows_batch_size=1024, + start_padding_enabled=False, + step_size: int = 1, + scaler_type: str = "identity", + random_seed: int = 1, + num_workers_loader: int = 0, + drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, + **trainer_kwargs + ): + + # Inherit BaseWindows class + super(TiDE, self).__init__( + h=h, + input_size=input_size, + futr_exog_list=futr_exog_list, + hist_exog_list=hist_exog_list, + stat_exog_list=stat_exog_list, + exclude_insample_y=exclude_insample_y, + loss=loss, + valid_loss=valid_loss, + max_steps=max_steps, + learning_rate=learning_rate, + num_lr_decays=num_lr_decays, + early_stop_patience_steps=early_stop_patience_steps, + val_check_steps=val_check_steps, + batch_size=batch_size, + valid_batch_size=valid_batch_size, + windows_batch_size=windows_batch_size, + inference_windows_batch_size=inference_windows_batch_size, + start_padding_enabled=start_padding_enabled, + step_size=step_size, + scaler_type=scaler_type, + random_seed=random_seed, + num_workers_loader=num_workers_loader, + drop_last_loader=drop_last_loader, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + **trainer_kwargs + ) + self.h = h + + self.futr_exog_size = len(self.futr_exog_list) + self.hist_exog_size = len(self.hist_exog_list) + self.stat_exog_size = len(self.stat_exog_list) + + if self.hist_exog_size > 0 or self.futr_exog_size > 0: + self.hist_exog_projection = MLPResidual( + input_dim=self.hist_exog_size, + hidden_size=hidden_size, + output_dim=temporal_width, + dropout=dropout, + layernorm=layernorm, + ) + if self.futr_exog_size > 0: + self.futr_exog_projection = MLPResidual( + input_dim=self.futr_exog_size, + hidden_size=hidden_size, + output_dim=temporal_width, + dropout=dropout, + layernorm=layernorm, + ) + + # Encoder + dense_encoder_input_size = ( + input_size + + input_size * (self.hist_exog_size > 0) * temporal_width + + (input_size + h) * (self.futr_exog_size > 0) * temporal_width + + (self.stat_exog_size > 0) * self.stat_exog_size + ) + + dense_encoder_layers = [ + MLPResidual( + input_dim=dense_encoder_input_size if i == 0 else hidden_size, + hidden_size=hidden_size, + output_dim=hidden_size, + dropout=dropout, + layernorm=layernorm, + ) + for i in range(num_encoder_layers) + ] + self.dense_encoder = nn.Sequential(*dense_encoder_layers) + + # Decoder + decoder_output_size = decoder_output_dim * h + dense_decoder_layers = [ + MLPResidual( + input_dim=hidden_size, + hidden_size=hidden_size, + output_dim=( + decoder_output_size if i == num_decoder_layers - 1 else hidden_size + ), + dropout=dropout, + layernorm=layernorm, + ) + for i in range(num_decoder_layers) + ] + self.dense_decoder = 
nn.Sequential(*dense_decoder_layers) + + # Temporal decoder with loss dependent dimensions + self.temporal_decoder = MLPResidual( + input_dim=decoder_output_dim + (self.futr_exog_size > 0) * temporal_width, + hidden_size=temporal_decoder_dim, + output_dim=self.loss.outputsize_multiplier, + dropout=dropout, + layernorm=layernorm, + ) + + # Global skip connection + self.global_skip = nn.Linear( + in_features=input_size, out_features=h * self.loss.outputsize_multiplier + ) + + def forward(self, windows_batch): + # Parse windows_batch + x = windows_batch["insample_y"].unsqueeze(-1) # [B, L, 1] + hist_exog = windows_batch["hist_exog"] # [B, L, X] + futr_exog = windows_batch["futr_exog"] # [B, L + h, F] + stat_exog = windows_batch["stat_exog"] # [B, S] + batch_size, seq_len = x.shape[:2] # B = batch_size, L = seq_len + + # Flatten insample_y + x = x.reshape(batch_size, -1) # [B, L, 1] -> [B, L] + + # Global skip connection + x_skip = self.global_skip(x) # [B, L] -> [B, h * n_outputs] + x_skip = x_skip.reshape( + batch_size, self.h, -1 + ) # [B, h * n_outputs] -> [B, h, n_outputs] + + # Concatenate x with flattened historical exogenous + if self.hist_exog_size > 0: + x_hist_exog = self.hist_exog_projection( + hist_exog + ) # [B, L, X] -> [B, L, temporal_width] + x_hist_exog = x_hist_exog.reshape( + batch_size, -1 + ) # [B, L, temporal_width] -> [B, L * temporal_width] + x = torch.cat( + (x, x_hist_exog), dim=1 + ) # [B, L] + [B, L * temporal_width] -> [B, L * (1 + temporal_width)] + + # Concatenate x with flattened future exogenous + if self.futr_exog_size > 0: + x_futr_exog = self.futr_exog_projection( + futr_exog + ) # [B, L + h, F] -> [B, L + h, temporal_width] + x_futr_exog_flat = x_futr_exog.reshape( + batch_size, -1 + ) # [B, L + h, temporal_width] -> [B, (L + h) * temporal_width] + x = torch.cat( + (x, x_futr_exog_flat), dim=1 + ) # [B, L * (1 + temporal_width)] + [B, (L + h) * temporal_width] -> [B, L * (1 + 2 * temporal_width) + h * temporal_width] + + # Concatenate x with static exogenous + if self.stat_exog_size > 0: + x = torch.cat( + (x, stat_exog), dim=1 + ) # [B, L * (1 + 2 * temporal_width) + h * temporal_width] + [B, S] -> [B, L * (1 + 2 * temporal_width) + h * temporal_width + S] + + # Dense encoder + x = self.dense_encoder( + x + ) # [B, L * (1 + 2 * temporal_width) + h * temporal_width + S] -> [B, hidden_size] + + # Dense decoder + x = self.dense_decoder(x) # [B, hidden_size] -> [B, decoder_output_dim * h] + x = x.reshape( + batch_size, self.h, -1 + ) # [B, decoder_output_dim * h] -> [B, h, decoder_output_dim] + + # Stack with futr_exog for horizon part of futr_exog + if self.futr_exog_size > 0: + x_futr_exog_h = x_futr_exog[ + :, seq_len: + ] # [B, L + h, temporal_width] -> [B, h, temporal_width] + x = torch.cat( + (x, x_futr_exog_h), dim=2 + ) # [B, h, decoder_output_dim] + [B, h, temporal_width] -> [B, h, temporal_width + decoder_output_dim] + + # Temporal decoder + x = self.temporal_decoder( + x + ) # [B, h, temporal_width + decoder_output_dim] -> [B, h, n_outputs] + + # Map to output domain + forecast = self.loss.domain_map(x + x_skip) + + return forecast