From bee1e11283a45dad87ed97e6316f302eb2613d92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 26 Dec 2024 00:59:33 +0100 Subject: [PATCH] refactor split init --- pytorch_forecasting/models/deepar/__init__.py | 447 +-------- pytorch_forecasting/models/deepar/_deepar.py | 446 +++++++++ pytorch_forecasting/models/mlp/__init__.py | 179 +--- pytorch_forecasting/models/mlp/_decodermlp.py | 179 ++++ pytorch_forecasting/models/nbeats/__init__.py | 376 +------- pytorch_forecasting/models/nbeats/_nbeats.py | 376 ++++++++ pytorch_forecasting/models/nhits/__init__.py | 595 +----------- pytorch_forecasting/models/nhits/_nhits.py | 595 ++++++++++++ pytorch_forecasting/models/rnn/__init__.py | 318 +------ pytorch_forecasting/models/rnn/_rnn.py | 317 ++++++ .../temporal_fusion_transformer/__init__.py | 899 +----------------- .../temporal_fusion_transformer/_tft.py | 898 +++++++++++++++++ 12 files changed, 2837 insertions(+), 2788 deletions(-) create mode 100644 pytorch_forecasting/models/deepar/_deepar.py create mode 100644 pytorch_forecasting/models/mlp/_decodermlp.py create mode 100644 pytorch_forecasting/models/nbeats/_nbeats.py create mode 100644 pytorch_forecasting/models/nhits/_nhits.py create mode 100644 pytorch_forecasting/models/rnn/_rnn.py create mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/_tft.py diff --git a/pytorch_forecasting/models/deepar/__init__.py b/pytorch_forecasting/models/deepar/__init__.py index f9bcb186..679f296f 100644 --- a/pytorch_forecasting/models/deepar/__init__.py +++ b/pytorch_forecasting/models/deepar/__init__.py @@ -1,446 +1,5 @@ -""" -`DeepAR: Probabilistic forecasting with autoregressive recurrent networks -`_ -which is the one of the most popular forecasting algorithms and is often used as a baseline -""" +"""DeepAR: Probabilistic forecasting with autoregressive recurrent networks.""" -from copy import deepcopy -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from pytorch_forecasting.models.deepar._deepar import DeepAR -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -from torch.utils.data.dataloader import DataLoader - -from pytorch_forecasting.data.encoders import MultiNormalizer, NaNLabelEncoder -from pytorch_forecasting.data.timeseries import TimeSeriesDataSet -from pytorch_forecasting.metrics import ( - MAE, - MAPE, - MASE, - RMSE, - SMAPE, - DistributionLoss, - MultiLoss, - MultivariateDistributionLoss, - NormalDistributionLoss, -) -from pytorch_forecasting.models.base_model import AutoRegressiveBaseModelWithCovariates, Prediction -from pytorch_forecasting.models.nn import HiddenState, MultiEmbedding, get_rnn -from pytorch_forecasting.utils import apply_to_list, to_list - - -class DeepAR(AutoRegressiveBaseModelWithCovariates): - def __init__( - self, - cell_type: str = "LSTM", - hidden_size: int = 10, - rnn_layers: int = 2, - dropout: float = 0.1, - static_categoricals: Optional[List[str]] = None, - static_reals: Optional[List[str]] = None, - time_varying_categoricals_encoder: Optional[List[str]] = None, - time_varying_categoricals_decoder: Optional[List[str]] = None, - categorical_groups: Optional[Dict[str, List[str]]] = None, - time_varying_reals_encoder: Optional[List[str]] = None, - time_varying_reals_decoder: Optional[List[str]] = None, - embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, - embedding_paddings: Optional[List[str]] = None, - embedding_labels: Optional[Dict[str, np.ndarray]] = None, - x_reals: Optional[List[str]] = None, - 
x_categoricals: Optional[List[str]] = None, - n_validation_samples: int = None, - n_plotting_samples: int = None, - target: Union[str, List[str]] = None, - target_lags: Optional[Dict[str, List[int]]] = None, - loss: DistributionLoss = None, - logging_metrics: nn.ModuleList = None, - **kwargs, - ): - """ - DeepAR Network. - - The code is based on the article `DeepAR: Probabilistic forecasting with autoregressive recurrent networks - `_. - - By using a Multivariate Loss such as the - :py:class:`~pytorch_forecasting.metrics.MultivariateNormalDistributionLoss`, - the network is converted into a `DeepVAR network `_. - - Args: - cell_type (str, optional): Recurrent cell type ["LSTM", "GRU"]. Defaults to "LSTM". - hidden_size (int, optional): hidden recurrent size - the most important hyperparameter along with - ``rnn_layers``. Defaults to 10. - rnn_layers (int, optional): Number of RNN layers - important hyperparameter. Defaults to 2. - dropout (float, optional): Dropout in RNN layers. Defaults to 0.1. - static_categoricals: integer of positions of static categorical variables - static_reals: integer of positions of static continuous variables - time_varying_categoricals_encoder: integer of positions of categorical variables for encoder - time_varying_categoricals_decoder: integer of positions of categorical variables for decoder - time_varying_reals_encoder: integer of positions of continuous variables for encoder - time_varying_reals_decoder: integer of positions of continuous variables for decoder - categorical_groups: dictionary where values - are list of categorical variables that are forming together a new categorical - variable which is the key in the dictionary - x_reals: order of continuous variables in tensor passed to forward function - x_categoricals: order of categorical variables in tensor passed to forward function - embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and - embedding size - embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector - embedding_labels: dictionary mapping (string) indices to list of categorical labels - n_validation_samples (int, optional): Number of samples to use for calculating validation metrics. - Defaults to None, i.e. no sampling at validation stage and using "mean" of distribution for logging - metrics calculation. - n_plotting_samples (int, optional): Number of samples to generate for plotting predictions - during training. Defaults to ``n_validation_samples`` if not None or 100 otherwise. - target (str, optional): Target variable or list of target variables. Defaults to None. - target_lags (Dict[str, Dict[str, int]]): dictionary of target names mapped to list of time steps by - which the variable should be lagged. - Lags can be useful to indicate seasonality to the models. If you know the seasonalit(ies) of your data, - add at least the target variables with the corresponding lags to improve performance. - Defaults to no lags, i.e. an empty dictionary. - loss (DistributionLoss, optional): Distribution loss function. Keep in mind that each distribution - loss function might have specific requirements for target normalization. - Defaults to :py:class:`~pytorch_forecasting.metrics.NormalDistributionLoss`. - logging_metrics (nn.ModuleList, optional): Metrics to log during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). 
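For orientation, the interface documented above composes as follows. A minimal, hypothetical training and sampling sketch (``train_df`` with columns ``time_idx``, ``series`` and ``value`` is an assumed pandas DataFrame; all hyperparameter values are illustrative, not defaults mandated by this patch):

```python
# Hedged sketch: `train_df` and the chosen hyperparameters are assumptions.
from lightning.pytorch import Trainer  # `pytorch_lightning` on older installs

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.models.deepar import DeepAR

dataset = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=30,
    max_prediction_length=7,
    time_varying_unknown_reals=["value"],
)
model = DeepAR.from_dataset(dataset, cell_type="LSTM", hidden_size=32, rnn_layers=2)
Trainer(max_epochs=10).fit(model, dataset.to_dataloader(train=True, batch_size=64))

# draw 100 Monte Carlo trajectories per series instead of point forecasts
samples = model.predict(
    dataset.to_dataloader(train=False, batch_size=64), mode="samples", n_samples=100
)
```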
- """ - if loss is None: - loss = NormalDistributionLoss() - if logging_metrics is None: - logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - if n_plotting_samples is None: - if n_validation_samples is None: - n_plotting_samples = n_validation_samples - else: - n_plotting_samples = 100 - if static_categoricals is None: - static_categoricals = [] - if static_reals is None: - static_reals = [] - if time_varying_categoricals_encoder is None: - time_varying_categoricals_encoder = [] - if time_varying_categoricals_decoder is None: - time_varying_categoricals_decoder = [] - if categorical_groups is None: - categorical_groups = {} - if time_varying_reals_encoder is None: - time_varying_reals_encoder = [] - if time_varying_reals_decoder is None: - time_varying_reals_decoder = [] - if embedding_sizes is None: - embedding_sizes = {} - if embedding_paddings is None: - embedding_paddings = [] - if embedding_labels is None: - embedding_labels = {} - if x_reals is None: - x_reals = [] - if x_categoricals is None: - x_categoricals = [] - if target_lags is None: - target_lags = {} - self.save_hyperparameters() - # store loss function separately as it is a module - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - - self.embeddings = MultiEmbedding( - embedding_sizes=embedding_sizes, - embedding_paddings=embedding_paddings, - categorical_groups=categorical_groups, - x_categoricals=x_categoricals, - ) - - lagged_target_names = [l for lags in target_lags.values() for l in lags] - assert set(self.encoder_variables) - set(to_list(target)) - set(lagged_target_names) == set( - self.decoder_variables - ) - set(lagged_target_names), "Encoder and decoder variables have to be the same apart from target variable" - for targeti in to_list(target): - assert ( - targeti in time_varying_reals_encoder - ), f"target {targeti} has to be real" # todo: remove this restriction - assert (isinstance(target, str) and isinstance(loss, DistributionLoss)) or ( - isinstance(target, (list, tuple)) and isinstance(loss, MultiLoss) and len(loss) == len(target) - ), "number of targets should be equivalent to number of loss metrics" - - rnn_class = get_rnn(cell_type) - cont_size = len(self.reals) - cat_size = sum(self.embeddings.output_size.values()) - input_size = cont_size + cat_size - self.rnn = rnn_class( - input_size=input_size, - hidden_size=self.hparams.hidden_size, - num_layers=self.hparams.rnn_layers, - dropout=self.hparams.dropout if self.hparams.rnn_layers > 1 else 0, - batch_first=True, - ) - - # add linear layers for argument projects - if isinstance(target, str): # single target - self.distribution_projector = nn.Linear(self.hparams.hidden_size, len(self.loss.distribution_arguments)) - else: # multi target - self.distribution_projector = nn.ModuleList( - [nn.Linear(self.hparams.hidden_size, len(args)) for args in self.loss.distribution_arguments] - ) - - @classmethod - def from_dataset( - cls, - dataset: TimeSeriesDataSet, - allowed_encoder_known_variable_names: List[str] = None, - **kwargs, - ): - """ - Create model from dataset. 
- - Args: - dataset: timeseries dataset - allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all - **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) - - Returns: - DeepAR network - """ - new_kwargs = {} - if dataset.multi_target: - new_kwargs.setdefault("loss", MultiLoss([NormalDistributionLoss()] * len(dataset.target_names))) - new_kwargs.update(kwargs) - assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and ( - not isinstance(dataset.target_normalizer, MultiNormalizer) - or all(not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer) - ), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction - if isinstance(new_kwargs.get("loss", None), MultivariateDistributionLoss): - assert ( - dataset.min_prediction_length == dataset.max_prediction_length - ), "Multivariate models require constant prediction lenghts" - return super().from_dataset( - dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs - ) - - def construct_input_vector( - self, x_cat: torch.Tensor, x_cont: torch.Tensor, one_off_target: torch.Tensor = None - ) -> torch.Tensor: - """ - Create input vector into RNN network - - Args: - one_off_target: tensor to insert into first position of target. If None (default), remove first time step. - """ - # create input vector - if len(self.categoricals) > 0: - embeddings = self.embeddings(x_cat) - flat_embeddings = torch.cat(list(embeddings.values()), dim=-1) - input_vector = flat_embeddings - - if len(self.reals) > 0: - input_vector = x_cont.clone() - - if len(self.reals) > 0 and len(self.categoricals) > 0: - input_vector = torch.cat([x_cont, flat_embeddings], dim=-1) - - # shift target by one - input_vector[..., self.target_positions] = torch.roll( - input_vector[..., self.target_positions], shifts=1, dims=1 - ) - - if one_off_target is not None: # set first target input (which is rolled over) - input_vector[:, 0, self.target_positions] = one_off_target - else: - input_vector = input_vector[:, 1:] - - # shift target - return input_vector - - def encode(self, x: Dict[str, torch.Tensor]) -> HiddenState: - """ - Encode sequence into hidden state - """ - # encode using rnn - assert x["encoder_lengths"].min() > 0 - encoder_lengths = x["encoder_lengths"] - 1 - input_vector = self.construct_input_vector(x["encoder_cat"], x["encoder_cont"]) - _, hidden_state = self.rnn( - input_vector, lengths=encoder_lengths, enforce_sorted=False - ) # second ouput is not needed (hidden state) - return hidden_state - - def decode_all( - self, - x: torch.Tensor, - hidden_state: HiddenState, - lengths: torch.Tensor = None, - ): - decoder_output, hidden_state = self.rnn(x, hidden_state, lengths=lengths, enforce_sorted=False) - if isinstance(self.hparams.target, str): # single target - output = self.distribution_projector(decoder_output) - else: - output = [projector(decoder_output) for projector in self.distribution_projector] - return output, hidden_state - - def decode( - self, - input_vector: torch.Tensor, - target_scale: torch.Tensor, - decoder_lengths: torch.Tensor, - hidden_state: HiddenState, - n_samples: int = None, - ) -> Tuple[torch.Tensor, bool]: - """ - Decode hidden state of RNN into prediction. 
If n_smaples is given, - decode not by using actual values but rather by - sampling new targets from past predictions iteratively - """ - if n_samples is None: - output, _ = self.decode_all(input_vector, hidden_state, lengths=decoder_lengths) - output = self.transform_output(output, target_scale=target_scale) - else: - # run in eval, i.e. simulation mode - target_pos = self.target_positions - lagged_target_positions = self.lagged_target_positions - # repeat for n_samples - input_vector = input_vector.repeat_interleave(n_samples, 0) - hidden_state = self.rnn.repeat_interleave(hidden_state, n_samples) - target_scale = apply_to_list(target_scale, lambda x: x.repeat_interleave(n_samples, 0)) - - # define function to run at every decoding step - def decode_one( - idx, - lagged_targets, - hidden_state, - ): - x = input_vector[:, [idx]] - x[:, 0, target_pos] = lagged_targets[-1] - for lag, lag_positions in lagged_target_positions.items(): - if idx > lag: - x[:, 0, lag_positions] = lagged_targets[-lag] - prediction, hidden_state = self.decode_all(x, hidden_state) - prediction = apply_to_list(prediction, lambda x: x[:, 0]) # select first time step - return prediction, hidden_state - - # make predictions which are fed into next step - output = self.decode_autoregressive( - decode_one, - first_target=input_vector[:, 0, target_pos], - first_hidden_state=hidden_state, - target_scale=target_scale, - n_decoder_steps=input_vector.size(1), - n_samples=n_samples, - ) - # reshape predictions for n_samples: - # from n_samples * batch_size x time steps to batch_size x time steps x n_samples - output = apply_to_list(output, lambda x: x.reshape(-1, n_samples, input_vector.size(1)).permute(0, 2, 1)) - return output - - def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: - """ - Forward network - """ - hidden_state = self.encode(x) - # decode - input_vector = self.construct_input_vector( - x["decoder_cat"], - x["decoder_cont"], - one_off_target=x["encoder_cont"][ - torch.arange(x["encoder_cont"].size(0), device=x["encoder_cont"].device), - x["encoder_lengths"] - 1, - self.target_positions.unsqueeze(-1), - ].T.contiguous(), - ) - - if self.training: - assert n_samples is None, "cannot sample from decoder when training" - output = self.decode( - input_vector, - decoder_lengths=x["decoder_lengths"], - target_scale=x["target_scale"], - hidden_state=hidden_state, - n_samples=n_samples, - ) - # return relevant part - return self.to_network_output(prediction=output) - - def create_log(self, x, y, out, batch_idx): - n_samples = [self.hparams.n_validation_samples, self.hparams.n_plotting_samples][self.training] - log = super().create_log( - x, - y, - out, - batch_idx, - prediction_kwargs=dict(n_samples=n_samples), - quantiles_kwargs=dict(n_samples=n_samples), - ) - return log - - def predict( - self, - data: Union[DataLoader, pd.DataFrame, TimeSeriesDataSet], - mode: Union[str, Tuple[str, str]] = "prediction", - return_index: bool = False, - return_decoder_lengths: bool = False, - batch_size: int = 64, - num_workers: int = 0, - fast_dev_run: bool = False, - return_x: bool = False, - return_y: bool = False, - mode_kwargs: Dict[str, Any] = None, - trainer_kwargs: Optional[Dict[str, Any]] = None, - write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch", - output_dir: Optional[str] = None, - n_samples: int = 100, - **kwargs, - ) -> Prediction: - """ - predict dataloader - - Args: - dataloader: dataloader, dataframe or dataset - mode: one of "prediction", "quantiles", 
"samples" or "raw", or tuple ``("raw", output_name)`` where - output_name is a name in the dictionary returned by ``forward()`` - return_index: if to return the prediction index (in the same order as the output, i.e. the row of the - dataframe corresponds to the first dimension of the output and the given time index is the time index - of the first prediction) - return_decoder_lengths: if to return decoder_lengths (in the same order as the output - batch_size: batch size for dataloader - only used if data is not a dataloader is passed - num_workers: number of workers for dataloader - only used if data is not a dataloader is passed - fast_dev_run: if to only return results of first batch - show_progress_bar: if to show progress bar. Defaults to False. - return_x: if to return network inputs (in the same order as prediction output) - return_y: if to return network targets (in the same order as prediction output) - mode_kwargs (Dict[str, Any]): keyword arguments for ``to_prediction()`` or ``to_quantiles()`` - for modes "prediction" and "quantiles" - trainer_kwargs (Dict[str, Any], optional): keyword arguments for the trainer - write_interval: interval to write predictions to disk - output_dir: directory to write predictions to. Defaults to None. If set function will return empty list - n_samples: number of samples to draw. Defaults to 100. - - Returns: - Prediction: if one of the ```return`` arguments is present, - prediction tuple with fields ``prediction``, ``x``, ``y``, ``index`` and ``decoder_lengths`` - """ - if isinstance(mode, str): - if mode in ["prediction", "quantiles"]: - if mode_kwargs is None: - mode_kwargs = dict(use_metric=False) - else: - mode_kwargs = deepcopy(mode_kwargs) - mode_kwargs["use_metric"] = False - elif mode == "samples": - mode = ("raw", "prediction") - return super().predict( - data=data, - mode=mode, - return_decoder_lengths=return_decoder_lengths, - return_index=return_index, - n_samples=n_samples, # new keyword that is passed to forward function - return_x=return_x, - fast_dev_run=fast_dev_run, - num_workers=num_workers, - batch_size=batch_size, - mode_kwargs=mode_kwargs, - trainer_kwargs=trainer_kwargs, - write_interval=write_interval, - output_dir=output_dir, - return_y=return_y, - **kwargs, - ) +__all__ = ["DeepAR"] diff --git a/pytorch_forecasting/models/deepar/_deepar.py b/pytorch_forecasting/models/deepar/_deepar.py new file mode 100644 index 00000000..f9bcb186 --- /dev/null +++ b/pytorch_forecasting/models/deepar/_deepar.py @@ -0,0 +1,446 @@ +""" +`DeepAR: Probabilistic forecasting with autoregressive recurrent networks +`_ +which is the one of the most popular forecasting algorithms and is often used as a baseline +""" + +from copy import deepcopy +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader + +from pytorch_forecasting.data.encoders import MultiNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.metrics import ( + MAE, + MAPE, + MASE, + RMSE, + SMAPE, + DistributionLoss, + MultiLoss, + MultivariateDistributionLoss, + NormalDistributionLoss, +) +from pytorch_forecasting.models.base_model import AutoRegressiveBaseModelWithCovariates, Prediction +from pytorch_forecasting.models.nn import HiddenState, MultiEmbedding, get_rnn +from pytorch_forecasting.utils import apply_to_list, to_list + + +class 
DeepAR(AutoRegressiveBaseModelWithCovariates): + def __init__( + self, + cell_type: str = "LSTM", + hidden_size: int = 10, + rnn_layers: int = 2, + dropout: float = 0.1, + static_categoricals: Optional[List[str]] = None, + static_reals: Optional[List[str]] = None, + time_varying_categoricals_encoder: Optional[List[str]] = None, + time_varying_categoricals_decoder: Optional[List[str]] = None, + categorical_groups: Optional[Dict[str, List[str]]] = None, + time_varying_reals_encoder: Optional[List[str]] = None, + time_varying_reals_decoder: Optional[List[str]] = None, + embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, + embedding_paddings: Optional[List[str]] = None, + embedding_labels: Optional[Dict[str, np.ndarray]] = None, + x_reals: Optional[List[str]] = None, + x_categoricals: Optional[List[str]] = None, + n_validation_samples: int = None, + n_plotting_samples: int = None, + target: Union[str, List[str]] = None, + target_lags: Optional[Dict[str, List[int]]] = None, + loss: DistributionLoss = None, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + DeepAR Network. + + The code is based on the article `DeepAR: Probabilistic forecasting with autoregressive recurrent networks + `_. + + By using a Multivariate Loss such as the + :py:class:`~pytorch_forecasting.metrics.MultivariateNormalDistributionLoss`, + the network is converted into a `DeepVAR network `_. + + Args: + cell_type (str, optional): Recurrent cell type ["LSTM", "GRU"]. Defaults to "LSTM". + hidden_size (int, optional): hidden recurrent size - the most important hyperparameter along with + ``rnn_layers``. Defaults to 10. + rnn_layers (int, optional): Number of RNN layers - important hyperparameter. Defaults to 2. + dropout (float, optional): Dropout in RNN layers. Defaults to 0.1. + static_categoricals: integer of positions of static categorical variables + static_reals: integer of positions of static continuous variables + time_varying_categoricals_encoder: integer of positions of categorical variables for encoder + time_varying_categoricals_decoder: integer of positions of categorical variables for decoder + time_varying_reals_encoder: integer of positions of continuous variables for encoder + time_varying_reals_decoder: integer of positions of continuous variables for decoder + categorical_groups: dictionary where values + are list of categorical variables that are forming together a new categorical + variable which is the key in the dictionary + x_reals: order of continuous variables in tensor passed to forward function + x_categoricals: order of categorical variables in tensor passed to forward function + embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and + embedding size + embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector + embedding_labels: dictionary mapping (string) indices to list of categorical labels + n_validation_samples (int, optional): Number of samples to use for calculating validation metrics. + Defaults to None, i.e. no sampling at validation stage and using "mean" of distribution for logging + metrics calculation. + n_plotting_samples (int, optional): Number of samples to generate for plotting predictions + during training. Defaults to ``n_validation_samples`` if not None or 100 otherwise. + target (str, optional): Target variable or list of target variables. Defaults to None. 
+ target_lags (Dict[str, List[int]], optional): dictionary of target names mapped to list of time steps by + which the variable should be lagged. + Lags can be useful to indicate seasonality to the models. If you know the seasonality or seasonalities of your data, + add at least the target variables with the corresponding lags to improve performance. + Defaults to no lags, i.e. an empty dictionary. + loss (DistributionLoss, optional): Distribution loss function. Keep in mind that each distribution + loss function might have specific requirements for target normalization. + Defaults to :py:class:`~pytorch_forecasting.metrics.NormalDistributionLoss`. + logging_metrics (nn.ModuleList, optional): Metrics to log during training. + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). + """ + if loss is None: + loss = NormalDistributionLoss() + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + if n_plotting_samples is None: + if n_validation_samples is None: + n_plotting_samples = 100 + else: + n_plotting_samples = n_validation_samples + if static_categoricals is None: + static_categoricals = [] + if static_reals is None: + static_reals = [] + if time_varying_categoricals_encoder is None: + time_varying_categoricals_encoder = [] + if time_varying_categoricals_decoder is None: + time_varying_categoricals_decoder = [] + if categorical_groups is None: + categorical_groups = {} + if time_varying_reals_encoder is None: + time_varying_reals_encoder = [] + if time_varying_reals_decoder is None: + time_varying_reals_decoder = [] + if embedding_sizes is None: + embedding_sizes = {} + if embedding_paddings is None: + embedding_paddings = [] + if embedding_labels is None: + embedding_labels = {} + if x_reals is None: + x_reals = [] + if x_categoricals is None: + x_categoricals = [] + if target_lags is None: + target_lags = {} + self.save_hyperparameters() + # store loss function separately as it is a module + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + self.embeddings = MultiEmbedding( + embedding_sizes=embedding_sizes, + embedding_paddings=embedding_paddings, + categorical_groups=categorical_groups, + x_categoricals=x_categoricals, + ) + + lagged_target_names = [l for lags in target_lags.values() for l in lags] + assert set(self.encoder_variables) - set(to_list(target)) - set(lagged_target_names) == set( + self.decoder_variables + ) - set(lagged_target_names), "Encoder and decoder variables have to be the same apart from target variable" + for targeti in to_list(target): + assert ( + targeti in time_varying_reals_encoder + ), f"target {targeti} has to be real" # todo: remove this restriction + assert (isinstance(target, str) and isinstance(loss, DistributionLoss)) or ( + isinstance(target, (list, tuple)) and isinstance(loss, MultiLoss) and len(loss) == len(target) + ), "number of targets should be equivalent to number of loss metrics" + + rnn_class = get_rnn(cell_type) + cont_size = len(self.reals) + cat_size = sum(self.embeddings.output_size.values()) + input_size = cont_size + cat_size + self.rnn = rnn_class( + input_size=input_size, + hidden_size=self.hparams.hidden_size, + num_layers=self.hparams.rnn_layers, + dropout=self.hparams.dropout if self.hparams.rnn_layers > 1 else 0, + batch_first=True, + ) + + # add linear layers for argument projection + if isinstance(target, str): # single target + self.distribution_projector = nn.Linear(self.hparams.hidden_size, len(self.loss.distribution_arguments)) + else: # multi target
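+ # one projector per target: each Linear maps the hidden state to that target's distribution arguments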
+ self.distribution_projector = nn.ModuleList( + [nn.Linear(self.hparams.hidden_size, len(args)) for args in self.loss.distribution_arguments] + ) + + @classmethod + def from_dataset( + cls, + dataset: TimeSeriesDataSet, + allowed_encoder_known_variable_names: List[str] = None, + **kwargs, + ): + """ + Create model from dataset. + + Args: + dataset: timeseries dataset + allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all + **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) + + Returns: + DeepAR network + """ + new_kwargs = {} + if dataset.multi_target: + new_kwargs.setdefault("loss", MultiLoss([NormalDistributionLoss()] * len(dataset.target_names))) + new_kwargs.update(kwargs) + assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and ( + not isinstance(dataset.target_normalizer, MultiNormalizer) + or all(not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer) + ), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction + if isinstance(new_kwargs.get("loss", None), MultivariateDistributionLoss): + assert ( + dataset.min_prediction_length == dataset.max_prediction_length + ), "Multivariate models require constant prediction lengths" + return super().from_dataset( + dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs + ) + + def construct_input_vector( + self, x_cat: torch.Tensor, x_cont: torch.Tensor, one_off_target: torch.Tensor = None + ) -> torch.Tensor: + """ + Create input vector into RNN network + + Args: + one_off_target: tensor to insert into first position of target. If None (default), remove first time step. + """ + # create input vector + if len(self.categoricals) > 0: + embeddings = self.embeddings(x_cat) + flat_embeddings = torch.cat(list(embeddings.values()), dim=-1) + input_vector = flat_embeddings + + if len(self.reals) > 0: + input_vector = x_cont.clone() + + if len(self.reals) > 0 and len(self.categoricals) > 0: + input_vector = torch.cat([x_cont, flat_embeddings], dim=-1) + + # shift target by one + input_vector[..., self.target_positions] = torch.roll( + input_vector[..., self.target_positions], shifts=1, dims=1 + ) + + if one_off_target is not None: # set first target input (which is rolled over) + input_vector[:, 0, self.target_positions] = one_off_target + else: + input_vector = input_vector[:, 1:] + + return input_vector + + def encode(self, x: Dict[str, torch.Tensor]) -> HiddenState: + """ + Encode sequence into hidden state + """ + # encode using rnn + assert x["encoder_lengths"].min() > 0 + encoder_lengths = x["encoder_lengths"] - 1 + input_vector = self.construct_input_vector(x["encoder_cat"], x["encoder_cont"]) + _, hidden_state = self.rnn( + input_vector, lengths=encoder_lengths, enforce_sorted=False + ) # output sequence is not needed, only the hidden state + return hidden_state + + def decode_all( + self, + x: torch.Tensor, + hidden_state: HiddenState, + lengths: torch.Tensor = None, + ): + decoder_output, hidden_state = self.rnn(x, hidden_state, lengths=lengths, enforce_sorted=False) + if isinstance(self.hparams.target, str): # single target + output = self.distribution_projector(decoder_output) + else: + output = [projector(decoder_output) for projector in self.distribution_projector] + return output, hidden_state + + def decode( + self, + input_vector: torch.Tensor, + target_scale: torch.Tensor, + decoder_lengths: 
torch.Tensor, + hidden_state: HiddenState, + n_samples: int = None, + ) -> Tuple[torch.Tensor, bool]: + """ + Decode hidden state of RNN into prediction. If n_samples is given, + decode not by using actual values but rather by + sampling new targets from past predictions iteratively + """ + if n_samples is None: + output, _ = self.decode_all(input_vector, hidden_state, lengths=decoder_lengths) + output = self.transform_output(output, target_scale=target_scale) + else: + # run in eval, i.e. simulation mode + target_pos = self.target_positions + lagged_target_positions = self.lagged_target_positions + # repeat for n_samples + input_vector = input_vector.repeat_interleave(n_samples, 0) + hidden_state = self.rnn.repeat_interleave(hidden_state, n_samples) + target_scale = apply_to_list(target_scale, lambda x: x.repeat_interleave(n_samples, 0)) + + # define function to run at every decoding step + def decode_one( + idx, + lagged_targets, + hidden_state, + ): + x = input_vector[:, [idx]] + x[:, 0, target_pos] = lagged_targets[-1] + for lag, lag_positions in lagged_target_positions.items(): + if idx > lag: + x[:, 0, lag_positions] = lagged_targets[-lag] + prediction, hidden_state = self.decode_all(x, hidden_state) + prediction = apply_to_list(prediction, lambda x: x[:, 0]) # select first time step + return prediction, hidden_state + + # make predictions which are fed into next step + output = self.decode_autoregressive( + decode_one, + first_target=input_vector[:, 0, target_pos], + first_hidden_state=hidden_state, + target_scale=target_scale, + n_decoder_steps=input_vector.size(1), + n_samples=n_samples, + ) + # reshape predictions for n_samples: + # from n_samples * batch_size x time steps to batch_size x time steps x n_samples + output = apply_to_list(output, lambda x: x.reshape(-1, n_samples, input_vector.size(1)).permute(0, 2, 1)) + return output + + def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: + """ + Forward network + """ + hidden_state = self.encode(x) + # decode + input_vector = self.construct_input_vector( + x["decoder_cat"], + x["decoder_cont"], + one_off_target=x["encoder_cont"][ + torch.arange(x["encoder_cont"].size(0), device=x["encoder_cont"].device), + x["encoder_lengths"] - 1, + self.target_positions.unsqueeze(-1), + ].T.contiguous(), + ) + + if self.training: + assert n_samples is None, "cannot sample from decoder when training" + output = self.decode( + input_vector, + decoder_lengths=x["decoder_lengths"], + target_scale=x["target_scale"], + hidden_state=hidden_state, + n_samples=n_samples, + ) + # return relevant part + return self.to_network_output(prediction=output) + + def create_log(self, x, y, out, batch_idx): + n_samples = [self.hparams.n_validation_samples, self.hparams.n_plotting_samples][self.training] + log = super().create_log( + x, + y, + out, + batch_idx, + prediction_kwargs=dict(n_samples=n_samples), + quantiles_kwargs=dict(n_samples=n_samples), + ) + return log + + def predict( + self, + data: Union[DataLoader, pd.DataFrame, TimeSeriesDataSet], + mode: Union[str, Tuple[str, str]] = "prediction", + return_index: bool = False, + return_decoder_lengths: bool = False, + batch_size: int = 64, + num_workers: int = 0, + fast_dev_run: bool = False, + return_x: bool = False, + return_y: bool = False, + mode_kwargs: Dict[str, Any] = None, + trainer_kwargs: Optional[Dict[str, Any]] = None, + write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch", + output_dir: Optional[str] = None, + n_samples: int = 100, +
**kwargs, + ) -> Prediction: + """ + Predict dataloader. + + Args: + data: dataloader, dataframe or dataset + mode: one of "prediction", "quantiles", "samples" or "raw", or tuple ``("raw", output_name)`` where + output_name is a name in the dictionary returned by ``forward()`` + return_index: if to return the prediction index (in the same order as the output, i.e. the row of the + dataframe corresponds to the first dimension of the output and the given time index is the time index + of the first prediction) + return_decoder_lengths: if to return decoder_lengths (in the same order as the output) + batch_size: batch size for dataloader - only used if data is not a dataloader + num_workers: number of workers for dataloader - only used if data is not a dataloader + fast_dev_run: if to only return results of first batch + return_x: if to return network inputs (in the same order as prediction output) + return_y: if to return network targets (in the same order as prediction output) + mode_kwargs (Dict[str, Any]): keyword arguments for ``to_prediction()`` or ``to_quantiles()`` + for modes "prediction" and "quantiles" + trainer_kwargs (Dict[str, Any], optional): keyword arguments for the trainer + write_interval: interval to write predictions to disk + output_dir: directory to write predictions to. Defaults to None. If set, the function will return an empty list + n_samples: number of samples to draw. Defaults to 100. + + Returns: + Prediction: if one of the ``return_*`` arguments is present, + prediction tuple with fields ``prediction``, ``x``, ``y``, ``index`` and ``decoder_lengths`` + """ + if isinstance(mode, str): + if mode in ["prediction", "quantiles"]: + if mode_kwargs is None: + mode_kwargs = dict(use_metric=False) + else: + mode_kwargs = deepcopy(mode_kwargs) + mode_kwargs["use_metric"] = False + elif mode == "samples": + mode = ("raw", "prediction") + return super().predict( + data=data, + mode=mode, + return_decoder_lengths=return_decoder_lengths, + return_index=return_index, + n_samples=n_samples, # new keyword that is passed to forward function + return_x=return_x, + fast_dev_run=fast_dev_run, + num_workers=num_workers, + batch_size=batch_size, + mode_kwargs=mode_kwargs, + trainer_kwargs=trainer_kwargs, + write_interval=write_interval, + output_dir=output_dir, + return_y=return_y, + **kwargs, + ) diff --git a/pytorch_forecasting/models/mlp/__init__.py b/pytorch_forecasting/models/mlp/__init__.py index 24b72c0f..6a3532fc 100644 --- a/pytorch_forecasting/models/mlp/__init__.py +++ b/pytorch_forecasting/models/mlp/__init__.py @@ -1,179 +1,6 @@ -""" -Simple models based on fully connected networks -""" +"""Simple models based on fully connected networks.""" -from typing import Dict, List, Tuple, Union, Optional - -import numpy as np -import torch -from torch import nn - -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss -from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.mlp._decodermlp import DecoderMLP from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule -from pytorch_forecasting.models.nn.embeddings import MultiEmbedding - - -class DecoderMLP(BaseModelWithCovariates): - """ - MLP on the decoder. - - MLP that predicts output only based on information available in the decoder. 
- """ - - def __init__( - self, - activation_class: str = "ReLU", - hidden_size: int = 300, - n_hidden_layers: int = 3, - dropout: float = 0.1, - norm: bool = True, - static_categoricals: Optional[List[str]] = None, - static_reals: Optional[List[str]] = None, - time_varying_categoricals_encoder: Optional[List[str]] = None, - time_varying_categoricals_decoder: Optional[List[str]] = None, - categorical_groups: Optional[Dict[str, List[str]]] = None, - time_varying_reals_encoder: Optional[List[str]] = None, - time_varying_reals_decoder: Optional[List[str]] = None, - embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, - embedding_paddings: Optional[List[str]] = None, - embedding_labels: Optional[Dict[str, np.ndarray]] = None, - x_reals: Optional[List[str]] = None, - x_categoricals: Optional[List[str]] = None, - output_size: Union[int, List[int]] = 1, - target: Union[str, List[str]] = None, - loss: MultiHorizonMetric = None, - logging_metrics: nn.ModuleList = None, - **kwargs, - ): - """ - Args: - activation_class (str, optional): PyTorch activation class. Defaults to "ReLU". - hidden_size (int, optional): hidden recurrent size - the most important hyperparameter along with - ``n_hidden_layers``. Defaults to 10. - n_hidden_layers (int, optional): Number of hidden layers - important hyperparameter. Defaults to 2. - dropout (float, optional): Dropout. Defaults to 0.1. - norm (bool, optional): if to use normalization in the MLP. Defaults to True. - static_categoricals: integer of positions of static categorical variables - static_reals: integer of positions of static continuous variables - time_varying_categoricals_encoder: integer of positions of categorical variables for encoder - time_varying_categoricals_decoder: integer of positions of categorical variables for decoder - time_varying_reals_encoder: integer of positions of continuous variables for encoder - time_varying_reals_decoder: integer of positions of continuous variables for decoder - categorical_groups: dictionary where values - are list of categorical variables that are forming together a new categorical - variable which is the key in the dictionary - x_reals: order of continuous variables in tensor passed to forward function - x_categoricals: order of categorical variables in tensor passed to forward function - embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and - embedding size - embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector - embedding_labels: dictionary mapping (string) indices to list of categorical labels - output_size (Union[int, List[int]], optional): number of outputs (e.g. number of quantiles for - QuantileLoss and one target or list of output sizes). - target (str, optional): Target variable or list of target variables. Defaults to None. - loss (MultiHorizonMetric, optional): loss: loss function taking prediction and targets. - Defaults to QuantileLoss. - logging_metrics (nn.ModuleList, optional): Metrics to log during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). 
- """ - if loss is None: - loss = QuantileLoss() - if logging_metrics is None: - logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - if static_categoricals is None: - static_categoricals = [] - if static_reals is None: - static_reals = [] - if time_varying_reals_encoder is None: - time_varying_reals_encoder = [] - if time_varying_categoricals_decoder is None: - time_varying_categoricals_decoder = [] - if categorical_groups is None: - categorical_groups = {} - if time_varying_reals_encoder is None: - time_varying_reals_encoder = [] - if time_varying_reals_decoder is None: - time_varying_reals_decoder = [] - if embedding_sizes is None: - embedding_sizes = {} - if embedding_paddings is None: - embedding_paddings = [] - if embedding_labels is None: - embedding_labels = {} - if x_reals is None: - x_reals = [] - if x_categoricals is None: - x_categoricals = [] - self.save_hyperparameters() - # store loss function separately as it is a module - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - - self.input_embeddings = MultiEmbedding( - embedding_sizes={ - name: val - for name, val in embedding_sizes.items() - if name in self.decoder_variables + self.static_variables - }, - embedding_paddings=embedding_paddings, - categorical_groups=categorical_groups, - x_categoricals=x_categoricals, - ) - # define network - if isinstance(self.hparams.output_size, int): - mlp_output_size = self.hparams.output_size - else: - mlp_output_size = sum(self.hparams.output_size) - - cont_size = len(self.decoder_reals_positions) - cat_size = sum(self.input_embeddings.output_size.values()) - input_size = cont_size + cat_size - - self.mlp = FullyConnectedModule( - dropout=dropout, - norm=self.hparams.norm, - activation_class=getattr(nn, self.hparams.activation_class), - input_size=input_size, - output_size=mlp_output_size, - hidden_size=self.hparams.hidden_size, - n_hidden_layers=self.hparams.n_hidden_layers, - ) - - @property - def decoder_reals_positions(self) -> List[int]: - return [ - self.hparams.x_reals.index(name) - for name in self.reals - if name in self.decoder_variables + self.static_variables - ] - - def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: - """ - Forward network - """ - # x is a batch generated based on the TimeSeriesDataset - batch_size = x["decoder_lengths"].size(0) - embeddings = self.input_embeddings(x["decoder_cat"]) # returns dictionary with embedding tensors - network_input = torch.cat( - [x["decoder_cont"][..., self.decoder_reals_positions]] + list(embeddings.values()), - dim=-1, - ) - prediction = self.mlp(network_input.view(-1, self.mlp.input_size)).view( - batch_size, network_input.size(1), self.mlp.output_size - ) - - # cut prediction into pieces for multiple targets - if self.n_targets > 1: - prediction = torch.split(prediction, self.hparams.output_size, dim=-1) - - # We need to return a dictionary that at least contains the prediction - # The parameter can be directly forwarded from the input. 
- prediction = self.transform_output(prediction, target_scale=x["target_scale"]) - return self.to_network_output(prediction=prediction) - @classmethod - def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): - new_kwargs = cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss()) - kwargs.update(new_kwargs) - return super().from_dataset(dataset, **kwargs) +__all__ = ["DecoderMLP", "FullyConnectedModule"] diff --git a/pytorch_forecasting/models/mlp/_decodermlp.py b/pytorch_forecasting/models/mlp/_decodermlp.py new file mode 100644 index 00000000..24b72c0f --- /dev/null +++ b/pytorch_forecasting/models/mlp/_decodermlp.py @@ -0,0 +1,179 @@ +""" +Simple models based on fully connected networks +""" + +from typing import Dict, List, Tuple, Union, Optional + +import numpy as np +import torch +from torch import nn + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss +from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule +from pytorch_forecasting.models.nn.embeddings import MultiEmbedding + + +class DecoderMLP(BaseModelWithCovariates): + """ + MLP on the decoder. + + MLP that predicts output only based on information available in the decoder. + """ + + def __init__( + self, + activation_class: str = "ReLU", + hidden_size: int = 300, + n_hidden_layers: int = 3, + dropout: float = 0.1, + norm: bool = True, + static_categoricals: Optional[List[str]] = None, + static_reals: Optional[List[str]] = None, + time_varying_categoricals_encoder: Optional[List[str]] = None, + time_varying_categoricals_decoder: Optional[List[str]] = None, + categorical_groups: Optional[Dict[str, List[str]]] = None, + time_varying_reals_encoder: Optional[List[str]] = None, + time_varying_reals_decoder: Optional[List[str]] = None, + embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, + embedding_paddings: Optional[List[str]] = None, + embedding_labels: Optional[Dict[str, np.ndarray]] = None, + x_reals: Optional[List[str]] = None, + x_categoricals: Optional[List[str]] = None, + output_size: Union[int, List[int]] = 1, + target: Union[str, List[str]] = None, + loss: MultiHorizonMetric = None, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + Args: + activation_class (str, optional): PyTorch activation class. Defaults to "ReLU". + hidden_size (int, optional): hidden size - the most important hyperparameter along with + ``n_hidden_layers``. Defaults to 300. + n_hidden_layers (int, optional): Number of hidden layers - important hyperparameter. Defaults to 3. + dropout (float, optional): Dropout. Defaults to 0.1. + norm (bool, optional): if to use normalization in the MLP. Defaults to True. 
+ static_categoricals: integer of positions of static categorical variables + static_reals: integer of positions of static continuous variables + time_varying_categoricals_encoder: integer of positions of categorical variables for encoder + time_varying_categoricals_decoder: integer of positions of categorical variables for decoder + time_varying_reals_encoder: integer of positions of continuous variables for encoder + time_varying_reals_decoder: integer of positions of continuous variables for decoder + categorical_groups: dictionary where values + are list of categorical variables that are forming together a new categorical + variable which is the key in the dictionary + x_reals: order of continuous variables in tensor passed to forward function + x_categoricals: order of categorical variables in tensor passed to forward function + embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and + embedding size + embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector + embedding_labels: dictionary mapping (string) indices to list of categorical labels + output_size (Union[int, List[int]], optional): number of outputs (e.g. number of quantiles for + QuantileLoss and one target or list of output sizes). + target (str, optional): Target variable or list of target variables. Defaults to None. + loss (MultiHorizonMetric, optional): loss function taking prediction and targets. + Defaults to QuantileLoss. + logging_metrics (nn.ModuleList, optional): Metrics to log during training. + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). + """ + if loss is None: + loss = QuantileLoss() + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + if static_categoricals is None: + static_categoricals = [] + if static_reals is None: + static_reals = [] + if time_varying_categoricals_encoder is None: + time_varying_categoricals_encoder = [] + if time_varying_categoricals_decoder is None: + time_varying_categoricals_decoder = [] + if categorical_groups is None: + categorical_groups = {} + if time_varying_reals_encoder is None: + time_varying_reals_encoder = [] + if time_varying_reals_decoder is None: + time_varying_reals_decoder = [] + if embedding_sizes is None: + embedding_sizes = {} + if embedding_paddings is None: + embedding_paddings = [] + if embedding_labels is None: + embedding_labels = {} + if x_reals is None: + x_reals = [] + if x_categoricals is None: + x_categoricals = [] + self.save_hyperparameters() + # store loss function separately as it is a module + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + self.input_embeddings = MultiEmbedding( + embedding_sizes={ + name: val + for name, val in embedding_sizes.items() + if name in self.decoder_variables + self.static_variables + }, + embedding_paddings=embedding_paddings, + categorical_groups=categorical_groups, + x_categoricals=x_categoricals, + ) + # define network + if isinstance(self.hparams.output_size, int): + mlp_output_size = self.hparams.output_size + else: + mlp_output_size = sum(self.hparams.output_size) + + cont_size = len(self.decoder_reals_positions) + cat_size = sum(self.input_embeddings.output_size.values()) + input_size = cont_size + cat_size + + self.mlp = FullyConnectedModule( + dropout=dropout, + norm=self.hparams.norm, + activation_class=getattr(nn, self.hparams.activation_class), + input_size=input_size, + output_size=mlp_output_size, 
hidden_size=self.hparams.hidden_size, + n_hidden_layers=self.hparams.n_hidden_layers, + ) + + @property + def decoder_reals_positions(self) -> List[int]: + return [ + self.hparams.x_reals.index(name) + for name in self.reals + if name in self.decoder_variables + self.static_variables + ] + + def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: + """ + Forward network + """ + # x is a batch generated based on the TimeSeriesDataSet + batch_size = x["decoder_lengths"].size(0) + embeddings = self.input_embeddings(x["decoder_cat"]) # returns dictionary with embedding tensors + network_input = torch.cat( + [x["decoder_cont"][..., self.decoder_reals_positions]] + list(embeddings.values()), + dim=-1, + ) + prediction = self.mlp(network_input.view(-1, self.mlp.input_size)).view( + batch_size, network_input.size(1), self.mlp.output_size + ) + + # cut prediction into pieces for multiple targets + if self.n_targets > 1: + prediction = torch.split(prediction, self.hparams.output_size, dim=-1) + + # We need to return a dictionary that at least contains the prediction. + # The target scale can be directly forwarded from the input. + prediction = self.transform_output(prediction, target_scale=x["target_scale"]) + return self.to_network_output(prediction=prediction) + + @classmethod + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): + new_kwargs = cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss()) + kwargs.update(new_kwargs) + return super().from_dataset(dataset, **kwargs) diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index 8d00392c..dcf4e1b3 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -1,376 +1,6 @@ -""" -N-Beats model for timeseries forecasting without covariates. -""" +"""N-Beats model for timeseries forecasting without covariates.""" -from typing import Dict, List, Optional - -import torch -from torch import nn - -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.data.encoders import NaNLabelEncoder -from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.base_model import BaseModel +from pytorch_forecasting.models.nbeats._nbeats import NBeats from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock -from pytorch_forecasting.utils._dependencies import _check_matplotlib - - -class NBeats(BaseModel): - def __init__( - self, - stack_types: Optional[List[str]] = None, - num_blocks: Optional[List[int]] = None, - num_block_layers: Optional[List[int]] = None, - widths: Optional[List[int]] = None, - sharing: Optional[List[bool]] = None, - expansion_coefficient_lengths: Optional[List[int]] = None, - prediction_length: int = 1, - context_length: int = 1, - dropout: float = 0.1, - learning_rate: float = 1e-2, - log_interval: int = -1, - log_gradient_flow: bool = False, - log_val_interval: int = None, - weight_decay: float = 1e-3, - loss: MultiHorizonMetric = None, - reduce_on_plateau_patience: int = 1000, - backcast_loss_ratio: float = 0.0, - logging_metrics: nn.ModuleList = None, - **kwargs, - ): - """ - Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. - - Based on the article - `N-BEATS: Neural basis expansion analysis for interpretable time series - forecasting `_. 
The network has (if used as ensemble) outperformed all - other methods - including ensembles of traditional statical methods in the M4 competition. The M4 competition is arguably - the most - important benchmark for univariate time series forecasting. - - The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently shown to consistently outperform - N-BEATS. - - Args: - stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings - of length 1 or ‘num_stacks’. Default and recommended value - for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”] - num_blocks: The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’. - Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3] - num_block_layers: Number of fully connected layers with ReLu activation per block. A list of ints of length - 1 or ‘num_stacks’. - Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4] - width: Widths of the fully connected layers with ReLu activation in the blocks. - A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [512] - Recommended value for interpretable mode: [256, 2048] - sharing: Whether the weights are shared with the other blocks per stack. - A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [False] - Recommended value for interpretable mode: [True] - expansion_coefficient_length: If the type is “G” (generic), then the length of the expansion - coefficient. - If type is “T” (trend), then it corresponds to the degree of the polynomial. If the type is “S” - (seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep. - A list of ints of length 1 or ‘num_stacks’. Default value for generic mode: [32] Recommended value for - interpretable mode: [3] - prediction_length: Length of the prediction. Also known as 'horizon'. - context_length: Number of time units that condition the predictions. Also known as 'lookback period'. - Should be between 1-10 times the prediction length. - backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss. - A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and - forecast lengths). Defaults to 0.0, i.e. no weight. - loss: loss to optimize. Defaults to MASE(). - log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training - failures - reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 - logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - **kwargs: additional arguments to :py:class:`~BaseModel`. 
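The two recommended configurations translate into ``from_dataset`` calls along these lines (a sketch only; ``dataset`` is assumed to contain the target as its sole variable, since N-BEATS supports no covariates):

```python
# Hedged sketch mirroring the recommendations in the docstring above.
from pytorch_forecasting.models.nbeats import NBeats

# interpretable mode: trend + seasonality stacks
interpretable = NBeats.from_dataset(
    dataset,
    stack_types=["trend", "seasonality"],
    num_blocks=[3, 3],
    num_block_layers=[4, 4],
    widths=[256, 2048],
    sharing=[True, True],
    expansion_coefficient_lengths=[3, 7],
)

# generic mode: a single generic stack
generic = NBeats.from_dataset(
    dataset, stack_types=["generic"], num_blocks=[1], widths=[512]
)
```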
- """ - if expansion_coefficient_lengths is None: - expansion_coefficient_lengths = [3, 7] - if sharing is None: - sharing = [True, True] - if widths is None: - widths = [32, 512] - if num_block_layers is None: - num_block_layers = [3, 3] - if num_blocks is None: - num_blocks = [3, 3] - if stack_types is None: - stack_types = ["trend", "seasonality"] - if logging_metrics is None: - logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - if loss is None: - loss = MASE() - self.save_hyperparameters() - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - - # setup stacks - self.net_blocks = nn.ModuleList() - for stack_id, stack_type in enumerate(stack_types): - for _ in range(num_blocks[stack_id]): - if stack_type == "generic": - net_block = NBEATSGenericBlock( - units=self.hparams.widths[stack_id], - thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id], - num_block_layers=self.hparams.num_block_layers[stack_id], - backcast_length=context_length, - forecast_length=prediction_length, - dropout=self.hparams.dropout, - ) - elif stack_type == "seasonality": - net_block = NBEATSSeasonalBlock( - units=self.hparams.widths[stack_id], - num_block_layers=self.hparams.num_block_layers[stack_id], - backcast_length=context_length, - forecast_length=prediction_length, - min_period=self.hparams.expansion_coefficient_lengths[stack_id], - dropout=self.hparams.dropout, - ) - elif stack_type == "trend": - net_block = NBEATSTrendBlock( - units=self.hparams.widths[stack_id], - thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id], - num_block_layers=self.hparams.num_block_layers[stack_id], - backcast_length=context_length, - forecast_length=prediction_length, - dropout=self.hparams.dropout, - ) - else: - raise ValueError(f"Unknown stack type {stack_type}") - - self.net_blocks.append(net_block) - - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - Pass forward of network. - - Args: - x (Dict[str, torch.Tensor]): input from dataloader generated from - :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. 
- - Returns: - Dict[str, torch.Tensor]: output of model - """ - target = x["encoder_cont"][..., 0] - - timesteps = self.hparams.context_length + self.hparams.prediction_length - generic_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)] - trend_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)] - seasonal_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)] - forecast = torch.zeros( - (target.size(0), self.hparams.prediction_length), dtype=torch.float32, device=self.device - ) - - backcast = target # initialize backcast - for i, block in enumerate(self.net_blocks): - # evaluate block - backcast_block, forecast_block = block(backcast) - - # add for interpretation - full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) - if isinstance(block, NBEATSTrendBlock): - trend_forecast.append(full) - elif isinstance(block, NBEATSSeasonalBlock): - seasonal_forecast.append(full) - else: - generic_forecast.append(full) - - # update backcast and forecast - backcast = ( - backcast - backcast_block - ) # do not use backcast -= backcast_block as this signifies an inline operation - forecast = forecast + forecast_block - - return self.to_network_output( - prediction=self.transform_output(forecast, target_scale=x["target_scale"]), - backcast=self.transform_output(prediction=target - backcast, target_scale=x["target_scale"]), - trend=self.transform_output(torch.stack(trend_forecast, dim=0).sum(0), target_scale=x["target_scale"]), - seasonality=self.transform_output( - torch.stack(seasonal_forecast, dim=0).sum(0), target_scale=x["target_scale"] - ), - generic=self.transform_output(torch.stack(generic_forecast, dim=0).sum(0), target_scale=x["target_scale"]), - ) - - @classmethod - def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): - """ - Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. - - Args: - dataset (TimeSeriesDataSet): dataset where sole predictor is the target. - **kwargs: additional arguments to be passed to ``__init__`` method. 
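For orientation, the ``forward`` above (moved verbatim into ``_nbeats.py`` below) implements the doubly residual stacking of the N-BEATS paper: each block explains part of the remaining history (the backcast) and contributes an additive piece to the forecast. A standalone sketch of that control flow, with a toy block standing in for the real ``NBEATS*Block`` modules:

import torch

def toy_block(history):
    # Returns (backcast, forecast): here simply the series mean, repeated.
    mean = history.mean(dim=1, keepdim=True)
    return mean.expand_as(history), mean.repeat(1, 3)

def doubly_residual(blocks, target, prediction_length=3):
    backcast = target
    forecast = torch.zeros(target.size(0), prediction_length)
    for block in blocks:
        backcast_block, forecast_block = block(backcast)
        backcast = backcast - backcast_block  # out-of-place, as in the code above
        forecast = forecast + forecast_block
    return backcast, forecast

history = torch.randn(2, 8)
residual, forecast = doubly_residual([toy_block, toy_block], history)
assert residual.shape == history.shape and forecast.shape == (2, 3)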
- - Returns: - NBeats - """ - new_kwargs = {"prediction_length": dataset.max_prediction_length, "context_length": dataset.max_encoder_length} - new_kwargs.update(kwargs) - - # validate arguments - assert isinstance(dataset.target, str), "only one target is allowed (passed as string to dataset)" - assert not isinstance( - dataset.target_normalizer, NaNLabelEncoder - ), "only regression tasks are supported - target must not be categorical" - assert ( - dataset.min_encoder_length == dataset.max_encoder_length - ), "only fixed encoder length is allowed, but min_encoder_length != max_encoder_length" - - assert ( - dataset.max_prediction_length == dataset.min_prediction_length - ), "only fixed prediction length is allowed, but max_prediction_length != min_prediction_length" - - assert dataset.randomize_length is None, "length has to be fixed, but randomize_length is not None" - assert not dataset.add_relative_time_idx, "add_relative_time_idx has to be False" - - assert ( - len(dataset.flat_categoricals) == 0 - and len(dataset.reals) == 1 - and len(dataset._time_varying_unknown_reals) == 1 - and dataset._time_varying_unknown_reals[0] == dataset.target - ), "The only variable as input should be the target which is part of time_varying_unknown_reals" - - # initialize class - return super().from_dataset(dataset, **new_kwargs) - - def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: - """ - Take training / validation step. - """ - log, out = super().step(x, y, batch_idx=batch_idx) - - if self.hparams.backcast_loss_ratio > 0 and not self.predicting: # add loss from backcast - backcast = out["backcast"] - backcast_weight = ( - self.hparams.backcast_loss_ratio * self.hparams.prediction_length / self.hparams.context_length - ) - backcast_weight = backcast_weight / (backcast_weight + 1) # normalize - forecast_weight = 1 - backcast_weight - if isinstance(self.loss, MASE): - backcast_loss = self.loss(backcast, x["encoder_target"], x["decoder_target"]) * backcast_weight - else: - backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight - label = ["val", "train"][self.training] - self.log( - f"{label}_backcast_loss", - backcast_loss, - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - self.log( - f"{label}_forecast_loss", - log["loss"], - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - log["loss"] = log["loss"] * forecast_weight + backcast_loss - - self.log_interpretation(x, out, batch_idx=batch_idx) - return log, out - - def log_interpretation(self, x, out, batch_idx): - """ - Log interpretation of network predictions in tensorboard. - """ - mpl_available = _check_matplotlib("log_interpretation", raise_error=False) - - # Don't log figures if matplotlib or add_figure is not available - if not mpl_available or not self._logger_supports("add_figure"): - return None - - label = ["val", "train"][self.training] - if self.log_interval > 0 and batch_idx % self.log_interval == 0: - fig = self.plot_interpretation(x, out, idx=0) - name = f"{label.capitalize()} interpretation of item 0 in " - if self.training: - name += f"step {self.global_step}" - else: - name += f"batch {batch_idx}" - self.logger.experiment.add_figure(name, fig, global_step=self.global_step) - - def plot_interpretation( - self, - x: Dict[str, torch.Tensor], - output: Dict[str, torch.Tensor], - idx: int, - ax=None, - plot_seasonality_and_generic_on_secondary_axis: bool = False, - ): - """ - Plot interpretation. 
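To make the backcast weighting in ``step`` above concrete, here is the arithmetic for illustrative (non-default) settings; ``backcast_weight`` and ``forecast_weight`` always sum to one:

# backcast_loss_ratio = 1.0, context of 24 steps, horizon of 6 steps
backcast_weight = 1.0 * 6 / 24  # 0.25
backcast_weight = backcast_weight / (backcast_weight + 1)  # 0.2 (normalized)
forecast_weight = 1 - backcast_weight  # 0.8
assert abs(backcast_weight + forecast_weight - 1.0) < 1e-9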
- - Plot two pannels: prediction and backcast vs actuals and - decomposition of prediction into trend, seasonality and generic forecast. - - Args: - x (Dict[str, torch.Tensor]): network input - output (Dict[str, torch.Tensor]): network output - idx (int): index of sample for which to plot the interpretation. - ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation. - Defaults to None. - plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot seasonality and - generic forecast on secondary axis in second panel. Defaults to False. - - Returns: - plt.Figure: matplotlib figure - """ - _check_matplotlib("plot_interpretation") - - import matplotlib.pyplot as plt - - if ax is None: - fig, ax = plt.subplots(2, 1, figsize=(6, 8)) - else: - fig = ax[0].get_figure() - - time = torch.arange(-self.hparams.context_length, self.hparams.prediction_length) - - # plot target vs prediction - ax[0].plot(time, torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]).detach().cpu(), label="target") - ax[0].plot( - time, - torch.cat( - [ - output["backcast"][idx].detach(), - output["prediction"][idx].detach(), - ], - dim=0, - ).cpu(), - label="prediction", - ) - ax[0].set_xlabel("Time") - - # plot blocks - prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) - next(prop_cycle) # prediction - next(prop_cycle) # observations - if plot_seasonality_and_generic_on_secondary_axis: - ax2 = ax[1].twinx() - ax2.set_ylabel("Seasonality / Generic") - else: - ax2 = ax[1] - for title in ["trend", "seasonality", "generic"]: - if title not in self.hparams.stack_types: - continue - if title == "trend": - ax[1].plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - else: - ax2.plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - ax[1].set_xlabel("Time") - ax[1].set_ylabel("Decomposition") - fig.legend() - return fig +__all__ = ["NBeats", "NBEATSGenericBlock", "NBEATSSeasonalBlock", "NBEATSTrendBlock"] diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py new file mode 100644 index 00000000..8d00392c --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -0,0 +1,376 @@ +""" +N-Beats model for timeseries forecasting without covariates. 
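The net effect of the ``__init__.py`` hunk above is that the public import path is unchanged; only the implementation moves into a private module. A quick sanity check of that invariant (assuming the refactored package is installed):

from pytorch_forecasting.models.nbeats import NBeats
from pytorch_forecasting.models.nbeats._nbeats import NBeats as _NBeats

assert NBeats is _NBeats  # same class object, re-exported from __init__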
+""" + +from typing import Dict, List, Optional + +import torch +from torch import nn + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.data.encoders import NaNLabelEncoder +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.base_model import BaseModel +from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock +from pytorch_forecasting.utils._dependencies import _check_matplotlib + + +class NBeats(BaseModel): + def __init__( + self, + stack_types: Optional[List[str]] = None, + num_blocks: Optional[List[int]] = None, + num_block_layers: Optional[List[int]] = None, + widths: Optional[List[int]] = None, + sharing: Optional[List[bool]] = None, + expansion_coefficient_lengths: Optional[List[int]] = None, + prediction_length: int = 1, + context_length: int = 1, + dropout: float = 0.1, + learning_rate: float = 1e-2, + log_interval: int = -1, + log_gradient_flow: bool = False, + log_val_interval: int = None, + weight_decay: float = 1e-3, + loss: MultiHorizonMetric = None, + reduce_on_plateau_patience: int = 1000, + backcast_loss_ratio: float = 0.0, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. + + Based on the article + `N-BEATS: Neural basis expansion analysis for interpretable time series + forecasting `_. The network has (if used as ensemble) outperformed all + other methods + including ensembles of traditional statical methods in the M4 competition. The M4 competition is arguably + the most + important benchmark for univariate time series forecasting. + + The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently shown to consistently outperform + N-BEATS. + + Args: + stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings + of length 1 or ‘num_stacks’. Default and recommended value + for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”] + num_blocks: The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’. + Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3] + num_block_layers: Number of fully connected layers with ReLu activation per block. A list of ints of length + 1 or ‘num_stacks’. + Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4] + width: Widths of the fully connected layers with ReLu activation in the blocks. + A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [512] + Recommended value for interpretable mode: [256, 2048] + sharing: Whether the weights are shared with the other blocks per stack. + A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [False] + Recommended value for interpretable mode: [True] + expansion_coefficient_length: If the type is “G” (generic), then the length of the expansion + coefficient. + If type is “T” (trend), then it corresponds to the degree of the polynomial. If the type is “S” + (seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep. + A list of ints of length 1 or ‘num_stacks’. Default value for generic mode: [32] Recommended value for + interpretable mode: [3] + prediction_length: Length of the prediction. Also known as 'horizon'. 
+ context_length: Number of time units that condition the predictions. Also known as 'lookback period'. + Should be between 1-10 times the prediction length. + backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss. + A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and + forecast lengths). Defaults to 0.0, i.e. no weight. + loss: loss to optimize. Defaults to MASE(). + log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training + failures + reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 + logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training. + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + **kwargs: additional arguments to :py:class:`~BaseModel`. + """ + if expansion_coefficient_lengths is None: + expansion_coefficient_lengths = [3, 7] + if sharing is None: + sharing = [True, True] + if widths is None: + widths = [32, 512] + if num_block_layers is None: + num_block_layers = [3, 3] + if num_blocks is None: + num_blocks = [3, 3] + if stack_types is None: + stack_types = ["trend", "seasonality"] + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + if loss is None: + loss = MASE() + self.save_hyperparameters() + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + # setup stacks + self.net_blocks = nn.ModuleList() + for stack_id, stack_type in enumerate(stack_types): + for _ in range(num_blocks[stack_id]): + if stack_type == "generic": + net_block = NBEATSGenericBlock( + units=self.hparams.widths[stack_id], + thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + dropout=self.hparams.dropout, + ) + elif stack_type == "seasonality": + net_block = NBEATSSeasonalBlock( + units=self.hparams.widths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + min_period=self.hparams.expansion_coefficient_lengths[stack_id], + dropout=self.hparams.dropout, + ) + elif stack_type == "trend": + net_block = NBEATSTrendBlock( + units=self.hparams.widths[stack_id], + thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + dropout=self.hparams.dropout, + ) + else: + raise ValueError(f"Unknown stack type {stack_type}") + + self.net_blocks.append(net_block) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Pass forward of network. + + Args: + x (Dict[str, torch.Tensor]): input from dataloader generated from + :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. 
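The defaults assigned above select the interpretable trend/seasonality configuration. For reference, hypothetical keyword sets mirroring the recommendations in the Args section (note that the code's actual defaults use smaller widths than the recommended interpretable values):

interpretable_mode = dict(
    stack_types=["trend", "seasonality"],
    num_blocks=[3, 3],
    num_block_layers=[4, 4],
    widths=[256, 2048],
    sharing=[True, True],
    expansion_coefficient_lengths=[3, 7],
)
generic_mode = dict(
    stack_types=["generic"],
    num_blocks=[1],
    num_block_layers=[4],
    widths=[512],
    sharing=[False],
    expansion_coefficient_lengths=[32],
)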
+ + Returns: + Dict[str, torch.Tensor]: output of model + """ + target = x["encoder_cont"][..., 0] + + timesteps = self.hparams.context_length + self.hparams.prediction_length + generic_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)] + trend_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)] + seasonal_forecast = [torch.zeros((target.size(0), timesteps), dtype=torch.float32, device=self.device)] + forecast = torch.zeros( + (target.size(0), self.hparams.prediction_length), dtype=torch.float32, device=self.device + ) + + backcast = target # initialize backcast + for i, block in enumerate(self.net_blocks): + # evaluate block + backcast_block, forecast_block = block(backcast) + + # add for interpretation + full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) + if isinstance(block, NBEATSTrendBlock): + trend_forecast.append(full) + elif isinstance(block, NBEATSSeasonalBlock): + seasonal_forecast.append(full) + else: + generic_forecast.append(full) + + # update backcast and forecast + backcast = ( + backcast - backcast_block + ) # do not use backcast -= backcast_block as this signifies an inline operation + forecast = forecast + forecast_block + + return self.to_network_output( + prediction=self.transform_output(forecast, target_scale=x["target_scale"]), + backcast=self.transform_output(prediction=target - backcast, target_scale=x["target_scale"]), + trend=self.transform_output(torch.stack(trend_forecast, dim=0).sum(0), target_scale=x["target_scale"]), + seasonality=self.transform_output( + torch.stack(seasonal_forecast, dim=0).sum(0), target_scale=x["target_scale"] + ), + generic=self.transform_output(torch.stack(generic_forecast, dim=0).sum(0), target_scale=x["target_scale"]), + ) + + @classmethod + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): + """ + Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Args: + dataset (TimeSeriesDataSet): dataset where sole predictor is the target. + **kwargs: additional arguments to be passed to ``__init__`` method. 
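``from_dataset`` above derives ``prediction_length`` and ``context_length`` from the dataset; the assertions that follow require a single real-valued target, fixed encoder and prediction lengths, and no further covariates. A sketch of a compliant dataset, using synthetic data and hypothetical column names:

import pandas as pd
from pytorch_forecasting.data import TimeSeriesDataSet

data = pd.DataFrame(
    {"value": [0.1 * i for i in range(60)], "time_idx": list(range(60)), "series": ["a"] * 60}
)
dataset = TimeSeriesDataSet(
    data,
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=24,
    max_prediction_length=6,
    time_varying_unknown_reals=["value"],  # sole input is the target itself
)
# net = NBeats.from_dataset(dataset)  # picks up context/prediction lengths from the dataset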
+ + Returns: + NBeats + """ + new_kwargs = {"prediction_length": dataset.max_prediction_length, "context_length": dataset.max_encoder_length} + new_kwargs.update(kwargs) + + # validate arguments + assert isinstance(dataset.target, str), "only one target is allowed (passed as string to dataset)" + assert not isinstance( + dataset.target_normalizer, NaNLabelEncoder + ), "only regression tasks are supported - target must not be categorical" + assert ( + dataset.min_encoder_length == dataset.max_encoder_length + ), "only fixed encoder length is allowed, but min_encoder_length != max_encoder_length" + + assert ( + dataset.max_prediction_length == dataset.min_prediction_length + ), "only fixed prediction length is allowed, but max_prediction_length != min_prediction_length" + + assert dataset.randomize_length is None, "length has to be fixed, but randomize_length is not None" + assert not dataset.add_relative_time_idx, "add_relative_time_idx has to be False" + + assert ( + len(dataset.flat_categoricals) == 0 + and len(dataset.reals) == 1 + and len(dataset._time_varying_unknown_reals) == 1 + and dataset._time_varying_unknown_reals[0] == dataset.target + ), "The only variable as input should be the target which is part of time_varying_unknown_reals" + + # initialize class + return super().from_dataset(dataset, **new_kwargs) + + def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: + """ + Take training / validation step. + """ + log, out = super().step(x, y, batch_idx=batch_idx) + + if self.hparams.backcast_loss_ratio > 0 and not self.predicting: # add loss from backcast + backcast = out["backcast"] + backcast_weight = ( + self.hparams.backcast_loss_ratio * self.hparams.prediction_length / self.hparams.context_length + ) + backcast_weight = backcast_weight / (backcast_weight + 1) # normalize + forecast_weight = 1 - backcast_weight + if isinstance(self.loss, MASE): + backcast_loss = self.loss(backcast, x["encoder_target"], x["decoder_target"]) * backcast_weight + else: + backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight + label = ["val", "train"][self.training] + self.log( + f"{label}_backcast_loss", + backcast_loss, + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + self.log( + f"{label}_forecast_loss", + log["loss"], + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + log["loss"] = log["loss"] * forecast_weight + backcast_loss + + self.log_interpretation(x, out, batch_idx=batch_idx) + return log, out + + def log_interpretation(self, x, out, batch_idx): + """ + Log interpretation of network predictions in tensorboard. + """ + mpl_available = _check_matplotlib("log_interpretation", raise_error=False) + + # Don't log figures if matplotlib or add_figure is not available + if not mpl_available or not self._logger_supports("add_figure"): + return None + + label = ["val", "train"][self.training] + if self.log_interval > 0 and batch_idx % self.log_interval == 0: + fig = self.plot_interpretation(x, out, idx=0) + name = f"{label.capitalize()} interpretation of item 0 in " + if self.training: + name += f"step {self.global_step}" + else: + name += f"batch {batch_idx}" + self.logger.experiment.add_figure(name, fig, global_step=self.global_step) + + def plot_interpretation( + self, + x: Dict[str, torch.Tensor], + output: Dict[str, torch.Tensor], + idx: int, + ax=None, + plot_seasonality_and_generic_on_secondary_axis: bool = False, + ): + """ + Plot interpretation. 
+ + Plot two pannels: prediction and backcast vs actuals and + decomposition of prediction into trend, seasonality and generic forecast. + + Args: + x (Dict[str, torch.Tensor]): network input + output (Dict[str, torch.Tensor]): network output + idx (int): index of sample for which to plot the interpretation. + ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation. + Defaults to None. + plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot seasonality and + generic forecast on secondary axis in second panel. Defaults to False. + + Returns: + plt.Figure: matplotlib figure + """ + _check_matplotlib("plot_interpretation") + + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots(2, 1, figsize=(6, 8)) + else: + fig = ax[0].get_figure() + + time = torch.arange(-self.hparams.context_length, self.hparams.prediction_length) + + # plot target vs prediction + ax[0].plot(time, torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]).detach().cpu(), label="target") + ax[0].plot( + time, + torch.cat( + [ + output["backcast"][idx].detach(), + output["prediction"][idx].detach(), + ], + dim=0, + ).cpu(), + label="prediction", + ) + ax[0].set_xlabel("Time") + + # plot blocks + prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) + next(prop_cycle) # prediction + next(prop_cycle) # observations + if plot_seasonality_and_generic_on_secondary_axis: + ax2 = ax[1].twinx() + ax2.set_ylabel("Seasonality / Generic") + else: + ax2 = ax[1] + for title in ["trend", "seasonality", "generic"]: + if title not in self.hparams.stack_types: + continue + if title == "trend": + ax[1].plot( + time, + output[title][idx].detach().cpu(), + label=title.capitalize(), + c=next(prop_cycle)["color"], + ) + else: + ax2.plot( + time, + output[title][idx].detach().cpu(), + label=title.capitalize(), + c=next(prop_cycle)["color"], + ) + ax[1].set_xlabel("Time") + ax[1].set_ylabel("Decomposition") + + fig.legend() + return fig diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py index 6d790213..e7eb452d 100644 --- a/pytorch_forecasting/models/nhits/__init__.py +++ b/pytorch_forecasting/models/nhits/__init__.py @@ -1,595 +1,6 @@ -""" -N-HiTS model for timeseries forecasting with covariates. 
-""" +"""N-HiTS model for timeseries forecasting with covariates.""" -from copy import copy -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from torch import nn - -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.data.encoders import NaNLabelEncoder -from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss -from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.nhits._nhits import NHiTS from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule -from pytorch_forecasting.models.nn.embeddings import MultiEmbedding -from pytorch_forecasting.utils import create_mask, to_list -from pytorch_forecasting.utils._dependencies import _check_matplotlib - - -class NHiTS(BaseModelWithCovariates): - def __init__( - self, - output_size: Union[int, List[int]] = 1, - static_categoricals: Optional[List[str]] = None, - static_reals: Optional[List[str]] = None, - time_varying_categoricals_encoder: Optional[List[str]] = None, - time_varying_categoricals_decoder: Optional[List[str]] = None, - categorical_groups: Optional[Dict[str, List[str]]] = None, - time_varying_reals_encoder: Optional[List[str]] = None, - time_varying_reals_decoder: Optional[List[str]] = None, - embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, - embedding_paddings: Optional[List[str]] = None, - embedding_labels: Optional[List[str]] = None, - x_reals: Optional[List[str]] = None, - x_categoricals: Optional[List[str]] = None, - context_length: int = 1, - prediction_length: int = 1, - static_hidden_size: Optional[int] = None, - naive_level: bool = True, - shared_weights: bool = True, - activation: str = "ReLU", - initialization: str = "lecun_normal", - n_blocks: Optional[List[str]] = None, - n_layers: Union[int, List[int]] = 2, - hidden_size: int = 512, - pooling_sizes: Optional[List[int]] = None, - downsample_frequencies: Optional[List[int]] = None, - pooling_mode: str = "max", - interpolation_mode: str = "linear", - batch_normalization: bool = False, - dropout: float = 0.0, - learning_rate: float = 1e-2, - log_interval: int = -1, - log_gradient_flow: bool = False, - log_val_interval: int = None, - weight_decay: float = 1e-3, - loss: MultiHorizonMetric = None, - reduce_on_plateau_patience: int = 1000, - backcast_loss_ratio: float = 0.0, - logging_metrics: nn.ModuleList = None, - **kwargs, - ): - """ - Initialize N-HiTS Model - use its :py:meth:`~from_dataset` method if possible. - - Based on the article - `N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting `_. - The network has shown to increase accuracy by ~25% against - :py:class:`~pytorch_forecasting.models.nbeats.NBeats` and also supports covariates. - - Args: - hidden_size (int): size of hidden layers and can range from 8 to 1024 - use 32-128 if no - covariates are employed. Defaults to 512. - static_hidden_size (Optional[int], optional): size of hidden layers for static variables. - Defaults to hidden_size. - loss: loss to optimize. Defaults to MASE(). QuantileLoss is also supported - shared_weights (bool, optional): if True, weights of blocks are shared in each stack. Defaults to True. - naive_level (bool, optional): if True, native forecast of last observation is added at the beginnging. - Defaults to True. - initialization (str, optional): Initialization method. One of ['orthogonal', 'he_uniform', 'glorot_uniform', - 'glorot_normal', 'lecun_normal']. 
Defaults to "lecun_normal". - n_blocks (List[int], optional): list of blocks used in each stack (i.e. length of stacks). - Defaults to [1, 1, 1]. - n_layers (Union[int, List[int]], optional): Number of layers per block or list of number of - layers used by blocks in each stack (i.e. length of stacks). Defaults to 2. - pooling_sizes (Optional[List[int]], optional): List of pooling sizes for input for each stack, - i.e. higher means more smoothing of input. Using an ordering of higher to lower in the list - improves results. - Defaults to a heuristic. - pooling_mode (str, optional): Pooling mode for summarizing input. One of ['max','average']. - Defaults to "max". - downsample_frequencies (Optional[List[int]], optional): Downsample multiplier of output for each stack, i.e. - higher means more interpolation at forecast time is required. Should be equal or higher - than pooling_sizes but smaller equal prediction_length. - Defaults to a heuristic to match pooling_sizes. - interpolation_mode (str, optional): Interpolation mode for forecasting. One of ['linear', 'nearest', - 'cubic-x'] where 'x' is replaced by a batch size for the interpolation. Defaults to "linear". - batch_normalization (bool, optional): Whether carry out batch normalization. Defaults to False. - dropout (float, optional): dropout rate for hidden layers. Defaults to 0.0. - activation (str, optional): activation function. One of ['ReLU', 'Softplus', 'Tanh', 'SELU', - 'LeakyReLU', 'PReLU', 'Sigmoid']. Defaults to "ReLU". - output_size: number of outputs (typically number of quantiles for QuantileLoss and one target or list - of output sizes but currently only point-forecasts allowed). Set automatically. - static_categoricals: names of static categorical variables - static_reals: names of static continuous variables - time_varying_categoricals_encoder: names of categorical variables for encoder - time_varying_categoricals_decoder: names of categorical variables for decoder - time_varying_reals_encoder: names of continuous variables for encoder - time_varying_reals_decoder: names of continuous variables for decoder - categorical_groups: dictionary where values - are list of categorical variables that are forming together a new categorical - variable which is the key in the dictionary - x_reals: order of continuous variables in tensor passed to forward function - x_categoricals: order of categorical variables in tensor passed to forward function - hidden_continuous_size: default for hidden size for processing continous variables (similar to categorical - embedding size) - hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection - (fallback to hidden_continuous_size if index is not in dictionary) - embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and - embedding size - embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector - embedding_labels: dictionary mapping (string) indices to list of categorical labels - learning_rate: learning rate - log_interval: log predictions every x batches, do not log if 0 or less, log interpretation if > 0. If < 1.0 - , will log multiple entries per batch. Defaults to -1. - log_val_interval: frequency with which to log validation set metrics, defaults to log_interval - log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training - failures - prediction_length: Length of the prediction. Also known as 'horizon'. 
- context_length: Number of time units that condition the predictions. Also known as 'lookback period'. - Should be between 1-10 times the prediction length. - backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss. - A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and - forecast lengths). Defaults to 0.0, i.e. no weight. - log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training - failures - reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 - logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - **kwargs: additional arguments to :py:class:`~BaseModel`. - """ - if static_categoricals is None: - static_categoricals = [] - if static_reals is None: - static_reals = [] - if time_varying_categoricals_encoder is None: - time_varying_categoricals_encoder = [] - if time_varying_categoricals_decoder is None: - time_varying_categoricals_decoder = [] - if categorical_groups is None: - categorical_groups = {} - if time_varying_reals_encoder is None: - time_varying_reals_encoder = [] - if time_varying_reals_decoder is None: - time_varying_reals_decoder = [] - if embedding_sizes is None: - embedding_sizes = {} - if embedding_paddings is None: - embedding_paddings = [] - if embedding_labels is None: - embedding_labels = {} - if x_reals is None: - x_reals = [] - if x_categoricals is None: - x_categoricals = [] - if n_blocks is None: - n_blocks = [1, 1, 1] - if logging_metrics is None: - logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - if loss is None: - loss = MASE() - - if activation == "SELU": - self.hparams.initialization = "lecun_normal" - - # provide default downsampling sizes - n_stacks = len(n_blocks) - if pooling_sizes is None: - pooling_sizes = np.exp2(np.round(np.linspace(0.49, np.log2(prediction_length / 2), n_stacks))) - pooling_sizes = [int(x) for x in pooling_sizes[::-1]] - # remove zero from pooling_sizes - pooling_sizes = max(pooling_sizes, [1] * len(pooling_sizes)) - if downsample_frequencies is None: - downsample_frequencies = [min(prediction_length, int(np.power(x, 1.5))) for x in pooling_sizes] - # remove zero from downsample_frequencies - downsample_frequencies = max(downsample_frequencies, [1] * len(downsample_frequencies)) - - # set static hidden size - if static_hidden_size is None: - static_hidden_size = hidden_size - - # set layers - if isinstance(n_layers, int): - n_layers = [n_layers] * n_stacks - - self.save_hyperparameters() - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - - self.embeddings = MultiEmbedding( - embedding_sizes=self.hparams.embedding_sizes, - categorical_groups=self.hparams.categorical_groups, - embedding_paddings=self.hparams.embedding_paddings, - x_categoricals=self.hparams.x_categoricals, - ) - - self.model = NHiTSModule( - context_length=self.hparams.context_length, - prediction_length=self.hparams.prediction_length, - output_size=to_list(output_size), - static_size=self.static_size, - encoder_covariate_size=self.encoder_covariate_size, - decoder_covariate_size=self.decoder_covariate_size, - static_hidden_size=self.hparams.static_hidden_size, - n_blocks=self.hparams.n_blocks, - n_layers=self.hparams.n_layers, - hidden_size=self.n_stacks * [2 * [self.hparams.hidden_size]], - 
pooling_sizes=self.hparams.pooling_sizes, - downsample_frequencies=self.hparams.downsample_frequencies, - pooling_mode=self.hparams.pooling_mode, - interpolation_mode=self.hparams.interpolation_mode, - dropout=self.hparams.dropout, - activation=self.hparams.activation, - initialization=self.hparams.initialization, - batch_normalization=self.hparams.batch_normalization, - shared_weights=self.hparams.shared_weights, - naive_level=self.hparams.naive_level, - ) - - @property - def decoder_covariate_size(self) -> int: - """Decoder covariates size. - - Returns: - int: size of time-dependent covariates used by the decoder - """ - return len(set(self.hparams.time_varying_reals_decoder) - set(self.target_names)) + sum( - self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder - ) - - @property - def encoder_covariate_size(self) -> int: - """Encoder covariate size. - - Returns: - int: size of time-dependent covariates used by the encoder - """ - return len(set(self.hparams.time_varying_reals_encoder) - set(self.target_names)) + sum( - self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder - ) - - @property - def static_size(self) -> int: - """Static covariate size. - - Returns: - int: size of static covariates - """ - return len(self.hparams.static_reals) + sum( - self.embeddings.output_size[name] for name in self.hparams.static_categoricals - ) - - @property - def n_stacks(self) -> int: - """Number of stacks. - - Returns: - int: number of stacks. - """ - return len(self.hparams.n_blocks) - - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - Pass forward of network. - - Args: - x (Dict[str, torch.Tensor]): input from dataloader generated from - :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. 
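The size properties above count network inputs by excluding targets from the real-valued variables and adding the embedding widths of the categorical ones. A worked example with hypothetical variable names:

# Hypothetical configuration
time_varying_reals_encoder = ["price", "volume"]  # "volume" is the target
target_names = ["volume"]
time_varying_categoricals_encoder = ["holiday"]
embedding_output_size = {"holiday": 3}  # width of the "holiday" embedding

encoder_covariate_size = len(set(time_varying_reals_encoder) - set(target_names)) + sum(
    embedding_output_size[name] for name in time_varying_categoricals_encoder
)
assert encoder_covariate_size == 4  # 1 real covariate + 3 embedding dimensions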
- - Returns: - Dict[str, torch.Tensor]: output of model - """ - # covariates - if self.encoder_covariate_size > 0: - encoder_features = self.extract_features(x, self.embeddings, period="encoder") - encoder_x_t = torch.concat( - [encoder_features[name] for name in self.encoder_variables if name not in self.target_names], - dim=2, - ) - else: - encoder_x_t = None - - if self.decoder_covariate_size > 0: - decoder_features = self.extract_features(x, self.embeddings, period="decoder") - decoder_x_t = torch.concat([decoder_features[name] for name in self.decoder_variables], dim=2) - else: - decoder_x_t = None - - # statics - if self.static_size > 0: - x_s = torch.concat([encoder_features[name][:, 0] for name in self.static_variables], dim=1) - else: - x_s = None - - # target - encoder_y = x["encoder_cont"][..., self.target_positions] - encoder_mask = create_mask(x["encoder_lengths"].max(), x["encoder_lengths"], inverse=True) - - # run model - forecast, backcast, block_forecasts, block_backcasts = self.model( - encoder_y, encoder_mask, encoder_x_t, decoder_x_t, x_s - ) - backcast = encoder_y - backcast - - # create block output: detach and split by block - block_backcasts = block_backcasts.detach() - block_forecasts = block_forecasts.detach() - - if isinstance(self.hparams.output_size, (tuple, list)): - forecast = forecast.split(self.hparams.output_size, dim=2) - backcast = backcast.split(1, dim=2) - block_backcasts = tuple( - self.transform_output(block.squeeze(3).split(1, dim=2), target_scale=x["target_scale"]) - for block in block_backcasts.split(1, dim=3) - ) - block_forecasts = tuple( - self.transform_output( - block.squeeze(3).split(self.hparams.output_size, dim=2), target_scale=x["target_scale"] - ) - for block in block_forecasts.split(1, dim=3) - ) - else: - block_backcasts = tuple( - self.transform_output(block.squeeze(3), target_scale=x["target_scale"], loss=MultiHorizonMetric()) - for block in block_backcasts.split(1, dim=3) - ) - block_forecasts = tuple( - self.transform_output(block.squeeze(3), target_scale=x["target_scale"]) - for block in block_forecasts.split(1, dim=3) - ) - - return self.to_network_output( - prediction=self.transform_output( - forecast, target_scale=x["target_scale"] - ), # (n_outputs x) n_samples x n_timesteps x output_size - backcast=self.transform_output( - backcast, target_scale=x["target_scale"], loss=MultiHorizonMetric() - ), # (n_outputs x) n_samples x n_timesteps x 1 - block_backcasts=block_backcasts, # n_blocks x (n_outputs x) n_samples x n_timesteps x 1 - block_forecasts=block_forecasts, # n_blocks x (n_outputs x) n_samples x n_timesteps x output_size - ) - - @classmethod - def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): - """ - Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. - - Args: - dataset (TimeSeriesDataSet): dataset where sole predictor is the target. - **kwargs: additional arguments to be passed to ``__init__`` method. 
- - Returns: - NBeats - """ - # validate arguments - assert not isinstance( - dataset.target_normalizer, NaNLabelEncoder - ), "only regression tasks are supported - target must not be categorical" - assert ( - dataset.min_encoder_length == dataset.max_encoder_length - ), "only fixed encoder length is allowed, but min_encoder_length != max_encoder_length" - - assert ( - dataset.max_prediction_length == dataset.min_prediction_length - ), "only fixed prediction length is allowed, but max_prediction_length != min_prediction_length" - - assert dataset.randomize_length is None, "length has to be fixed, but randomize_length is not None" - assert not dataset.add_relative_time_idx, "add_relative_time_idx has to be False" - - new_kwargs = copy(kwargs) - new_kwargs.update( - {"prediction_length": dataset.max_prediction_length, "context_length": dataset.max_encoder_length} - ) - new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, MASE())) - - assert (new_kwargs.get("backcast_loss_ratio", 0) == 0) | ( - isinstance(new_kwargs["output_size"], int) and new_kwargs["output_size"] == 1 - ) or all( - o == 1 for o in new_kwargs["output_size"] - ), "output sizes can only be of size 1, i.e. point forecasts if backcast_loss_ratio > 0" - - # initialize class - return super().from_dataset(dataset, **new_kwargs) - - def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: - """ - Take training / validation step. - """ - log, out = super().step(x, y, batch_idx=batch_idx) - - if self.hparams.backcast_loss_ratio > 0 and not self.predicting: # add loss from backcast - backcast = out["backcast"] - backcast_weight = ( - self.hparams.backcast_loss_ratio * self.hparams.prediction_length / self.hparams.context_length - ) - backcast_weight = backcast_weight / (backcast_weight + 1) # normalize - forecast_weight = 1 - backcast_weight - if isinstance(self.loss, (MultiLoss, MASE)): - backcast_loss = ( - self.loss( - backcast, - (x["encoder_target"], None), - encoder_target=x["decoder_target"], - encoder_lengths=x["decoder_lengths"], - ) - * backcast_weight - ) - else: - backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight - label = ["val", "train"][self.training] - self.log( - f"{label}_backcast_loss", - backcast_loss, - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - self.log( - f"{label}_forecast_loss", - log["loss"], - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - log["loss"] = log["loss"] * forecast_weight + backcast_loss - - # log interpretation - self.log_interpretation(x, out, batch_idx=batch_idx) - return log, out - - def plot_interpretation( - self, - x: Dict[str, torch.Tensor], - output: Dict[str, torch.Tensor], - idx: int, - ax=None, - ): - """ - Plot interpretation. - - Plot two pannels: prediction and backcast vs actuals and - decomposition of prediction into different block predictions which capture different frequencies. - - Args: - x (Dict[str, torch.Tensor]): network input - output (Dict[str, torch.Tensor]): network output - idx (int): index of sample for which to plot the interpretation. - ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation. - Defaults to None. 
- - Returns: - plt.Figure: matplotlib figure - """ - _check_matplotlib("plot_interpretation") - - from matplotlib import pyplot as plt - - if not isinstance(self.loss, MultiLoss): # not multi-target - prediction = self.to_prediction(dict(prediction=output["prediction"][[idx]].detach()))[0].cpu() - block_forecasts = [ - self.to_prediction(dict(prediction=block[[idx]].detach()))[0].cpu() - for block in output["block_forecasts"] - ] - elif isinstance(output["prediction"], (tuple, list)): # multi-target - figs = [] - # predictions and block forecasts need to be converted - prediction = [p[[idx]].detach() for p in output["prediction"]] # select index - prediction = self.to_prediction(dict(prediction=prediction)) # transform to prediction - prediction = [p[0].cpu() for p in prediction] # select first and only index - - block_forecasts = [ - self.to_prediction(dict(prediction=[b[[idx]].detach() for b in block])) - for block in output["block_forecasts"] - ] - block_forecasts = [[b[0].cpu() for b in block] for block in block_forecasts] - - for i in range(len(self.target_names)): - if ax is not None: - ax_i = ax[i] - else: - ax_i = None - - figs.append( - self.plot_interpretation( - dict(encoder_target=x["encoder_target"][i], decoder_target=x["decoder_target"][i]), - dict( - backcast=output["backcast"][i], - prediction=prediction[i], - block_backcasts=[block[i] for block in output["block_backcasts"]], - block_forecasts=[block[i] for block in block_forecasts], - ), - idx=idx, - ax=ax_i, - ) - ) - return figs - else: - prediction = output["prediction"] # multi target that has already been transformed - block_forecasts = output["block_forecasts"] - - if ax is None: - fig, ax = plt.subplots(2, 1, figsize=(6, 8), sharex=True, sharey=True) - else: - fig = ax[0].get_figure() - - # plot target vs prediction - # target - prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) - color = next(prop_cycle)["color"] - ax[0].plot(torch.arange(-self.hparams.context_length, 0), x["encoder_target"][idx].detach().cpu(), c=color) - ax[0].plot( - torch.arange(self.hparams.prediction_length), - x["decoder_target"][idx].detach().cpu(), - label="Target", - c=color, - ) - # prediction - color = next(prop_cycle)["color"] - ax[0].plot( - torch.arange(-self.hparams.context_length, 0), - output["backcast"][idx][..., 0].detach().cpu(), - label="Backcast", - c=color, - ) - ax[0].plot( - torch.arange(self.hparams.prediction_length), - prediction, - label="Forecast", - c=color, - ) - - # plot blocks - for pooling_size, block_backcast, block_forecast in zip( - self.hparams.pooling_sizes, output["block_backcasts"][1:], block_forecasts - ): - color = next(prop_cycle)["color"] - ax[1].plot( - torch.arange(-self.hparams.context_length, 0), - block_backcast[idx][..., 0].detach().cpu(), - c=color, - ) - ax[1].plot( - torch.arange(self.hparams.prediction_length), - block_forecast, - c=color, - label=f"Pooling size: {pooling_size}", - ) - ax[1].set_xlabel("Time") - - fig.legend() - return fig - - def log_interpretation(self, x, out, batch_idx): - """ - Log interpretation of network predictions in tensorboard. 
-        """
-        mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
-
-        # Don't log figures if matplotlib or add_figure is not available
-        if not mpl_available or not self._logger_supports("add_figure"):
-            return None
-        label = ["val", "train"][self.training]
-        if self.log_interval > 0 and batch_idx % self.log_interval == 0:
-            fig = self.plot_interpretation(x, out, idx=0)
-            name = f"{label.capitalize()} interpretation of item 0 in "
-            if self.training:
-                name += f"step {self.global_step}"
-            else:
-                name += f"batch {batch_idx}"
-            self.logger.experiment.add_figure(name, fig, global_step=self.global_step)
-            if isinstance(fig, (list, tuple)):
-                for idx, f in enumerate(fig):
-                    self.logger.experiment.add_figure(
-                        f"{self.target_names[idx]} {name}",
-                        f,
-                        global_step=self.global_step,
-                    )
-            else:
-                self.logger.experiment.add_figure(
-                    name,
-                    fig,
-                    global_step=self.global_step,
-                )
+__all__ = ["NHiTS", "NHiTSModule"]
diff --git a/pytorch_forecasting/models/nhits/_nhits.py b/pytorch_forecasting/models/nhits/_nhits.py
new file mode 100644
index 00000000..6d790213
--- /dev/null
+++ b/pytorch_forecasting/models/nhits/_nhits.py
@@ -0,0 +1,595 @@
+"""
+N-HiTS model for timeseries forecasting with covariates.
+"""
+
+from copy import copy
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from pytorch_forecasting.data import TimeSeriesDataSet
+from pytorch_forecasting.data.encoders import NaNLabelEncoder
+from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss
+from pytorch_forecasting.models.base_model import BaseModelWithCovariates
+from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule
+from pytorch_forecasting.models.nn.embeddings import MultiEmbedding
+from pytorch_forecasting.utils import create_mask, to_list
+from pytorch_forecasting.utils._dependencies import _check_matplotlib
+
+
+class NHiTS(BaseModelWithCovariates):
+    def __init__(
+        self,
+        output_size: Union[int, List[int]] = 1,
+        static_categoricals: Optional[List[str]] = None,
+        static_reals: Optional[List[str]] = None,
+        time_varying_categoricals_encoder: Optional[List[str]] = None,
+        time_varying_categoricals_decoder: Optional[List[str]] = None,
+        categorical_groups: Optional[Dict[str, List[str]]] = None,
+        time_varying_reals_encoder: Optional[List[str]] = None,
+        time_varying_reals_decoder: Optional[List[str]] = None,
+        embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None,
+        embedding_paddings: Optional[List[str]] = None,
+        embedding_labels: Optional[List[str]] = None,
+        x_reals: Optional[List[str]] = None,
+        x_categoricals: Optional[List[str]] = None,
+        context_length: int = 1,
+        prediction_length: int = 1,
+        static_hidden_size: Optional[int] = None,
+        naive_level: bool = True,
+        shared_weights: bool = True,
+        activation: str = "ReLU",
+        initialization: str = "lecun_normal",
+        n_blocks: Optional[List[int]] = None,
+        n_layers: Union[int, List[int]] = 2,
+        hidden_size: int = 512,
+        pooling_sizes: Optional[List[int]] = None,
+        downsample_frequencies: Optional[List[int]] = None,
+        pooling_mode: str = "max",
+        interpolation_mode: str = "linear",
+        batch_normalization: bool = False,
+        dropout: float = 0.0,
+        learning_rate: float = 1e-2,
+        log_interval: int = -1,
+        log_gradient_flow: bool = False,
+        log_val_interval: int = None,
+        weight_decay: float = 1e-3,
+        loss: MultiHorizonMetric = None,
+        reduce_on_plateau_patience: int = 1000,
+        backcast_loss_ratio: float = 0.0,
+        logging_metrics: nn.ModuleList = None,
+        **kwargs,
+    ):
+        """
+        Initialize N-HiTS Model - use its :py:meth:`~from_dataset` method if possible.
+
+        Based on the article
+        `N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting `_.
+        The network has been shown to increase accuracy by ~25% against
+        :py:class:`~pytorch_forecasting.models.nbeats.NBeats` and also supports covariates.
+
+        Args:
+            hidden_size (int): size of hidden layers and can range from 8 to 1024 - use 32-128 if no
+                covariates are employed. Defaults to 512.
+            static_hidden_size (Optional[int], optional): size of hidden layers for static variables.
+                Defaults to hidden_size.
+            loss: loss to optimize. Defaults to MASE(). QuantileLoss is also supported.
+            shared_weights (bool, optional): if True, weights of blocks are shared in each stack. Defaults to True.
+            naive_level (bool, optional): if True, a naive forecast of the last observation is added at the
+                beginning. Defaults to True.
+            initialization (str, optional): Initialization method. One of ['orthogonal', 'he_uniform',
+                'glorot_uniform', 'glorot_normal', 'lecun_normal']. Defaults to "lecun_normal".
+            n_blocks (List[int], optional): list of blocks used in each stack (i.e. length of stacks).
+                Defaults to [1, 1, 1].
+            n_layers (Union[int, List[int]], optional): Number of layers per block or list of number of
+                layers used by blocks in each stack (i.e. length of stacks). Defaults to 2.
+            pooling_sizes (Optional[List[int]], optional): List of pooling sizes for input for each stack,
+                i.e. higher means more smoothing of input. Using an ordering of higher to lower in the list
+                improves results.
+                Defaults to a heuristic.
+            pooling_mode (str, optional): Pooling mode for summarizing input. One of ['max', 'average'].
+                Defaults to "max".
+            downsample_frequencies (Optional[List[int]], optional): Downsample multiplier of output for each
+                stack, i.e. higher means more interpolation at forecast time is required. Should be equal to or
+                higher than pooling_sizes but smaller than or equal to prediction_length.
+                Defaults to a heuristic to match pooling_sizes.
+            interpolation_mode (str, optional): Interpolation mode for forecasting. One of ['linear', 'nearest',
+                'cubic-x'] where 'x' is replaced by a batch size for the interpolation. Defaults to "linear".
+            batch_normalization (bool, optional): Whether to carry out batch normalization. Defaults to False.
+            dropout (float, optional): dropout rate for hidden layers. Defaults to 0.0.
+            activation (str, optional): activation function. One of ['ReLU', 'Softplus', 'Tanh', 'SELU',
+                'LeakyReLU', 'PReLU', 'Sigmoid']. Defaults to "ReLU".
+            output_size: number of outputs (typically number of quantiles for QuantileLoss and one target or list
+                of output sizes but currently only point-forecasts allowed). Set automatically.
+            static_categoricals: names of static categorical variables
+            static_reals: names of static continuous variables
+            time_varying_categoricals_encoder: names of categorical variables for encoder
+            time_varying_categoricals_decoder: names of categorical variables for decoder
+            time_varying_reals_encoder: names of continuous variables for encoder
+            time_varying_reals_decoder: names of continuous variables for decoder
+            categorical_groups: dictionary where values
+                are list of categorical variables that are forming together a new categorical
+                variable which is the key in the dictionary
+            x_reals: order of continuous variables in tensor passed to forward function
+            x_categoricals: order of categorical variables in tensor passed to forward function
+            hidden_continuous_size: default for hidden size for processing continuous variables (similar to
+                categorical embedding size)
+            hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection
+                (fallback to hidden_continuous_size if index is not in dictionary)
+            embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and
+                embedding size
+            embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector
+            embedding_labels: dictionary mapping (string) indices to list of categorical labels
+            learning_rate: learning rate
+            log_interval: log predictions every x batches, do not log if 0 or less, log interpretation if > 0.
+                If < 1.0, will log multiple entries per batch. Defaults to -1.
+            log_val_interval: frequency with which to log validation set metrics, defaults to log_interval
+            log_gradient_flow: whether to log gradient flow; this takes time and should only be done to diagnose
+                training failures
+            prediction_length: Length of the prediction. Also known as 'horizon'.
+            context_length: Number of time units that condition the predictions. Also known as 'lookback period'.
+                Should be between 1-10 times the prediction length.
+            backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss.
+                A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast
+                and forecast lengths). Defaults to 0.0, i.e. no weight.
+            reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
+            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training.
+                Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
+ """ + if static_categoricals is None: + static_categoricals = [] + if static_reals is None: + static_reals = [] + if time_varying_categoricals_encoder is None: + time_varying_categoricals_encoder = [] + if time_varying_categoricals_decoder is None: + time_varying_categoricals_decoder = [] + if categorical_groups is None: + categorical_groups = {} + if time_varying_reals_encoder is None: + time_varying_reals_encoder = [] + if time_varying_reals_decoder is None: + time_varying_reals_decoder = [] + if embedding_sizes is None: + embedding_sizes = {} + if embedding_paddings is None: + embedding_paddings = [] + if embedding_labels is None: + embedding_labels = {} + if x_reals is None: + x_reals = [] + if x_categoricals is None: + x_categoricals = [] + if n_blocks is None: + n_blocks = [1, 1, 1] + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + if loss is None: + loss = MASE() + + if activation == "SELU": + self.hparams.initialization = "lecun_normal" + + # provide default downsampling sizes + n_stacks = len(n_blocks) + if pooling_sizes is None: + pooling_sizes = np.exp2(np.round(np.linspace(0.49, np.log2(prediction_length / 2), n_stacks))) + pooling_sizes = [int(x) for x in pooling_sizes[::-1]] + # remove zero from pooling_sizes + pooling_sizes = max(pooling_sizes, [1] * len(pooling_sizes)) + if downsample_frequencies is None: + downsample_frequencies = [min(prediction_length, int(np.power(x, 1.5))) for x in pooling_sizes] + # remove zero from downsample_frequencies + downsample_frequencies = max(downsample_frequencies, [1] * len(downsample_frequencies)) + + # set static hidden size + if static_hidden_size is None: + static_hidden_size = hidden_size + + # set layers + if isinstance(n_layers, int): + n_layers = [n_layers] * n_stacks + + self.save_hyperparameters() + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + self.embeddings = MultiEmbedding( + embedding_sizes=self.hparams.embedding_sizes, + categorical_groups=self.hparams.categorical_groups, + embedding_paddings=self.hparams.embedding_paddings, + x_categoricals=self.hparams.x_categoricals, + ) + + self.model = NHiTSModule( + context_length=self.hparams.context_length, + prediction_length=self.hparams.prediction_length, + output_size=to_list(output_size), + static_size=self.static_size, + encoder_covariate_size=self.encoder_covariate_size, + decoder_covariate_size=self.decoder_covariate_size, + static_hidden_size=self.hparams.static_hidden_size, + n_blocks=self.hparams.n_blocks, + n_layers=self.hparams.n_layers, + hidden_size=self.n_stacks * [2 * [self.hparams.hidden_size]], + pooling_sizes=self.hparams.pooling_sizes, + downsample_frequencies=self.hparams.downsample_frequencies, + pooling_mode=self.hparams.pooling_mode, + interpolation_mode=self.hparams.interpolation_mode, + dropout=self.hparams.dropout, + activation=self.hparams.activation, + initialization=self.hparams.initialization, + batch_normalization=self.hparams.batch_normalization, + shared_weights=self.hparams.shared_weights, + naive_level=self.hparams.naive_level, + ) + + @property + def decoder_covariate_size(self) -> int: + """Decoder covariates size. 
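Two constructions above are easy to misread: the pooling heuristic derives exponentially spaced pooling sizes from the horizon, and the ``hidden_size`` handed to ``NHiTSModule`` is a nested list of per-stack layer widths. A worked example for ``prediction_length = 8`` and three stacks:

import numpy as np

prediction_length, n_stacks, hidden_size = 8, 3, 512

pooling_sizes = np.exp2(np.round(np.linspace(0.49, np.log2(prediction_length / 2), n_stacks)))
pooling_sizes = [int(x) for x in pooling_sizes[::-1]]
assert pooling_sizes == [4, 2, 1]  # coarser smoothing in earlier stacks

downsample_frequencies = [min(prediction_length, int(np.power(x, 1.5))) for x in pooling_sizes]
assert downsample_frequencies == [8, 2, 1]

nested_hidden_size = n_stacks * [2 * [hidden_size]]
assert nested_hidden_size == [[512, 512], [512, 512], [512, 512]]  # two layer widths per stack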
+ + Returns: + int: size of time-dependent covariates used by the decoder + """ + return len(set(self.hparams.time_varying_reals_decoder) - set(self.target_names)) + sum( + self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder + ) + + @property + def encoder_covariate_size(self) -> int: + """Encoder covariate size. + + Returns: + int: size of time-dependent covariates used by the encoder + """ + return len(set(self.hparams.time_varying_reals_encoder) - set(self.target_names)) + sum( + self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder + ) + + @property + def static_size(self) -> int: + """Static covariate size. + + Returns: + int: size of static covariates + """ + return len(self.hparams.static_reals) + sum( + self.embeddings.output_size[name] for name in self.hparams.static_categoricals + ) + + @property + def n_stacks(self) -> int: + """Number of stacks. + + Returns: + int: number of stacks. + """ + return len(self.hparams.n_blocks) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Pass forward of network. + + Args: + x (Dict[str, torch.Tensor]): input from dataloader generated from + :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Returns: + Dict[str, torch.Tensor]: output of model + """ + # covariates + if self.encoder_covariate_size > 0: + encoder_features = self.extract_features(x, self.embeddings, period="encoder") + encoder_x_t = torch.concat( + [encoder_features[name] for name in self.encoder_variables if name not in self.target_names], + dim=2, + ) + else: + encoder_x_t = None + + if self.decoder_covariate_size > 0: + decoder_features = self.extract_features(x, self.embeddings, period="decoder") + decoder_x_t = torch.concat([decoder_features[name] for name in self.decoder_variables], dim=2) + else: + decoder_x_t = None + + # statics + if self.static_size > 0: + x_s = torch.concat([encoder_features[name][:, 0] for name in self.static_variables], dim=1) + else: + x_s = None + + # target + encoder_y = x["encoder_cont"][..., self.target_positions] + encoder_mask = create_mask(x["encoder_lengths"].max(), x["encoder_lengths"], inverse=True) + + # run model + forecast, backcast, block_forecasts, block_backcasts = self.model( + encoder_y, encoder_mask, encoder_x_t, decoder_x_t, x_s + ) + backcast = encoder_y - backcast + + # create block output: detach and split by block + block_backcasts = block_backcasts.detach() + block_forecasts = block_forecasts.detach() + + if isinstance(self.hparams.output_size, (tuple, list)): + forecast = forecast.split(self.hparams.output_size, dim=2) + backcast = backcast.split(1, dim=2) + block_backcasts = tuple( + self.transform_output(block.squeeze(3).split(1, dim=2), target_scale=x["target_scale"]) + for block in block_backcasts.split(1, dim=3) + ) + block_forecasts = tuple( + self.transform_output( + block.squeeze(3).split(self.hparams.output_size, dim=2), target_scale=x["target_scale"] + ) + for block in block_forecasts.split(1, dim=3) + ) + else: + block_backcasts = tuple( + self.transform_output(block.squeeze(3), target_scale=x["target_scale"], loss=MultiHorizonMetric()) + for block in block_backcasts.split(1, dim=3) + ) + block_forecasts = tuple( + self.transform_output(block.squeeze(3), target_scale=x["target_scale"]) + for block in block_forecasts.split(1, dim=3) + ) + + return self.to_network_output( + prediction=self.transform_output( + forecast, target_scale=x["target_scale"] + ), # (n_outputs x) n_samples x 
n_timesteps x output_size + backcast=self.transform_output( + backcast, target_scale=x["target_scale"], loss=MultiHorizonMetric() + ), # (n_outputs x) n_samples x n_timesteps x 1 + block_backcasts=block_backcasts, # n_blocks x (n_outputs x) n_samples x n_timesteps x 1 + block_forecasts=block_forecasts, # n_blocks x (n_outputs x) n_samples x n_timesteps x output_size + ) + + @classmethod + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): + """ + Convenience function to create network from :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Args: + dataset (TimeSeriesDataSet): time series dataset. + **kwargs: additional arguments to be passed to ``__init__`` method. + + Returns: + NHiTS + """ + # validate arguments + assert not isinstance( + dataset.target_normalizer, NaNLabelEncoder + ), "only regression tasks are supported - target must not be categorical" + assert ( + dataset.min_encoder_length == dataset.max_encoder_length + ), "only fixed encoder length is allowed, but min_encoder_length != max_encoder_length" + + assert ( + dataset.max_prediction_length == dataset.min_prediction_length + ), "only fixed prediction length is allowed, but max_prediction_length != min_prediction_length" + + assert dataset.randomize_length is None, "length has to be fixed, but randomize_length is not None" + assert not dataset.add_relative_time_idx, "add_relative_time_idx has to be False" + + new_kwargs = copy(kwargs) + new_kwargs.update( + {"prediction_length": dataset.max_prediction_length, "context_length": dataset.max_encoder_length} + ) + new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, MASE())) + + assert (new_kwargs.get("backcast_loss_ratio", 0) == 0) | ( + isinstance(new_kwargs["output_size"], int) and new_kwargs["output_size"] == 1 + ) or all( + o == 1 for o in new_kwargs["output_size"] + ), "output sizes can only be of size 1, i.e. point forecasts if backcast_loss_ratio > 0" + + # initialize class + return super().from_dataset(dataset, **new_kwargs) + + def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: + """ + Take training / validation step. + """ + log, out = super().step(x, y, batch_idx=batch_idx) + + if self.hparams.backcast_loss_ratio > 0 and not self.predicting: # add loss from backcast + backcast = out["backcast"] + backcast_weight = ( + self.hparams.backcast_loss_ratio * self.hparams.prediction_length / self.hparams.context_length + ) + backcast_weight = backcast_weight / (backcast_weight + 1) # normalize + forecast_weight = 1 - backcast_weight + if isinstance(self.loss, (MultiLoss, MASE)): + backcast_loss = ( + self.loss( + backcast, + (x["encoder_target"], None), + encoder_target=x["decoder_target"], + encoder_lengths=x["decoder_lengths"], + ) + * backcast_weight + ) + else: + backcast_loss = self.loss(backcast, x["encoder_target"]) * backcast_weight + label = ["val", "train"][self.training] + self.log( + f"{label}_backcast_loss", + backcast_loss, + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + self.log( + f"{label}_forecast_loss", + log["loss"], + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + log["loss"] = log["loss"] * forecast_weight + backcast_loss + + # log interpretation + self.log_interpretation(x, out, batch_idx=batch_idx) + return log, out + + def plot_interpretation( + self, + x: Dict[str, torch.Tensor], + output: Dict[str, torch.Tensor], + idx: int, + ax=None, + ): + """ + Plot interpretation.
+ + Plot two panels: prediction and backcast vs actuals and + decomposition of prediction into different block predictions which capture different frequencies. + + Args: + x (Dict[str, torch.Tensor]): network input + output (Dict[str, torch.Tensor]): network output + idx (int): index of sample for which to plot the interpretation. + ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation. + Defaults to None. + + Returns: + plt.Figure: matplotlib figure + """ + _check_matplotlib("plot_interpretation") + + from matplotlib import pyplot as plt + + if not isinstance(self.loss, MultiLoss): # not multi-target + prediction = self.to_prediction(dict(prediction=output["prediction"][[idx]].detach()))[0].cpu() + block_forecasts = [ + self.to_prediction(dict(prediction=block[[idx]].detach()))[0].cpu() + for block in output["block_forecasts"] + ] + elif isinstance(output["prediction"], (tuple, list)): # multi-target + figs = [] + # predictions and block forecasts need to be converted + prediction = [p[[idx]].detach() for p in output["prediction"]] # select index + prediction = self.to_prediction(dict(prediction=prediction)) # transform to prediction + prediction = [p[0].cpu() for p in prediction] # select first and only index + + block_forecasts = [ + self.to_prediction(dict(prediction=[b[[idx]].detach() for b in block])) + for block in output["block_forecasts"] + ] + block_forecasts = [[b[0].cpu() for b in block] for block in block_forecasts] + + for i in range(len(self.target_names)): + if ax is not None: + ax_i = ax[i] + else: + ax_i = None + + figs.append( + self.plot_interpretation( + dict(encoder_target=x["encoder_target"][i], decoder_target=x["decoder_target"][i]), + dict( + backcast=output["backcast"][i], + prediction=prediction[i], + block_backcasts=[block[i] for block in output["block_backcasts"]], + block_forecasts=[block[i] for block in block_forecasts], + ), + idx=idx, + ax=ax_i, + ) + ) + return figs + else: + prediction = output["prediction"] # multi target that has already been transformed + block_forecasts = output["block_forecasts"] + + if ax is None: + fig, ax = plt.subplots(2, 1, figsize=(6, 8), sharex=True, sharey=True) + else: + fig = ax[0].get_figure() + + # plot target vs prediction + # target + prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) + color = next(prop_cycle)["color"] + ax[0].plot(torch.arange(-self.hparams.context_length, 0), x["encoder_target"][idx].detach().cpu(), c=color) + ax[0].plot( + torch.arange(self.hparams.prediction_length), + x["decoder_target"][idx].detach().cpu(), + label="Target", + c=color, + ) + # prediction + color = next(prop_cycle)["color"] + ax[0].plot( + torch.arange(-self.hparams.context_length, 0), + output["backcast"][idx][..., 0].detach().cpu(), + label="Backcast", + c=color, + ) + ax[0].plot( + torch.arange(self.hparams.prediction_length), + prediction, + label="Forecast", + c=color, + ) + + # plot blocks + for pooling_size, block_backcast, block_forecast in zip( + self.hparams.pooling_sizes, output["block_backcasts"][1:], block_forecasts + ): + color = next(prop_cycle)["color"] + ax[1].plot( + torch.arange(-self.hparams.context_length, 0), + block_backcast[idx][..., 0].detach().cpu(), + c=color, + ) + ax[1].plot( + torch.arange(self.hparams.prediction_length), + block_forecast, + c=color, + label=f"Pooling size: {pooling_size}", + ) + ax[1].set_xlabel("Time") + + fig.legend() + return fig + + def log_interpretation(self, x, out, batch_idx): + """ + Log interpretation of network predictions
in tensorboard. + """ + mpl_available = _check_matplotlib("log_interpretation", raise_error=False) + + # Don't log figures if matplotlib or add_figure is not available + if not mpl_available or not self._logger_supports("add_figure"): + return None + + label = ["val", "train"][self.training] + if self.log_interval > 0 and batch_idx % self.log_interval == 0: + fig = self.plot_interpretation(x, out, idx=0) + name = f"{label.capitalize()} interpretation of item 0 in " + if self.training: + name += f"step {self.global_step}" + else: + name += f"batch {batch_idx}" + if isinstance(fig, (list, tuple)): + for idx, f in enumerate(fig): + self.logger.experiment.add_figure( + f"{self.target_names[idx]} {name}", + f, + global_step=self.global_step, + ) + else: + self.logger.experiment.add_figure( + name, + fig, + global_step=self.global_step, + ) diff --git a/pytorch_forecasting/models/rnn/__init__.py b/pytorch_forecasting/models/rnn/__init__.py index 142892dc..dfa9d809 100644 --- a/pytorch_forecasting/models/rnn/__init__.py +++ b/pytorch_forecasting/models/rnn/__init__.py @@ -1,317 +1,5 @@ -""" -Simple recurrent model - either with LSTM or GRU cells. -""" +"""Simple recurrent model - either with LSTM or GRU cells.""" -from copy import copy -from typing import Dict, List, Tuple, Union, Optional +from pytorch_forecasting.models.rnn._rnn import RecurrentNetwork -import numpy as np -import torch -import torch.nn as nn - -from pytorch_forecasting.data.encoders import MultiNormalizer, NaNLabelEncoder -from pytorch_forecasting.data.timeseries import TimeSeriesDataSet -from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss, QuantileLoss -from pytorch_forecasting.models.base_model import AutoRegressiveBaseModelWithCovariates -from pytorch_forecasting.models.nn import HiddenState, MultiEmbedding, get_rnn -from pytorch_forecasting.utils import apply_to_list, to_list - - -class RecurrentNetwork(AutoRegressiveBaseModelWithCovariates): - def __init__( - self, - cell_type: str = "LSTM", - hidden_size: int = 10, - rnn_layers: int = 2, - dropout: float = 0.1, - static_categoricals: Optional[List[str]] = None, - static_reals: Optional[List[str]] = None, - time_varying_categoricals_encoder: Optional[List[str]] = None, - time_varying_categoricals_decoder: Optional[List[str]] = None, - categorical_groups: Optional[Dict[str, List[str]]] = None, - time_varying_reals_encoder: Optional[List[str]] = None, - time_varying_reals_decoder: Optional[List[str]] = None, - embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, - embedding_paddings: Optional[List[str]] = None, - embedding_labels: Optional[Dict[str, np.ndarray]] = None, - x_reals: Optional[List[str]] = None, - x_categoricals: Optional[List[str]] = None, - output_size: Union[int, List[int]] = 1, - target: Union[str, List[str]] = None, - target_lags: Optional[Dict[str, List[int]]] = None, - loss: MultiHorizonMetric = None, - logging_metrics: nn.ModuleList = None, - **kwargs, - ): - """ - Recurrent Network. - - Simple LSTM or GRU layer followed by output layer - - Args: - cell_type (str, optional): Recurrent cell type ["LSTM", "GRU"]. Defaults to "LSTM". - hidden_size (int, optional): hidden recurrent size - the most important hyperparameter along with - ``rnn_layers``. Defaults to 10. - rnn_layers (int, optional): Number of RNN layers - important hyperparameter. Defaults to 2. - dropout (float, optional): Dropout in RNN layers. Defaults to 0.1.
- static_categoricals: integer of positions of static categorical variables - static_reals: integer of positions of static continuous variables - time_varying_categoricals_encoder: integer of positions of categorical variables for encoder - time_varying_categoricals_decoder: integer of positions of categorical variables for decoder - time_varying_reals_encoder: integer of positions of continuous variables for encoder - time_varying_reals_decoder: integer of positions of continuous variables for decoder - categorical_groups: dictionary where values - are list of categorical variables that are forming together a new categorical - variable which is the key in the dictionary - x_reals: order of continuous variables in tensor passed to forward function - x_categoricals: order of categorical variables in tensor passed to forward function - embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and - embedding size - embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector - embedding_labels: dictionary mapping (string) indices to list of categorical labels - output_size (Union[int, List[int]], optional): number of outputs (e.g. number of quantiles for - QuantileLoss and one target or list of output sizes). - target (str, optional): Target variable or list of target variables. Defaults to None. - target_lags (Dict[str, Dict[str, int]]): dictionary of target names mapped to list of time steps by - which the variable should be lagged. - Lags can be useful to indicate seasonality to the models. If you know the seasonalit(ies) of your data, - add at least the target variables with the corresponding lags to improve performance. - Defaults to no lags, i.e. an empty dictionary. - loss (MultiHorizonMetric, optional): loss: loss function taking prediction and targets. - logging_metrics (nn.ModuleList, optional): Metrics to log during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). 
- """ - if static_categoricals is None: - static_categoricals = [] - if static_reals is None: - static_reals = [] - if time_varying_categoricals_encoder is None: - time_varying_categoricals_encoder = [] - if time_varying_categoricals_decoder is None: - time_varying_categoricals_decoder = [] - if categorical_groups is None: - categorical_groups = {} - if time_varying_reals_encoder is None: - time_varying_reals_encoder = [] - if time_varying_reals_decoder is None: - time_varying_reals_decoder = [] - if embedding_sizes is None: - embedding_sizes = {} - if embedding_paddings is None: - embedding_paddings = [] - if embedding_labels is None: - embedding_labels = {} - if x_reals is None: - x_reals = [] - if x_categoricals is None: - x_categoricals = [] - if target_lags is None: - target_lags = {} - if loss is None: - loss = MAE() - if logging_metrics is None: - logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - self.save_hyperparameters() - # store loss function separately as it is a module - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - - self.embeddings = MultiEmbedding( - embedding_sizes=embedding_sizes, - embedding_paddings=embedding_paddings, - categorical_groups=categorical_groups, - x_categoricals=x_categoricals, - ) - - lagged_target_names = [l for lags in target_lags.values() for l in lags] - assert set(self.encoder_variables) - set(to_list(target)) - set(lagged_target_names) == set( - self.decoder_variables - ) - set(lagged_target_names), "Encoder and decoder variables have to be the same apart from target variable" - for targeti in to_list(target): - assert ( - targeti in time_varying_reals_encoder - ), f"target {targeti} has to be real" # todo: remove this restriction - assert (isinstance(target, str) and isinstance(loss, MultiHorizonMetric)) or ( - isinstance(target, (list, tuple)) and isinstance(loss, MultiLoss) and len(loss) == len(target) - ), "number of targets should be equivalent to number of loss metrics" - - rnn_class = get_rnn(cell_type) - cont_size = len(self.reals) - cat_size = sum(self.embeddings.output_size.values()) - input_size = cont_size + cat_size - self.rnn = rnn_class( - input_size=input_size, - hidden_size=self.hparams.hidden_size, - num_layers=self.hparams.rnn_layers, - dropout=self.hparams.dropout if self.hparams.rnn_layers > 1 else 0, - batch_first=True, - ) - - # add linear layers for argument projects - if isinstance(target, str): # single target - self.output_projector = nn.Linear(self.hparams.hidden_size, self.hparams.output_size) - assert not isinstance(self.loss, QuantileLoss), "QuantileLoss does not work with recurrent network" - else: # multi target - self.output_projector = nn.ModuleList( - [nn.Linear(self.hparams.hidden_size, size) for size in self.hparams.output_size] - ) - for l in self.loss: - assert not isinstance(l, QuantileLoss), "QuantileLoss does not work with recurrent network" - - @classmethod - def from_dataset( - cls, - dataset: TimeSeriesDataSet, - allowed_encoder_known_variable_names: List[str] = None, - **kwargs, - ): - """ - Create model from dataset. 
- - Args: - dataset: timeseries dataset - allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all - **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) - - Returns: - Recurrent network - """ - new_kwargs = copy(kwargs) - new_kwargs.update(cls.deduce_default_output_parameters(dataset=dataset, kwargs=kwargs, default_loss=MAE())) - assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and ( - not isinstance(dataset.target_normalizer, MultiNormalizer) - or all(not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer) - ), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction - return super().from_dataset( - dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs - ) - - def construct_input_vector( - self, x_cat: torch.Tensor, x_cont: torch.Tensor, one_off_target: torch.Tensor = None - ) -> torch.Tensor: - """ - Create input vector into RNN network - - Args: - one_off_target: tensor to insert into first position of target. If None (default), remove first time step. - """ - # create input vector - if len(self.categoricals) > 0: - embeddings = self.embeddings(x_cat) - flat_embeddings = torch.cat(list(embeddings.values()), dim=-1) - input_vector = flat_embeddings - - if len(self.reals) > 0: - input_vector = x_cont.clone() - - if len(self.reals) > 0 and len(self.categoricals) > 0: - input_vector = torch.cat([x_cont, flat_embeddings], dim=-1) - - # shift target by one - input_vector[..., self.target_positions] = torch.roll( - input_vector[..., self.target_positions], shifts=1, dims=1 - ) - - if one_off_target is not None: # set first target input (which is rolled over) - input_vector[:, 0, self.target_positions] = one_off_target - else: - input_vector = input_vector[:, 1:] - - # shift target - return input_vector - - def encode(self, x: Dict[str, torch.Tensor]) -> HiddenState: - """ - Encode sequence into hidden state - """ - # encode using rnn - assert x["encoder_lengths"].min() > 0 - encoder_lengths = x["encoder_lengths"] - 1 - input_vector = self.construct_input_vector(x["encoder_cat"], x["encoder_cont"]) - _, hidden_state = self.rnn( - input_vector, lengths=encoder_lengths, enforce_sorted=False - ) # second ouput is not needed (hidden state) - return hidden_state - - def decode_all( - self, - x: torch.Tensor, - hidden_state: HiddenState, - lengths: torch.Tensor = None, - ): - decoder_output, hidden_state = self.rnn(x, hidden_state, lengths=lengths, enforce_sorted=False) - if isinstance(self.hparams.target, str): # single target - output = self.output_projector(decoder_output) - else: - output = [projector(decoder_output) for projector in self.output_projector] - return output, hidden_state - - def decode( - self, - input_vector: torch.Tensor, - target_scale: torch.Tensor, - decoder_lengths: torch.Tensor, - hidden_state: HiddenState, - n_samples: int = None, - ) -> Tuple[torch.Tensor, bool]: - """ - Decode hidden state of RNN into prediction. If n_smaples is given, - decode not by using actual values but rather by - sampling new targets from past predictions iteratively - """ - if self.training: - output, _ = self.decode_all(input_vector, hidden_state, lengths=decoder_lengths) - output = self.transform_output(output, target_scale=target_scale) - else: - # run in eval, i.e. 
simulation mode - target_pos = self.target_positions - lagged_target_positions = self.lagged_target_positions - - # define function to run at every decoding step - def decode_one( - idx, - lagged_targets, - hidden_state, - ): - x = input_vector[:, [idx]] - x[:, 0, target_pos] = lagged_targets[-1] - for lag, lag_positions in lagged_target_positions.items(): - if idx > lag: - x[:, 0, lag_positions] = lagged_targets[-lag] - prediction, hidden_state = self.decode_all(x, hidden_state) - prediction = apply_to_list(prediction, lambda x: x[:, 0]) # select first time step - return prediction, hidden_state - - # make predictions which are fed into next step - output = self.decode_autoregressive( - decode_one, - first_target=input_vector[:, 0, target_pos], - first_hidden_state=hidden_state, - target_scale=target_scale, - n_decoder_steps=input_vector.size(1), - ) - return output - - def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: - """ - Forward network - """ - hidden_state = self.encode(x) - # decode - input_vector = self.construct_input_vector( - x["decoder_cat"], - x["decoder_cont"], - one_off_target=x["encoder_cont"][ - torch.arange(x["encoder_cont"].size(0), device=x["encoder_cont"].device), - x["encoder_lengths"] - 1, - self.target_positions.unsqueeze(-1), - ].T.contiguous(), - ) - - output = self.decode( - input_vector, - decoder_lengths=x["decoder_lengths"], - target_scale=x["target_scale"], - hidden_state=hidden_state, - ) - # return relevant part - return self.to_network_output(prediction=output) +__all__ = ["RecurrentNetwork"] diff --git a/pytorch_forecasting/models/rnn/_rnn.py b/pytorch_forecasting/models/rnn/_rnn.py new file mode 100644 index 00000000..142892dc --- /dev/null +++ b/pytorch_forecasting/models/rnn/_rnn.py @@ -0,0 +1,317 @@ +""" +Simple recurrent model - either with LSTM or GRU cells. 
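+ + A minimal usage sketch (illustrative only; assumes ``training_dataset`` is a prepared :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` with a continuous target):: + + from pytorch_forecasting import RecurrentNetwork + + # from_dataset wires covariates, embeddings and output sizes from the dataset + net = RecurrentNetwork.from_dataset(training_dataset, cell_type="GRU", hidden_size=32)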
+""" + +from copy import copy +from typing import Dict, List, Tuple, Union, Optional + +import numpy as np +import torch +import torch.nn as nn + +from pytorch_forecasting.data.encoders import MultiNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss, QuantileLoss +from pytorch_forecasting.models.base_model import AutoRegressiveBaseModelWithCovariates +from pytorch_forecasting.models.nn import HiddenState, MultiEmbedding, get_rnn +from pytorch_forecasting.utils import apply_to_list, to_list + + +class RecurrentNetwork(AutoRegressiveBaseModelWithCovariates): + def __init__( + self, + cell_type: str = "LSTM", + hidden_size: int = 10, + rnn_layers: int = 2, + dropout: float = 0.1, + static_categoricals: Optional[List[str]] = None, + static_reals: Optional[List[str]] = None, + time_varying_categoricals_encoder: Optional[List[str]] = None, + time_varying_categoricals_decoder: Optional[List[str]] = None, + categorical_groups: Optional[Dict[str, List[str]]] = None, + time_varying_reals_encoder: Optional[List[str]] = None, + time_varying_reals_decoder: Optional[List[str]] = None, + embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, + embedding_paddings: Optional[List[str]] = None, + embedding_labels: Optional[Dict[str, np.ndarray]] = None, + x_reals: Optional[List[str]] = None, + x_categoricals: Optional[List[str]] = None, + output_size: Union[int, List[int]] = 1, + target: Union[str, List[str]] = None, + target_lags: Optional[Dict[str, List[int]]] = None, + loss: MultiHorizonMetric = None, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + Recurrent Network. + + Simple LSTM or GRU layer followed by output layer + + Args: + cell_type (str, optional): Recurrent cell type ["LSTM", "GRU"]. Defaults to "LSTM". + hidden_size (int, optional): hidden recurrent size - the most important hyperparameter along with + ``rnn_layers``. Defaults to 10. + rnn_layers (int, optional): Number of RNN layers - important hyperparameter. Defaults to 2. + dropout (float, optional): Dropout in RNN layers. Defaults to 0.1. + static_categoricals: integer of positions of static categorical variables + static_reals: integer of positions of static continuous variables + time_varying_categoricals_encoder: integer of positions of categorical variables for encoder + time_varying_categoricals_decoder: integer of positions of categorical variables for decoder + time_varying_reals_encoder: integer of positions of continuous variables for encoder + time_varying_reals_decoder: integer of positions of continuous variables for decoder + categorical_groups: dictionary where values + are list of categorical variables that are forming together a new categorical + variable which is the key in the dictionary + x_reals: order of continuous variables in tensor passed to forward function + x_categoricals: order of categorical variables in tensor passed to forward function + embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and + embedding size + embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector + embedding_labels: dictionary mapping (string) indices to list of categorical labels + output_size (Union[int, List[int]], optional): number of outputs (e.g. number of quantiles for + QuantileLoss and one target or list of output sizes). 
+ target (str, optional): Target variable or list of target variables. Defaults to None. + target_lags (Dict[str, Dict[str, int]]): dictionary of target names mapped to list of time steps by + which the variable should be lagged. + Lags can be useful to indicate seasonality to the models. If you know the seasonality (or seasonalities) of your data, + add at least the target variables with the corresponding lags to improve performance. + Defaults to no lags, i.e. an empty dictionary. + loss (MultiHorizonMetric, optional): loss function taking prediction and targets. + logging_metrics (nn.ModuleList, optional): Metrics to log during training. + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). + """ + if static_categoricals is None: + static_categoricals = [] + if static_reals is None: + static_reals = [] + if time_varying_categoricals_encoder is None: + time_varying_categoricals_encoder = [] + if time_varying_categoricals_decoder is None: + time_varying_categoricals_decoder = [] + if categorical_groups is None: + categorical_groups = {} + if time_varying_reals_encoder is None: + time_varying_reals_encoder = [] + if time_varying_reals_decoder is None: + time_varying_reals_decoder = [] + if embedding_sizes is None: + embedding_sizes = {} + if embedding_paddings is None: + embedding_paddings = [] + if embedding_labels is None: + embedding_labels = {} + if x_reals is None: + x_reals = [] + if x_categoricals is None: + x_categoricals = [] + if target_lags is None: + target_lags = {} + if loss is None: + loss = MAE() + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + self.save_hyperparameters() + # store loss function separately as it is a module + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + self.embeddings = MultiEmbedding( + embedding_sizes=embedding_sizes, + embedding_paddings=embedding_paddings, + categorical_groups=categorical_groups, + x_categoricals=x_categoricals, + ) + + lagged_target_names = [l for lags in target_lags.values() for l in lags] + assert set(self.encoder_variables) - set(to_list(target)) - set(lagged_target_names) == set( + self.decoder_variables + ) - set(lagged_target_names), "Encoder and decoder variables have to be the same apart from target variable" + for targeti in to_list(target): + assert ( + targeti in time_varying_reals_encoder + ), f"target {targeti} has to be real" # todo: remove this restriction + assert (isinstance(target, str) and isinstance(loss, MultiHorizonMetric)) or ( + isinstance(target, (list, tuple)) and isinstance(loss, MultiLoss) and len(loss) == len(target) + ), "number of targets should be equivalent to number of loss metrics" + + rnn_class = get_rnn(cell_type) + cont_size = len(self.reals) + cat_size = sum(self.embeddings.output_size.values()) + input_size = cont_size + cat_size + self.rnn = rnn_class( + input_size=input_size, + hidden_size=self.hparams.hidden_size, + num_layers=self.hparams.rnn_layers, + dropout=self.hparams.dropout if self.hparams.rnn_layers > 1 else 0, + batch_first=True, + ) + + # add linear layers for output projection + if isinstance(target, str): # single target + self.output_projector = nn.Linear(self.hparams.hidden_size, self.hparams.output_size) + assert not isinstance(self.loss, QuantileLoss), "QuantileLoss does not work with recurrent network" + else: # multi target + self.output_projector = nn.ModuleList( + [nn.Linear(self.hparams.hidden_size, size) for size in self.hparams.output_size] + ) + for l in self.loss:
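+ # note: the autoregressive decoder feeds its own point predictions back as + # inputs, so multi-quantile outputs (QuantileLoss) have no single value to + # feed back - hence the checks here and in the single-target branch above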
assert not isinstance(l, QuantileLoss), "QuantileLoss does not work with recurrent network" + + @classmethod + def from_dataset( + cls, + dataset: TimeSeriesDataSet, + allowed_encoder_known_variable_names: List[str] = None, + **kwargs, + ): + """ + Create model from dataset. + + Args: + dataset: timeseries dataset + allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all + **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) + + Returns: + Recurrent network + """ + new_kwargs = copy(kwargs) + new_kwargs.update(cls.deduce_default_output_parameters(dataset=dataset, kwargs=kwargs, default_loss=MAE())) + assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and ( + not isinstance(dataset.target_normalizer, MultiNormalizer) + or all(not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer) + ), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction + return super().from_dataset( + dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs + ) + + def construct_input_vector( + self, x_cat: torch.Tensor, x_cont: torch.Tensor, one_off_target: torch.Tensor = None + ) -> torch.Tensor: + """ + Create input vector into RNN network + + Args: + one_off_target: tensor to insert into first position of target. If None (default), remove first time step. + """ + # create input vector + if len(self.categoricals) > 0: + embeddings = self.embeddings(x_cat) + flat_embeddings = torch.cat(list(embeddings.values()), dim=-1) + input_vector = flat_embeddings + + if len(self.reals) > 0: + input_vector = x_cont.clone() + + if len(self.reals) > 0 and len(self.categoricals) > 0: + input_vector = torch.cat([x_cont, flat_embeddings], dim=-1) + + # shift target by one + input_vector[..., self.target_positions] = torch.roll( + input_vector[..., self.target_positions], shifts=1, dims=1 + ) + + if one_off_target is not None: # set first target input (which is rolled over) + input_vector[:, 0, self.target_positions] = one_off_target + else: + input_vector = input_vector[:, 1:] + + return input_vector + + def encode(self, x: Dict[str, torch.Tensor]) -> HiddenState: + """ + Encode sequence into hidden state + """ + # encode using rnn + assert x["encoder_lengths"].min() > 0 + encoder_lengths = x["encoder_lengths"] - 1 + input_vector = self.construct_input_vector(x["encoder_cat"], x["encoder_cont"]) + _, hidden_state = self.rnn( + input_vector, lengths=encoder_lengths, enforce_sorted=False + ) # only the hidden state (second output) is needed + return hidden_state + + def decode_all( + self, + x: torch.Tensor, + hidden_state: HiddenState, + lengths: torch.Tensor = None, + ): + decoder_output, hidden_state = self.rnn(x, hidden_state, lengths=lengths, enforce_sorted=False) + if isinstance(self.hparams.target, str): # single target + output = self.output_projector(decoder_output) + else: + output = [projector(decoder_output) for projector in self.output_projector] + return output, hidden_state + + def decode( + self, + input_vector: torch.Tensor, + target_scale: torch.Tensor, + decoder_lengths: torch.Tensor, + hidden_state: HiddenState, + n_samples: int = None, + ) -> Tuple[torch.Tensor, bool]: + """ + Decode hidden state of RNN into prediction.
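+ While ``self.training`` is True, actual target values are fed in (teacher forcing); + otherwise predictions are fed back step by step via ``decode_autoregressive``.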
If n_samples is given, + decode not by using actual values but rather by + sampling new targets from past predictions iteratively + """ + if self.training: + output, _ = self.decode_all(input_vector, hidden_state, lengths=decoder_lengths) + output = self.transform_output(output, target_scale=target_scale) + else: + # run in eval, i.e. simulation mode + target_pos = self.target_positions + lagged_target_positions = self.lagged_target_positions + + # define function to run at every decoding step + def decode_one( + idx, + lagged_targets, + hidden_state, + ): + x = input_vector[:, [idx]] + x[:, 0, target_pos] = lagged_targets[-1] + for lag, lag_positions in lagged_target_positions.items(): + if idx > lag: + x[:, 0, lag_positions] = lagged_targets[-lag] + prediction, hidden_state = self.decode_all(x, hidden_state) + prediction = apply_to_list(prediction, lambda x: x[:, 0]) # select first time step + return prediction, hidden_state + + # make predictions which are fed into next step + output = self.decode_autoregressive( + decode_one, + first_target=input_vector[:, 0, target_pos], + first_hidden_state=hidden_state, + target_scale=target_scale, + n_decoder_steps=input_vector.size(1), + ) + return output + + def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: + """ + Forward network + """ + hidden_state = self.encode(x) + # decode + input_vector = self.construct_input_vector( + x["decoder_cat"], + x["decoder_cont"], + one_off_target=x["encoder_cont"][ + torch.arange(x["encoder_cont"].size(0), device=x["encoder_cont"].device), + x["encoder_lengths"] - 1, + self.target_positions.unsqueeze(-1), + ].T.contiguous(), + ) + + output = self.decode( + input_vector, + decoder_lengths=x["decoder_lengths"], + target_scale=x["target_scale"], + hidden_state=hidden_state, + ) + # return relevant part + return self.to_network_output(prediction=output) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index 5a5ebeac..90a73ff1 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -1,19 +1,6 @@ -""" -The temporal fusion transformer is a powerful predictive model for forecasting timeseries -""" +"""Temporal fusion transformer for forecasting timeseries.""" -from copy import copy -from typing import Dict, List, Tuple, Union, Optional - -import numpy as np -import torch -from torch import nn -from torchmetrics import Metric as LightningMetric - -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.metrics import MAE, MAPE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss -from pytorch_forecasting.models.base_model import BaseModelWithCovariates -from pytorch_forecasting.models.nn import LSTM, MultiEmbedding +from pytorch_forecasting.models.temporal_fusion_transformer._tft import TemporalFusionTransformer from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( AddNorm, GateAddNorm, @@ -22,877 +9,13 @@ InterpretableMultiHeadAttention, VariableSelectionNetwork, ) -from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list -from pytorch_forecasting.utils._dependencies import _check_matplotlib - - -class TemporalFusionTransformer(BaseModelWithCovariates): - def __init__( - self, - hidden_size: int = 16, - lstm_layers: int = 1, - dropout: float = 0.1, - output_size: Union[int,
List[int]] = 7, - loss: MultiHorizonMetric = None, - attention_head_size: int = 4, - max_encoder_length: int = 10, - static_categoricals: Optional[List[str]] = None, - static_reals: Optional[List[str]] = None, - time_varying_categoricals_encoder: Optional[List[str]] = None, - time_varying_categoricals_decoder: Optional[List[str]] = None, - categorical_groups: Optional[Union[Dict, List[str]]] = None, - time_varying_reals_encoder: Optional[List[str]] = None, - time_varying_reals_decoder: Optional[List[str]] = None, - x_reals: Optional[List[str]] = None, - x_categoricals: Optional[List[str]] = None, - hidden_continuous_size: int = 8, - hidden_continuous_sizes: Optional[Dict[str, int]] = None, - embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, - embedding_paddings: Optional[List[str]] = None, - embedding_labels: Optional[Dict[str, np.ndarray]] = None, - learning_rate: float = 1e-3, - log_interval: Union[int, float] = -1, - log_val_interval: Union[int, float] = None, - log_gradient_flow: bool = False, - reduce_on_plateau_patience: int = 1000, - monotone_constaints: Optional[Dict[str, int]] = None, - share_single_variable_networks: bool = False, - causal_attention: bool = True, - logging_metrics: nn.ModuleList = None, - **kwargs, - ): - """ - Temporal Fusion Transformer for forecasting timeseries - use its :py:meth:`~from_dataset` method if possible. - - Implementation of the article - `Temporal Fusion Transformers for Interpretable Multi-horizon Time Series - Forecasting `_. The network outperforms DeepAR by Amazon by 36-69% - in benchmarks. - - Enhancements compared to the original implementation (apart from capabilities added through base model - such as monotone constraints): - - * static variables can be continuous - * multiple categorical variables can be summarized with an EmbeddingBag - * variable encoder and decoder length by sample - * categorical embeddings are not transformed by variable selection network (because it is a redundant operation) - * variable dimension in variable selection network are scaled up via linear interpolation to reduce - number of parameters - * non-linear variable processing in variable selection network can be shared among decoder and encoder - (not shared by default) - - Tune its hyperparameters with - :py:func:`~pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters`. - - Args: - - hidden_size: hidden size of network which is its main hyperparameter and can range from 8 to 512 - lstm_layers: number of LSTM layers (2 is mostly optimal) - dropout: dropout rate - output_size: number of outputs (e.g. number of quantiles for QuantileLoss and one target or list - of output sizes). 
- loss: loss function taking prediction and targets - attention_head_size: number of attention heads (4 is a good default) - max_encoder_length: length to encode (can be far longer than the decoder length but does not have to be) - static_categoricals: names of static categorical variables - static_reals: names of static continuous variables - time_varying_categoricals_encoder: names of categorical variables for encoder - time_varying_categoricals_decoder: names of categorical variables for decoder - time_varying_reals_encoder: names of continuous variables for encoder - time_varying_reals_decoder: names of continuous variables for decoder - categorical_groups: dictionary where values - are list of categorical variables that are forming together a new categorical - variable which is the key in the dictionary - x_reals: order of continuous variables in tensor passed to forward function - x_categoricals: order of categorical variables in tensor passed to forward function - hidden_continuous_size: default for hidden size for processing continous variables (similar to categorical - embedding size) - hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection - (fallback to hidden_continuous_size if index is not in dictionary) - embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and - embedding size - embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector - embedding_labels: dictionary mapping (string) indices to list of categorical labels - learning_rate: learning rate - log_interval: log predictions every x batches, do not log if 0 or less, log interpretation if > 0. If < 1.0 - , will log multiple entries per batch. Defaults to -1. - log_val_interval: frequency with which to log validation set metrics, defaults to log_interval - log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training - failures - reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 - monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder - variables mapping - position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive, - larger numbers add more weight to the constraint vs. the loss but are usually not necessary). - This constraint significantly slows down training. Defaults to {}. - share_single_variable_networks (bool): if to share the single variable networks between the encoder and - decoder. Defaults to False. - causal_attention (bool): If to attend only at previous timesteps in the decoder or also include future - predictions. Defaults to True. - logging_metrics (nn.ModuleList[LightningMetric]): list of metrics that are logged during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]). - **kwargs: additional arguments to :py:class:`~BaseModel`. 
- """ - if monotone_constaints is None: - monotone_constaints = {} - if embedding_labels is None: - embedding_labels = {} - if embedding_paddings is None: - embedding_paddings = [] - if embedding_sizes is None: - embedding_sizes = {} - if hidden_continuous_sizes is None: - hidden_continuous_sizes = {} - if x_categoricals is None: - x_categoricals = [] - if x_reals is None: - x_reals = [] - if time_varying_reals_decoder is None: - time_varying_reals_decoder = [] - if time_varying_reals_encoder is None: - time_varying_reals_encoder = [] - if categorical_groups is None: - categorical_groups = {} - if time_varying_categoricals_decoder is None: - time_varying_categoricals_decoder = [] - if time_varying_categoricals_encoder is None: - time_varying_categoricals_encoder = [] - if static_reals is None: - static_reals = [] - if static_categoricals is None: - static_categoricals = [] - if logging_metrics is None: - logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]) - if loss is None: - loss = QuantileLoss() - self.save_hyperparameters() - # store loss function separately as it is a module - assert isinstance(loss, LightningMetric), "Loss has to be a PyTorch Lightning `Metric`" - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - - # processing inputs - # embeddings - self.input_embeddings = MultiEmbedding( - embedding_sizes=self.hparams.embedding_sizes, - categorical_groups=self.hparams.categorical_groups, - embedding_paddings=self.hparams.embedding_paddings, - x_categoricals=self.hparams.x_categoricals, - max_embedding_size=self.hparams.hidden_size, - ) - - # continuous variable processing - self.prescalers = nn.ModuleDict( - { - name: nn.Linear(1, self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size)) - for name in self.reals - } - ) - - # variable selection - # variable selection for static variables - static_input_sizes = { - name: self.input_embeddings.output_size[name] for name in self.hparams.static_categoricals - } - static_input_sizes.update( - { - name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size) - for name in self.hparams.static_reals - } - ) - self.static_variable_selection = VariableSelectionNetwork( - input_sizes=static_input_sizes, - hidden_size=self.hparams.hidden_size, - input_embedding_flags={name: True for name in self.hparams.static_categoricals}, - dropout=self.hparams.dropout, - prescalers=self.prescalers, - ) - - # variable selection for encoder and decoder - encoder_input_sizes = { - name: self.input_embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder - } - encoder_input_sizes.update( - { - name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size) - for name in self.hparams.time_varying_reals_encoder - } - ) - - decoder_input_sizes = { - name: self.input_embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder - } - decoder_input_sizes.update( - { - name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size) - for name in self.hparams.time_varying_reals_decoder - } - ) - - # create single variable grns that are shared across decoder and encoder - if self.hparams.share_single_variable_networks: - self.shared_single_variable_grns = nn.ModuleDict() - for name, input_size in encoder_input_sizes.items(): - self.shared_single_variable_grns[name] = GatedResidualNetwork( - input_size, - min(input_size, self.hparams.hidden_size), - 
self.hparams.hidden_size, - self.hparams.dropout, - ) - for name, input_size in decoder_input_sizes.items(): - if name not in self.shared_single_variable_grns: - self.shared_single_variable_grns[name] = GatedResidualNetwork( - input_size, - min(input_size, self.hparams.hidden_size), - self.hparams.hidden_size, - self.hparams.dropout, - ) - - self.encoder_variable_selection = VariableSelectionNetwork( - input_sizes=encoder_input_sizes, - hidden_size=self.hparams.hidden_size, - input_embedding_flags={name: True for name in self.hparams.time_varying_categoricals_encoder}, - dropout=self.hparams.dropout, - context_size=self.hparams.hidden_size, - prescalers=self.prescalers, - single_variable_grns=( - {} if not self.hparams.share_single_variable_networks else self.shared_single_variable_grns - ), - ) - - self.decoder_variable_selection = VariableSelectionNetwork( - input_sizes=decoder_input_sizes, - hidden_size=self.hparams.hidden_size, - input_embedding_flags={name: True for name in self.hparams.time_varying_categoricals_decoder}, - dropout=self.hparams.dropout, - context_size=self.hparams.hidden_size, - prescalers=self.prescalers, - single_variable_grns=( - {} if not self.hparams.share_single_variable_networks else self.shared_single_variable_grns - ), - ) - - # static encoders - # for variable selection - self.static_context_variable_selection = GatedResidualNetwork( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - output_size=self.hparams.hidden_size, - dropout=self.hparams.dropout, - ) - - # for hidden state of the lstm - self.static_context_initial_hidden_lstm = GatedResidualNetwork( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - output_size=self.hparams.hidden_size, - dropout=self.hparams.dropout, - ) - - # for cell state of the lstm - self.static_context_initial_cell_lstm = GatedResidualNetwork( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - output_size=self.hparams.hidden_size, - dropout=self.hparams.dropout, - ) - - # for post lstm static enrichment - self.static_context_enrichment = GatedResidualNetwork( - self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.dropout - ) - - # lstm encoder (history) and decoder (future) for local processing - self.lstm_encoder = LSTM( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - num_layers=self.hparams.lstm_layers, - dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0, - batch_first=True, - ) - - self.lstm_decoder = LSTM( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - num_layers=self.hparams.lstm_layers, - dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0, - batch_first=True, - ) - - # skip connection for lstm - self.post_lstm_gate_encoder = GatedLinearUnit(self.hparams.hidden_size, dropout=self.hparams.dropout) - self.post_lstm_gate_decoder = self.post_lstm_gate_encoder - # self.post_lstm_gate_decoder = GatedLinearUnit(self.hparams.hidden_size, dropout=self.hparams.dropout) - self.post_lstm_add_norm_encoder = AddNorm(self.hparams.hidden_size, trainable_add=False) - # self.post_lstm_add_norm_decoder = AddNorm(self.hparams.hidden_size, trainable_add=True) - self.post_lstm_add_norm_decoder = self.post_lstm_add_norm_encoder - - # static enrichment and processing past LSTM - self.static_enrichment = GatedResidualNetwork( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - 
output_size=self.hparams.hidden_size, - dropout=self.hparams.dropout, - context_size=self.hparams.hidden_size, - ) - - # attention for long-range processing - self.multihead_attn = InterpretableMultiHeadAttention( - d_model=self.hparams.hidden_size, n_head=self.hparams.attention_head_size, dropout=self.hparams.dropout - ) - self.post_attn_gate_norm = GateAddNorm( - self.hparams.hidden_size, dropout=self.hparams.dropout, trainable_add=False - ) - self.pos_wise_ff = GatedResidualNetwork( - self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.hidden_size, dropout=self.hparams.dropout - ) - - # output processing -> no dropout at this late stage - self.pre_output_gate_norm = GateAddNorm(self.hparams.hidden_size, dropout=None, trainable_add=False) - - if self.n_targets > 1: # if to run with multiple targets - self.output_layer = nn.ModuleList( - [nn.Linear(self.hparams.hidden_size, output_size) for output_size in self.hparams.output_size] - ) - else: - self.output_layer = nn.Linear(self.hparams.hidden_size, self.hparams.output_size) - - @classmethod - def from_dataset( - cls, - dataset: TimeSeriesDataSet, - allowed_encoder_known_variable_names: List[str] = None, - **kwargs, - ): - """ - Create model from dataset. - - Args: - dataset: timeseries dataset - allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all - **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) - - Returns: - TemporalFusionTransformer - """ - # add maximum encoder length - # update defaults - new_kwargs = copy(kwargs) - new_kwargs["max_encoder_length"] = dataset.max_encoder_length - new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss())) - - # create class and return - return super().from_dataset( - dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs - ) - - def expand_static_context(self, context, timesteps): - """ - add time dimension to static context - """ - return context[:, None].expand(-1, timesteps, -1) - - def get_attention_mask(self, encoder_lengths: torch.LongTensor, decoder_lengths: torch.LongTensor): - """ - Returns causal mask to apply for self-attention layer. - """ - decoder_length = decoder_lengths.max() - if self.hparams.causal_attention: - # indices to which is attended - attend_step = torch.arange(decoder_length, device=self.device) - # indices for which is predicted - predict_step = torch.arange(0, decoder_length, device=self.device)[:, None] - # do not attend to steps to self or after prediction - decoder_mask = (attend_step >= predict_step).unsqueeze(0).expand(encoder_lengths.size(0), -1, -1) - else: - # there is value in attending to future forecasts if they are made with knowledge currently - # available - # one possibility is here to use a second attention layer for future attention (assuming different effects - # matter in the future than the past) - # or alternatively using the same layer but allowing forward attention - i.e. 
only - # masking out non-available data and self - decoder_mask = create_mask(decoder_length, decoder_lengths).unsqueeze(1).expand(-1, decoder_length, -1) - # do not attend to steps where data is padded - encoder_mask = create_mask(encoder_lengths.max(), encoder_lengths).unsqueeze(1).expand(-1, decoder_length, -1) - # combine masks along attended time - first encoder and then decoder - mask = torch.cat( - ( - encoder_mask, - decoder_mask, - ), - dim=2, - ) - return mask - - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - input dimensions: n_samples x time x variables - """ - encoder_lengths = x["encoder_lengths"] - decoder_lengths = x["decoder_lengths"] - x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1) # concatenate in time dimension - x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1) # concatenate in time dimension - timesteps = x_cont.size(1) # encode + decode length - max_encoder_length = int(encoder_lengths.max()) - input_vectors = self.input_embeddings(x_cat) - input_vectors.update( - { - name: x_cont[..., idx].unsqueeze(-1) - for idx, name in enumerate(self.hparams.x_reals) - if name in self.reals - } - ) - - # Embedding and variable selection - if len(self.static_variables) > 0: - # static embeddings will be constant over entire batch - static_embedding = {name: input_vectors[name][:, 0] for name in self.static_variables} - static_embedding, static_variable_selection = self.static_variable_selection(static_embedding) - else: - static_embedding = torch.zeros( - (x_cont.size(0), self.hparams.hidden_size), dtype=self.dtype, device=self.device - ) - static_variable_selection = torch.zeros((x_cont.size(0), 0), dtype=self.dtype, device=self.device) - - static_context_variable_selection = self.expand_static_context( - self.static_context_variable_selection(static_embedding), timesteps - ) - - embeddings_varying_encoder = { - name: input_vectors[name][:, :max_encoder_length] for name in self.encoder_variables - } - embeddings_varying_encoder, encoder_sparse_weights = self.encoder_variable_selection( - embeddings_varying_encoder, - static_context_variable_selection[:, :max_encoder_length], - ) - - embeddings_varying_decoder = { - name: input_vectors[name][:, max_encoder_length:] for name in self.decoder_variables # select decoder - } - embeddings_varying_decoder, decoder_sparse_weights = self.decoder_variable_selection( - embeddings_varying_decoder, - static_context_variable_selection[:, max_encoder_length:], - ) - - # LSTM - # calculate initial state - input_hidden = self.static_context_initial_hidden_lstm(static_embedding).expand( - self.hparams.lstm_layers, -1, -1 - ) - input_cell = self.static_context_initial_cell_lstm(static_embedding).expand(self.hparams.lstm_layers, -1, -1) - - # run local encoder - encoder_output, (hidden, cell) = self.lstm_encoder( - embeddings_varying_encoder, (input_hidden, input_cell), lengths=encoder_lengths, enforce_sorted=False - ) - - # run local decoder - decoder_output, _ = self.lstm_decoder( - embeddings_varying_decoder, - (hidden, cell), - lengths=decoder_lengths, - enforce_sorted=False, - ) - - # skip connection over lstm - lstm_output_encoder = self.post_lstm_gate_encoder(encoder_output) - lstm_output_encoder = self.post_lstm_add_norm_encoder(lstm_output_encoder, embeddings_varying_encoder) - - lstm_output_decoder = self.post_lstm_gate_decoder(decoder_output) - lstm_output_decoder = self.post_lstm_add_norm_decoder(lstm_output_decoder, embeddings_varying_decoder) - - lstm_output = 
torch.cat([lstm_output_encoder, lstm_output_decoder], dim=1) - - # static enrichment - static_context_enrichment = self.static_context_enrichment(static_embedding) - attn_input = self.static_enrichment( - lstm_output, self.expand_static_context(static_context_enrichment, timesteps) - ) - - # Attention - attn_output, attn_output_weights = self.multihead_attn( - q=attn_input[:, max_encoder_length:], # query only for predictions - k=attn_input, - v=attn_input, - mask=self.get_attention_mask(encoder_lengths=encoder_lengths, decoder_lengths=decoder_lengths), - ) - - # skip connection over attention - attn_output = self.post_attn_gate_norm(attn_output, attn_input[:, max_encoder_length:]) - - output = self.pos_wise_ff(attn_output) - - # skip connection over temporal fusion decoder (not LSTM decoder despite the LSTM output contains - # a skip from the variable selection network) - output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:]) - if self.n_targets > 1: # if to use multi-target architecture - output = [output_layer(output) for output_layer in self.output_layer] - else: - output = self.output_layer(output) - - return self.to_network_output( - prediction=self.transform_output(output, target_scale=x["target_scale"]), - encoder_attention=attn_output_weights[..., :max_encoder_length], - decoder_attention=attn_output_weights[..., max_encoder_length:], - static_variables=static_variable_selection, - encoder_variables=encoder_sparse_weights, - decoder_variables=decoder_sparse_weights, - decoder_lengths=decoder_lengths, - encoder_lengths=encoder_lengths, - ) - - def on_fit_end(self): - if self.log_interval > 0: - self.log_embeddings() - - def create_log(self, x, y, out, batch_idx, **kwargs): - log = super().create_log(x, y, out, batch_idx, **kwargs) - if self.log_interval > 0: - log["interpretation"] = self._log_interpretation(out) - return log - - def _log_interpretation(self, out): - # calculate interpretations etc for latter logging - interpretation = self.interpret_output( - detach(out), - reduction="sum", - attention_prediction_horizon=0, # attention only for first prediction horizon - ) - return interpretation - - def on_epoch_end(self, outputs): - """ - run at epoch end for training or validation - """ - if self.log_interval > 0 and not self.training: - self.log_interpretation(outputs) - - def interpret_output( - self, - out: Dict[str, torch.Tensor], - reduction: str = "none", - attention_prediction_horizon: int = 0, - ) -> Dict[str, torch.Tensor]: - """ - interpret output of model - - Args: - out: output as produced by ``forward()`` - reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for - normalizing by encode lengths - attention_prediction_horizon: which prediction horizon to use for attention - - Returns: - interpretations that can be plotted with ``plot_interpretation()`` - """ - # take attention and concatenate if a list to proper attention object - batch_size = len(out["decoder_attention"]) - if isinstance(out["decoder_attention"], (list, tuple)): - # start with decoder attention - # assume issue is in last dimension, we need to find max - max_last_dimension = max(x.size(-1) for x in out["decoder_attention"]) - first_elm = out["decoder_attention"][0] - # create new attention tensor into which we will scatter - decoder_attention = torch.full( - (batch_size, *first_elm.shape[:-1], max_last_dimension), - float("nan"), - dtype=first_elm.dtype, - device=first_elm.device, - ) - # scatter into tensor - for idx, x in 
enumerate(out["decoder_attention"]): - decoder_length = out["decoder_lengths"][idx] - decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length] - else: - decoder_attention = out["decoder_attention"].clone() - decoder_mask = create_mask(out["decoder_attention"].size(1), out["decoder_lengths"]) - decoder_attention[decoder_mask[..., None, None].expand_as(decoder_attention)] = float("nan") - - if isinstance(out["encoder_attention"], (list, tuple)): - # same game for encoder attention - # create new attention tensor into which we will scatter - first_elm = out["encoder_attention"][0] - encoder_attention = torch.full( - (batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length), - float("nan"), - dtype=first_elm.dtype, - device=first_elm.device, - ) - # scatter into tensor - for idx, x in enumerate(out["encoder_attention"]): - encoder_length = out["encoder_lengths"][idx] - encoder_attention[idx, :, :, self.hparams.max_encoder_length - encoder_length :] = x[ - ..., :encoder_length - ] - else: - # roll encoder attention (so start last encoder value is on the right) - encoder_attention = out["encoder_attention"].clone() - shifts = encoder_attention.size(3) - out["encoder_lengths"] - new_index = ( - torch.arange(encoder_attention.size(3), device=encoder_attention.device)[None, None, None].expand_as( - encoder_attention - ) - - shifts[:, None, None, None] - ) % encoder_attention.size(3) - encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index) - # expand encoder_attention to full size - if encoder_attention.size(-1) < self.hparams.max_encoder_length: - encoder_attention = torch.concat( - [ - torch.full( - ( - *encoder_attention.shape[:-1], - self.hparams.max_encoder_length - out["encoder_lengths"].max(), - ), - float("nan"), - dtype=encoder_attention.dtype, - device=encoder_attention.device, - ), - encoder_attention, - ], - dim=-1, - ) - - # combine attention vector - attention = torch.concat([encoder_attention, decoder_attention], dim=-1) - attention[attention < 1e-5] = float("nan") - - # histogram of decode and encode lengths - encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length) - decoder_length_histogram = integer_histogram( - out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1) - ) - - # mask where decoder and encoder where not applied when averaging variable selection weights - encoder_variables = out["encoder_variables"].squeeze(-2).clone() - encode_mask = create_mask(encoder_variables.size(1), out["encoder_lengths"]) - encoder_variables = encoder_variables.masked_fill(encode_mask.unsqueeze(-1), 0.0).sum(dim=1) - encoder_variables /= ( - out["encoder_lengths"] - .where(out["encoder_lengths"] > 0, torch.ones_like(out["encoder_lengths"])) - .unsqueeze(-1) - ) - - decoder_variables = out["decoder_variables"].squeeze(-2).clone() - decode_mask = create_mask(decoder_variables.size(1), out["decoder_lengths"]) - decoder_variables = decoder_variables.masked_fill(decode_mask.unsqueeze(-1), 0.0).sum(dim=1) - decoder_variables /= out["decoder_lengths"].unsqueeze(-1) - - # static variables need no masking - static_variables = out["static_variables"].squeeze(1) - # attention is batch x time x heads x time_to_attend - # average over heads + only keep prediction attention and attention on observed timesteps - attention = masked_op( - attention[ - :, attention_prediction_horizon, :, : self.hparams.max_encoder_length + attention_prediction_horizon - ], - op="mean", - dim=1, - ) - - if reduction 
!= "none": # if to average over batches - static_variables = static_variables.sum(dim=0) - encoder_variables = encoder_variables.sum(dim=0) - decoder_variables = decoder_variables.sum(dim=0) - - attention = masked_op(attention, dim=0, op=reduction) - else: - attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1) # renormalize - - interpretation = dict( - attention=attention.masked_fill(torch.isnan(attention), 0.0), - static_variables=static_variables, - encoder_variables=encoder_variables, - decoder_variables=decoder_variables, - encoder_length_histogram=encoder_length_histogram, - decoder_length_histogram=decoder_length_histogram, - ) - return interpretation - - def plot_prediction( - self, - x: Dict[str, torch.Tensor], - out: Dict[str, torch.Tensor], - idx: int, - plot_attention: bool = True, - add_loss_to_title: bool = False, - show_future_observed: bool = True, - ax=None, - **kwargs, - ): - """ - Plot actuals vs prediction and attention - - Args: - x (Dict[str, torch.Tensor]): network input - out (Dict[str, torch.Tensor]): network output - idx (int): sample index - plot_attention: if to plot attention on secondary axis - add_loss_to_title: if to add loss to title. Default to False. - show_future_observed: if to show actuals for future. Defaults to True. - ax: matplotlib axes to plot on - - Returns: - plt.Figure: matplotlib figure - """ - # plot prediction as normal - fig = super().plot_prediction( - x, - out, - idx=idx, - add_loss_to_title=add_loss_to_title, - show_future_observed=show_future_observed, - ax=ax, - **kwargs, - ) - - # add attention on secondary axis - if plot_attention: - interpretation = self.interpret_output(out.iget(slice(idx, idx + 1))) - for f in to_list(fig): - ax = f.axes[0] - ax2 = ax.twinx() - ax2.set_ylabel("Attention") - encoder_length = x["encoder_lengths"][0] - ax2.plot( - torch.arange(-encoder_length, 0), - interpretation["attention"][0, -encoder_length:].detach().cpu(), - alpha=0.2, - color="k", - ) - f.tight_layout() - return fig - - def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]): - """ - Make figures that interpret model. 
- - * Attention - * Variable selection weights / importances - - Args: - interpretation: as obtained from ``interpret_output()`` - - Returns: - dictionary of matplotlib figures - """ - _check_matplotlib("plot_interpretation") - - import matplotlib.pyplot as plt - - figs = {} - - # attention - fig, ax = plt.subplots() - attention = interpretation["attention"].detach().cpu() - attention = attention / attention.sum(-1).unsqueeze(-1) - ax.plot( - np.arange(-self.hparams.max_encoder_length, attention.size(0) - self.hparams.max_encoder_length), attention - ) - ax.set_xlabel("Time index") - ax.set_ylabel("Attention") - ax.set_title("Attention") - figs["attention"] = fig - - # variable selection - def make_selection_plot(title, values, labels): - fig, ax = plt.subplots(figsize=(7, len(values) * 0.25 + 2)) - order = np.argsort(values) - values = values / values.sum(-1).unsqueeze(-1) - ax.barh(np.arange(len(values)), values[order] * 100, tick_label=np.asarray(labels)[order]) - ax.set_title(title) - ax.set_xlabel("Importance in %") - plt.tight_layout() - return fig - - figs["static_variables"] = make_selection_plot( - "Static variables importance", interpretation["static_variables"].detach().cpu(), self.static_variables - ) - figs["encoder_variables"] = make_selection_plot( - "Encoder variables importance", interpretation["encoder_variables"].detach().cpu(), self.encoder_variables - ) - figs["decoder_variables"] = make_selection_plot( - "Decoder variables importance", interpretation["decoder_variables"].detach().cpu(), self.decoder_variables - ) - - return figs - - def log_interpretation(self, outputs): - """ - Log interpretation metrics to tensorboard. - """ - # extract interpretations - interpretation = { - # use padded_stack because decoder length histogram can be of different length - name: padded_stack([x["interpretation"][name].detach() for x in outputs], side="right", value=0).sum(0) - for name in outputs[0]["interpretation"].keys() - } - # normalize attention with length histogram squared to account for: 1. zeros in attention and - # 2. 
higher attention due to less values - attention_occurances = interpretation["encoder_length_histogram"][1:].flip(0).float().cumsum(0) - attention_occurances = attention_occurances / attention_occurances.max() - attention_occurances = torch.cat( - [ - attention_occurances, - torch.ones( - interpretation["attention"].size(0) - attention_occurances.size(0), - dtype=attention_occurances.dtype, - device=attention_occurances.device, - ), - ], - dim=0, - ) - interpretation["attention"] = interpretation["attention"] / attention_occurances.pow(2).clamp(1.0) - interpretation["attention"] = interpretation["attention"] / interpretation["attention"].sum() - - mpl_available = _check_matplotlib("log_interpretation", raise_error=False) - - # Don't log figures if matplotlib or add_figure is not available - if not mpl_available or not self._logger_supports("add_figure"): - return None - - import matplotlib.pyplot as plt - - figs = self.plot_interpretation(interpretation) # make interpretation figures - label = self.current_stage - # log to tensorboard - for name, fig in figs.items(): - self.logger.experiment.add_figure( - f"{label.capitalize()} {name} importance", fig, global_step=self.global_step - ) - - # log lengths of encoder/decoder - for type in ["encoder", "decoder"]: - fig, ax = plt.subplots() - lengths = ( - padded_stack([out["interpretation"][f"{type}_length_histogram"] for out in outputs]) - .sum(0) - .detach() - .cpu() - ) - if type == "decoder": - start = 1 - else: - start = 0 - ax.plot(torch.arange(start, start + len(lengths)), lengths) - ax.set_xlabel(f"{type.capitalize()} length") - ax.set_ylabel("Number of samples") - ax.set_title(f"{type.capitalize()} length distribution in {label} epoch") - - self.logger.experiment.add_figure( - f"{label.capitalize()} {type} length distribution", fig, global_step=self.global_step - ) - - def log_embeddings(self): - """ - Log embeddings to tensorboard - """ - - # Don't log embeddings if add_embedding is not available - if not self._logger_supports("add_embedding"): - return None - for name, emb in self.input_embeddings.items(): - labels = self.hparams.embedding_labels[name] - self.logger.experiment.add_embedding( - emb.weight.data.detach().cpu(), metadata=labels, tag=name, global_step=self.global_step - ) +__all__ = [ + "TemporalFusionTransformer", + "AddNorm", + "GateAddNorm", + "GatedLinearUnit", + "GatedResidualNetwork", + "InterpretableMultiHeadAttention", + "VariableSelectionNetwork", +] diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft.py new file mode 100644 index 00000000..5a5ebeac --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft.py @@ -0,0 +1,898 @@ +""" +The temporal fusion transformer is a powerful predictive model for forecasting timeseries +""" + +from copy import copy +from typing import Dict, List, Tuple, Union, Optional + +import numpy as np +import torch +from torch import nn +from torchmetrics import Metric as LightningMetric + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.metrics import MAE, MAPE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss +from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.nn import LSTM, MultiEmbedding +from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( + AddNorm, + GateAddNorm, + GatedLinearUnit, + GatedResidualNetwork, + InterpretableMultiHeadAttention, + 
VariableSelectionNetwork, +) +from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list +from pytorch_forecasting.utils._dependencies import _check_matplotlib + + +class TemporalFusionTransformer(BaseModelWithCovariates): + def __init__( + self, + hidden_size: int = 16, + lstm_layers: int = 1, + dropout: float = 0.1, + output_size: Union[int, List[int]] = 7, + loss: MultiHorizonMetric = None, + attention_head_size: int = 4, + max_encoder_length: int = 10, + static_categoricals: Optional[List[str]] = None, + static_reals: Optional[List[str]] = None, + time_varying_categoricals_encoder: Optional[List[str]] = None, + time_varying_categoricals_decoder: Optional[List[str]] = None, + categorical_groups: Optional[Union[Dict, List[str]]] = None, + time_varying_reals_encoder: Optional[List[str]] = None, + time_varying_reals_decoder: Optional[List[str]] = None, + x_reals: Optional[List[str]] = None, + x_categoricals: Optional[List[str]] = None, + hidden_continuous_size: int = 8, + hidden_continuous_sizes: Optional[Dict[str, int]] = None, + embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, + embedding_paddings: Optional[List[str]] = None, + embedding_labels: Optional[Dict[str, np.ndarray]] = None, + learning_rate: float = 1e-3, + log_interval: Union[int, float] = -1, + log_val_interval: Union[int, float] = None, + log_gradient_flow: bool = False, + reduce_on_plateau_patience: int = 1000, + monotone_constaints: Optional[Dict[str, int]] = None, + share_single_variable_networks: bool = False, + causal_attention: bool = True, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + Temporal Fusion Transformer for forecasting timeseries - use its :py:meth:`~from_dataset` method if possible. + + Implementation of the article + `Temporal Fusion Transformers for Interpretable Multi-horizon Time Series + Forecasting `_. The network outperforms DeepAR by Amazon by 36-69% + in benchmarks. + + Enhancements compared to the original implementation (apart from capabilities added through base model + such as monotone constraints): + + * static variables can be continuous + * multiple categorical variables can be summarized with an EmbeddingBag + * variable encoder and decoder length by sample + * categorical embeddings are not transformed by variable selection network (because it is a redundant operation) + * variable dimension in variable selection network are scaled up via linear interpolation to reduce + number of parameters + * non-linear variable processing in variable selection network can be shared among decoder and encoder + (not shared by default) + + Tune its hyperparameters with + :py:func:`~pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters`. + + Args: + + hidden_size: hidden size of network which is its main hyperparameter and can range from 8 to 512 + lstm_layers: number of LSTM layers (2 is mostly optimal) + dropout: dropout rate + output_size: number of outputs (e.g. number of quantiles for QuantileLoss and one target or list + of output sizes). 
+            loss: loss function taking prediction and targets
+            attention_head_size: number of attention heads (4 is a good default)
+            max_encoder_length: length to encode (can be far longer than the decoder length but does not have to be)
+            static_categoricals: names of static categorical variables
+            static_reals: names of static continuous variables
+            time_varying_categoricals_encoder: names of categorical variables for encoder
+            time_varying_categoricals_decoder: names of categorical variables for decoder
+            time_varying_reals_encoder: names of continuous variables for encoder
+            time_varying_reals_decoder: names of continuous variables for decoder
+            categorical_groups: dictionary where values are lists of categorical variables that together form a
+                new categorical variable, which is the key in the dictionary
+            x_reals: order of continuous variables in tensor passed to forward function
+            x_categoricals: order of categorical variables in tensor passed to forward function
+            hidden_continuous_size: default hidden size for processing continuous variables (similar to categorical
+                embedding size)
+            hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection
+                (fallback to hidden_continuous_size if index is not in dictionary)
+            embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and
+                embedding size
+            embedding_paddings: list of indices for embeddings which map the zero embedding to a zero vector
+            embedding_labels: dictionary mapping (string) indices to list of categorical labels
+            learning_rate: learning rate
+            log_interval: log predictions every x batches, do not log if 0 or less, log interpretation if > 0.
+                If < 1.0, will log multiple entries per batch. Defaults to -1.
+            log_val_interval: frequency with which to log validation set metrics, defaults to log_interval
+            log_gradient_flow: whether to log gradient flow; this takes time and should only be done to diagnose
+                training failures
+            reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
+            monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder
+                variables mapping position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and
+                ``+1`` for positive, larger numbers add more weight to the constraint vs. the loss but are usually
+                not necessary). This constraint significantly slows down training. Defaults to {}.
+            share_single_variable_networks (bool): whether to share the single variable networks between the encoder
+                and decoder. Defaults to False.
+            causal_attention (bool): whether to attend only to previous timesteps in the decoder or also include
+                future predictions. Defaults to True.
+            logging_metrics (nn.ModuleList[LightningMetric]): list of metrics that are logged during training.
+                Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
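+
+        Example:
+
+            A minimal usage sketch, not a full training loop; it assumes ``training`` is an
+            already constructed :py:class:`~pytorch_forecasting.data.TimeSeriesDataSet`::
+
+                from pytorch_forecasting import TemporalFusionTransformer
+                from pytorch_forecasting.metrics import QuantileLoss
+
+                # most hyperparameters are deduced from the dataset by from_dataset()
+                tft = TemporalFusionTransformer.from_dataset(
+                    training,
+                    hidden_size=16,
+                    attention_head_size=4,
+                    dropout=0.1,
+                    hidden_continuous_size=8,
+                    loss=QuantileLoss(),
+                )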
+ """ + if monotone_constaints is None: + monotone_constaints = {} + if embedding_labels is None: + embedding_labels = {} + if embedding_paddings is None: + embedding_paddings = [] + if embedding_sizes is None: + embedding_sizes = {} + if hidden_continuous_sizes is None: + hidden_continuous_sizes = {} + if x_categoricals is None: + x_categoricals = [] + if x_reals is None: + x_reals = [] + if time_varying_reals_decoder is None: + time_varying_reals_decoder = [] + if time_varying_reals_encoder is None: + time_varying_reals_encoder = [] + if categorical_groups is None: + categorical_groups = {} + if time_varying_categoricals_decoder is None: + time_varying_categoricals_decoder = [] + if time_varying_categoricals_encoder is None: + time_varying_categoricals_encoder = [] + if static_reals is None: + static_reals = [] + if static_categoricals is None: + static_categoricals = [] + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]) + if loss is None: + loss = QuantileLoss() + self.save_hyperparameters() + # store loss function separately as it is a module + assert isinstance(loss, LightningMetric), "Loss has to be a PyTorch Lightning `Metric`" + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + # processing inputs + # embeddings + self.input_embeddings = MultiEmbedding( + embedding_sizes=self.hparams.embedding_sizes, + categorical_groups=self.hparams.categorical_groups, + embedding_paddings=self.hparams.embedding_paddings, + x_categoricals=self.hparams.x_categoricals, + max_embedding_size=self.hparams.hidden_size, + ) + + # continuous variable processing + self.prescalers = nn.ModuleDict( + { + name: nn.Linear(1, self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size)) + for name in self.reals + } + ) + + # variable selection + # variable selection for static variables + static_input_sizes = { + name: self.input_embeddings.output_size[name] for name in self.hparams.static_categoricals + } + static_input_sizes.update( + { + name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size) + for name in self.hparams.static_reals + } + ) + self.static_variable_selection = VariableSelectionNetwork( + input_sizes=static_input_sizes, + hidden_size=self.hparams.hidden_size, + input_embedding_flags={name: True for name in self.hparams.static_categoricals}, + dropout=self.hparams.dropout, + prescalers=self.prescalers, + ) + + # variable selection for encoder and decoder + encoder_input_sizes = { + name: self.input_embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder + } + encoder_input_sizes.update( + { + name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size) + for name in self.hparams.time_varying_reals_encoder + } + ) + + decoder_input_sizes = { + name: self.input_embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder + } + decoder_input_sizes.update( + { + name: self.hparams.hidden_continuous_sizes.get(name, self.hparams.hidden_continuous_size) + for name in self.hparams.time_varying_reals_decoder + } + ) + + # create single variable grns that are shared across decoder and encoder + if self.hparams.share_single_variable_networks: + self.shared_single_variable_grns = nn.ModuleDict() + for name, input_size in encoder_input_sizes.items(): + self.shared_single_variable_grns[name] = GatedResidualNetwork( + input_size, + min(input_size, self.hparams.hidden_size), + 
self.hparams.hidden_size, + self.hparams.dropout, + ) + for name, input_size in decoder_input_sizes.items(): + if name not in self.shared_single_variable_grns: + self.shared_single_variable_grns[name] = GatedResidualNetwork( + input_size, + min(input_size, self.hparams.hidden_size), + self.hparams.hidden_size, + self.hparams.dropout, + ) + + self.encoder_variable_selection = VariableSelectionNetwork( + input_sizes=encoder_input_sizes, + hidden_size=self.hparams.hidden_size, + input_embedding_flags={name: True for name in self.hparams.time_varying_categoricals_encoder}, + dropout=self.hparams.dropout, + context_size=self.hparams.hidden_size, + prescalers=self.prescalers, + single_variable_grns=( + {} if not self.hparams.share_single_variable_networks else self.shared_single_variable_grns + ), + ) + + self.decoder_variable_selection = VariableSelectionNetwork( + input_sizes=decoder_input_sizes, + hidden_size=self.hparams.hidden_size, + input_embedding_flags={name: True for name in self.hparams.time_varying_categoricals_decoder}, + dropout=self.hparams.dropout, + context_size=self.hparams.hidden_size, + prescalers=self.prescalers, + single_variable_grns=( + {} if not self.hparams.share_single_variable_networks else self.shared_single_variable_grns + ), + ) + + # static encoders + # for variable selection + self.static_context_variable_selection = GatedResidualNetwork( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + output_size=self.hparams.hidden_size, + dropout=self.hparams.dropout, + ) + + # for hidden state of the lstm + self.static_context_initial_hidden_lstm = GatedResidualNetwork( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + output_size=self.hparams.hidden_size, + dropout=self.hparams.dropout, + ) + + # for cell state of the lstm + self.static_context_initial_cell_lstm = GatedResidualNetwork( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + output_size=self.hparams.hidden_size, + dropout=self.hparams.dropout, + ) + + # for post lstm static enrichment + self.static_context_enrichment = GatedResidualNetwork( + self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.dropout + ) + + # lstm encoder (history) and decoder (future) for local processing + self.lstm_encoder = LSTM( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + num_layers=self.hparams.lstm_layers, + dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0, + batch_first=True, + ) + + self.lstm_decoder = LSTM( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + num_layers=self.hparams.lstm_layers, + dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0, + batch_first=True, + ) + + # skip connection for lstm + self.post_lstm_gate_encoder = GatedLinearUnit(self.hparams.hidden_size, dropout=self.hparams.dropout) + self.post_lstm_gate_decoder = self.post_lstm_gate_encoder + # self.post_lstm_gate_decoder = GatedLinearUnit(self.hparams.hidden_size, dropout=self.hparams.dropout) + self.post_lstm_add_norm_encoder = AddNorm(self.hparams.hidden_size, trainable_add=False) + # self.post_lstm_add_norm_decoder = AddNorm(self.hparams.hidden_size, trainable_add=True) + self.post_lstm_add_norm_decoder = self.post_lstm_add_norm_encoder + + # static enrichment and processing past LSTM + self.static_enrichment = GatedResidualNetwork( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + 
output_size=self.hparams.hidden_size, + dropout=self.hparams.dropout, + context_size=self.hparams.hidden_size, + ) + + # attention for long-range processing + self.multihead_attn = InterpretableMultiHeadAttention( + d_model=self.hparams.hidden_size, n_head=self.hparams.attention_head_size, dropout=self.hparams.dropout + ) + self.post_attn_gate_norm = GateAddNorm( + self.hparams.hidden_size, dropout=self.hparams.dropout, trainable_add=False + ) + self.pos_wise_ff = GatedResidualNetwork( + self.hparams.hidden_size, self.hparams.hidden_size, self.hparams.hidden_size, dropout=self.hparams.dropout + ) + + # output processing -> no dropout at this late stage + self.pre_output_gate_norm = GateAddNorm(self.hparams.hidden_size, dropout=None, trainable_add=False) + + if self.n_targets > 1: # if to run with multiple targets + self.output_layer = nn.ModuleList( + [nn.Linear(self.hparams.hidden_size, output_size) for output_size in self.hparams.output_size] + ) + else: + self.output_layer = nn.Linear(self.hparams.hidden_size, self.hparams.output_size) + + @classmethod + def from_dataset( + cls, + dataset: TimeSeriesDataSet, + allowed_encoder_known_variable_names: List[str] = None, + **kwargs, + ): + """ + Create model from dataset. + + Args: + dataset: timeseries dataset + allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all + **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) + + Returns: + TemporalFusionTransformer + """ + # add maximum encoder length + # update defaults + new_kwargs = copy(kwargs) + new_kwargs["max_encoder_length"] = dataset.max_encoder_length + new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss())) + + # create class and return + return super().from_dataset( + dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs + ) + + def expand_static_context(self, context, timesteps): + """ + add time dimension to static context + """ + return context[:, None].expand(-1, timesteps, -1) + + def get_attention_mask(self, encoder_lengths: torch.LongTensor, decoder_lengths: torch.LongTensor): + """ + Returns causal mask to apply for self-attention layer. + """ + decoder_length = decoder_lengths.max() + if self.hparams.causal_attention: + # indices to which is attended + attend_step = torch.arange(decoder_length, device=self.device) + # indices for which is predicted + predict_step = torch.arange(0, decoder_length, device=self.device)[:, None] + # do not attend to steps to self or after prediction + decoder_mask = (attend_step >= predict_step).unsqueeze(0).expand(encoder_lengths.size(0), -1, -1) + else: + # there is value in attending to future forecasts if they are made with knowledge currently + # available + # one possibility is here to use a second attention layer for future attention (assuming different effects + # matter in the future than the past) + # or alternatively using the same layer but allowing forward attention - i.e. 
only + # masking out non-available data and self + decoder_mask = create_mask(decoder_length, decoder_lengths).unsqueeze(1).expand(-1, decoder_length, -1) + # do not attend to steps where data is padded + encoder_mask = create_mask(encoder_lengths.max(), encoder_lengths).unsqueeze(1).expand(-1, decoder_length, -1) + # combine masks along attended time - first encoder and then decoder + mask = torch.cat( + ( + encoder_mask, + decoder_mask, + ), + dim=2, + ) + return mask + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + input dimensions: n_samples x time x variables + """ + encoder_lengths = x["encoder_lengths"] + decoder_lengths = x["decoder_lengths"] + x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1) # concatenate in time dimension + x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1) # concatenate in time dimension + timesteps = x_cont.size(1) # encode + decode length + max_encoder_length = int(encoder_lengths.max()) + input_vectors = self.input_embeddings(x_cat) + input_vectors.update( + { + name: x_cont[..., idx].unsqueeze(-1) + for idx, name in enumerate(self.hparams.x_reals) + if name in self.reals + } + ) + + # Embedding and variable selection + if len(self.static_variables) > 0: + # static embeddings will be constant over entire batch + static_embedding = {name: input_vectors[name][:, 0] for name in self.static_variables} + static_embedding, static_variable_selection = self.static_variable_selection(static_embedding) + else: + static_embedding = torch.zeros( + (x_cont.size(0), self.hparams.hidden_size), dtype=self.dtype, device=self.device + ) + static_variable_selection = torch.zeros((x_cont.size(0), 0), dtype=self.dtype, device=self.device) + + static_context_variable_selection = self.expand_static_context( + self.static_context_variable_selection(static_embedding), timesteps + ) + + embeddings_varying_encoder = { + name: input_vectors[name][:, :max_encoder_length] for name in self.encoder_variables + } + embeddings_varying_encoder, encoder_sparse_weights = self.encoder_variable_selection( + embeddings_varying_encoder, + static_context_variable_selection[:, :max_encoder_length], + ) + + embeddings_varying_decoder = { + name: input_vectors[name][:, max_encoder_length:] for name in self.decoder_variables # select decoder + } + embeddings_varying_decoder, decoder_sparse_weights = self.decoder_variable_selection( + embeddings_varying_decoder, + static_context_variable_selection[:, max_encoder_length:], + ) + + # LSTM + # calculate initial state + input_hidden = self.static_context_initial_hidden_lstm(static_embedding).expand( + self.hparams.lstm_layers, -1, -1 + ) + input_cell = self.static_context_initial_cell_lstm(static_embedding).expand(self.hparams.lstm_layers, -1, -1) + + # run local encoder + encoder_output, (hidden, cell) = self.lstm_encoder( + embeddings_varying_encoder, (input_hidden, input_cell), lengths=encoder_lengths, enforce_sorted=False + ) + + # run local decoder + decoder_output, _ = self.lstm_decoder( + embeddings_varying_decoder, + (hidden, cell), + lengths=decoder_lengths, + enforce_sorted=False, + ) + + # skip connection over lstm + lstm_output_encoder = self.post_lstm_gate_encoder(encoder_output) + lstm_output_encoder = self.post_lstm_add_norm_encoder(lstm_output_encoder, embeddings_varying_encoder) + + lstm_output_decoder = self.post_lstm_gate_decoder(decoder_output) + lstm_output_decoder = self.post_lstm_add_norm_decoder(lstm_output_decoder, embeddings_varying_decoder) + + lstm_output = 
torch.cat([lstm_output_encoder, lstm_output_decoder], dim=1)
+
+        # static enrichment
+        static_context_enrichment = self.static_context_enrichment(static_embedding)
+        attn_input = self.static_enrichment(
+            lstm_output, self.expand_static_context(static_context_enrichment, timesteps)
+        )
+
+        # Attention
+        attn_output, attn_output_weights = self.multihead_attn(
+            q=attn_input[:, max_encoder_length:],  # query only for predictions
+            k=attn_input,
+            v=attn_input,
+            mask=self.get_attention_mask(encoder_lengths=encoder_lengths, decoder_lengths=decoder_lengths),
+        )
+
+        # skip connection over attention
+        attn_output = self.post_attn_gate_norm(attn_output, attn_input[:, max_encoder_length:])
+
+        output = self.pos_wise_ff(attn_output)
+
+        # skip connection over temporal fusion decoder (not the LSTM decoder, even though the LSTM output
+        # already contains a skip connection from the variable selection network)
+        output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:])
+        if self.n_targets > 1:  # one output layer per target
+            output = [output_layer(output) for output_layer in self.output_layer]
+        else:
+            output = self.output_layer(output)
+
+        return self.to_network_output(
+            prediction=self.transform_output(output, target_scale=x["target_scale"]),
+            encoder_attention=attn_output_weights[..., :max_encoder_length],
+            decoder_attention=attn_output_weights[..., max_encoder_length:],
+            static_variables=static_variable_selection,
+            encoder_variables=encoder_sparse_weights,
+            decoder_variables=decoder_sparse_weights,
+            decoder_lengths=decoder_lengths,
+            encoder_lengths=encoder_lengths,
+        )
+
+    def on_fit_end(self):
+        if self.log_interval > 0:
+            self.log_embeddings()
+
+    def create_log(self, x, y, out, batch_idx, **kwargs):
+        log = super().create_log(x, y, out, batch_idx, **kwargs)
+        if self.log_interval > 0:
+            log["interpretation"] = self._log_interpretation(out)
+        return log
+
+    def _log_interpretation(self, out):
+        # calculate interpretations etc. for later logging
+        interpretation = self.interpret_output(
+            detach(out),
+            reduction="sum",
+            attention_prediction_horizon=0,  # attention only for first prediction horizon
+        )
+        return interpretation
+
+    def on_epoch_end(self, outputs):
+        """
+        run at epoch end for training or validation
+        """
+        if self.log_interval > 0 and not self.training:
+            self.log_interpretation(outputs)
+
+    def interpret_output(
+        self,
+        out: Dict[str, torch.Tensor],
+        reduction: str = "none",
+        attention_prediction_horizon: int = 0,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        interpret output of model
+
+        Args:
+            out: output as produced by ``forward()``
+            reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
+                normalizing by encoder lengths
+            attention_prediction_horizon: which prediction horizon to use for attention
+
+        Returns:
+            interpretations that can be plotted with ``plot_interpretation()``
+        """
+        # if attention comes as a list, concatenate it into a single attention tensor
+        batch_size = len(out["decoder_attention"])
+        if isinstance(out["decoder_attention"], (list, tuple)):
+            # start with decoder attention
+            # tensors may differ in the last dimension, so find the maximum
+            max_last_dimension = max(x.size(-1) for x in out["decoder_attention"])
+            first_elm = out["decoder_attention"][0]
+            # create new attention tensor into which we will scatter
+            decoder_attention = torch.full(
+                (batch_size, *first_elm.shape[:-1], max_last_dimension),
+                float("nan"),
+                dtype=first_elm.dtype,
+                device=first_elm.device,
+            )
+            # scatter into tensor
+            for idx, x in enumerate(out["decoder_attention"]):
+                decoder_length = out["decoder_lengths"][idx]
+                decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length]
+        else:
+            decoder_attention = out["decoder_attention"].clone()
+            decoder_mask = create_mask(out["decoder_attention"].size(1), out["decoder_lengths"])
+            decoder_attention[decoder_mask[..., None, None].expand_as(decoder_attention)] = float("nan")
+
+        if isinstance(out["encoder_attention"], (list, tuple)):
+            # same procedure for encoder attention
+            # create new attention tensor into which we will scatter
+            first_elm = out["encoder_attention"][0]
+            encoder_attention = torch.full(
+                (batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length),
+                float("nan"),
+                dtype=first_elm.dtype,
+                device=first_elm.device,
+            )
+            # scatter into tensor
+            for idx, x in enumerate(out["encoder_attention"]):
+                encoder_length = out["encoder_lengths"][idx]
+                encoder_attention[idx, :, :, self.hparams.max_encoder_length - encoder_length :] = x[
+                    ..., :encoder_length
+                ]
+        else:
+            # roll encoder attention (so that the last encoder value is on the right)
+            encoder_attention = out["encoder_attention"].clone()
+            shifts = encoder_attention.size(3) - out["encoder_lengths"]
+            new_index = (
+                torch.arange(encoder_attention.size(3), device=encoder_attention.device)[None, None, None].expand_as(
+                    encoder_attention
+                )
+                - shifts[:, None, None, None]
+            ) % encoder_attention.size(3)
+            encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
+            # expand encoder_attention to full size
+            if encoder_attention.size(-1) < self.hparams.max_encoder_length:
+                encoder_attention = torch.concat(
+                    [
+                        torch.full(
+                            (
+                                *encoder_attention.shape[:-1],
+                                self.hparams.max_encoder_length - out["encoder_lengths"].max(),
+                            ),
+                            float("nan"),
+                            dtype=encoder_attention.dtype,
+                            device=encoder_attention.device,
+                        ),
+                        encoder_attention,
+                    ],
+                    dim=-1,
+                )
+
+        # combine attention vector
+        attention = torch.concat([encoder_attention, decoder_attention], dim=-1)
+        attention[attention < 1e-5] = float("nan")
+
+        # histogram of decoder and encoder lengths
+        encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
+        decoder_length_histogram = integer_histogram(
+            out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1)
+        )
+
+        # mask positions where decoder and encoder were not applied when averaging variable selection weights
+        encoder_variables = out["encoder_variables"].squeeze(-2).clone()
+        encode_mask = create_mask(encoder_variables.size(1), out["encoder_lengths"])
+        encoder_variables = encoder_variables.masked_fill(encode_mask.unsqueeze(-1), 0.0).sum(dim=1)
+        encoder_variables /= (
+            out["encoder_lengths"]
+            .where(out["encoder_lengths"] > 0, torch.ones_like(out["encoder_lengths"]))
+            .unsqueeze(-1)
+        )
+
+        decoder_variables = out["decoder_variables"].squeeze(-2).clone()
+        decode_mask = create_mask(decoder_variables.size(1), out["decoder_lengths"])
+        decoder_variables = decoder_variables.masked_fill(decode_mask.unsqueeze(-1), 0.0).sum(dim=1)
+        decoder_variables /= out["decoder_lengths"].unsqueeze(-1)
+
+        # static variables need no masking
+        static_variables = out["static_variables"].squeeze(1)
+        # attention is batch x time x heads x time_to_attend
+        # average over heads + only keep prediction attention and attention on observed timesteps
+        attention = masked_op(
+            attention[
+                :, attention_prediction_horizon, :, : self.hparams.max_encoder_length + attention_prediction_horizon
+            ],
+            op="mean",
+            dim=1,
+        )
+
+        if reduction 
!= "none": # if to average over batches + static_variables = static_variables.sum(dim=0) + encoder_variables = encoder_variables.sum(dim=0) + decoder_variables = decoder_variables.sum(dim=0) + + attention = masked_op(attention, dim=0, op=reduction) + else: + attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1) # renormalize + + interpretation = dict( + attention=attention.masked_fill(torch.isnan(attention), 0.0), + static_variables=static_variables, + encoder_variables=encoder_variables, + decoder_variables=decoder_variables, + encoder_length_histogram=encoder_length_histogram, + decoder_length_histogram=decoder_length_histogram, + ) + return interpretation + + def plot_prediction( + self, + x: Dict[str, torch.Tensor], + out: Dict[str, torch.Tensor], + idx: int, + plot_attention: bool = True, + add_loss_to_title: bool = False, + show_future_observed: bool = True, + ax=None, + **kwargs, + ): + """ + Plot actuals vs prediction and attention + + Args: + x (Dict[str, torch.Tensor]): network input + out (Dict[str, torch.Tensor]): network output + idx (int): sample index + plot_attention: if to plot attention on secondary axis + add_loss_to_title: if to add loss to title. Default to False. + show_future_observed: if to show actuals for future. Defaults to True. + ax: matplotlib axes to plot on + + Returns: + plt.Figure: matplotlib figure + """ + # plot prediction as normal + fig = super().plot_prediction( + x, + out, + idx=idx, + add_loss_to_title=add_loss_to_title, + show_future_observed=show_future_observed, + ax=ax, + **kwargs, + ) + + # add attention on secondary axis + if plot_attention: + interpretation = self.interpret_output(out.iget(slice(idx, idx + 1))) + for f in to_list(fig): + ax = f.axes[0] + ax2 = ax.twinx() + ax2.set_ylabel("Attention") + encoder_length = x["encoder_lengths"][0] + ax2.plot( + torch.arange(-encoder_length, 0), + interpretation["attention"][0, -encoder_length:].detach().cpu(), + alpha=0.2, + color="k", + ) + f.tight_layout() + return fig + + def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]): + """ + Make figures that interpret model. 
+ + * Attention + * Variable selection weights / importances + + Args: + interpretation: as obtained from ``interpret_output()`` + + Returns: + dictionary of matplotlib figures + """ + _check_matplotlib("plot_interpretation") + + import matplotlib.pyplot as plt + + figs = {} + + # attention + fig, ax = plt.subplots() + attention = interpretation["attention"].detach().cpu() + attention = attention / attention.sum(-1).unsqueeze(-1) + ax.plot( + np.arange(-self.hparams.max_encoder_length, attention.size(0) - self.hparams.max_encoder_length), attention + ) + ax.set_xlabel("Time index") + ax.set_ylabel("Attention") + ax.set_title("Attention") + figs["attention"] = fig + + # variable selection + def make_selection_plot(title, values, labels): + fig, ax = plt.subplots(figsize=(7, len(values) * 0.25 + 2)) + order = np.argsort(values) + values = values / values.sum(-1).unsqueeze(-1) + ax.barh(np.arange(len(values)), values[order] * 100, tick_label=np.asarray(labels)[order]) + ax.set_title(title) + ax.set_xlabel("Importance in %") + plt.tight_layout() + return fig + + figs["static_variables"] = make_selection_plot( + "Static variables importance", interpretation["static_variables"].detach().cpu(), self.static_variables + ) + figs["encoder_variables"] = make_selection_plot( + "Encoder variables importance", interpretation["encoder_variables"].detach().cpu(), self.encoder_variables + ) + figs["decoder_variables"] = make_selection_plot( + "Decoder variables importance", interpretation["decoder_variables"].detach().cpu(), self.decoder_variables + ) + + return figs + + def log_interpretation(self, outputs): + """ + Log interpretation metrics to tensorboard. + """ + # extract interpretations + interpretation = { + # use padded_stack because decoder length histogram can be of different length + name: padded_stack([x["interpretation"][name].detach() for x in outputs], side="right", value=0).sum(0) + for name in outputs[0]["interpretation"].keys() + } + # normalize attention with length histogram squared to account for: 1. zeros in attention and + # 2. 
higher attention due to fewer values
+        attention_occurances = interpretation["encoder_length_histogram"][1:].flip(0).float().cumsum(0)
+        attention_occurances = attention_occurances / attention_occurances.max()
+        attention_occurances = torch.cat(
+            [
+                attention_occurances,
+                torch.ones(
+                    interpretation["attention"].size(0) - attention_occurances.size(0),
+                    dtype=attention_occurances.dtype,
+                    device=attention_occurances.device,
+                ),
+            ],
+            dim=0,
+        )
+        interpretation["attention"] = interpretation["attention"] / attention_occurances.pow(2).clamp(1.0)
+        interpretation["attention"] = interpretation["attention"] / interpretation["attention"].sum()
+
+        mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
+
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
+            return None
+
+        import matplotlib.pyplot as plt
+
+        figs = self.plot_interpretation(interpretation)  # make interpretation figures
+        label = self.current_stage
+        # log to tensorboard
+        for name, fig in figs.items():
+            self.logger.experiment.add_figure(
+                f"{label.capitalize()} {name} importance", fig, global_step=self.global_step
+            )
+
+        # log lengths of encoder/decoder
+        for type in ["encoder", "decoder"]:
+            fig, ax = plt.subplots()
+            lengths = (
+                padded_stack([out["interpretation"][f"{type}_length_histogram"] for out in outputs])
+                .sum(0)
+                .detach()
+                .cpu()
+            )
+            if type == "decoder":
+                start = 1
+            else:
+                start = 0
+            ax.plot(torch.arange(start, start + len(lengths)), lengths)
+            ax.set_xlabel(f"{type.capitalize()} length")
+            ax.set_ylabel("Number of samples")
+            ax.set_title(f"{type.capitalize()} length distribution in {label} epoch")
+
+            self.logger.experiment.add_figure(
+                f"{label.capitalize()} {type} length distribution", fig, global_step=self.global_step
+            )
+
+    def log_embeddings(self):
+        """
+        Log embeddings to tensorboard
+        """
+
+        # Don't log embeddings if add_embedding is not available
+        if not self._logger_supports("add_embedding"):
+            return None
+
+        for name, emb in self.input_embeddings.items():
+            labels = self.hparams.embedding_labels[name]
+            self.logger.experiment.add_embedding(
+                emb.weight.data.detach().cpu(), metadata=labels, tag=name, global_step=self.global_step
+            )
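+
+
+# Illustrative usage sketch (a hedged example, not part of the module API): ``tft`` is
+# assumed to be a trained TemporalFusionTransformer and ``val_dataloader`` a hypothetical
+# validation dataloader. ``mode="raw"`` returns the network output that
+# ``interpret_output`` and ``plot_interpretation`` above consume:
+#
+#     raw_prediction = tft.predict(val_dataloader, mode="raw", return_x=True)
+#     interpretation = tft.interpret_output(raw_prediction.output, reduction="sum")
+#     figs = tft.plot_interpretation(interpretation)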