Backend PyTorch: Fix L2 regularizers for external_trainable_variables (
vl-dud authored Dec 2, 2024
1 parent 8275aeb commit 3544fdf
Showing 1 changed file with 35 additions and 24 deletions.
59 changes: 35 additions & 24 deletions deepxde/model.py
@@ -112,9 +112,10 @@ def compile(
                 weighted by the `loss_weights` coefficients.
             external_trainable_variables: A trainable ``dde.Variable`` object or a list
                 of trainable ``dde.Variable`` objects. The unknown parameters in the
-                physics systems that need to be recovered. If the backend is
-                tensorflow.compat.v1, `external_trainable_variables` is ignored, and all
-                trainable ``dde.Variable`` objects are automatically collected.
+                physics systems that need to be recovered. Regularization will not be
+                applied to these variables. If the backend is tensorflow.compat.v1,
+                `external_trainable_variables` is ignored, and all trainable ``dde.Variable``
+                objects are automatically collected.
             verbose (Integer): Controls the verbosity of the compile process.
         """
         if verbose > 0 and config.rank == 0:
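
For context, a minimal sketch of how `external_trainable_variables` interacts with regularization, assuming the usual DeepXDE inverse-problem API (`dde.Variable`, `Model.compile`) and that `FNN` accepts a `regularization` argument on the PyTorch backend; the PDE and hyperparameters are illustrative only:

import deepxde as dde

# Unknown PDE coefficient to be recovered from data; with this commit,
# the L2 penalty no longer applies to it (except under L-BFGS, which
# prints the warning added below).
C = dde.Variable(2.0)

def pde(x, y):
    dy_t = dde.grad.jacobian(y, x, i=0, j=1)
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)
    return dy_t - C * dy_xx

geom = dde.geometry.GeometryXTime(
    dde.geometry.Interval(-1, 1), dde.geometry.TimeDomain(0, 1)
)
data = dde.data.TimePDE(geom, pde, [], num_domain=40)
net = dde.nn.FNN(
    [2] + [32] * 3 + [1], "tanh", "Glorot uniform", regularization=["l2", 1e-5]
)
model = dde.Model(data, net)
model.compile("adam", lr=1e-3, external_trainable_variables=C)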
@@ -330,30 +331,40 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
                 False, inputs, targets, auxiliary_vars, self.data.losses_test
             )
 
-        # Another way is using per-parameter options
-        # https://pytorch.org/docs/stable/optim.html#per-parameter-options,
-        # but not all optimizers (such as L-BFGS) support this.
-        trainable_variables = (
-            list(self.net.parameters()) + self.external_trainable_variables
-        )
-        if self.net.regularizer is None:
-            self.opt, self.lr_scheduler = optimizers.get(
-                trainable_variables, self.opt_name, learning_rate=lr, decay=decay
-            )
-        else:
-            if self.net.regularizer[0] == "l2":
-                self.opt, self.lr_scheduler = optimizers.get(
-                    trainable_variables,
-                    self.opt_name,
-                    learning_rate=lr,
-                    decay=decay,
-                    weight_decay=self.net.regularizer[1],
-                )
-            else:
-                raise NotImplementedError(
-                    f"{self.net.regularizer[0]} regularization to be implemented for "
-                    "backend pytorch."
-                )
+        weight_decay = 0
+        if self.net.regularizer is not None:
+            if self.net.regularizer[0] != "l2":
+                raise NotImplementedError(
+                    f"{self.net.regularizer[0]} regularization to be implemented for "
+                    "backend pytorch"
+                )
+            weight_decay = self.net.regularizer[1]
+
+        optimizer_params = self.net.parameters()
+        if self.external_trainable_variables:
+            # L-BFGS doesn't support per-parameter options.
+            if self.opt_name in ["L-BFGS", "L-BFGS-B"]:
+                optimizer_params = (
+                    list(optimizer_params) + self.external_trainable_variables
+                )
+                if weight_decay > 0:
+                    print(
+                        "Warning: L2 regularization will also be applied to external_trainable_variables. "
+                        "Ensure this is intended behavior."
+                    )
+            else:
+                optimizer_params = [
+                    {"params": optimizer_params},
+                    {"params": self.external_trainable_variables, "weight_decay": 0},
+                ]
+
+        self.opt, self.lr_scheduler = optimizers.get(
+            optimizer_params,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=weight_decay,
+        )
 
         def train_step(inputs, targets, auxiliary_vars):
             def closure():
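
The new code relies on PyTorch's per-parameter options (parameter groups): each group can override the optimizer-level defaults, so the network weights keep the weight_decay (L2) penalty while the external variables get weight_decay=0. A standalone sketch of the mechanism, with illustrative names; the L-BFGS limitation is the ValueError PyTorch itself raises:

import torch

net = torch.nn.Linear(2, 1)
# Stand-in for an external trainable variable (an unknown PDE coefficient).
ext_var = torch.nn.Parameter(torch.tensor(1.0))

# Two parameter groups: the optimizer-level weight_decay applies to the
# network weights, while the second group overrides it with 0 so that
# ext_var is left unregularized.
opt = torch.optim.Adam(
    [
        {"params": net.parameters()},
        {"params": [ext_var], "weight_decay": 0},
    ],
    lr=1e-3,
    weight_decay=1e-5,
)

# L-BFGS rejects parameter groups outright, which is why the commit falls
# back to a flat parameter list (and warns) for "L-BFGS"/"L-BFGS-B":
try:
    torch.optim.LBFGS([{"params": net.parameters()}, {"params": [ext_var]}])
except ValueError as err:
    print(err)  # LBFGS doesn't support per-parameter options (parameter groups)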
