Skip to content

Commit

Permalink
fix_mlp_multivariate_static_input
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint committed Oct 2, 2024
1 parent 8d378c6 commit ec260e7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
8 changes: 7 additions & 1 deletion nbs/models.mlpmultivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,11 @@
" x = torch.cat(( x, futr_exog.reshape(batch_size, -1) ), dim=1)\n",
"\n",
" if self.stat_exog_size > 0:\n",
" x = torch.cat(( x, stat_exog.reshape(batch_size, -1) ), dim=1)\n",
" stat_exog = stat_exog.reshape(-1) # [N, S] -> [N * S]\n",
" stat_exog = stat_exog.unsqueeze(0)\\\n",
" .repeat(batch_size, \n",
" 1) # [N * S] -> [B, N * S] \n",
" x = torch.cat((x, stat_exog), dim=1)\n",
"\n",
" for layer in self.mlp:\n",
" x = torch.relu(layer(x))\n",
Expand Down Expand Up @@ -362,6 +366,8 @@
"model = MLPMultivariate(h=12, \n",
" input_size=24,\n",
" n_series=2,\n",
" stat_exog_list=['airline1'],\n",
" futr_exog_list=['trend'], \n",
" loss = MAE(),\n",
" scaler_type='robust',\n",
" learning_rate=1e-3,\n",
Expand Down
6 changes: 5 additions & 1 deletion neuralforecast/models/mlpmultivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ def forward(self, windows_batch):
x = torch.cat((x, futr_exog.reshape(batch_size, -1)), dim=1)

if self.stat_exog_size > 0:
x = torch.cat((x, stat_exog.reshape(batch_size, -1)), dim=1)
stat_exog = stat_exog.reshape(-1) # [N, S] -> [N * S]
stat_exog = stat_exog.unsqueeze(0).repeat(
batch_size, 1
) # [N * S] -> [B, N * S]
x = torch.cat((x, stat_exog), dim=1)

for layer in self.mlp:
x = torch.relu(layer(x))
Expand Down

0 comments on commit ec260e7

Please sign in to comment.