
Added Reversible Instance Normalization #1865

Merged: 23 commits into unit8co:master on Jul 18, 2023
Conversation

@alexcolpitts96 (Contributor) commented Jun 29, 2023

Fixes #1121, fixes #1861.

Adds RIN as a normalization layer option requested in #1121.

This is a prerequisite component for #1861 and #1807.

@dennisbader (Collaborator) left a comment

Very nice, thanks a lot for this very cool addition @alexcolpitts96 🚀 :)

I only had some minor suggestions and general questions about small deviations from the official implementation and about how the norm is used in Darts.

Outdated review threads (resolved) on:
darts/models/components/layer_norm_variants.py
darts/tests/models/components/test_layer_norm_variants.py
@codecov-commenter commented Jul 15, 2023

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (0e1b084) 93.93% compared to head (e301fef) 93.94%.


Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1865   +/-   ##
=======================================
  Coverage   93.93%   93.94%           
=======================================
  Files         125      125           
  Lines       11743    11755   +12     
=======================================
+ Hits        11031    11043   +12     
  Misses        712      712           
Impacted Files                                      Coverage Δ
darts/models/components/layer_norm_variants.py      100.00% <100.00%> (ø)

... and 6 files with indirect coverage changes


@dennisbader (Collaborator) left a comment

Nice, thanks @alexcolpitts96 🚀
Only had some last minor suggestions.
After that we can merge :)

Outdated review threads (resolved) on:
CHANGELOG.md
darts/tests/models/components/test_layer_norm_variants.py
        self.affine_bias = nn.Parameter(torch.zeros(self.input_dim))

    def forward(self, x):
        calc_dims = tuple(range(1, x.ndim - 1))
@dennisbader (Collaborator) commented:

Can you add a comment explaining what shape x is expected to have, what each dimension of x represents, and why we take the mean over dimensions range(1, x.ndim - 1)?
This will make it easier to interpret in the future.

It will probably be easiest to assume that x is based on the output of Darts torch datasets (loaders).

@alexcolpitts96 (Contributor, Author) replied:

I added some comments. I want to keep it as taking a torch.Tensor for now rather than a specific dataset. I may change this when integrating it with the models, since it may make covariate handling very messy in the models.

When I add it to TiDE (and maybe NHiTS for testing), I will see how it looks and update it accordingly in another PR.
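For context, here is a rough sketch of the kind of shape comment being asked for, assuming x follows the layout produced by Darts' torch datasets, i.e. (batch, time, ..., features). The class name RevIN, the affine_weight parameter, and the inverse method below are illustrative and may differ from the code actually merged in this PR:

```python
import torch
import torch.nn as nn


class RevIN(nn.Module):
    """Illustrative Reversible Instance Normalization (sketch only, not the PR's exact code)."""

    def __init__(self, input_dim: int, eps: float = 1e-5):
        super().__init__()
        self.input_dim = input_dim
        self.eps = eps
        # learnable affine parameters over the feature (last) dimension
        self.affine_weight = nn.Parameter(torch.ones(self.input_dim))
        self.affine_bias = nn.Parameter(torch.zeros(self.input_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is expected to have shape (batch, time, ..., features): dim 0 is the
        # batch, the last dim holds the components/features, and the dims in
        # between are time (and possibly sample) dims. Statistics are therefore
        # computed per instance and per feature, i.e. over every dim except the
        # first (batch) and the last (features).
        calc_dims = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=calc_dims, keepdim=True).detach()
        self.stdev = torch.sqrt(
            torch.var(x, dim=calc_dims, keepdim=True, unbiased=False) + self.eps
        ).detach()
        x = (x - self.mean) / self.stdev
        return x * self.affine_weight + self.affine_bias

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        # undo the affine transform, then restore the per-instance scale and location
        x = (x - self.affine_bias) / (self.affine_weight + self.eps**2)
        return x * self.stdev + self.mean
```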

@dennisbader (Collaborator) left a comment
Thanks a lot @alexcolpitts96, everything is looking great now! 💯
Will merge as soon as all tests passed.

@dennisbader dennisbader merged commit acd5279 into unit8co:master Jul 18, 2023
9 checks passed
@gdevos010 (Contributor) commented:

@dennisbader @alexcolpitts96 I haven't had a chance to do a deep dive into the code changes or into what the TiDE/TSMixer models require, but since this PR closed my issue, I thought I would chime in with my understanding of RevIN.

RevIN is not a layer normalization. You can see in the Informer example, and in the PatchTST implementation, that it is applied at the start and end of the forward() method, not between the layers the way layer norms are. To reduce confusion, I would move it to a separate file, since having it in layer_norm_variants.py matters when we do the layer norm lookups in some of the models.

The easiest way to implement RevIN is to add it to PLForecastingModule, as recommended here, so that all models have access to it. To me, #1121 is not complete until it is added to PLForecastingModule. I see RevIN included in many of the new forecasting models, so I appreciate the effort to include it.
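To make that placement concrete, here is a minimal, hypothetical sketch (not Darts code, and with the learnable affine parameters of full RevIN omitted) of a forecasting module that normalizes its input at the very start of forward() and de-normalizes its output at the very end, rather than between layers:

```python
import torch
import torch.nn as nn


class ModelWithRevIN(nn.Module):
    """Hypothetical model body (a single linear layer standing in for a
    TiDE/PatchTST-style backbone) wrapped by reversible instance normalization."""

    def __init__(self, input_len: int, output_len: int, eps: float = 1e-5):
        super().__init__()
        self.linear = nn.Linear(input_len, output_len)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, input_len, features)
        # 1) normalize each instance at the very start of forward()
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True, unbiased=False) + self.eps
        x = (x - mean) / std
        # 2) run the actual model body
        y = self.linear(x.transpose(1, 2)).transpose(1, 2)  # (batch, output_len, features)
        # 3) de-normalize the output at the very end of forward()
        return y * std + mean
```

A PLForecastingModule-level integration would presumably follow the same pattern: apply the normalization to the input once on entry and the inverse to the prediction once on exit.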

@dennisbader (Collaborator) replied:

You're right @gdevos010. We should add it to a file other than the layer norm variants. So far RevIN is not used by any model, but @alexcolpitts96 mentioned integrating it while working on the other PRs.
I'll reopen your issue so we won't forget.

For non-probabilistic models (not using a likelihood) it will work fine at the beginning and end of the forward pass. We still have to check how to best do it for models using a likelihood.
