From f2c957ed4a8778b8e896457fcc48c96715543c3a Mon Sep 17 00:00:00 2001
From: mwamsojo
Date: Thu, 19 Dec 2024 14:47:32 +0100
Subject: [PATCH] Feat: introduce horizon weighting for distribution losses,
 similar to point losses like MAE. This has proven useful in our
 applications. Also clean notebook outputs.

---
 nbs/losses.pytorch.ipynb         | 84 ++++++++++++++++++++++----------
 neuralforecast/_modidx.py        |  2 +
 neuralforecast/losses/pytorch.py | 30 +++++++++++-
 3 files changed, 89 insertions(+), 27 deletions(-)

diff --git a/nbs/losses.pytorch.ipynb b/nbs/losses.pytorch.ipynb
index d8d333dd7..8840aaef5 100644
--- a/nbs/losses.pytorch.ipynb
+++ b/nbs/losses.pytorch.ipynb
@@ -2475,7 +2475,7 @@
     "\n",
     "    \"\"\"\n",
     "    def __init__(self, distribution, level=[80, 90], quantiles=None,\n",
-    "                 num_samples=1000, return_params=False, **distribution_kwargs):\n",
+    "                 num_samples=1000, return_params=False, horizon_weight=None, **distribution_kwargs):\n",
     "        super(DistributionLoss, self).__init__()\n",
     "\n",
     "        qs, self.output_names = level_to_outputs(level)\n",
@@ -2489,6 +2489,12 @@
     "        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
     "        num_qk = len(self.quantiles)\n",
     "\n",
+    "        # Generate a horizon weight tensor from the array\n",
+    "        if horizon_weight is not None:\n",
+    "            horizon_weight = torch.Tensor(horizon_weight.flatten())\n",
+    "        self.horizon_weight = horizon_weight\n",
+    "\n",
+    "\n",
     "        if \"num_pieces\" not in distribution_kwargs:\n",
     "            num_pieces = 5\n",
     "        else:\n",
@@ -2610,36 +2616,62 @@
     "\n",
     "        return samples, sample_mean, quants\n",
     "\n",
-    "    def __call__(self,\n",
-    "                 y: torch.Tensor,\n",
-    "                 distr_args: torch.Tensor,\n",
-    "                 mask: Union[torch.Tensor, None] = None):\n",
+    "\n",
+    "\n",
+    "    def _compute_weights(self, y, mask):\n",
+    "        \"\"\"\n",
+    "        Compute final weights for each datapoint (based on all weights and all masks).\n",
+    "        Set horizon_weight to a ones[H] tensor if not set.\n",
+    "        If set, check that it has the same length as the horizon of y.\n",
     "        \"\"\"\n",
-    "        Computes the negative log-likelihood objective function. \n",
-    "        To estimate the following predictive distribution:\n",
+    "        if mask is None:\n",
+    "            mask = torch.ones_like(y, device=y.device)\n",
+    "        else:\n",
+    "            mask = mask.unsqueeze(1) # Add Q dimension.\n",
     "\n",
-    "        $$\mathrm{P}(\mathbf{y}_{\\\\tau}\,|\,\\\\theta) \\\\quad \mathrm{and} \\\\quad -\log(\mathrm{P}(\mathbf{y}_{\\\\tau}\,|\,\\\\theta))$$\n",
     "\n",
-    "        where $\\\\theta$ represents the distributions parameters. It aditionally \n",
-    "        summarizes the objective signal using a weighted average using the `mask` tensor. \n",
+    "        # get uniform weights if none\n",
+    "        if self.horizon_weight is None:\n",
+    "            self.horizon_weight = torch.ones(mask.shape[-1])\n",
+    "        else:\n",
+    "            assert mask.shape[-1] == len(self.horizon_weight), \\\n",
+    "                'horizon_weight must have same length as Y'\n",
+    "        weights = self.horizon_weight.clone()\n",
+    "        weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)\n",
+    "        return weights * mask\n",
+    "\n",
     "\n",
