-
Notifications
You must be signed in to change notification settings - Fork 874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add normalization to BlockRNNModel #1748
Open
JanFidor
wants to merge
9
commits into
unit8co:master
Choose a base branch
from
JanFidor:feature/rnn-normalization
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
8804f0b
add normalization to block_rnn
JanFidor ca9fddb
remove todo
JanFidor e4c6089
fix indexing
JanFidor 80c9e10
clean up unused code
JanFidor e89c22e
pass hidden state to fc layer
JanFidor 544ab44
Merge branch 'master' into feature/rnn-normalization
JanFidor 76a58fe
update block rnn
JanFidor f951dae
Merge remote-tracking branch 'upstream/master' into feature/rnn-norma…
JanFidor 5de39ea
refactor temporal batch norm
JanFidor File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
io_processor, | ||
) | ||
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel | ||
from darts.utils.torch import ExtractRnnOutput, TemporalBatchNorm1d | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
@@ -30,6 +31,7 @@ def __init__( | |
nr_params: int, | ||
num_layers_out_fc: Optional[List] = None, | ||
dropout: float = 0.0, | ||
normalization: str = None, | ||
**kwargs, | ||
): | ||
"""This class allows to create custom block RNN modules that can later be used with Darts' | ||
|
@@ -63,6 +65,8 @@ def __init__( | |
This network connects the last hidden layer of the PyTorch RNN module to the output. | ||
dropout | ||
The fraction of neurons that are dropped in all-but-last RNN layers. | ||
normalization | ||
The name of the normalization applied after RNN and FC layers ("batch", "layer") | ||
**kwargs | ||
all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class. | ||
""" | ||
|
@@ -77,6 +81,7 @@ def __init__( | |
self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc | ||
self.dropout = dropout | ||
self.out_len = self.output_chunk_length | ||
self.normalization = normalization | ||
|
||
@io_processor | ||
@abstractmethod | ||
|
@@ -143,37 +148,34 @@ def __init__( | |
self.name = name | ||
|
||
# Defining the RNN module | ||
self.rnn = getattr(nn, self.name)( | ||
self.rnn = self._rnn_sequence( | ||
name, | ||
self.input_size, | ||
self.hidden_dim, | ||
self.num_layers, | ||
batch_first=True, | ||
dropout=self.dropout, | ||
self.dropout, | ||
self.normalization, | ||
) | ||
|
||
# The RNN module is followed by a fully connected layer, which maps the last hidden layer | ||
# to the output of desired length | ||
last = self.hidden_dim | ||
feats = [] | ||
for feature in self.num_layers_out_fc + [ | ||
self.out_len * self.target_size * self.nr_params | ||
]: | ||
feats.append(nn.Linear(last, feature)) | ||
last = feature | ||
self.fc = nn.Sequential(*feats) | ||
self.fc = self._fc_layer( | ||
self.hidden_dim, | ||
self.num_layers_out_fc, | ||
self.target_size, | ||
self.normalization, | ||
) | ||
|
||
@io_processor | ||
def forward(self, x_in: Tuple): | ||
x, _ = x_in | ||
# data is of size (batch_size, input_chunk_length, input_size) | ||
batch_size = x.size(0) | ||
|
||
out, hidden = self.rnn(x) | ||
hidden = self.rnn(x) | ||
|
||
""" Here, we apply the FC network only on the last output point (at the last time step) | ||
""" | ||
if self.name == "LSTM": | ||
hidden = hidden[0] | ||
predictions = hidden[-1, :, :] | ||
predictions = self.fc(predictions) | ||
predictions = predictions.view( | ||
|
@@ -183,6 +185,61 @@ def forward(self, x_in: Tuple): | |
# predictions is of size (batch_size, output_chunk_length, 1) | ||
return predictions | ||
|
||
def _rnn_sequence( | ||
self, | ||
name: str, | ||
input_size: int, | ||
hidden_dim: int, | ||
num_layers: int, | ||
dropout: float = 0.0, | ||
normalization: str = None, | ||
): | ||
|
||
modules = [] | ||
is_lstm = self.name == "LSTM" | ||
for i in range(num_layers): | ||
input = input_size if (i == 0) else hidden_dim | ||
is_last = i == num_layers - 1 | ||
rnn = getattr(nn, name)(input, hidden_dim, 1, batch_first=True) | ||
|
||
modules.append(rnn) | ||
modules.append(ExtractRnnOutput(not is_last, is_lstm)) | ||
modules.append(nn.Dropout(dropout)) | ||
if normalization: | ||
modules.append(self._normalization_layer(normalization, hidden_dim)) | ||
if not is_last: # pytorch RNNs don't have dropout applied on the last layer | ||
modules.append(nn.Dropout(dropout)) | ||
|
||
return nn.Sequential(*modules) | ||
|
||
def _fc_layer( | ||
self, | ||
input_size: int, | ||
num_layers_out_fc: list[int], | ||
target_size: int, | ||
normalization: str = None, | ||
): | ||
if not num_layers_out_fc: | ||
num_layers_out_fc = [] | ||
|
||
last = input_size | ||
feats = [] | ||
for feature in num_layers_out_fc + [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i will rather use the extend method for lists |
||
self.output_chunk_length * target_size * self.nr_params | ||
]: | ||
if normalization: | ||
feats.append(self._normalization_layer(normalization, last)) | ||
feats.append(nn.Linear(last, feature)) | ||
last = feature | ||
return nn.Sequential(*feats) | ||
|
||
def _normalization_layer(self, normalization: str, hidden_size: int): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if |
||
|
||
if normalization == "batch": | ||
return TemporalBatchNorm1d(hidden_size) | ||
elif normalization == "layer": | ||
return nn.LayerNorm(hidden_size) | ||
|
||
|
||
class BlockRNNModel(PastCovariatesTorchModel): | ||
def __init__( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get this point here.
num_layers_out_fc
is a list of integers correct?Suppose
num_layers_out_fc = []
, then notnum_layers_out_fc is True.
So why
num_layers_out_fc = []
?