From 8804f0ba0640b47e49570a4a5bdc6edf97b29fdd Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Mon, 8 May 2023 08:11:38 +0200 Subject: [PATCH 1/7] add normalization to block_rnn --- darts/models/forecasting/block_rnn_model.py | 81 +++++++++++++++++---- darts/utils/torch.py | 25 +++++++ 2 files changed, 93 insertions(+), 13 deletions(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index 71dea28e43..a2f1e94bf4 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -11,6 +11,7 @@ from darts.logging import get_logger, raise_if_not from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel +from darts.utils.torch import ExtractRnnOutput, TemporalBatchNorm1d logger = get_logger(__name__) @@ -27,6 +28,7 @@ def __init__( nr_params: int, num_layers_out_fc: Optional[List] = None, dropout: float = 0.0, + normalization: str = None, **kwargs, ): @@ -62,6 +64,8 @@ def __init__( This network connects the last hidden layer of the PyTorch RNN module to the output. dropout The fraction of neurons that are dropped in all-but-last RNN layers. + normalization + The name of the normalization applied after RNN and FC layers ("batch", "layer") **kwargs all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class. @@ -88,32 +92,25 @@ def __init__( self.name = name # Defining the RNN module - self.rnn = getattr(nn, name)( - input_size, hidden_dim, num_layers, batch_first=True, dropout=dropout + self.rnn = self._rnn_sequence( + name, input_size, hidden_dim, num_layers, dropout, normalization ) # The RNN module is followed by a fully connected layer, which maps the last hidden layer # to the output of desired length - last = hidden_dim - feats = [] - for feature in num_layers_out_fc + [ - self.output_chunk_length * target_size * nr_params - ]: - feats.append(nn.Linear(last, feature)) - last = feature - self.fc = nn.Sequential(*feats) + self.fc = self._fc_layer( + hidden_dim, num_layers_out_fc, target_size, normalization + ) def forward(self, x_in: Tuple): x, _ = x_in # data is of size (batch_size, input_chunk_length, input_size) batch_size = x.size(0) - out, hidden = self.rnn(x) + hidden = self.rnn(x) """ Here, we apply the FC network only on the last output point (at the last time step) """ - if self.name == "LSTM": - hidden = hidden[0] predictions = hidden[-1, :, :] predictions = self.fc(predictions) predictions = predictions.view( @@ -123,6 +120,64 @@ def forward(self, x_in: Tuple): # predictions is of size (batch_size, output_chunk_length, 1) return predictions + def _rnn_sequence( + self, + name: str, + input_size: int, + hidden_dim: int, + num_layers: int, + dropout: float = 0.0, + normalization: str = None, + ): + + modules = [] + for i in range(num_layers): + input = input_size if (i == 0) else hidden_dim + rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True) + + modules.append(rnn) + modules.append(ExtractRnnOutput()) + + if normalization: + modules.append(self._normalization_layer(normalization, hidden_dim)) + if ( + i < num_layers - 1 + ): # pytorch RNNs don't have dropout applied on the last layer + modules.append(nn.Dropout(dropout)) + return nn.Sequential(*modules) + + def _fc_layer( + self, + input_size: int, + num_layers_out_fc: list[int], + target_size: int, + normalization: str = None, + ): + if not num_layers_out_fc: + 
num_layers_out_fc = [] + + last = input_size + feats = [] + for feature in num_layers_out_fc: + if normalization: + feats.append(self._normalization_layer(normalization, last, False)) + feats.append(nn.Linear(last, feature)) + last = feature + feats.append(nn.Linear(last, self.out_len * target_size * self.nr_params)) + return nn.Sequential(*feats) + + def _normalization_layer( + self, normalization: str, hidden_size: int, is_temporal: bool + ): + + if normalization == "batch": + if is_temporal: + return TemporalBatchNorm1d(hidden_size) + else: + return nn.BatchNorm1d(hidden_size) + elif normalization == "layer": + return nn.LayerNorm(hidden_size) + class BlockRNNModel(PastCovariatesTorchModel): def __init__( diff --git a/darts/utils/torch.py b/darts/utils/torch.py index 552f285384..d70141aba1 100644 --- a/darts/utils/torch.py +++ b/darts/utils/torch.py @@ -112,3 +112,28 @@ def decorator(self, *args, **kwargs) -> T: return decorated(self, *args, **kwargs) return decorator + + +class TemporalBatchNorm1d(nn.Module): + def __init__(self, feature_size) -> None: + super().__init__() + self.norm = nn.BatchNorm1d(feature_size) + + def forward(self, input): + input = self._reshape_input(input) # Reshape N L C -> N C L + input = self.norm(input) + input = self._reshape_input(input) + return input if len(input) > 1 else input[0] + + def _reshape_input(self, x): + shape = x.shape + return x.reshape(shape[0], shape[2], shape[1]) + + +class ExtractRnnOutput(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, input): + output, _ = input + return output From ca9fddb7039b42cc20f0909c11d3ded2eee52763 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Mon, 8 May 2023 08:38:58 +0200 Subject: [PATCH 2/7] remove todo --- darts/models/forecasting/block_rnn_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index a2f1e94bf4..1f8917a4d5 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -16,7 +16,6 @@ logger = get_logger(__name__) -# TODO add batch norm class _BlockRNNModule(PLPastCovariatesModule): def __init__( self, From e4c608998b1802149a0e04f345739a146764704e Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Thu, 11 May 2023 18:00:17 +0200 Subject: [PATCH 3/7] fix indexing --- darts/models/forecasting/block_rnn_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index 1f8917a4d5..4477be7ea5 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -110,7 +110,7 @@ def forward(self, x_in: Tuple): """ Here, we apply the FC network only on the last output point (at the last time step) """ - predictions = hidden[-1, :, :] + predictions = hidden[:, -1, :] predictions = self.fc(predictions) predictions = predictions.view( batch_size, self.out_len, self.target_size, self.nr_params @@ -157,12 +157,13 @@ def _fc_layer( last = input_size feats = [] - for feature in num_layers_out_fc: + for feature in num_layers_out_fc + [ + self.output_chunk_length * target_size * self.nr_params + ]: if normalization: feats.append(self._normalization_layer(normalization, last, False)) feats.append(nn.Linear(last, feature)) last = feature - feats.append(nn.Linear(last, self.out_len * target_size * self.nr_params)) return nn.Sequential(*feats) def _normalization_layer( From 
80c9e100a3b4d5a37ad97d26c9351197447e37ac Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Thu, 11 May 2023 18:21:18 +0200 Subject: [PATCH 4/7] clean up unused code --- darts/models/forecasting/block_rnn_model.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index 4477be7ea5..ce116fd74d 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -132,6 +132,7 @@ def _rnn_sequence( modules = [] for i in range(num_layers): input = input_size if (i == 0) else hidden_dim + is_last = i == num_layers - 1 rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True) modules.append(rnn) @@ -139,9 +140,7 @@ def _rnn_sequence( if normalization: modules.append(self._normalization_layer(normalization, hidden_dim)) - if ( - i < num_layers - 1 - ): # pytorch RNNs don't have dropout applied on the last layer + if is_last: # pytorch RNNs don't have dropout applied on the last layer modules.append(nn.Dropout(dropout)) return nn.Sequential(*modules) @@ -161,20 +160,15 @@ def _fc_layer( self.output_chunk_length * target_size * self.nr_params ]: if normalization: - feats.append(self._normalization_layer(normalization, last, False)) + feats.append(self._normalization_layer(normalization, last)) feats.append(nn.Linear(last, feature)) last = feature return nn.Sequential(*feats) - def _normalization_layer( - self, normalization: str, hidden_size: int, is_temporal: bool - ): + def _normalization_layer(self, normalization: str, hidden_size: int): if normalization == "batch": - if is_temporal: - return TemporalBatchNorm1d(hidden_size) - else: - return nn.BatchNorm1d(hidden_size) + return TemporalBatchNorm1d(hidden_size) elif normalization == "layer": return nn.LayerNorm(hidden_size) From e89c22e69e4d0d9389a44920d2ca4b6ddd88ccca Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Mon, 15 May 2023 08:32:21 +0200 Subject: [PATCH 5/7] pass hidden state to fc layer --- darts/models/forecasting/block_rnn_model.py | 10 ++++++---- darts/utils/torch.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index ce116fd74d..327440b798 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -110,7 +110,7 @@ def forward(self, x_in: Tuple): """ Here, we apply the FC network only on the last output point (at the last time step) """ - predictions = hidden[:, -1, :] + predictions = hidden[-1, :, :] predictions = self.fc(predictions) predictions = predictions.view( batch_size, self.out_len, self.target_size, self.nr_params @@ -130,18 +130,20 @@ def _rnn_sequence( ): modules = [] + is_lstm = self.name == "LSTM" for i in range(num_layers): input = input_size if (i == 0) else hidden_dim is_last = i == num_layers - 1 rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True) modules.append(rnn) - modules.append(ExtractRnnOutput()) - + modules.append(ExtractRnnOutput(not is_last, is_lstm)) + modules.append(nn.Dropout(dropout)) if normalization: modules.append(self._normalization_layer(normalization, hidden_dim)) - if is_last: # pytorch RNNs don't have dropout applied on the last layer + if not is_last: # pytorch RNNs don't have dropout applied on the last layer modules.append(nn.Dropout(dropout)) + return nn.Sequential(*modules) def _fc_layer( diff --git a/darts/utils/torch.py 
b/darts/utils/torch.py index d70141aba1..3e6934afc9 100644 --- a/darts/utils/torch.py +++ b/darts/utils/torch.py @@ -131,9 +131,15 @@ def _reshape_input(self, x): class ExtractRnnOutput(nn.Module): - def __init__(self) -> None: + def __init__(self, is_output, is_lstm) -> None: + self.is_output = is_output + self.is_lstm = is_lstm super().__init__() def forward(self, input): - output, _ = input - return output + output, hidden = input + if self.is_output: + return output + if self.is_lstm: + return hidden[0] + return hidden From 76a58fe4b341fb9f796cd353283f9c71621f7380 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Wed, 7 Feb 2024 19:02:42 +0100 Subject: [PATCH 6/7] update block rnn --- darts/models/forecasting/block_rnn_model.py | 242 +++++++++++++++----- 1 file changed, 181 insertions(+), 61 deletions(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index 327440b798..aeb3799039 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -3,23 +3,27 @@ ------------------------------- """ -from typing import List, Optional, Tuple, Union +import inspect +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn as nn -from darts.logging import get_logger, raise_if_not -from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule +from darts.logging import get_logger, raise_log +from darts.models.forecasting.pl_forecasting_module import ( + PLPastCovariatesModule, + io_processor, +) from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel from darts.utils.torch import ExtractRnnOutput, TemporalBatchNorm1d logger = get_logger(__name__) -class _BlockRNNModule(PLPastCovariatesModule): +class CustomBlockRNNModule(PLPastCovariatesModule, ABC): def __init__( self, - name: str, input_size: int, hidden_dim: int, num_layers: int, @@ -30,24 +34,22 @@ def __init__( normalization: str = None, **kwargs, ): + """This class allows to create custom block RNN modules that can later be used with Darts' + :class:`BlockRNNModel`. It adds the backbone that is required to be used with Darts' + :class:`TorchForecastingModel` and :class:`BlockRNNModel`. - """PyTorch module implementing a block RNN to be used in `BlockRNNModel`. + To create a new module, subclass from :class:`CustomBlockRNNModule` and: - PyTorch module implementing a simple block RNN with the specified `name` layer. - This module combines a PyTorch RNN module, together with a fully connected network, which maps the - last hidden layers to output of the desired size `output_chunk_length` and makes it compatible with - `BlockRNNModel`s. + * Define the architecture in the module constructor (`__init__()`) - This module uses an RNN to encode the input sequence, and subsequently uses a fully connected - network as the decoder which takes as input the last hidden state of the encoder RNN. - The final output of the decoder is a sequence of length `output_chunk_length`. In this sense, - the `_BlockRNNModule` produces 'blocks' of forecasts at a time (which is different - from `_RNNModule` used by the `RNNModel`). + * Add the `forward()` method and define the logic of your module's forward pass + + * Use the custom module class when creating a new :class:`BlockRNNModel` with parameter `model`. + + You can use `darts.models.forecasting.block_rnn_model._BlockRNNModule` as an example. 
Parameters ---------- - name - The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM"). input_size The dimensionality of the input time series. hidden_dim @@ -67,6 +69,68 @@ def __init__( The name of the normalization applied after RNN and FC layers ("batch", "layer") **kwargs all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class. + """ + super().__init__(**kwargs) + + # Defining parameters + self.input_size = input_size + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.target_size = target_size + self.nr_params = nr_params + self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc + self.dropout = dropout + self.out_len = self.output_chunk_length + self.normalization = normalization + + @io_processor + @abstractmethod + def forward(self, x_in: Tuple) -> torch.Tensor: + """BlockRNN Module forward. + + Parameters + ---------- + x_in + Tuple of Tensors containing the features of the input sequence. The tuple has elements + (past target, historic future covariates, future covariates, static covariates). + The shape of the past target is `(batch_size, input_length, input_size)`. + + Returns + ------- + torch.Tensor + The BlockRNN output Tensor with shape `(batch_size, output_chunk_length, target_size, nr_params)`. + It contains the prediction at the last time step of the sequence. + """ + pass + + +# TODO add batch norm +class _BlockRNNModule(CustomBlockRNNModule): + def __init__( + self, + name: str, + **kwargs, + ): + + """PyTorch module implementing a block RNN to be used in `BlockRNNModel`. + + PyTorch module implementing a simple block RNN with the specified `name` layer. + This module combines a PyTorch RNN module, together with a fully connected network, which maps the + last hidden layers to output of the desired size `output_chunk_length` and makes it compatible with + `BlockRNNModel`s. + + This module uses an RNN to encode the input sequence, and subsequently uses a fully connected + network as the decoder which takes as input the last hidden state of the encoder RNN. + The final output of the decoder is a sequence of length `output_chunk_length`. In this sense, + the `_BlockRNNModule` produces 'blocks' of forecasts at a time (which is different + from `_RNNModule` used by the `RNNModel`). + + Parameters + ---------- + name + The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM"). + **kwargs + all parameters required for the :class:`darts.model.forecasting_models.CustomBlockRNNModule` base class. 
Inputs ------ @@ -81,26 +145,28 @@ def __init__( super().__init__(**kwargs) - # Defining parameters - self.hidden_dim = hidden_dim - self.n_layers = num_layers - self.target_size = target_size - self.nr_params = nr_params - num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc - self.out_len = self.output_chunk_length self.name = name # Defining the RNN module self.rnn = self._rnn_sequence( - name, input_size, hidden_dim, num_layers, dropout, normalization + name, + self.input_size, + self.hidden_dim, + self.num_layers, + self.dropout, + self.normalization, ) # The RNN module is followed by a fully connected layer, which maps the last hidden layer # to the output of desired length self.fc = self._fc_layer( - hidden_dim, num_layers_out_fc, target_size, normalization + self.hidden_dim, + self.num_layers_out_fc, + self.target_size, + self.normalization, ) + @io_processor def forward(self, x_in: Tuple): x, _ = x_in # data is of size (batch_size, input_chunk_length, input_size) @@ -180,7 +246,7 @@ def __init__( self, input_chunk_length: int, output_chunk_length: int, - model: Union[str, nn.Module] = "RNN", + model: Union[str, Type[CustomBlockRNNModule]] = "RNN", hidden_dim: int = 25, n_rnn_layers: int = 1, hidden_fc_sizes: Optional[List] = None, @@ -206,13 +272,19 @@ def __init__( Parameters ---------- input_chunk_length - The number of time steps that will be fed to the internal forecasting module + Number of time steps in the past to take as a model input (per chunk). Applies to the target + series, and past and/or future covariates (if the model supports it). output_chunk_length - Number of time steps to be output by the internal forecasting module. + Number of time steps predicted at once (per chunk) by the internal model. Also, the number of future values + from future covariates to use as a model input (if the model supports future covariates). It is not the same + as forecast horizon `n` used in `predict()`, which is the desired number of prediction points generated + using either a one-shot- or auto-regressive forecast. Setting `n <= output_chunk_length` prevents + auto-regression. This is useful when the covariates don't extend far enough into the future, or to prohibit + the model from using future values of past and / or future covariates for prediction (depending on the + model's covariate support). model - Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU"), - or a PyTorch module with the same specifications as - :class:`darts.models.block_rnn_model._BlockRNNModule`. + Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU"), or a subclass of + :class:`CustomBlockRNNModule` (the class itself, not an object of the class) with a custom logic. hidden_dim Size for feature maps for each hidden RNN layer (:math:`h_n`). In Darts version <= 0.21, hidden_dim was referred as hidden_size. @@ -247,6 +319,9 @@ def __init__( to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. + use_reversible_instance_norm + Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [1]_. + It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. 
n_epochs @@ -270,7 +345,7 @@ def __init__( If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will be discarded). Default: ``False``. save_checkpoints - Whether or not to automatically save the untrained model and checkpoints from training. + Whether to automatically save the untrained model and checkpoints from training. To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where :class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`, :class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using @@ -287,12 +362,16 @@ def __init__( .. highlight:: python .. code-block:: python + def encode_year(idx): + return (idx.year - 1950) / 50 + add_encoders={ 'cyclic': {'future': ['month']}, 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, - 'custom': {'past': [lambda idx: (idx.year - 1950) / 50]}, - 'transformer': Scaler() + 'custom': {'past': [encode_year]}, + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state @@ -311,7 +390,6 @@ def __init__( "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. @@ -349,6 +427,41 @@ def __init__( show_warnings whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of your forecasting use case. Default: ``False``. + + References + ---------- + .. [1] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against + Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p + + Examples + -------- + >>> from darts.datasets import WeatherDataset + >>> from darts.models import BlockRNNModel + >>> series = WeatherDataset().load() + >>> # predicting atmospheric pressure + >>> target = series['p (mbar)'][:100] + >>> # optionally, use past observed rainfall (pretending to be unknown beyond index 100) + >>> past_cov = series['rain (mm)'][:100] + >>> # predict 6 pressure values using the 12 past values of pressure and rainfall, as well as the 6 temperature + >>> model = BlockRNNModel( + >>> input_chunk_length=12, + >>> output_chunk_length=6, + >>> n_rnn_layers=2, + >>> n_epochs=50, + >>> ) + >>> model.fit(target, past_covariates=past_cov) + >>> pred = model.predict(6) + >>> pred.values() + array([[4.97979827], + [3.9707572 ], + [5.27869295], + [5.19697244], + [5.28424783], + [5.22497681]]) + + .. note:: + `RNN example notebook `_ presents techniques + that can be used to improve the forecasts quality compared to this simple usage example. """ super().__init__(**self._extract_torch_model_params(**self.model_params)) @@ -357,14 +470,16 @@ def __init__( # check we got right model type specified: if model not in ["RNN", "LSTM", "GRU"]: - raise_if_not( - isinstance(model, nn.Module), - '{} is not a valid RNN model.\n Please specify "RNN", "LSTM", ' - '"GRU", or give your own PyTorch nn.Module'.format( - model.__class__.__name__ - ), - logger, - ) + if not inspect.isclass(model) or not issubclass( + model, CustomBlockRNNModule + ): + raise_log( + ValueError( + "`model` is not a valid RNN model. 
Please specify 'RNN', 'LSTM', 'GRU', or give a subclass " + "(not an instance) of darts.models.forecasting.rnn_model.CustomBlockRNNModule." + ), + logger=logger, + ) self.rnn_type_or_module = model self.hidden_fc_sizes = hidden_fc_sizes @@ -372,6 +487,10 @@ def __init__( self.n_rnn_layers = n_rnn_layers self.dropout = dropout + @property + def supports_multivariate(self) -> bool: + return True + def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: # samples are made of (past_target, past_covariates, future_target) input_dim = train_sample[0].shape[1] + ( @@ -380,21 +499,22 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: output_dim = train_sample[-1].shape[1] nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters - if self.rnn_type_or_module in ["RNN", "LSTM", "GRU"]: - hidden_fc_sizes = ( - [] if self.hidden_fc_sizes is None else self.hidden_fc_sizes - ) - model = _BlockRNNModule( - name=self.rnn_type_or_module, - input_size=input_dim, - target_size=output_dim, - nr_params=nr_params, - hidden_dim=self.hidden_dim, - num_layers=self.n_rnn_layers, - num_layers_out_fc=hidden_fc_sizes, - dropout=self.dropout, - **self.pl_module_params, - ) + hidden_fc_sizes = [] if self.hidden_fc_sizes is None else self.hidden_fc_sizes + + kwargs = {} + if isinstance(self.rnn_type_or_module, str): + model_cls = _BlockRNNModule + kwargs["name"] = self.rnn_type_or_module else: - model = self.rnn_type_or_module - return model + model_cls = self.rnn_type_or_module + return model_cls( + input_size=input_dim, + target_size=output_dim, + nr_params=nr_params, + hidden_dim=self.hidden_dim, + num_layers=self.n_rnn_layers, + num_layers_out_fc=hidden_fc_sizes, + dropout=self.dropout, + **self.pl_module_params, + **kwargs, + ) From 5de39eafaf12633c55fdbb5471a556efcb192d45 Mon Sep 17 00:00:00 2001 From: Jan Fidor Date: Wed, 7 Feb 2024 19:12:40 +0100 Subject: [PATCH 7/7] refactor temporal batch norm --- darts/utils/torch.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/darts/utils/torch.py b/darts/utils/torch.py index 3e6934afc9..1ee4bb251e 100644 --- a/darts/utils/torch.py +++ b/darts/utils/torch.py @@ -120,15 +120,11 @@ def __init__(self, feature_size) -> None: self.norm = nn.BatchNorm1d(feature_size) def forward(self, input): - input = self._reshape_input(input) # Reshape N L C -> N C L + input = input.swapaxes(1, 2) input = self.norm(input) - input = self._reshape_input(input) + input = input.swapaxes(1, 2) return input if len(input) > 1 else input[0] - def _reshape_input(self, x): - shape = x.shape - return x.reshape(shape[0], shape[2], shape[1]) - class ExtractRnnOutput(nn.Module): def __init__(self, is_output, is_lstm) -> None:
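
The sketch below illustrates how the pieces introduced in this series fit together: a minimal `CustomBlockRNNModule` subclass passed to `BlockRNNModel` through the reworked `model` argument. It is an illustrative example, not part of the patches: `TinyBlockRNN` and the training settings are hypothetical, and it assumes `CustomBlockRNNModule` and `io_processor` are importable from the paths used in the diffs.

```python
# Hypothetical usage sketch based on the patches above; TinyBlockRNN is not part
# of the series, and the import paths are taken from the diffs as shown.
import torch.nn as nn

from darts.datasets import WeatherDataset
from darts.models import BlockRNNModel
from darts.models.forecasting.block_rnn_model import CustomBlockRNNModule
from darts.models.forecasting.pl_forecasting_module import io_processor


class TinyBlockRNN(CustomBlockRNNModule):
    """A single GRU encoder followed by a linear head mapping the last hidden
    state to `output_chunk_length * target_size * nr_params` outputs."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # input_size, hidden_dim, out_len, target_size and nr_params are set by
        # the CustomBlockRNNModule base class added in this series.
        self.rnn = nn.GRU(self.input_size, self.hidden_dim, batch_first=True)
        self.head = nn.Linear(
            self.hidden_dim, self.out_len * self.target_size * self.nr_params
        )

    @io_processor
    def forward(self, x_in):
        x, _ = x_in  # past target (and past covariates): (batch, input_chunk_length, input_size)
        out, _ = self.rnn(x)
        pred = self.head(out[:, -1, :])  # decode from the last time step only
        return pred.view(x.size(0), self.out_len, self.target_size, self.nr_params)


series = WeatherDataset().load()
target = series["p (mbar)"][:100]

# Pass the subclass itself (not an instance), as enforced by the new `model` check.
model = BlockRNNModel(
    input_chunk_length=12,
    output_chunk_length=6,
    model=TinyBlockRNN,
    hidden_dim=25,
    n_epochs=5,
)
model.fit(target)
print(model.predict(6).values())
```

On the normalization side, the series adds a module-level `normalization` option ("batch" or "layer"); whether and how `BlockRNNModel` exposes it publicly is not shown in these diffs. `TemporalBatchNorm1d` swaps the time and feature axes before calling `nn.BatchNorm1d`, which expects `(N, C, L)` input while the RNN emits `(N, L, C)`; the final patch replaces the earlier `reshape`-based transposition, which reorders elements rather than transposing them, with `swapaxes`.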