Feature/transformer refactorisation #1915

Open · wants to merge 18 commits into master · Changes from 1 commit
9 changes: 3 additions & 6 deletions darts/models/forecasting/transformer_model.py
madtoinou marked this conversation as resolved.
@@ -348,7 +348,6 @@ def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):

return predictions

# Allow teacher forcing
def training_step(self, train_batch, batch_idx) -> torch.Tensor:
"""performs the training step"""
train_batch = list(train_batch)
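
For context on the teacher-forcing comment touched above: teacher forcing feeds the decoder the ground-truth target sequence, shifted right by one step, instead of its own predictions during training. A minimal, self-contained sketch with a plain torch.nn.Transformer (hypothetical names and shapes, not the darts implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch, not the darts code: during training the decoder is
# conditioned on the ground-truth targets shifted right by one step, rather
# than on its own previous predictions.
def teacher_forced_training_step(model: nn.Transformer,
                                 src: torch.Tensor,   # (src_len, batch, d_model)
                                 tgt: torch.Tensor    # (tgt_len, batch, d_model)
                                 ) -> torch.Tensor:
    # decoder input: last encoder step followed by all but the last target step
    decoder_input = torch.cat([src[-1:], tgt[:-1]], dim=0)
    # causal mask so position t cannot attend to later positions
    tgt_mask = model.generate_square_subsequent_mask(decoder_input.size(0))
    out = model(src, decoder_input, tgt_mask=tgt_mask)
    return F.mse_loss(out, tgt)

# toy usage
model = nn.Transformer(d_model=16, nhead=4, num_encoder_layers=1, num_decoder_layers=1)
loss = teacher_forced_training_step(model, torch.randn(12, 8, 16), torch.randn(6, 8, 16))
loss.backward()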
@@ -410,7 +409,7 @@ def __init__(
The multi-head attention mechanism is highly parallelizable, which makes the transformer architecture
very suitable to be trained with GPUs.

The transformer architecture implemented here is based on [1]_ and uses teacher forcing [2]_.
The transformer architecture implemented here is based on [1]_ and uses teacher forcing [4]_.
JanFidor marked this conversation as resolved.

This model supports past covariates (known for `input_chunk_length` points before prediction time).
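
For illustration, a minimal fit/predict sketch with a toy past covariate series (series values, lengths and hyperparameters are placeholders, not part of this PR):

.. code-block:: python

    import numpy as np
    from darts import TimeSeries
    from darts.models import TransformerModel

    # toy target and past covariate of equal length
    target_series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 100)))
    weather_series = TimeSeries.from_values(np.random.randn(100))

    model = TransformerModel(input_chunk_length=24, output_chunk_length=12, n_epochs=2)
    model.fit(target_series, past_covariates=weather_series)
    forecast = model.predict(n=12)
..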

@@ -513,14 +512,11 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [encode_year]},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'transformer': Scaler()
}
..
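
As a usage note, the dictionary above is passed at model construction. A hedged sketch (illustrative hyperparameters; only 'past' encoders are kept here, since this model consumes past covariates only):

.. code-block:: python

    from darts.dataprocessing.transformers import Scaler
    from darts.models import TransformerModel

    model = TransformerModel(
        input_chunk_length=24,   # illustrative values
        output_chunk_length=12,
        add_encoders={
            'datetime_attribute': {'past': ['month']},
            'position': {'past': ['relative']},
            'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
            'transformer': Scaler(),
        },
    )
    # calling model.fit(series) then builds and scales these covariates automatically
..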
@@ -587,6 +583,7 @@ def encode_year(idx):
.. [2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arXiv https://arxiv.org/abs/2002.05202.
.. [3] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p
.. [4] Teacher Forcing PyTorch tutorial: https://github.com/pytorch/examples/tree/main/word_language_model

Examples
--------