From eba9b550c31393a4bee2794ffdff31d455a120ea Mon Sep 17 00:00:00 2001
From: lijialin03
Date: Tue, 17 Dec 2024 02:27:05 +0000
Subject: [PATCH] update code and add regularizer for nn

---
 deepxde/nn/paddle/nn.py                 |  1 +
 deepxde/optimizers/paddle/optimizers.py | 16 +++++++---------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/deepxde/nn/paddle/nn.py b/deepxde/nn/paddle/nn.py
index 6609027dc..74e6e5723 100644
--- a/deepxde/nn/paddle/nn.py
+++ b/deepxde/nn/paddle/nn.py
@@ -6,6 +6,7 @@ class NN(paddle.nn.Layer):
 
     def __init__(self):
         super().__init__()
+        self.regularizer = None
         self._input_transform = None
         self._output_transform = None
 
diff --git a/deepxde/optimizers/paddle/optimizers.py b/deepxde/optimizers/paddle/optimizers.py
index 1a12f9f5c..edb2ad7d8 100644
--- a/deepxde/optimizers/paddle/optimizers.py
+++ b/deepxde/optimizers/paddle/optimizers.py
@@ -62,13 +62,11 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
             weight_decay=weight_decay,
         )
     if optimizer == "adamw":
-        if isinstance(weight_decay, paddle.regularizer.L2Decay):
-            if weight_decay._coeff == 0:
-                raise ValueError("AdamW optimizer requires non-zero weight decay")
-            return paddle.optimizer.AdamW(
-                learning_rate=learning_rate,
-                parameters=params,
-                weight_decay=weight_decay._coeff,
-            )
-        raise ValueError("AdamW optimizer requires l2 regularizer")
+        if not isinstance(weight_decay, paddle.regularizer.L2Decay) or weight_decay._coeff == 0:
+            raise ValueError("AdamW optimizer requires L2 regularizer and non-zero weight decay")
+        return paddle.optimizer.AdamW(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay._coeff,
+        )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
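
Usage note (not part of the patch): a minimal sketch of how the patched get() helper
could be exercised directly under the Paddle backend. The toy Linear network and the
1e-4 decay coefficient are illustrative assumptions, not values taken from DeepXDE.

# Assumes the DeepXDE backend is set to paddle (e.g. DDE_BACKEND=paddle).
import paddle

from deepxde.optimizers.paddle.optimizers import get

# Toy parameters to optimize; any paddle.nn.Layer would do here.
net = paddle.nn.Linear(2, 1)

# After this patch, "adamw" requires an L2Decay regularizer with a non-zero
# coefficient; anything else raises ValueError in a single up-front check.
opt = get(
    net.parameters(),
    "adamw",
    learning_rate=1e-3,
    weight_decay=paddle.regularizer.L2Decay(1e-4),
)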