From 3544fdfd4e43ffb920d2b191a53d66946226199e Mon Sep 17 00:00:00 2001
From: vl-dud <60846135+vl-dud@users.noreply.github.com>
Date: Mon, 2 Dec 2024 01:34:32 +0000
Subject: [PATCH] Backend PyTorch: Fix L2 regularizers for
 external_trainable_variables (#1884)

---
 deepxde/model.py | 59 ++++++++++++++++++++++++++++--------------------
 1 file changed, 35 insertions(+), 24 deletions(-)

diff --git a/deepxde/model.py b/deepxde/model.py
index 458dfe5cf..1644c2529 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -112,9 +112,10 @@ def compile(
                 weighted by the `loss_weights` coefficients.
             external_trainable_variables: A trainable ``dde.Variable`` object or a list
                 of trainable ``dde.Variable`` objects. The unknown parameters in the
-                physics systems that need to be recovered. If the backend is
-                tensorflow.compat.v1, `external_trainable_variables` is ignored, and all
-                trainable ``dde.Variable`` objects are automatically collected.
+                physics systems that need to be recovered. Regularization will not be
+                applied to these variables. If the backend is tensorflow.compat.v1,
+                `external_trainable_variables` is ignored, and all trainable ``dde.Variable``
+                objects are automatically collected.
             verbose (Integer): Controls the verbosity of the compile process.
         """
         if verbose > 0 and config.rank == 0:
@@ -330,30 +331,40 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
                 False, inputs, targets, auxiliary_vars, self.data.losses_test
             )
 
-        # Another way is using per-parameter options
-        # https://pytorch.org/docs/stable/optim.html#per-parameter-options,
-        # but not all optimizers (such as L-BFGS) support this.
-        trainable_variables = (
-            list(self.net.parameters()) + self.external_trainable_variables
-        )
-        if self.net.regularizer is None:
-            self.opt, self.lr_scheduler = optimizers.get(
-                trainable_variables, self.opt_name, learning_rate=lr, decay=decay
-            )
-        else:
-            if self.net.regularizer[0] == "l2":
-                self.opt, self.lr_scheduler = optimizers.get(
-                    trainable_variables,
-                    self.opt_name,
-                    learning_rate=lr,
-                    decay=decay,
-                    weight_decay=self.net.regularizer[1],
-                )
-            else:
+        weight_decay = 0
+        if self.net.regularizer is not None:
+            if self.net.regularizer[0] != "l2":
                 raise NotImplementedError(
                     f"{self.net.regularizer[0]} regularization to be implemented for "
-                    "backend pytorch."
+                    "backend pytorch"
                 )
+            weight_decay = self.net.regularizer[1]
+
+        optimizer_params = self.net.parameters()
+        if self.external_trainable_variables:
+            # L-BFGS doesn't support per-parameter options.
+            if self.opt_name in ["L-BFGS", "L-BFGS-B"]:
+                optimizer_params = (
+                    list(optimizer_params) + self.external_trainable_variables
+                )
+                if weight_decay > 0:
+                    print(
+                        "Warning: L2 regularization will also be applied to external_trainable_variables. "
+                        "Ensure this is intended behavior."
+                    )
+            else:
+                optimizer_params = [
+                    {"params": optimizer_params},
+                    {"params": self.external_trainable_variables, "weight_decay": 0},
+                ]
+
+        self.opt, self.lr_scheduler = optimizers.get(
+            optimizer_params,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=weight_decay,
+        )
 
     def train_step(inputs, targets, auxiliary_vars):
         def closure():
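
For reference, the parameter-group branch added above relies on PyTorch's
per-parameter options
(https://pytorch.org/docs/stable/optim.html#per-parameter-options), where each
group may override optimizer-wide defaults such as weight_decay. A minimal
standalone sketch of that mechanism, using torch.optim.Adam with a toy network
and a hypothetical external variable standing in for deepxde's objects:

    import torch

    # Network weights: these should receive the L2 penalty.
    net = torch.nn.Sequential(
        torch.nn.Linear(2, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
    )
    # Unknown physical parameter recovered during training; decaying it
    # toward zero would bias the recovered value, so it gets no penalty.
    external_var = torch.nn.Parameter(torch.tensor(1.0))

    optimizer = torch.optim.Adam(
        [
            # Inherits the optimizer-wide weight_decay set below.
            {"params": net.parameters()},
            # Overrides it: the external variable is left unregularized.
            {"params": [external_var], "weight_decay": 0},
        ],
        lr=1e-3,
        weight_decay=1e-4,
    )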
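
The L-BFGS special case exists because torch.optim.LBFGS accepts exactly one
parameter group and raises a ValueError otherwise, so the grouped form cannot
be used there; the patch instead falls back to a single flat parameter list
and warns that any weight decay then covers the external variables as well.
A quick illustration of the constraint:

    import torch

    w = torch.nn.Parameter(torch.zeros(3))
    v = torch.nn.Parameter(torch.zeros(1))

    try:
        # Two parameter groups are rejected at construction time.
        torch.optim.LBFGS([{"params": [w]}, {"params": [v]}])
    except ValueError as err:
        print(err)  # LBFGS doesn't support per-parameter options

    # A single flat list is the only supported form.
    opt = torch.optim.LBFGS([w, v], lr=1.0)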