Added Reversible Instance Normalization #1865
Conversation
Very nice, thanks a lot for this very cool addition @alexcolpitts96 🚀 :)
I only had some minor suggestions and general questions about small deviations from the official implementation and the usage of the norm in Darts.
Codecov Report

@@            Coverage Diff            @@
##           master    #1865   +/-   ##
=======================================
  Coverage   93.93%   93.94%
=======================================
  Files         125      125
  Lines       11743    11755    +12
=======================================
+ Hits        11031    11043    +12
  Misses        712      712

☔ View full report in Codecov by Sentry.
Nice, thanks @alexcolpitts96 🚀
Only had some last minor suggestions.
After that we can merge :)
self.affine_bias = nn.Parameter(torch.zeros(self.input_dim))

def forward(self, x):
    calc_dims = tuple(range(1, x.ndim - 1))
Can you add a comment on what shape x is expected to be, what each dimension of x represents, and why we take the mean over dimensions range(1, x.ndim - 1)? This will make it easier to interpret in the future.
It will probably be easiest to assume that x is based on the output of Darts' torch datasets (loaders).
I added some comments. I want to keep it taking a torch.Tensor for now rather than a specific dataset. I may change this when it comes to integrating it with models, since it could make covariate handling very gross in the models.
When I add it to TiDE (and maybe NHiTS for testing) I will see how it looks and update it accordingly in another PR.
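(For context, a minimal sketch of what such a commented module might look like. This is an assumption-laden illustration, not the merged code: the class name RINorm, the (batch, time, features) layout taken from Darts' torch datasets, and the eps default are guesses based on this thread and the official RevIN implementation.)

import torch
import torch.nn as nn


class RINorm(nn.Module):
    """Sketch of Reversible Instance Normalization (RevIN).

    Assumes x has shape (batch, time, features), as produced by
    Darts' torch datasets; name and defaults are illustrative.
    """

    def __init__(self, input_dim: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.affine_weight = nn.Parameter(torch.ones(self.input_dim))
            self.affine_bias = nn.Parameter(torch.zeros(self.input_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, features). Statistics are computed per instance
        # and per feature, i.e. over every dim except batch (0) and
        # features (-1) -- for a 3-D input that is just the time dimension.
        calc_dims = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=calc_dims, keepdim=True).detach()
        self.stdev = torch.sqrt(
            torch.var(x, dim=calc_dims, keepdim=True, unbiased=False) + self.eps
        ).detach()
        x = (x - self.mean) / self.stdev
        if self.affine:
            x = x * self.affine_weight + self.affine_bias
        return x

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        # Undo the transform so outputs come back in the original scale;
        # eps guards against dividing by near-zero learned weights.
        if self.affine:
            x = (x - self.affine_bias) / (self.affine_weight + self.eps)
        return x * self.stdev + self.mean

Note that forward() stores the per-instance mean and stdev on the module so that inverse() can reuse them, which is what makes the normalization reversible.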
Thanks a lot @alexcolpitts96, everything is looking great now! 💯
Will merge as soon as all tests pass.
@dennisbader @alexcolpitts96 I haven't had a chance to do a deep dive into the code changes or into what the TiDE/TSMixer models require, but since this MR closed my issue, I thought I would chime in with my understanding of RevIN. RevIN is not a layer normalization. You can see in the Informer example, and in the PatchTST implementation, that it is applied at the start and end of the forward pass. The easiest way to implement RevIN is to add it …
You're right @gdevos010. We should add it to a different file than the layer norm variants. So far RevIN is not used by any model, but @alexcolpitts96 mentioned integrating it while working on the other PRs. For non-probabilistic models (those not using a likelihood) it will work fine at the beginning and end of the forward pass. We still have to check how best to do it for models using a likelihood.
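(As a rough illustration of that usage pattern — normalize on the way in, denormalize on the way out — here is a hypothetical wrapper. RINorm refers to the sketch above, the wrapper name is invented, and the inner model is any point forecaster; none of this reflects the eventual Darts integration.)

import torch.nn as nn


class ForecastWithRevIN(nn.Module):
    # Hypothetical wrapper: apply RevIN at the start and end of forward().
    def __init__(self, model: nn.Module, input_dim: int):
        super().__init__()
        self.rin = RINorm(input_dim)  # the sketch above; name is assumed
        self.model = model

    def forward(self, x):
        x = self.rin(x)              # normalize model inputs
        y = self.model(x)            # e.g. a TiDE- or NHiTS-style forecaster
        return self.rin.inverse(y)   # restore the original scale

For probabilistic models the denormalization would instead have to act on the predicted distribution parameters, which is the open question above.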
Fixes #1121, fixes #1861.
Adds RIN (Reversible Instance Normalization) as a normalization layer option requested in #1121.
This is a prerequisite component for #1861 and #1807.