Skip to content

Commit

Permalink
add_back_item()
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint committed Sep 27, 2024
1 parent 536e8f6 commit e08ed44
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 18 deletions.
6 changes: 3 additions & 3 deletions nbs/common.base_multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -390,12 +390,12 @@
"\n",
" self.log(\n",
" 'train_loss',\n",
" loss.detach(),\n",
" loss.detach().item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.train_trajectories.append((self.global_step, loss.detach()))\n",
" self.train_trajectories.append((self.global_step, loss.detach().item()))\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
Expand Down Expand Up @@ -440,7 +440,7 @@
"\n",
" self.log(\n",
" 'valid_loss',\n",
" valid_loss.detach(),\n",
" valid_loss.detach().item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
Expand Down
6 changes: 3 additions & 3 deletions nbs/common.base_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,12 @@
"\n",
" self.log(\n",
" 'train_loss',\n",
" loss.detach(),\n",
" loss.detach().item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.train_trajectories.append((self.global_step, loss.detach()))\n",
" self.train_trajectories.append((self.global_step, loss.detach().item()))\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
Expand Down Expand Up @@ -438,7 +438,7 @@
"\n",
" self.log(\n",
" 'valid_loss',\n",
" valid_loss.detach(),\n",
" valid_loss.detach().item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
Expand Down
6 changes: 3 additions & 3 deletions nbs/common.base_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,12 @@
"\n",
" self.log(\n",
" 'train_loss',\n",
" loss.detach(),\n",
" loss.detach().item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.train_trajectories.append((self.global_step, loss.detach()))\n",
" self.train_trajectories.append((self.global_step, loss.detach().item()))\n",
" return loss\n",
"\n",
" def _compute_valid_loss(self, outsample_y, output, outsample_mask, temporal_cols, y_idx):\n",
Expand Down Expand Up @@ -533,7 +533,7 @@
"\n",
" self.log(\n",
" 'valid_loss',\n",
" valid_loss.detach(),\n",
" valid_loss.detach().item(),\n",
" batch_size=batch_size,\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
Expand Down
6 changes: 3 additions & 3 deletions neuralforecast/common/_base_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,12 @@ def training_step(self, batch, batch_idx):

self.log(
"train_loss",
loss.detach(),
loss.detach().item(),
batch_size=outsample_y.size(0),
prog_bar=True,
on_epoch=True,
)
self.train_trajectories.append((self.global_step, loss.detach()))
self.train_trajectories.append((self.global_step, loss.detach().item()))
return loss

def validation_step(self, batch, batch_idx):
Expand Down Expand Up @@ -456,7 +456,7 @@ def validation_step(self, batch, batch_idx):

self.log(
"valid_loss",
valid_loss.detach(),
valid_loss.detach().item(),
batch_size=outsample_y.size(0),
prog_bar=True,
on_epoch=True,
Expand Down
6 changes: 3 additions & 3 deletions neuralforecast/common/_base_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,12 @@ def training_step(self, batch, batch_idx):

self.log(
"train_loss",
loss.detach(),
loss.detach().item(),
batch_size=outsample_y.size(0),
prog_bar=True,
on_epoch=True,
)
self.train_trajectories.append((self.global_step, loss.detach()))
self.train_trajectories.append((self.global_step, loss.detach().item()))
return loss

def validation_step(self, batch, batch_idx):
Expand Down Expand Up @@ -447,7 +447,7 @@ def validation_step(self, batch, batch_idx):

self.log(
"valid_loss",
valid_loss.detach(),
valid_loss.detach().item(),
batch_size=outsample_y.size(0),
prog_bar=True,
on_epoch=True,
Expand Down
6 changes: 3 additions & 3 deletions neuralforecast/common/_base_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,12 @@ def training_step(self, batch, batch_idx):

self.log(
"train_loss",
loss.detach(),
loss.detach().item(),
batch_size=outsample_y.size(0),
prog_bar=True,
on_epoch=True,
)
self.train_trajectories.append((self.global_step, loss.detach()))
self.train_trajectories.append((self.global_step, loss.detach().item()))
return loss

def _compute_valid_loss(
Expand Down Expand Up @@ -551,7 +551,7 @@ def validation_step(self, batch, batch_idx):

self.log(
"valid_loss",
valid_loss.detach(),
valid_loss.detach().item(),
batch_size=batch_size,
prog_bar=True,
on_epoch=True,
Expand Down

0 comments on commit e08ed44

Please sign in to comment.