-
Notifications
You must be signed in to change notification settings - Fork 874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/model deeptime #1329
base: master
Are you sure you want to change the base?
Feat/model deeptime #1329
Conversation
…ociated tests (inspired from the nbeats tests)
…figure_optimizer method
…and taking advantage of the safeguards offered by darts
Codecov ReportBase: 93.97% // Head: 94.03% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #1329 +/- ##
==========================================
+ Coverage 93.97% 94.03% +0.05%
==========================================
Files 82 83 +1
Lines 8917 9102 +185
==========================================
+ Hits 8380 8559 +179
- Misses 537 543 +6
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
few comments on quick initial things I noticed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks very good, nice job @madtoinou !
After glancing at the paper, I also think it would be a nice addition to Darts.
I haven't looked into all minute details of the processing being done but I trust you :) I've got a few small comments. Perhaps the main one concerns nr_params
which we should try to exploit before we merge.
…_global_forecasting_models, the number of epochs during the reduced to the striuct minimum, corrected typo in docstring, removed the mutable default argument, added check in TorchForecastingModel for n_epochs
…crease the length of the prediction from 2 to 3, last version was relying on erroneous broadcasting
…failing at the moment
…eepTime, removed some comments in the forward method, corrected typo
…duler default argument, added check for scheduler warmup_epochs value, removed hardcoded typing of ridge regressor regularization term
…d to warmup_epochs
I find that the results do not seem fantastic in the probabilistic setting. E.g. when running the following code: from darts.datasets import AirPassengersDataset
from darts.dataprocessing.transformers import Scaler
from darts.models import DeepTimeModel
from darts.utils.likelihood_models import GaussianLikelihood, LaplaceLikelihood
series = AirPassengersDataset().load().astype(np.float32)
scaler = Scaler()
train, val = scaler.fit_transform(series[:-36]), scaler.transform(series[-36:])
model = DeepTimeModel(input_chunk_length=24,
output_chunk_length=12,
likelihood=GaussianLikelihood())
model.fit(train, epochs=100)
pred = model.predict(series=train, n=36, num_samples=300)
train.plot()
pred.plot() I get this - the variance seems almost zero. I'm wondering whether this might be due to our treatment of the distributions parameters, which perhaps happens too early in the processing (when creating the time representations), which could (maybe?) cause degenerate results. Could we maybe find a way to "tile" tensors somewhere else later in the forward pass? WDYT @madtoinou ? |
This is indeed a bit disappointing, I should have spend more time looking at the variance of the resulting distribution. There is not much room for tiling downstream: after the INR (fully connected network), there is only the ridge regression trying to solve the equation AX = B where A is the time representation transpose time itself, and B is the time representation transposed multiplied by the observations. I don't see how we could tweak this part. I am going to experiment with using different Fourier features for each distribution parameter (before the INR), it should add bit of heterogeneity but I am not sure that it could solve the issue. |
Fixes #1152.
Summary
Implement the DeepTIMe model from https://arxiv.org/pdf/2207.06046.pdf, based on the original repository https://github.com/salesforce/DeepTime and the article pseudo-code.
Also implement some basics tests, inspired by the tests for N-Beats.
Other Information
In the original article, distinct optimizers are defined for the three groups of parameters: Ridge Regression regularization term, the biais/norm of the Implicit Neural Representation (INR) network and the weights of the INR. This was accomplished by overriding the configure_optimizer method and partially breaking the logic behind the lr_scheduler_cls and lr_scheduler_kwargs arguments. To make the model easier to use out of the box, the default arguments correspond to the original article parameters (including for the optimizer).
All the module necessary for this architecture were included in the same file to limit the fragmentation of the code. The Ridge Regression and the INR modules could however be extracted if others models require them.
The support for the
nr_params
functionnality is not implemented yet.