diff --git a/pytorch_forecasting/models/tide/__init__.py b/pytorch_forecasting/models/tide/__init__.py index b2eb7365..b8556fbe 100644 --- a/pytorch_forecasting/models/tide/__init__.py +++ b/pytorch_forecasting/models/tide/__init__.py @@ -163,8 +163,11 @@ def decoder_covariate_size(self) -> int: Returns: int: size of time-dependent covariates used by the decoder """ - return len(set(self.hparams.time_varying_reals_decoder) - set(self.target_names)) + sum( - self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder + return len( + set(self.hparams.time_varying_reals_decoder) - set(self.target_names) + ) + sum( + self.embeddings.output_size[name] + for name in self.hparams.time_varying_categoricals_decoder ) @property @@ -174,8 +177,11 @@ def encoder_covariate_size(self) -> int: Returns: int: size of time-dependent covariates used by the encoder """ - return len(set(self.hparams.time_varying_reals_encoder) - set(self.target_names)) + sum( - self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder + return len( + set(self.hparams.time_varying_reals_encoder) - set(self.target_names) + ) + sum( + self.embeddings.output_size[name] + for name in self.hparams.time_varying_categoricals_encoder ) @property @@ -186,7 +192,8 @@ def static_size(self) -> int: int: size of static covariates """ return len(self.hparams.static_reals) + sum( - self.embeddings.output_size[name] for name in self.hparams.static_categoricals + self.embeddings.output_size[name] + for name in self.hparams.static_categoricals ) @classmethod @@ -215,12 +222,19 @@ def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): dataset.max_prediction_length == dataset.min_prediction_length ), "only fixed prediction length is allowed, but max_prediction_length != min_prediction_length" - assert dataset.randomize_length is None, "length has to be fixed, but randomize_length is not None" - assert not dataset.add_relative_time_idx, "add_relative_time_idx has to be False" + assert ( + dataset.randomize_length is None + ), "length has to be fixed, but randomize_length is not None" + assert ( + not dataset.add_relative_time_idx + ), "add_relative_time_idx has to be False" new_kwargs = copy(kwargs) new_kwargs.update( - {"output_chunk_length": dataset.max_prediction_length, "input_chunk_length": dataset.max_encoder_length} + { + "output_chunk_length": dataset.max_prediction_length, + "input_chunk_length": dataset.max_encoder_length, + } ) new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, MAE())) # initialize class @@ -246,7 +260,11 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if self.encoder_covariate_size > 0: # encoder_features = self.extract_features(x, self.embeddings, period="encoder") encoder_x_t = torch.concat( - [encoder_features[name] for name in self.encoder_variables if name not in self.target_names], + [ + encoder_features[name] + for name in self.encoder_variables + if name not in self.target_names + ], dim=2, ) input_vector = torch.concat((encoder_y, encoder_x_t), dim=2) @@ -256,14 +274,20 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: input_vector = encoder_y if self.decoder_covariate_size > 0: - decoder_features = self.extract_features(x, self.embeddings, period="decoder") - decoder_x_t = torch.concat([decoder_features[name] for name in self.decoder_variables], dim=2) + decoder_features = self.extract_features( + x, self.embeddings, period="decoder" + ) + decoder_x_t = torch.concat( + [decoder_features[name] for name in self.decoder_variables], dim=2 + ) else: decoder_x_t = None # statics if self.static_size > 0: - x_s = torch.concat([encoder_features[name][:, 0] for name in self.static_variables], dim=1) + x_s = torch.concat( + [encoder_features[name][:, 0] for name in self.static_variables], dim=1 + ) else: x_s = None diff --git a/pytorch_forecasting/models/tide/sub_modules.py b/pytorch_forecasting/models/tide/sub_modules.py index da1316c5..75097cd9 100644 --- a/pytorch_forecasting/models/tide/sub_modules.py +++ b/pytorch_forecasting/models/tide/sub_modules.py @@ -152,7 +152,11 @@ def __init__( else: historical_future_covariates_flat_dim = 0 - encoder_dim = self.input_chunk_length * output_dim + historical_future_covariates_flat_dim + static_cov_dim + encoder_dim = ( + self.input_chunk_length * output_dim + + historical_future_covariates_flat_dim + + static_cov_dim + ) self.encoders = nn.Sequential( _ResidualBlock( @@ -209,9 +213,13 @@ def __init__( dropout=dropout, ) - self.lookback_skip = nn.Linear(self.input_chunk_length, self.output_chunk_length) + self.lookback_skip = nn.Linear( + self.input_chunk_length, self.output_chunk_length + ) - def forward(self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]) -> torch.Tensor: + def forward( + self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] + ) -> torch.Tensor: """TiDE model forward pass. Parameters ---------- @@ -247,7 +255,9 @@ def forward(self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tor ) if self.temporal_width_future: # project input features across all input and output time steps - x_dynamic_future_covariates = self.future_cov_projection(x_dynamic_future_covariates) + x_dynamic_future_covariates = self.future_cov_projection( + x_dynamic_future_covariates + ) else: x_dynamic_future_covariates = None @@ -270,7 +280,11 @@ def forward(self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tor # stack and temporally decode with future covariate last output steps temporal_decoder_input = [ decoded, - (x_dynamic_future_covariates[:, -self.output_chunk_length :, :] if self.future_cov_dim > 0 else None), + ( + x_dynamic_future_covariates[:, -self.output_chunk_length :, :] + if self.future_cov_dim > 0 + else None + ), ] temporal_decoder_input = [t for t in temporal_decoder_input if t is not None] @@ -283,7 +297,9 @@ def forward(self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tor skip = self.lookback_skip(x_lookback.transpose(1, 2)).transpose(1, 2) # add skip connection - y = temporal_decoded + skip.reshape_as(temporal_decoded) # skip.view(temporal_decoded.shape) + y = temporal_decoded + skip.reshape_as( + temporal_decoded + ) # skip.view(temporal_decoded.shape) y = y.view(-1, self.output_chunk_length, self.output_dim) return y