-
Notifications
You must be signed in to change notification settings - Fork 874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add normalization to BlockRNNModel #1748
base: master
Are you sure you want to change the base?
Conversation
Some of the tests were failing, I'll check if continues after merging develop. One of them was test_fit_predict_determinism() which after debugging turned out to fail for ARIMA model, It wasn't in a scope of this PR so I'm unsure what might have happened. Might be a problem with my local build, I'll wait and see what the github actions say |
Codecov ReportAttention:
❗ Your organization needs to install the Codecov GitHub app to enable full functionality. Additional details and impacted files@@ Coverage Diff @@
## master #1748 +/- ##
==========================================
- Coverage 93.88% 93.78% -0.10%
==========================================
Files 135 135
Lines 13425 13461 +36
==========================================
+ Hits 12604 12625 +21
- Misses 821 836 +15 ☔ View full report in Codecov by Sentry. |
I've been thinking whether adding batch norm makes sense in this case, as repeated rescaling would cause gradient explosion, the very thing LSTM / GRU were supposed to combat. I'm inclined to only allow layer normalization (maybe also group norm), so that users don't accidentally fall into that trap. Let me know if you think that would fit with darts design philosophy ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
darts/utils/torch.py
Outdated
self.norm = nn.BatchNorm1d(feature_size) | ||
|
||
def forward(self, input): | ||
input = self._reshape_input(input) # Reshape N L C -> N C L |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line is more about swapping axes that reshaping, I would instead use the corresponding torch function:
input = self._reshape_input(input) # Reshape N L C -> N C L | |
# Reshape N L C -> N C L | |
input = input.swapaxes(1,2) |
darts/utils/torch.py
Outdated
def forward(self, input): | ||
input = self._reshape_input(input) # Reshape N L C -> N C L | ||
input = self.norm(input) | ||
input = self._reshape_input(input) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input = self._reshape_input(input) | |
input = input.swapaxes(1,2) |
Thanks for another review @madtoinou ! The article looks exciting (at least after skimming it and reading the abstract ;P ) I found some implementations online, but I'd rather understand the actual idea first before implementing it, so it might take me a little longer compared to the other 2 PRs |
Hi @madtoinou , quick update! I've read the paper and have an idea how to implement it. It might need a little bit of magic to get the time_step_index into the model input, but I think it should be doable, I'll let you know when I'll get everything running or if I stumble into some problem |
Quick update @madtoinou. I've been browsing through the codebase and wanted to get your thoughts on my planned approach. I think that the simplest approach would be to manually add a past encoder with static position, but that would require expanding IntegerIndexEncoder which only supports 'relative' for now. That said, I'm not sure at which point the Encoders are applied to the TS and this approach depends on it happening before TS are sliced for training. It's also possible to manually add a "static index" component, but I think this approach would be more elegant and static IntegerIndexEncoder might be useful in other implementations in the future |
Hi again @madtoinou! I wanted to get your thoughts on my new idea for the implementation. I went back to the paper and found the mention of using the batch norms specifically when training. Wouldn't it just suffice to store |
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
1a4627f
to
544ab44
Compare
…lization # Conflicts: # darts/models/forecasting/block_rnn_model.py
It took some playing around but I think I managed to fix most of the git history (please ignore the git push --force hahaha) |
target_size: int, | ||
normalization: str = None, | ||
): | ||
if not num_layers_out_fc: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get this point here.
num_layers_out_fc
is a list of integers correct?
Suppose num_layers_out_fc = []
, then not num_layers_out_fc is True.
So why num_layers_out_fc = []
?
|
||
last = input_size | ||
feats = [] | ||
for feature in num_layers_out_fc + [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i will rather use the extend method for lists
last = feature | ||
return nn.Sequential(*feats) | ||
|
||
def _normalization_layer(self, normalization: str, hidden_size: int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if normalization
is different from batch
and layer
the method return None
. is this intended?
Fixes #1649.
Summary
I've added normalization parameter to the BlockRNNModel, I've brainstormed how to do it for RNNModel and I couldn't come up with a way that wouldn't require some type of dynamic aggregation of the hidden states, so I decided to make the PR for BlockRNN for now.
I added two torch modules to simplify the rnn sequence, not sure if it's the cleanest way to implement it, but it's at least very readable.
Other Information
I also added layer norm, because it was a simple addition and it seems to be the recommended normalization for RNNs. I also considered adding group normalization, but it would either need constant num_groups parameter or additional constructor parameter for BlockRNNModel