From 61e8d2d65801f1697d2dcc23196eff8b49af662c Mon Sep 17 00:00:00 2001 From: Marco Date: Mon, 16 Sep 2024 13:40:38 -0400 Subject: [PATCH] FIX: timemixer shapes mismatch and doc update (#1138) * Use math.ceil to prevent shape mismatch * Show exog support for KAN in doc * FEAT: TimeLLM is faster and supports more LLMs (#1139) * Fix issue #950: Reduce TimeLLM setup time for training * Restore changes on the examples * Revert changes to nbs/models.ipynb, nbs/models.softs.ipynb and neuralforecast/_modidx.py * Revert changes to nbs/models.ipynb, nbs/models.softs.ipynb and neuralforecast/_modidx.py * Refactor code to dynamically load models with AutoModel, AutoTokenizer, and AutoConfig - Updated load_model_and_tokenizer function to use AutoModel, AutoTokenizer, and AutoConfig for flexible model loading. - Included default model(gpt2) for cases where the specified model fails to load. - Kept llm, llm_config, and llm_tokenizer arguments to minimize changes. - Changed llm from storing pretrained weights to accepting pretrained model path to reduce necessary modifications. This update enhances the flexibility and reliability of model loading based on received feedback while minimizing necessary changes. * Refactor code to dynamically load models with AutoModel, AutoTokenizer, and AutoConfig - Updated load_model_and_tokenizer function to use AutoModel, AutoTokenizer, and AutoConfig for flexible model loading. - Included default model(gpt2) for cases where the specified model fails to load. - Kept llm, llm_config, and llm_tokenizer arguments to minimize changes. - Changed llm from storing pretrained weights to accepting pretrained model path to reduce necessary modifications. This update enhances the flexibility and reliability of model loading based on received feedback while minimizing necessary changes. * clear output * modify test code * Optimize model loading and add deprecation warning - Simplify model loading logic - Add constant for default model name - Improve error handling for model loading - Add success messages for model loading - Implement deprecation warning for 'llm_config' and 'llm_tokenizer' parameters - Update print messages for clarity - Remove redundant code This commit improves code readability, maintainability, and user experience by providing clearer feedback and warnings about deprecated parameters. 
* Resolved conflict in nbs/models.timellm.ipynb --------- Co-authored-by: ive2go Co-authored-by: Olivier Sprangers <45119856+elephaint@users.noreply.github.com> * Consistency with math.ceil --------- Co-authored-by: Olivier Sprangers <45119856+elephaint@users.noreply.github.com> Co-authored-by: ive2go --- nbs/docs/capabilities/01_overview.ipynb | 2 +- nbs/models.timemixer.ipynb | 3283 +---------------------- neuralforecast/models/timemixer.py | 19 +- 3 files changed, 26 insertions(+), 3278 deletions(-) diff --git a/nbs/docs/capabilities/01_overview.ipynb b/nbs/docs/capabilities/01_overview.ipynb index a71c552c5..11b964a7f 100644 --- a/nbs/docs/capabilities/01_overview.ipynb +++ b/nbs/docs/capabilities/01_overview.ipynb @@ -25,7 +25,7 @@ "|`HINT` | `AutoHINT` | Any7 | Both7 | Both7 | F/H/S | \n", "|`Informer` | `AutoInformer` | Transformer | Multivariate | Direct | F | \n", "|`iTransformer` | `AutoiTransformer` | Transformer | Multivariate | Direct | - | \n", - "|`KAN` | `AutoKAN` | KAN | Univariate | Direct | - | \n", + "|`KAN` | `AutoKAN` | KAN | Univariate | Direct | F/H/S | \n", "|`LSTM` | `AutoLSTM` | RNN | Univariate | Recursive | F/H/S | \n", "|`MLP` | `AutoMLP` | MLP | Univariate | Direct | F/H/S | \n", "|`MLPMultivariate` | `AutoMLPMultivariate` | MLP | Multivariate | Direct | F/H/S | \n", diff --git a/nbs/models.timemixer.ipynb b/nbs/models.timemixer.ipynb index bccb36adf..0aacd694c 100644 --- a/nbs/models.timemixer.ipynb +++ b/nbs/models.timemixer.ipynb @@ -35,7 +35,7 @@ "outputs": [], "source": [ "#| export\n", - "\n", + "import math\n", "import numpy as np\n", "\n", "import torch\n", @@ -243,13 +243,13 @@ " [\n", " nn.Sequential(\n", " torch.nn.Linear(\n", - " seq_len // (down_sampling_window ** i),\n", - " seq_len // (down_sampling_window ** (i + 1)),\n", + " math.ceil(seq_len // (down_sampling_window ** i)),\n", + " math.ceil(seq_len // (down_sampling_window ** (i + 1))),\n", " ),\n", " nn.GELU(),\n", " torch.nn.Linear(\n", - " seq_len // (down_sampling_window ** (i + 1)),\n", - " seq_len // (down_sampling_window ** (i + 1)),\n", + " math.ceil(seq_len // (down_sampling_window ** (i + 1))),\n", + " math.ceil(seq_len // (down_sampling_window ** (i + 1))),\n", " ),\n", "\n", " )\n", @@ -287,13 +287,13 @@ " [\n", " nn.Sequential(\n", " torch.nn.Linear(\n", - " seq_len // (down_sampling_window ** (i + 1)),\n", - " seq_len // (down_sampling_window ** i),\n", + " math.ceil(seq_len / (down_sampling_window ** (i + 1))),\n", + " math.ceil(seq_len / (down_sampling_window ** i)),\n", " ),\n", " nn.GELU(),\n", " torch.nn.Linear(\n", - " seq_len // (down_sampling_window ** i),\n", - " seq_len // (down_sampling_window ** i),\n", + " math.ceil(seq_len / (down_sampling_window ** i)),\n", + " math.ceil(seq_len / (down_sampling_window ** i)),\n", " ),\n", " )\n", " for i in reversed(range(down_sampling_layers))\n", @@ -573,7 +573,7 @@ " self.predict_layers = torch.nn.ModuleList(\n", " [\n", " torch.nn.Linear(\n", - " self.input_size // (self.down_sampling_window ** i),\n", + " math.ceil(self.input_size // (self.down_sampling_window ** i)),\n", " self.h,\n", " )\n", " for i in range(self.down_sampling_layers + 1)\n", @@ -773,149 +773,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "---\n", - "\n", - "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/timemixer.py#L329){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", - "\n", - "### TimeMixer\n", - "\n", - "> TimeMixer (h, 
input_size, n_series, stat_exog_list=None,\n", - "> hist_exog_list=None, futr_exog_list=None, d_model:int=32,\n", - "> d_ff:int=32, dropout:float=0.1, e_layers:int=4, top_k:int=5,\n", - "> decomp_method:str='moving_avg', moving_avg:int=25,\n", - "> channel_independence:int=0, down_sampling_layers:int=1,\n", - "> down_sampling_window:int=2, down_sampling_method:str='avg',\n", - "> use_norm:bool=True, decoder_input_size_multiplier:float=0.5,\n", - "> loss=MAE(), valid_loss=None, max_steps:int=1000,\n", - "> learning_rate:float=0.001, num_lr_decays:int=-1,\n", - "> early_stop_patience_steps:int=-1, val_check_steps:int=100,\n", - "> batch_size:int=32, step_size:int=1,\n", - "> scaler_type:str='identity', random_seed:int=1,\n", - "> num_workers_loader:int=0, drop_last_loader:bool=False,\n", - "> optimizer=None, optimizer_kwargs=None, lr_scheduler=None,\n", - "> lr_scheduler_kwargs=None, **trainer_kwargs)\n", - "\n", - "TimeMixer\n", - "**Parameters**
\n", - "`h`: int, Forecast horizon.
\n", - "`input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
\n", - "`n_series`: int, number of time-series.
\n", - "`futr_exog_list`: str list, future exogenous columns.
\n", - "`hist_exog_list`: str list, historic exogenous columns.
\n", - "`stat_exog_list`: str list, static exogenous columns.
\n", - "`d_model`: int, dimension of the model.
\n", - "`d_ff`: int, dimension of the fully-connected network.
\n", - "`dropout`: float, dropout rate.
\n", - "`e_layers`: int, number of encoder layers.
\n", - "`top_k`: int, number of selected frequencies.
\n", - "`decomp_method`: str, method of series decomposition [moving_avg, dft_decomp].
\n", - "`moving_avg`: int, window size of moving average.
\n", - "`channel_independence`: int, 0: channel dependence, 1: channel independence.
\n", - "`down_sampling_layers`: int, number of downsampling layers.
\n", - "`down_sampling_window`: int, size of downsampling window.
\n", - "`down_sampling_method`: str, down sampling method [avg, max, conv].
\n", - "`use_norm`: bool, whether to normalize or not.
\n", - " `decoder_input_size_multiplier`: float = 0.5.
\n", - "`loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", - "`valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", - "`max_steps`: int=1000, maximum number of training steps.
\n", - "`learning_rate`: float=1e-3, Learning rate between (0, 1).
\n", - "`num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
\n", - "`early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
\n", - "`val_check_steps`: int=100, Number of training steps between every validation loss check.
\n", - "`batch_size`: int=32, number of different series in each batch.
\n", - "`step_size`: int=1, step size between each window of temporal data.
\n", - "`scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
\n", - "`random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
\n", - "`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", - "`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", - "`alias`: str, optional, Custom name of the model.
\n", - "`optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n", - "`optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
\n", - "`lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).
\n", - "`lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.
\n", - "`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", - "\n", - "**References**
\n", - "[Shiyu Wang, Haixu Wu, Xiaoming Shi, Tengge Hu, Huakun Luo, Lintao Ma, James Y. Zhang, Jun Zhou.\"TimeMixer: Decomposable Multiscale Mixing For Time Series Forecasting\"](https://openreview.net/pdf?id=7oLshfEIC2)" - ], - "text/plain": [ - "---\n", - "\n", - "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/timemixer.py#L329){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", - "\n", - "### TimeMixer\n", - "\n", - "> TimeMixer (h, input_size, n_series, stat_exog_list=None,\n", - "> hist_exog_list=None, futr_exog_list=None, d_model:int=32,\n", - "> d_ff:int=32, dropout:float=0.1, e_layers:int=4, top_k:int=5,\n", - "> decomp_method:str='moving_avg', moving_avg:int=25,\n", - "> channel_independence:int=0, down_sampling_layers:int=1,\n", - "> down_sampling_window:int=2, down_sampling_method:str='avg',\n", - "> use_norm:bool=True, decoder_input_size_multiplier:float=0.5,\n", - "> loss=MAE(), valid_loss=None, max_steps:int=1000,\n", - "> learning_rate:float=0.001, num_lr_decays:int=-1,\n", - "> early_stop_patience_steps:int=-1, val_check_steps:int=100,\n", - "> batch_size:int=32, step_size:int=1,\n", - "> scaler_type:str='identity', random_seed:int=1,\n", - "> num_workers_loader:int=0, drop_last_loader:bool=False,\n", - "> optimizer=None, optimizer_kwargs=None, lr_scheduler=None,\n", - "> lr_scheduler_kwargs=None, **trainer_kwargs)\n", - "\n", - "TimeMixer\n", - "**Parameters**
\n", - "`h`: int, Forecast horizon.
\n", - "`input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
\n", - "`n_series`: int, number of time-series.
\n", - "`futr_exog_list`: str list, future exogenous columns.
\n", - "`hist_exog_list`: str list, historic exogenous columns.
\n", - "`stat_exog_list`: str list, static exogenous columns.
\n", - "`d_model`: int, dimension of the model.
\n", - "`d_ff`: int, dimension of the fully-connected network.
\n", - "`dropout`: float, dropout rate.
\n", - "`e_layers`: int, number of encoder layers.
\n", - "`top_k`: int, number of selected frequencies.
\n", - "`decomp_method`: str, method of series decomposition [moving_avg, dft_decomp].
\n", - "`moving_avg`: int, window size of moving average.
\n", - "`channel_independence`: int, 0: channel dependence, 1: channel independence.
\n", - "`down_sampling_layers`: int, number of downsampling layers.
\n", - "`down_sampling_window`: int, size of downsampling window.
\n", - "`down_sampling_method`: str, down sampling method [avg, max, conv].
\n", - "`use_norm`: bool, whether to normalize or not.
\n", - " `decoder_input_size_multiplier`: float = 0.5.
\n", - "`loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", - "`valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", - "`max_steps`: int=1000, maximum number of training steps.
\n", - "`learning_rate`: float=1e-3, Learning rate between (0, 1).
\n", - "`num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
\n", - "`early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
\n", - "`val_check_steps`: int=100, Number of training steps between every validation loss check.
\n", - "`batch_size`: int=32, number of different series in each batch.
\n", - "`step_size`: int=1, step size between each window of temporal data.
\n", - "`scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
\n", - "`random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
\n", - "`num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", - "`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", - "`alias`: str, optional, Custom name of the model.
\n", - "`optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n", - "`optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
\n", - "`lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).
\n", - "`lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.
\n", - "`**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", - "\n", - "**References**
\n", - "[Shiyu Wang, Haixu Wu, Xiaoming Shi, Tengge Hu, Huakun Luo, Lintao Ma, James Y. Zhang, Jun Zhou.\"TimeMixer: Decomposable Multiscale Mixing For Time Series Forecasting\"](https://openreview.net/pdf?id=7oLshfEIC2)" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_doc(TimeMixer)" ] @@ -924,71 +782,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "---\n", - "\n", - "### TimeMixer.fit\n", - "\n", - "> TimeMixer.fit (dataset, val_size=0, test_size=0, random_seed=None,\n", - "> distributed_config=None)\n", - "\n", - "Fit.\n", - "\n", - "The `fit` method, optimizes the neural network's weights using the\n", - "initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n", - "and the `loss` function as defined during the initialization.\n", - "Within `fit` we use a PyTorch Lightning `Trainer` that\n", - "inherits the initialization's `self.trainer_kwargs`, to customize\n", - "its inputs, see [PL's trainer arguments](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).\n", - "\n", - "The method is designed to be compatible with SKLearn-like classes\n", - "and in particular to be compatible with the StatsForecast library.\n", - "\n", - "By default the `model` is not saving training checkpoints to protect\n", - "disk memory, to get them change `enable_checkpointing=True` in `__init__`.\n", - "\n", - "**Parameters:**
\n", - "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n", - "`val_size`: int, validation size for temporal cross-validation.
\n", - "`test_size`: int, test size for temporal cross-validation.
" - ], - "text/plain": [ - "---\n", - "\n", - "### TimeMixer.fit\n", - "\n", - "> TimeMixer.fit (dataset, val_size=0, test_size=0, random_seed=None,\n", - "> distributed_config=None)\n", - "\n", - "Fit.\n", - "\n", - "The `fit` method, optimizes the neural network's weights using the\n", - "initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n", - "and the `loss` function as defined during the initialization.\n", - "Within `fit` we use a PyTorch Lightning `Trainer` that\n", - "inherits the initialization's `self.trainer_kwargs`, to customize\n", - "its inputs, see [PL's trainer arguments](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).\n", - "\n", - "The method is designed to be compatible with SKLearn-like classes\n", - "and in particular to be compatible with the StatsForecast library.\n", - "\n", - "By default the `model` is not saving training checkpoints to protect\n", - "disk memory, to get them change `enable_checkpointing=True` in `__init__`.\n", - "\n", - "**Parameters:**
\n", - "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n", - "`val_size`: int, validation size for temporal cross-validation.
\n", - "`test_size`: int, test size for temporal cross-validation.
" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_doc(TimeMixer.fit, name='TimeMixer.fit')" ] @@ -997,51 +791,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "---\n", - "\n", - "### TimeMixer.predict\n", - "\n", - "> TimeMixer.predict (dataset, test_size=None, step_size=1,\n", - "> random_seed=None, **data_module_kwargs)\n", - "\n", - "Predict.\n", - "\n", - "Neural network prediction with PL's `Trainer` execution of `predict_step`.\n", - "\n", - "**Parameters:**
\n", - "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n", - "`test_size`: int=None, test size for temporal cross-validation.
\n", - "`step_size`: int=1, Step size between each window.
\n", - "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule)." - ], - "text/plain": [ - "---\n", - "\n", - "### TimeMixer.predict\n", - "\n", - "> TimeMixer.predict (dataset, test_size=None, step_size=1,\n", - "> random_seed=None, **data_module_kwargs)\n", - "\n", - "Predict.\n", - "\n", - "Neural network prediction with PL's `Trainer` execution of `predict_step`.\n", - "\n", - "**Parameters:**
\n", - "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n", - "`test_size`: int=None, test size for temporal cross-validation.
\n", - "`step_size`: int=1, Step size between each window.
\n", - "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule)." - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_doc(TimeMixer.predict, name='TimeMixer.predict')" ] @@ -1057,1508 +807,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 1\n", - "GPU available: True (mps), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "\n", - " | Name | Type | Params\n", - "----------------------------------------------------------\n", - "0 | loss | MAE | 0 \n", - "1 | valid_loss | MAE | 0 \n", - "2 | padder | ConstantPad1d | 0 \n", - "3 | scaler | TemporalNorm | 0 \n", - "4 | pdm_blocks | ModuleList | 14.2 K\n", - "5 | preprocess | SeriesDecomp | 0 \n", - "6 | enc_embedding | DataEmbedding_wo_pos | 2.5 K \n", - "7 | normalize_layers | ModuleList | 8 \n", - "8 | predict_layers | ModuleList | 456 \n", - "9 | projection_layer | Linear | 33 \n", - "----------------------------------------------------------\n", - "14.8 K Trainable params\n", - "2.4 K Non-trainable params\n", - "17.2 K Total params\n", - "0.069 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9649c190a0e944a39e40f30fb182c4d7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "#| eval: false\n", "import pandas as pd\n", @@ -2616,1509 +865,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (mps), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "\n", - " | Name | Type | Params\n", - "------------------------------------------------------------\n", - "0 | loss | MAE | 0 \n", - "1 | valid_loss | MAE | 0 \n", - "2 | padder | ConstantPad1d | 0 \n", - "3 | scaler | TemporalNorm | 0 \n", - "4 | pdm_blocks | ModuleList | 22.6 K\n", - "5 | preprocess | SeriesDecomp | 0 \n", - "6 | enc_embedding | DataEmbedding_wo_pos | 2.6 K \n", - "7 | normalize_layers | ModuleList | 8 \n", - "8 | predict_layers | ModuleList | 456 \n", - "9 | projection_layer | Linear | 66 \n", - "10 | out_res_layers | ModuleList | 756 \n", - "11 | regression_layers | ModuleList | 456 \n", - "------------------------------------------------------------\n", - "24.6 K Trainable params\n", - "2.4 K Non-trainable params\n", - "27.0 K Total params\n", - "0.108 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f469cf035b9549df85a57a96d51d77a8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? 
[00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "#| eval: false\n", "fcst = NeuralForecast(models=[model], freq='M')\n", diff --git a/neuralforecast/models/timemixer.py b/neuralforecast/models/timemixer.py index 19b1d972c..571c4e96e 100644 --- a/neuralforecast/models/timemixer.py +++ b/neuralforecast/models/timemixer.py @@ -5,6 +5,7 @@ 'PastDecomposableMixing', 'TimeMixer'] # %% ../../nbs/models.timemixer.ipynb 3 +import math import numpy as np import torch @@ -157,13 +158,13 @@ def __init__(self, seq_len, down_sampling_window, down_sampling_layers): [ nn.Sequential( torch.nn.Linear( - seq_len // (down_sampling_window**i), - seq_len // (down_sampling_window ** (i + 1)), + math.ceil(seq_len // (down_sampling_window**i)), + math.ceil(seq_len // (down_sampling_window ** (i + 1))), ), nn.GELU(), torch.nn.Linear( - seq_len // (down_sampling_window ** (i + 1)), - seq_len // (down_sampling_window ** (i + 1)), + math.ceil(seq_len // (down_sampling_window ** (i + 1))), + math.ceil(seq_len // (down_sampling_window ** (i + 1))), ), ) for i in range(down_sampling_layers) @@ -200,13 +201,13 @@ def __init__(self, seq_len, down_sampling_window, down_sampling_layers): [ nn.Sequential( torch.nn.Linear( - seq_len // (down_sampling_window ** (i + 1)), - seq_len // (down_sampling_window**i), + math.ceil(seq_len / (down_sampling_window ** (i + 1))), + math.ceil(seq_len / (down_sampling_window**i)), ), nn.GELU(), torch.nn.Linear( - seq_len // (down_sampling_window**i), - seq_len // (down_sampling_window**i), + math.ceil(seq_len / (down_sampling_window**i)), + math.ceil(seq_len / (down_sampling_window**i)), ), ) for i in reversed(range(down_sampling_layers)) @@ -516,7 +517,7 @@ def __init__( self.predict_layers = torch.nn.ModuleList( [ torch.nn.Linear( - self.input_size // (self.down_sampling_window**i), + math.ceil(self.input_size // (self.down_sampling_window**i)), self.h, ) for i in range(self.down_sampling_layers + 1)
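
As a standalone illustration of the rounding issue the `math.ceil` change targets (a sketch, not part of the patch): a strided down-sampling convolution, similar to TimeMixer's `conv` down-sampling option, produces ceil-rounded output lengths, so linear layers sized with plain floor division come up one element short whenever the input length is not a multiple of the down-sampling window. The numbers below are hypothetical.

```python
import math
import torch
import torch.nn as nn

# Hypothetical input length that is not a multiple of the window.
seq_len, window = 95, 2
x = torch.randn(8, 1, seq_len)  # (batch, channels, time)

# Stride-2 convolution with circular padding, similar to the 'conv' down-sampling option.
down = nn.Conv1d(1, 1, kernel_size=3, padding=1, stride=window,
                 padding_mode="circular", bias=False)
pooled_len = down(x).shape[-1]

print(pooled_len)                   # 48: length actually produced by the conv
print(seq_len // window)            # 47: floor division is one short for odd lengths
print(math.ceil(seq_len / window))  # 48: ceil over true division matches
```

Whether a mismatch surfaces in practice depends on the chosen down-sampling method and the concrete `input_size`; the sketch only shows how the two rounding choices diverge for non-divisible lengths.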
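The TimeLLM commits in the message above describe moving `load_model_and_tokenizer` to the Hugging Face `Auto*` classes with a gpt2 fallback. Below is a minimal sketch of that loading pattern; the function name, default constant, and error handling are illustrative assumptions, not the exact neuralforecast implementation.

```python
from transformers import AutoConfig, AutoModel, AutoTokenizer

DEFAULT_LLM = "gpt2"  # assumed default, per the commit message

def load_model_and_tokenizer(llm=None):
    """Load any Hugging Face checkpoint by name or path, falling back to gpt2."""
    name = llm or DEFAULT_LLM
    try:
        config = AutoConfig.from_pretrained(name)
        model = AutoModel.from_pretrained(name, config=config)
        tokenizer = AutoTokenizer.from_pretrained(name)
        print(f"Loaded model and tokenizer from '{name}'")
    except (OSError, ValueError):
        # Fall back to the default model when the requested checkpoint cannot be loaded.
        print(f"Failed to load '{name}', falling back to '{DEFAULT_LLM}'")
        config = AutoConfig.from_pretrained(DEFAULT_LLM)
        model = AutoModel.from_pretrained(DEFAULT_LLM, config=config)
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_LLM)
    return model, tokenizer, config

# Example usage (hypothetical checkpoint name):
# model, tokenizer, config = load_model_and_tokenizer("distilgpt2")
```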