diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py
index 36cf6e210d..aeb3799039 100644
--- a/darts/models/forecasting/block_rnn_model.py
+++ b/darts/models/forecasting/block_rnn_model.py
@@ -16,6 +16,7 @@
     io_processor,
 )
 from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
+from darts.utils.torch import ExtractRnnOutput, TemporalBatchNorm1d
 
 logger = get_logger(__name__)
 
@@ -30,6 +31,7 @@ def __init__(
         nr_params: int,
         num_layers_out_fc: Optional[List] = None,
         dropout: float = 0.0,
+        normalization: Optional[str] = None,
         **kwargs,
     ):
         """This class allows to create custom block RNN modules that can later be used with Darts'
@@ -63,6 +65,8 @@ def __init__(
             This network connects the last hidden layer of the PyTorch RNN module to the output.
         dropout
             The fraction of neurons that are dropped in all-but-last RNN layers.
+        normalization
+            The name of the normalization to apply after each RNN and fully connected layer: "batch", "layer", or ``None`` (default) for no normalization.
         **kwargs
             all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class.
         """
@@ -77,6 +81,7 @@ def __init__(
         self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc
         self.dropout = dropout
         self.out_len = self.output_chunk_length
+        self.normalization = normalization
 
     @io_processor
     @abstractmethod
@@ -143,24 +148,23 @@ def __init__(
         self.name = name
 
         # Defining the RNN module
-        self.rnn = getattr(nn, self.name)(
+        self.rnn = self._rnn_sequence(
+            name,
             self.input_size,
             self.hidden_dim,
             self.num_layers,
-            batch_first=True,
-            dropout=self.dropout,
+            self.dropout,
+            self.normalization,
         )
 
         # The RNN module is followed by a fully connected layer, which maps the last hidden layer
         # to the output of desired length
-        last = self.hidden_dim
-        feats = []
-        for feature in self.num_layers_out_fc + [
-            self.out_len * self.target_size * self.nr_params
-        ]:
-            feats.append(nn.Linear(last, feature))
-            last = feature
-        self.fc = nn.Sequential(*feats)
+        self.fc = self._fc_layer(
+            self.hidden_dim,
+            self.num_layers_out_fc,
+            self.target_size,
+            self.normalization,
+        )
 
     @io_processor
     def forward(self, x_in: Tuple):
@@ -168,12 +172,10 @@ def forward(self, x_in: Tuple):
         # data is of size (batch_size, input_chunk_length, input_size)
         batch_size = x.size(0)
 
-        out, hidden = self.rnn(x)
+        hidden = self.rnn(x)
 
         """ Here, we apply the FC network only on the last output point (at the last time step)
         """
-        if self.name == "LSTM":
-            hidden = hidden[0]
         predictions = hidden[-1, :, :]
         predictions = self.fc(predictions)
         predictions = predictions.view(
@@ -183,6 +185,65 @@ def forward(self, x_in: Tuple):
 
         # predictions is of size (batch_size, output_chunk_length, 1)
         return predictions
 
+    def _rnn_sequence(
+        self,
+        name: str,
+        input_size: int,
+        hidden_dim: int,
+        num_layers: int,
+        dropout: float = 0.0,
+        normalization: Optional[str] = None,
+    ):
+        # Build the RNN as a stack of single-layer modules so that dropout and
+        # normalization can be inserted between the recurrent layers.
+        modules = []
+        is_lstm = name == "LSTM"
+        for i in range(num_layers):
+            in_size = input_size if (i == 0) else hidden_dim
+            is_last = i == num_layers - 1
+            rnn = getattr(nn, name)(in_size, hidden_dim, 1, batch_first=True)
+
+            modules.append(rnn)
+            modules.append(ExtractRnnOutput(not is_last, is_lstm))
+            if normalization:
+                modules.append(self._normalization_layer(normalization, hidden_dim))
+            if not is_last:  # pytorch RNNs don't have dropout applied on the last layer
+                modules.append(nn.Dropout(dropout))
+
+        return nn.Sequential(*modules)
+
+    def _fc_layer(
+        self,
+        input_size: int,
+        num_layers_out_fc: Optional[List[int]],
+        target_size: int,
+        normalization: Optional[str] = None,
+    ):
+        # Map the last hidden state to an output of the desired length, optionally
+        # normalizing the input of every linear layer.
+        if not num_layers_out_fc:
+            num_layers_out_fc = []
+
+        last = input_size
+        feats = []
+        for feature in num_layers_out_fc + [
+            self.output_chunk_length * target_size * self.nr_params
+        ]:
+            if normalization:
+                feats.append(self._normalization_layer(normalization, last))
+            feats.append(nn.Linear(last, feature))
+            last = feature
+        return nn.Sequential(*feats)
+
+    def _normalization_layer(self, normalization: str, hidden_size: int):
+        if normalization == "batch":
+            return TemporalBatchNorm1d(hidden_size)
+        if normalization == "layer":
+            return nn.LayerNorm(hidden_size)
+        raise ValueError(
+            f"Unknown normalization '{normalization}', must be 'batch' or 'layer'."
+        )
+
 
 class BlockRNNModel(PastCovariatesTorchModel):
     def __init__(
diff --git a/darts/utils/torch.py b/darts/utils/torch.py
index 552f285384..1ee4bb251e 100644
--- a/darts/utils/torch.py
+++ b/darts/utils/torch.py
@@ -112,3 +112,38 @@ def decorator(self, *args, **kwargs) -> T:
         return decorated(self, *args, **kwargs)
 
     return decorator
+
+
+class TemporalBatchNorm1d(nn.Module):
+    """Applies batch normalization over the feature dimension of a batch-first
+    sequence of shape (batch, time, features)."""
+
+    def __init__(self, feature_size: int) -> None:
+        super().__init__()
+        self.norm = nn.BatchNorm1d(feature_size)
+
+    def forward(self, input):
+        # nn.BatchNorm1d expects the features at dimension 1, so swap them in and out
+        input = input.swapaxes(1, 2)
+        input = self.norm(input)
+        return input.swapaxes(1, 2)
+
+
+class ExtractRnnOutput(nn.Module):
+    """Unpacks the ``(output, hidden)`` tuple returned by PyTorch RNN modules,
+    so they can be chained inside an ``nn.Sequential``."""
+
+    def __init__(self, is_output: bool, is_lstm: bool) -> None:
+        super().__init__()
+        self.is_output = is_output
+        self.is_lstm = is_lstm
+
+    def forward(self, input):
+        output, hidden = input
+        if self.is_output:
+            # the full output sequence, fed to the next layer in the stack
+            return output
+        if self.is_lstm:
+            # LSTMs return (hidden state, cell state); keep only the hidden state
+            return hidden[0]
+        return hidden
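
Note for review (not part of the diff): with the construction above, `_rnn_sequence` unrolls a two-layer LSTM with `normalization="layer"` into the `nn.Sequential` stack below. A minimal sketch, where `input_size`, `hidden_dim`, and `dropout` stand in for the `_BlockRNNModule` constructor arguments:

```python
import torch.nn as nn

from darts.utils.torch import ExtractRnnOutput

input_size, hidden_dim, dropout = 16, 32, 0.1

# Equivalent to _rnn_sequence("LSTM", input_size, hidden_dim, 2, dropout, "layer")
rnn_stack = nn.Sequential(
    nn.LSTM(input_size, hidden_dim, 1, batch_first=True),
    ExtractRnnOutput(True, True),   # keep the (batch, time, hidden_dim) output sequence
    nn.LayerNorm(hidden_dim),
    nn.Dropout(dropout),            # dropout only between recurrent layers
    nn.LSTM(hidden_dim, hidden_dim, 1, batch_first=True),
    ExtractRnnOutput(False, True),  # last layer: keep the (1, batch, hidden_dim) hidden state
    nn.LayerNorm(hidden_dim),
)
```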
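And a quick shape check of the two new helpers in isolation; the tensor sizes are illustrative only:

```python
import torch
from torch import nn

from darts.utils.torch import ExtractRnnOutput, TemporalBatchNorm1d

x = torch.randn(8, 12, 16)  # (batch, time, features), as produced by batch_first RNNs

# TemporalBatchNorm1d is shape-preserving: it swaps the features into the
# channel position expected by nn.BatchNorm1d and swaps them back
norm = TemporalBatchNorm1d(16)
assert norm(x).shape == (8, 12, 16)

# ExtractRnnOutput lets RNN modules live inside nn.Sequential by unpacking
# the (output, hidden) tuple they return
stack = nn.Sequential(
    nn.LSTM(16, 32, 1, batch_first=True),
    ExtractRnnOutput(False, True),  # last layer: keep the hidden state
)
hidden = stack(x)
assert hidden.shape == (1, 8, 32)  # (num_layers, batch, hidden_dim)
```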