Skip to content

Commit

Permalink
add TODO
Browse files Browse the repository at this point in the history
  • Loading branch information
gdevos010 committed Sep 23, 2024
1 parent 9bbd1c9 commit 3edce3a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions darts/models/forecasting/times_net_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def __init__(
num_layers: int,
num_kernels: int,
top_k: int,
embed_type:str="fixed",
freq:str="h",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -124,7 +126,9 @@ def __init__(
self.output_dim = output_dim
self.nr_params = nr_params

self.embedding = DataEmbedding(input_dim, hidden_size, "fixed", "h", 0.1)
# embed_type and freq are placeholders and are not used until the futures
# covariate in the forward method are figured out
self.embedding = DataEmbedding(input_dim, hidden_size, embed_type=embed_type, freq=freq, dropout=0.1)

self.model = nn.ModuleList([
TimesBlock(
Expand All @@ -148,7 +152,7 @@ def forward(self, x_in: Tuple) -> torch.Tensor:
x, _ = x_in

# Embedding
x = self.embedding(x, None)
x = self.embedding(x, None) # TODO: future covariate would go here
x = self.predict_linear(x.transpose(1, 2)).transpose(1, 2)

# TimesNet
Expand Down

0 comments on commit 3edce3a

Please sign in to comment.