-    "        **Parameters**<br>
\n", - " `y`: tensor, Actual values.
\n", - " `distr_args`: Constructor arguments for the underlying Distribution type.
\n", - " `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n", - " of the resulting distribution.
\n", - " `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n", - " of the resulting distribution.
\n", - " `mask`: tensor, Specifies date stamps per serie to consider in loss.
\n", "\n", - " **Returns**
\n", - " `loss`: scalar, weighted loss function against which backpropagation will be performed.
\n", - " \"\"\"\n", - " # Instantiate Scaled Decoupled Distribution\n", - " distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)\n", - " loss_values = -distr.log_prob(y)\n", - " loss_weights = mask\n", - " return weighted_average(loss_values, weights=loss_weights)" + " def __call__(self,\n", + " y: torch.Tensor,\n", + " distr_args: torch.Tensor,\n", + " mask: Union[torch.Tensor, None] = None):\n", + " \"\"\"\n", + " Computes the negative log-likelihood objective function. \n", + " To estimate the following predictive distribution:\n", + "\n", + " $$\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta) \\\\quad \\mathrm{and} \\\\quad -\\log(\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta))$$\n", + "\n", + " where $\\\\theta$ represents the distributions parameters. It aditionally \n", + " summarizes the objective signal using a weighted average using the `mask` tensor. \n", + " \n", + " **Parameters**
\n", + " `y`: tensor, Actual values.
\n", + " `distr_args`: Constructor arguments for the underlying Distribution type.
\n", + " `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n", + " of the resulting distribution.
\n", + " `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n", + " of the resulting distribution.
\n", + " `mask`: tensor, Specifies date stamps per serie to consider in loss.
\n", + "\n", + " **Returns**
\n", + " `loss`: scalar, weighted loss function against which backpropagation will be performed.
\n", + " \"\"\"\n", + " # Instantiate Scaled Decoupled Distribution\n", + " distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)\n", + " loss_values = -distr.log_prob(y)\n", + " loss_weights = self._compute_weights(y=y, mask=mask)\n", + " return weighted_average(loss_values, weights=loss_weights)" ] }, { diff --git a/neuralforecast/_modidx.py b/neuralforecast/_modidx.py index 25f008ce4..7f72d353d 100644 --- a/neuralforecast/_modidx.py +++ b/neuralforecast/_modidx.py @@ -284,6 +284,8 @@ 'neuralforecast/losses/pytorch.py'), 'neuralforecast.losses.pytorch.DistributionLoss.__init__': ( 'losses.pytorch.html#distributionloss.__init__', 'neuralforecast/losses/pytorch.py'), + 'neuralforecast.losses.pytorch.DistributionLoss._compute_weights': ( 'losses.pytorch.html#distributionloss._compute_weights', + 'neuralforecast/losses/pytorch.py'), 'neuralforecast.losses.pytorch.DistributionLoss.get_distribution': ( 'losses.pytorch.html#distributionloss.get_distribution', 'neuralforecast/losses/pytorch.py'), 'neuralforecast.losses.pytorch.DistributionLoss.sample': ( 'losses.pytorch.html#distributionloss.sample', diff --git a/neuralforecast/losses/pytorch.py b/neuralforecast/losses/pytorch.py index a713b5b31..53e61ebfc 100644 --- a/neuralforecast/losses/pytorch.py +++ b/neuralforecast/losses/pytorch.py @@ -1873,6 +1873,7 @@ def __init__( quantiles=None, num_samples=1000, return_params=False, + horizon_weight=None, **distribution_kwargs, ): super(DistributionLoss, self).__init__() @@ -1888,6 +1889,11 @@ def __init__( self.quantiles = torch.nn.Parameter(qs, requires_grad=False) num_qk = len(self.quantiles) + # Generate a horizon weight tensor from the array + if horizon_weight is not None: + horizon_weight = torch.Tensor(horizon_weight.flatten()) + self.horizon_weight = horizon_weight + if "num_pieces" not in distribution_kwargs: num_pieces = 5 else: @@ -2011,6 +2017,28 @@ def sample(self, distr_args: torch.Tensor, num_samples: Optional[int] = None): return samples, sample_mean, quants + def _compute_weights(self, y, mask): + """ + Compute final weights for each datapoint (based on all weights and all masks) + Set horizon_weight to a ones[H] tensor if not set. + If set, check that it has the same length as the horizon in x. + """ + if mask is None: + mask = torch.ones_like(y, device=y.device) + else: + mask = mask.unsqueeze(1) # Add Q dimension. + + # get uniform weights if none + if self.horizon_weight is None: + self.horizon_weight = torch.ones(mask.shape[-1]) + else: + assert mask.shape[-1] == len( + self.horizon_weight + ), "horizon_weight must have same length as Y" + weights = self.horizon_weight.clone() + weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device) + return weights * mask + def __call__( self, y: torch.Tensor, @@ -2041,7 +2069,7 @@ def __call__( # Instantiate Scaled Decoupled Distribution distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs) loss_values = -distr.log_prob(y) - loss_weights = mask + loss_weights = self._compute_weights(y=y, mask=mask) return weighted_average(loss_values, weights=loss_weights) # %% ../../nbs/losses.pytorch.ipynb 74