update code for adamw optimizer
lijialin03 committed Dec 16, 2024
1 parent c9c10ff commit 63d6fcd
Showing 2 changed files with 10 additions and 16 deletions.
10 changes: 1 addition & 9 deletions deepxde/model.py
@@ -517,15 +517,7 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
         trainable_variables = (
             list(self.net.parameters()) + self.external_trainable_variables
         )
-        regularizer = getattr(self.net, "regularizer", None)
-        if regularizer is not None:
-            weight_decay = (
-                self.net.regularizer_value
-                if self.opt_name == "adamw"
-                else self.net.regularizer
-            )
-        else:
-            weight_decay = None
+        weight_decay = getattr(self.net, "regularizer", None)
         self.opt = optimizers.get(
             trainable_variables,
             self.opt_name,
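
In model.py, the AdamW special-casing is dropped: whatever the network exposes as `regularizer` (or None) is now forwarded unchanged to `optimizers.get`. A minimal sketch of what that line produces, assuming a Paddle network whose `regularizer` attribute is an L2Decay object (the `DummyNet` class below is illustrative, not DeepXDE code):

import paddle

# Illustrative stand-in for a DeepXDE network that carries an L2 regularizer.
class DummyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(2, 1)
        self.regularizer = paddle.regularizer.L2Decay(coeff=1e-4)

net = DummyNet()
# The simplified line in Model: forward the regularizer object (or None) as-is;
# the AdamW-specific handling now lives in the Paddle optimizer module.
weight_decay = getattr(net, "regularizer", None)
print(type(weight_decay))  # <class 'paddle.regularizer.L2Decay'>
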
16 changes: 9 additions & 7 deletions deepxde/optimizers/paddle/optimizers.py
@@ -62,11 +62,13 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
             weight_decay=weight_decay,
         )
     if optimizer == "adamw":
-        if weight_decay[0] == 0:
-            raise ValueError("AdamW optimizer requires non-zero weight decay")
-        return paddle.optimizer.AdamW(
-            learning_rate=learning_rate,
-            parameters=params,
-            weight_decay=weight_decay[0],
-        )
+        if isinstance(weight_decay, paddle.regularizer.L2Decay):
+            if weight_decay._coeff == 0:
+                raise ValueError("AdamW optimizer requires non-zero weight decay")
+            return paddle.optimizer.AdamW(
+                learning_rate=learning_rate,
+                parameters=params,
+                weight_decay=weight_decay._coeff,
+            )
+        raise ValueError("AdamW optimizer requires l2 regularizer")
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
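
As a standalone illustration of what the new AdamW branch does (a sketch using plain Paddle, not the DeepXDE code itself): the L2Decay regularizer is unwrapped to its coefficient, since paddle.optimizer.AdamW takes weight_decay as a float rather than a regularizer object, and a zero coefficient is rejected.

import paddle

layer = paddle.nn.Linear(10, 1)
regularizer = paddle.regularizer.L2Decay(coeff=1e-4)  # what self.net.regularizer would hold
if regularizer._coeff == 0:
    raise ValueError("AdamW optimizer requires non-zero weight decay")
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=layer.parameters(),
    weight_decay=regularizer._coeff,  # the unwrapped float coefficient
)
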
