Skip to content

Commit

Permalink
[ENH] linting tide model (#1742)
Browse files Browse the repository at this point in the history
This PR lints the tide model files.
  • Loading branch information
fkiraly authored Dec 28, 2024
1 parent c07d28b commit 9661be4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 18 deletions.
48 changes: 36 additions & 12 deletions pytorch_forecasting/models/tide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,11 @@ def decoder_covariate_size(self) -> int:
Returns:
int: size of time-dependent covariates used by the decoder
"""
return len(set(self.hparams.time_varying_reals_decoder) - set(self.target_names)) + sum(
self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_decoder
return len(
set(self.hparams.time_varying_reals_decoder) - set(self.target_names)
) + sum(
self.embeddings.output_size[name]
for name in self.hparams.time_varying_categoricals_decoder
)

@property
Expand All @@ -174,8 +177,11 @@ def encoder_covariate_size(self) -> int:
Returns:
int: size of time-dependent covariates used by the encoder
"""
return len(set(self.hparams.time_varying_reals_encoder) - set(self.target_names)) + sum(
self.embeddings.output_size[name] for name in self.hparams.time_varying_categoricals_encoder
return len(
set(self.hparams.time_varying_reals_encoder) - set(self.target_names)
) + sum(
self.embeddings.output_size[name]
for name in self.hparams.time_varying_categoricals_encoder
)

@property
Expand All @@ -186,7 +192,8 @@ def static_size(self) -> int:
int: size of static covariates
"""
return len(self.hparams.static_reals) + sum(
self.embeddings.output_size[name] for name in self.hparams.static_categoricals
self.embeddings.output_size[name]
for name in self.hparams.static_categoricals
)

@classmethod
Expand Down Expand Up @@ -215,12 +222,19 @@ def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
dataset.max_prediction_length == dataset.min_prediction_length
), "only fixed prediction length is allowed, but max_prediction_length != min_prediction_length"

assert dataset.randomize_length is None, "length has to be fixed, but randomize_length is not None"
assert not dataset.add_relative_time_idx, "add_relative_time_idx has to be False"
assert (
dataset.randomize_length is None
), "length has to be fixed, but randomize_length is not None"
assert (
not dataset.add_relative_time_idx
), "add_relative_time_idx has to be False"

new_kwargs = copy(kwargs)
new_kwargs.update(
{"output_chunk_length": dataset.max_prediction_length, "input_chunk_length": dataset.max_encoder_length}
{
"output_chunk_length": dataset.max_prediction_length,
"input_chunk_length": dataset.max_encoder_length,
}
)
new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, MAE()))
# initialize class
Expand All @@ -246,7 +260,11 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
if self.encoder_covariate_size > 0:
# encoder_features = self.extract_features(x, self.embeddings, period="encoder")
encoder_x_t = torch.concat(
[encoder_features[name] for name in self.encoder_variables if name not in self.target_names],
[
encoder_features[name]
for name in self.encoder_variables
if name not in self.target_names
],
dim=2,
)
input_vector = torch.concat((encoder_y, encoder_x_t), dim=2)
Expand All @@ -256,14 +274,20 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
input_vector = encoder_y

if self.decoder_covariate_size > 0:
decoder_features = self.extract_features(x, self.embeddings, period="decoder")
decoder_x_t = torch.concat([decoder_features[name] for name in self.decoder_variables], dim=2)
decoder_features = self.extract_features(
x, self.embeddings, period="decoder"
)
decoder_x_t = torch.concat(
[decoder_features[name] for name in self.decoder_variables], dim=2
)
else:
decoder_x_t = None

# statics
if self.static_size > 0:
x_s = torch.concat([encoder_features[name][:, 0] for name in self.static_variables], dim=1)
x_s = torch.concat(
[encoder_features[name][:, 0] for name in self.static_variables], dim=1
)
else:
x_s = None

Expand Down
28 changes: 22 additions & 6 deletions pytorch_forecasting/models/tide/sub_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def __init__(
else:
historical_future_covariates_flat_dim = 0

encoder_dim = self.input_chunk_length * output_dim + historical_future_covariates_flat_dim + static_cov_dim
encoder_dim = (
self.input_chunk_length * output_dim
+ historical_future_covariates_flat_dim
+ static_cov_dim
)

self.encoders = nn.Sequential(
_ResidualBlock(
Expand Down Expand Up @@ -209,9 +213,13 @@ def __init__(
dropout=dropout,
)

self.lookback_skip = nn.Linear(self.input_chunk_length, self.output_chunk_length)
self.lookback_skip = nn.Linear(
self.input_chunk_length, self.output_chunk_length
)

def forward(self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]) -> torch.Tensor:
def forward(
self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
) -> torch.Tensor:
"""TiDE model forward pass.
Parameters
----------
Expand Down Expand Up @@ -247,7 +255,9 @@ def forward(self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tor
)
if self.temporal_width_future:
# project input features across all input and output time steps
x_dynamic_future_covariates = self.future_cov_projection(x_dynamic_future_covariates)
x_dynamic_future_covariates = self.future_cov_projection(
x_dynamic_future_covariates
)
else:
x_dynamic_future_covariates = None

Expand All @@ -270,7 +280,11 @@ def forward(self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tor
# stack and temporally decode with future covariate last output steps
temporal_decoder_input = [
decoded,
(x_dynamic_future_covariates[:, -self.output_chunk_length :, :] if self.future_cov_dim > 0 else None),
(
x_dynamic_future_covariates[:, -self.output_chunk_length :, :]
if self.future_cov_dim > 0
else None
),
]
temporal_decoder_input = [t for t in temporal_decoder_input if t is not None]

Expand All @@ -283,7 +297,9 @@ def forward(self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tor
skip = self.lookback_skip(x_lookback.transpose(1, 2)).transpose(1, 2)

# add skip connection
y = temporal_decoded + skip.reshape_as(temporal_decoded) # skip.view(temporal_decoded.shape)
y = temporal_decoded + skip.reshape_as(
temporal_decoded
) # skip.view(temporal_decoded.shape)

y = y.view(-1, self.output_chunk_length, self.output_dim)
return y

0 comments on commit 9661be4

Please sign in to comment.