From 47e3fbfcf1706185d6a1e7432123ba5085de65c6 Mon Sep 17 00:00:00 2001
From: "Adam M. Krajewski" <54290107+amkrajewski@users.noreply.github.com>
Date: Wed, 27 Mar 2024 21:13:38 -0400
Subject: [PATCH] (MA) implemented a hyperparameter search matrix for 27
 common options; meant mostly for tuning to smaller datasets

---
 pysipfenn/core/modelAdjusters.py | 161 ++++++++++++++++++++++++++++++-
 1 file changed, 159 insertions(+), 2 deletions(-)

diff --git a/pysipfenn/core/modelAdjusters.py b/pysipfenn/core/modelAdjusters.py
index 0198f26..c51c2e8 100644
--- a/pysipfenn/core/modelAdjusters.py
+++ b/pysipfenn/core/modelAdjusters.py
@@ -1,5 +1,5 @@
 import os
-from typing import Union, Literal, Tuple, List
+from typing import Union, Literal, Tuple, List, Dict
 from copy import deepcopy
 import gc
 
@@ -7,6 +7,7 @@
 import torch
 from torch.utils.data import DataLoader, TensorDataset
 import plotly.express as px
+import plotly.graph_objects as go
 from pysipfenn.core.pysipfenn import Calculator
 
 class LocalAdjuster:
@@ -257,7 +258,6 @@ def adjust(
         if verbose:
             print(f'Train: {transferLosses[-1]:.4f} | Epoch: 0/{epochs}')
 
-
         for epoch in range(epochs):
             model.train()
             for data, target in dataloaderTrain:
@@ -305,7 +305,160 @@ def adjust(
 
         return self.adjustedModel, transferLosses, validationLosses
 
+    def matrixHyperParameterSearch(
+            self,
+            validation: float = 0.2,
+            epochs: int = 100,
+            batchSize: int = 32,
+            lossFunction: Literal["MSE", "MAE"] = "MAE",
+            learningRates: Tuple[float] = (1e-6, 1e-5, 1e-4),
+            optimizers: Tuple[Literal["Adam", "AdamW", "Adamax", "RMSprop"]] = ("Adam", "AdamW", "Adamax"),
+            weightDecays: Tuple[float] = (1e-5, 1e-4, 1e-3),
+            verbose: bool = True,
+            plot: bool = True
+    ) -> Tuple[torch.nn.Module, Dict[str, Union[float, str]]]:
+        """
+        Performs a grid search over the hyperparameters provided to find the best combination. By default, it will
+        (a) plot the training history with plotly in your browser, and (b) print the best hyperparameters found. If
+        the ClearML platform was set to be used for logging (at the class initialization), the results will be
+        uploaded there as well. If the default values are used, it will test 27 combinations of learning rates,
+        optimizers, and weight decays. The method will then adjust the model to the best hyperparameters found,
+        corresponding to the lowest validation loss if validation is used, or the lowest training loss if validation
+        is not used (``validation=0``). Note that validation is used by default.
+
+        Args:
+            validation: Same as in the ``adjust`` method. Default is ``0.2``.
+            epochs: Same as in the ``adjust`` method. Default is ``100``.
+            batchSize: Same as in the ``adjust`` method. Default is ``32``.
+            lossFunction: Same as in the ``adjust`` method. Default is ``MAE``, i.e. Mean Absolute Error or L1 loss.
+            learningRates: Tuple of floats with the learning rates to be tested. Default is ``(1e-6, 1e-5, 1e-4)``.
+                See the ``adjust`` method for more information.
+            optimizers: Tuple of strings with the optimizers to be tested. Default is ``("Adam", "AdamW", "Adamax")``.
+                See the ``adjust`` method for more information.
+            weightDecays: Tuple of floats with the weight decays to be tested. Default is ``(1e-5, 1e-4, 1e-3)``.
+                See the ``adjust`` method for more information.
+            verbose: Same as in the ``adjust`` method. Default is ``True``.
+            plot: Whether to plot the training history after all the combinations are tested. Default is ``True``.
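+
+        Returns:
+            A tuple with (1) the best adjusted model found (also stored as ``self.adjustedModel``) and (2) a
+            dictionary with the best hyperparameters found (``learningRate``, ``optimizer``, ``weightDecay``, and
+            the ``epochs`` at which the best loss was reached).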
+ """ + if verbose: + print("Starting the hyperparameter search...") + + bestModel: torch.nn.Module = None + bestTrainingLoss: float = np.inf + bestValidationLoss: float = np.inf + bestHyperparameters: Dict[str, Union[float, str, None]] = { + "learningRate": None, + "optimizer": None, + "weightDecay": None, + "epochs": None + } + + trainLossHistory: List[List[float]] = [] + validationLossHistory: List[List[float]] = [] + labels: List[str] = [] + + for learningRate in learningRates: + for optimizer in optimizers: + for weightDecay in weightDecays: + labels.append(f"LR: {learningRate} | OPT: {optimizer} | WD: {weightDecay}") + model, trainingLoss, validationLoss = self.adjust( + validation=validation, + learningRate=learningRate, + epochs=epochs, + batchSize=batchSize, + optimizer=optimizer, + weightDecay=weightDecay, + lossFunction=lossFunction, + verbose=True + ) + trainLossHistory.append(trainingLoss) + validationLossHistory.append(validationLoss) + if validation > 0: + localBestValidationLoss, bestEpoch = min((val, idx) for idx, val in enumerate(validationLoss)) + if localBestValidationLoss < bestValidationLoss: + print(f"New best model found with LR: {learningRate}, OPT: {optimizer}, WD: {weightDecay}, " + f"Epoch: {bestEpoch + 1}/{epochs} | Train: {trainingLoss[bestEpoch]:.4f} | " + f"Validation: {localBestValidationLoss:.4f}") + del bestModel + gc.collect() + bestModel = model + bestTrainingLoss = trainingLoss[bestEpoch] + bestValidationLoss = localBestValidationLoss + bestHyperparameters["learningRate"] = learningRate + bestHyperparameters["optimizer"] = optimizer + bestHyperparameters["weightDecay"] = weightDecay + bestHyperparameters["epochs"] = bestEpoch + 1 + else: + print(f"Model with LR: {learningRate}, OPT: {optimizer}, WD: {weightDecay} did not improve.") + else: + localBestTrainingLoss, bestEpoch = min((val, idx) for idx, val in enumerate(trainingLoss)) + if localBestTrainingLoss < bestTrainingLoss: + print(f"New best model found with LR: {learningRate}, OPT: {optimizer}, WD: {weightDecay}, " + f"Epoch: {bestEpoch + 1}/{epochs} | Train: {localBestTrainingLoss:.4f}") + del bestModel + gc.collect() + bestModel = model + bestTrainingLoss = localBestTrainingLoss + bestHyperparameters["learningRate"] = learningRate + bestHyperparameters["optimizer"] = optimizer + bestHyperparameters["weightDecay"] = weightDecay + bestHyperparameters["epochs"] = bestEpoch + 1 + else: + print(f"Model with LR: {learningRate}, OPT: {optimizer}, WD: {weightDecay} did not improve.") + + if verbose: + print(f"\n\nBest model found with LR: {bestHyperparameters['learningRate']}, OPT: {bestHyperparameters['optimizer']}, " + f"WD: {bestHyperparameters['weightDecay']}, Epoch: {bestHyperparameters['epochs']}") + if validation > 0: + print(f"Train: {bestTrainingLoss:.4f} | Validation: {bestValidationLoss:.4f}") + else: + print(f"Train: {bestTrainingLoss:.4f}") + assert bestModel is not None, "The best model was not found. Something went wrong during the hyperparameter search." 
+        self.adjustedModel = bestModel
+        del bestModel
+        gc.collect()
+
+        if plot:
+            fig1 = go.Figure()
+            for idx, label in enumerate(labels):
+                fig1.add_trace(
+                    go.Scatter(
+                        x=np.arange(epochs+1),
+                        y=trainLossHistory[idx],
+                        mode='lines+markers',
+                        name=label)
+                )
+            fig1.update_layout(
+                title="Training Loss History",
+                xaxis_title="Epoch",
+                yaxis_title="Loss",
+                legend_title="Hyperparameters",
+                showlegend=True,
+                template="plotly_white"
+            )
+            fig1.show()
+            if validation > 0:
+                fig2 = go.Figure()
+                for idx, label in enumerate(labels):
+                    fig2.add_trace(
+                        go.Scatter(
+                            x=np.arange(epochs+1),
+                            y=validationLossHistory[idx],
+                            mode='lines+markers',
+                            name=label)
+                    )
+                fig2.update_layout(
+                    title="Validation Loss History",
+                    xaxis_title="Epoch",
+                    yaxis_title="Loss",
+                    legend_title="Hyperparameters",
+                    showlegend=True,
+                    template="plotly_white"
+                )
+                fig2.show()
+
+        return self.adjustedModel, bestHyperparameters
 
 
 
 
@@ -317,3 +470,7 @@ class OPTIMADEAdjuster(LocalAdjuster):
     settings used by that database or focusing its attention to specific chemistry like, for instance, all compounds of Sn
     and all perovskites. It accepts OPTIMADE query as an input and then operates based on the ``LocalAdjuster`` class.
     """
+
+
+
+
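
A minimal, hypothetical usage sketch of the new method (not part of the patch). Only the
``matrixHyperParameterSearch`` arguments come from the signature added above; the ``LocalAdjuster``
constructor arguments, the model name, and the file names are illustrative assumptions.

    from pysipfenn import Calculator
    from pysipfenn.core.modelAdjusters import LocalAdjuster

    # Assumed setup: a Calculator plus a small local dataset on disk (argument names are assumptions).
    c = Calculator()
    adjuster = LocalAdjuster(
        calculator=c,
        model="SIPFENN_Krajewski2022_NN30",   # example model name for illustration
        targetData="myLocalTargets.csv",      # hypothetical file names
        descriptorData="myLocalDescriptors.csv",
    )

    # Grid search over 3 learning rates x 3 optimizers x 3 weight decays = 27 combinations (the defaults).
    bestModel, bestHyperparameters = adjuster.matrixHyperParameterSearch(
        validation=0.2,
        epochs=100,
        batchSize=32,
        lossFunction="MAE",
        learningRates=(1e-6, 1e-5, 1e-4),
        optimizers=("Adam", "AdamW", "Adamax"),
        weightDecays=(1e-5, 1e-4, 1e-3),
        verbose=True,
        plot=True,
    )
    print(bestHyperparameters)  # e.g. {'learningRate': ..., 'optimizer': ..., 'weightDecay': ..., 'epochs': ...}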