Skip to content

Commit

Permalink
Remove old interface and deprecate the arguments
Browse files Browse the repository at this point in the history
nbdev_clean --clear_all
  • Loading branch information
JQGoh committed Dec 11, 2024
1 parent c0a7eb8 commit eca5d9b
Show file tree
Hide file tree
Showing 73 changed files with 63 additions and 1,075 deletions.
75 changes: 26 additions & 49 deletions nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
"import random\n",
"import warnings\n",
"from contextlib import contextmanager\n",
"from copy import deepcopy\n",
"from dataclasses import dataclass\n",
"\n",
"import fsspec\n",
Expand Down Expand Up @@ -121,10 +120,6 @@
" random_seed,\n",
" loss,\n",
" valid_loss,\n",
" optimizer,\n",
" optimizer_kwargs,\n",
" lr_scheduler,\n",
" lr_scheduler_kwargs,\n",
" futr_exog_list,\n",
" hist_exog_list,\n",
" stat_exog_list,\n",
Expand All @@ -150,18 +145,6 @@
" self.train_trajectories = []\n",
" self.valid_trajectories = []\n",
"\n",
" # Optimization\n",
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n",
"\n",
" # lr scheduler\n",
" if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
" raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
" self.lr_scheduler = lr_scheduler\n",
" self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
"\n",
" # customized by set_configure_optimizers()\n",
" self.config_optimizers = None\n",
"\n",
Expand Down Expand Up @@ -412,41 +395,19 @@
"\n",
" def configure_optimizers(self):\n",
" if self.config_optimizers is not None:\n",
" # return the customized optimizer settings if specified\n",
" return self.config_optimizers\n",
" \n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
" if 'lr' in optimizer_signature.parameters:\n",
" if 'lr' in optimizer_kwargs:\n",
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
" optimizer_kwargs['lr'] = self.learning_rate\n",
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
" else:\n",
" if self.optimizer_kwargs:\n",
" warnings.warn(\n",
" \"ignoring optimizer_kwargs as the optimizer is not specified\"\n",
" )\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" \n",
" lr_scheduler = {'frequency': 1, 'interval': 'step'}\n",
" if self.lr_scheduler:\n",
" lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n",
" lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n",
" if 'optimizer' in lr_scheduler_signature.parameters:\n",
" if 'optimizer' in lr_scheduler_kwargs:\n",
" warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n",
" del lr_scheduler_kwargs['optimizer']\n",
" lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n",
" else:\n",
" if self.lr_scheduler_kwargs:\n",
" warnings.warn(\n",
" \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n",
" ) \n",
" lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n",
" # default choice\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {\n",
" \"scheduler\": torch.optim.lr_scheduler.StepLR(\n",
" optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
" )\n",
" return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
" ),\n",
" \"frequency\": 1,\n",
" \"interval\": \"step\",\n",
" }\n",
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
"\n",
" def set_configure_optimizers(\n",
" self, \n",
Expand Down Expand Up @@ -528,6 +489,22 @@
" model.load_state_dict(content[\"state_dict\"], strict=True)\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "077ea025",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b36e87a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
10 changes: 1 addition & 9 deletions nbs/common.base_multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,12 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs, \n",
" valid_loss=valid_loss, \n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
Expand Down
8 changes: 0 additions & 8 deletions nbs/common.base_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,12 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
Expand Down
8 changes: 0 additions & 8 deletions nbs/common.base_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,12 @@
" drop_last_loader=False,\n",
" random_seed=1,\n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
Expand Down
178 changes: 25 additions & 153 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3172,15 +3172,22 @@
" mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
"\n",
" # using a customized optimizer\n",
" params.update({\n",
" \"optimizer\": torch.optim.Adadelta,\n",
" \"optimizer_kwargs\": {\"rho\": 0.45}, \n",
" })\n",
" optimizer = torch.optim.Adadelta(params=models2[0].parameters(), rho=0.75)\n",
" scheduler=torch.optim.lr_scheduler.StepLR(\n",
" optimizer=optimizer, step_size=10e7, gamma=0.5\n",
" )\n",
"\n",
" models2 = [nf_model(**params)]\n",
" models2[0].set_configure_optimizers(\n",
" optimizer=optimizer,\n",
" scheduler=scheduler,\n",
" )\n",
"\n",
" nf2 = NeuralForecast(models=models2, freq='M')\n",
" nf2.fit(AirPassengersPanel_train)\n",
" customized_optimizer_predict = nf2.predict()\n",
" mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
"\n",
" assert mean2 != mean"
]
},
Expand All @@ -3194,100 +3201,18 @@
"#| hide\n",
"# test that if the user-defined optimizer is not a subclass of torch.optim.optimizer, failed with exception\n",
"# tests cover different types of base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
"test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
"test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, optimizer=torch.nn.Module), contains=\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d908240f",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if we pass \"lr\" parameter, we expect warning and it ignores the passed in 'lr' parameter\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\n",
" \"h\": 12, \n",
" \"input_size\": 24, \n",
" \"max_steps\": 1, \n",
" \"optimizer\": torch.optim.Adadelta, \n",
" \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n",
" }\n",
"for model_name in [NHITS, RNN, StemGNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 10}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
" warnings.simplefilter('always', UserWarning)\n",
" nf.fit(AirPassengersPanel_train)\n",
" assert any(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\" in str(w.message) for w in issued_warnings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c97858b5-e6a0-4353-a48f-5a5460eb2314",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if we pass \"optimizer_kwargs\" but not \"optimizer\", we expect a warning\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\n",
" \"h\": 12, \n",
" \"input_size\": 24, \n",
" \"max_steps\": 1,\n",
" \"optimizer_kwargs\": {\"lr\": 0.8, \"rho\": 0.45}\n",
" }\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
" warnings.simplefilter('always', UserWarning)\n",
" nf.fit(AirPassengersPanel_train)\n",
" assert any(\"ignoring optimizer_kwargs as the optimizer is not specified\" in str(w.message) for w in issued_warnings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24142322",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test customized lr_scheduler behavior such that the user defined lr_scheduler result should differ from default\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 1}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" nf.fit(AirPassengersPanel_train)\n",
" default_optimizer_predict = nf.predict()\n",
" mean = default_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
"\n",
" # using a customized lr_scheduler, default is StepLR\n",
" params.update({\n",
" \"lr_scheduler\": torch.optim.lr_scheduler.ConstantLR,\n",
" \"lr_scheduler_kwargs\": {\"factor\": 0.78}, \n",
" })\n",
" models2 = [nf_model(**params)]\n",
" nf2 = NeuralForecast(models=models2, freq='M')\n",
" nf2.fit(AirPassengersPanel_train)\n",
" customized_optimizer_predict = nf2.predict()\n",
" mean2 = customized_optimizer_predict.loc[:, nf_model.__name__].mean()\n",
" assert mean2 != mean"
" model = model_name(**params) \n",
" optimizer = torch.nn.Module()\n",
" scheduler = torch.optim.lr_scheduler.StepLR(\n",
" optimizer=torch.optim.Adam(model.parameters()), step_size=10e7, gamma=0.5\n",
" ) \n",
" test_fail(lambda: model.set_configure_optimizers(optimizer=optimizer, scheduler=scheduler), contains=\"optimizer is not a valid instance of torch.optim.Optimizer\")\n"
]
},
{
Expand All @@ -3298,68 +3223,16 @@
"outputs": [],
"source": [
"#| hide\n",
"# test that if the user-defined lr_scheduler is not a subclass of torch.optim.lr_scheduler, failed with exception\n",
"# test that if the user-defined scheduler is not a subclass of torch.optim.lr_scheduler, failed with exception\n",
"# tests cover different types of base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"test_fail(lambda: NHITS(h=12, input_size=24, max_steps=10, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
"test_fail(lambda: RNN(h=12, input_size=24, max_steps=10, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
"test_fail(lambda: StemGNN(h=12, input_size=24, max_steps=10, n_series=2, lr_scheduler=torch.nn.Module), contains=\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1d8bebb",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if we pass in \"optimizer\" parameter, we expect warning and it ignores them\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\n",
" \"h\": 12, \n",
" \"input_size\": 24, \n",
" \"max_steps\": 1, \n",
" \"lr_scheduler\": torch.optim.lr_scheduler.ConstantLR, \n",
" \"lr_scheduler_kwargs\": {\"optimizer\": torch.optim.Adadelta, \"factor\": 0.22}\n",
" }\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
" warnings.simplefilter('always', UserWarning)\n",
" nf.fit(AirPassengersPanel_train)\n",
" assert any(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\" in str(w.message) for w in issued_warnings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06febece",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test that if we pass in \"lr_scheduler_kwargs\" but not \"lr_scheduler\", we expect a warning\n",
"# tests consider models implemented using different base classes such as BaseWindows, BaseRecurrent, BaseMultivariate\n",
"\n",
"for nf_model in [NHITS, RNN, StemGNN]:\n",
" params = {\n",
" \"h\": 12, \n",
" \"input_size\": 24, \n",
" \"max_steps\": 1,\n",
" \"lr_scheduler_kwargs\": {\"optimizer\": torch.optim.Adadelta, \"factor\": 0.22}\n",
" }\n",
"for model_name in [NHITS, RNN, StemGNN]:\n",
" params = {\"h\": 12, \"input_size\": 24, \"max_steps\": 10}\n",
" if nf_model.__name__ == \"StemGNN\":\n",
" params.update({\"n_series\": 2})\n",
" models = [nf_model(**params)]\n",
" nf = NeuralForecast(models=models, freq='M')\n",
" with warnings.catch_warnings(record=True) as issued_warnings:\n",
" warnings.simplefilter('always', UserWarning)\n",
" nf.fit(AirPassengersPanel_train)\n",
" assert any(\"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\" in str(w.message) for w in issued_warnings)\n"
" model = model_name(**params)\n",
" optimizer = torch.optim.Adam(model.parameters())\n",
" test_fail(lambda: model.set_configure_optimizers(optimizer=optimizer, scheduler=torch.nn.Module), contains=\"scheduler is not a valid instance of torch.optim.lr_scheduler.LRScheduler\")"
]
},
{
Expand Down Expand Up @@ -3493,7 +3366,6 @@
" models[0].set_configure_optimizers(\n",
" optimizer=optimizer,\n",
" scheduler=scheduler,\n",
"\n",
" )\n",
" nf2 = NeuralForecast(models=models, freq='M')\n",
" nf2.fit(AirPassengersPanel_train)\n",
Expand Down
Loading

0 comments on commit eca5d9b

Please sign in to comment.