Skip to content

Commit

Permalink
reduce default warnings (#974)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Apr 17, 2024
1 parent 0af2734 commit 0c1a760
Show file tree
Hide file tree
Showing 18 changed files with 209 additions and 1,498 deletions.
24 changes: 18 additions & 6 deletions nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@
" **trainer_kwargs,\n",
" ):\n",
" super().__init__()\n",
" self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n",
" with warnings.catch_warnings(record=False):\n",
" warnings.filterwarnings('ignore')\n",
" # the following line issues a warning about the loss attribute being saved\n",
" # but we do want to save it\n",
" self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n",
" self.random_seed = random_seed\n",
" pl.seed_everything(self.random_seed, workers=True)\n",
"\n",
Expand Down Expand Up @@ -240,8 +244,10 @@
" )\n",
"\n",
" if self.val_check_steps > self.max_steps:\n",
" warnings.warn('val_check_steps is greater than max_steps, \\\n",
" setting val_check_steps to max_steps')\n",
" warnings.warn(\n",
" 'val_check_steps is greater than max_steps, '\n",
" 'setting val_check_steps to max_steps.'\n",
" )\n",
" val_check_interval = min(self.val_check_steps, self.max_steps)\n",
" self.trainer_kwargs['val_check_interval'] = int(val_check_interval)\n",
" self.trainer_kwargs['check_val_every_n_epoch'] = None\n",
Expand Down Expand Up @@ -355,9 +361,15 @@
" def on_validation_epoch_end(self):\n",
" if self.val_size == 0:\n",
" return\n",
" avg_loss = torch.stack(self.validation_step_outputs).mean()\n",
" self.log(\"ptl/val_loss\", avg_loss, sync_dist=True)\n",
" self.valid_trajectories.append((self.global_step, float(avg_loss)))\n",
" losses = torch.stack(self.validation_step_outputs)\n",
" avg_loss = losses.mean().item()\n",
" self.log(\n",
" \"ptl/val_loss\",\n",
" avg_loss,\n",
" batch_size=losses.size(0),\n",
" sync_dist=True,\n",
" )\n",
" self.valid_trajectories.append((self.global_step, avg_loss))\n",
" self.validation_step_outputs.clear() # free memory (compute `avg_loss` per epoch)\n",
"\n",
" def save(self, path):\n",
Expand Down
18 changes: 15 additions & 3 deletions nbs/common.base_multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,14 @@
" print('output', torch.isnan(output).sum())\n",
" raise Exception('Loss is NaN, training stopped.')\n",
"\n",
" self.log('train_loss', loss, prog_bar=True, on_epoch=True)\n",
" self.train_trajectories.append((self.global_step, float(loss)))\n",
" self.log(\n",
" 'train_loss',\n",
" loss.item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.train_trajectories.append((self.global_step, loss.item()))\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
Expand Down Expand Up @@ -428,7 +434,13 @@
" if torch.isnan(valid_loss):\n",
" raise Exception('Loss is NaN, training stopped.')\n",
"\n",
" self.log('valid_loss', valid_loss, prog_bar=True, on_epoch=True)\n",
" self.log(\n",
" 'valid_loss',\n",
" valid_loss.item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.validation_step_outputs.append(valid_loss)\n",
" return valid_loss\n",
"\n",
Expand Down
18 changes: 15 additions & 3 deletions nbs/common.base_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,14 @@
" print('output', torch.isnan(output).sum())\n",
" raise Exception('Loss is NaN, training stopped.')\n",
"\n",
" self.log('train_loss', loss, batch_size=self.batch_size, prog_bar=True, on_epoch=True)\n",
" self.train_trajectories.append((self.global_step, float(loss)))\n",
" self.log(\n",
" 'train_loss',\n",
" loss.item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.train_trajectories.append((self.global_step, loss.item()))\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
Expand Down Expand Up @@ -425,7 +431,13 @@
" if torch.isnan(valid_loss):\n",
" raise Exception('Loss is NaN, training stopped.')\n",
"\n",
" self.log('valid_loss', valid_loss, batch_size=self.batch_size, prog_bar=True, on_epoch=True)\n",
" self.log(\n",
" 'valid_loss',\n",
" valid_loss.item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.validation_step_outputs.append(valid_loss)\n",
" return valid_loss\n",
"\n",
Expand Down
24 changes: 18 additions & 6 deletions nbs/common.base_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,14 @@
" print('output', torch.isnan(output).sum())\n",
" raise Exception('Loss is NaN, training stopped.')\n",
"\n",
" self.log('train_loss', loss, prog_bar=True, on_epoch=True)\n",
" self.train_trajectories.append((self.global_step, float(loss)))\n",
" self.log(\n",
" 'train_loss',\n",
" loss.item(),\n",
" batch_size=outsample_y.size(0),\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.train_trajectories.append((self.global_step, loss.item()))\n",
" return loss\n",
"\n",
" def _compute_valid_loss(self, outsample_y, output, outsample_mask, temporal_cols, y_idx):\n",
Expand Down Expand Up @@ -513,14 +519,20 @@
" batch_sizes.append(len(output_batch))\n",
" \n",
" valid_loss = torch.stack(valid_losses)\n",
" batch_sizes = torch.tensor(batch_sizes).to(valid_loss.device)\n",
" valid_loss = torch.sum(valid_loss * batch_sizes) \\\n",
" / torch.sum(batch_sizes)\n",
" batch_sizes = torch.tensor(batch_sizes, device=valid_loss.device)\n",
" batch_size = torch.sum(batch_sizes)\n",
" valid_loss = torch.sum(valid_loss * batch_sizes) / batch_size\n",
"\n",
" if torch.isnan(valid_loss):\n",
" raise Exception('Loss is NaN, training stopped.')\n",
"\n",
" self.log('valid_loss', valid_loss, prog_bar=True, on_epoch=True)\n",
" self.log(\n",
" 'valid_loss',\n",
" valid_loss.item(),\n",
" batch_size=batch_size,\n",
" prog_bar=True,\n",
" on_epoch=True,\n",
" )\n",
" self.validation_step_outputs.append(valid_loss)\n",
" return valid_loss\n",
"\n",
Expand Down
5 changes: 5 additions & 0 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"import fsspec\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"import utilsforecast.processing as ufp\n",
"from coreforecast.grouped_array import GroupedArray\n",
Expand Down Expand Up @@ -102,6 +103,10 @@
"outputs": [],
"source": [
"#| exporti\n",
"# this disables warnings about the number of workers in the dataloaders\n",
"# which the user can't control\n",
"pl.disable_possible_user_warnings()\n",
"\n",
"def _insample_times(\n",
" times: np.ndarray,\n",
" uids: Series,\n",
Expand Down
Loading

0 comments on commit 0c1a760

Please sign in to comment.