Skip to content

Commit

Permalink
Backend paddle: Refactor and add regularizer (#1894)
Browse files Browse the repository at this point in the history
  • Loading branch information
lijialin03 authored Dec 10, 2024
1 parent 3544fdf commit ec4bdd3
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 6 deletions.
39 changes: 39 additions & 0 deletions deepxde/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,42 @@ def sparse_dense_matmul(x, y):
Returns:
Tensor: The multiplication result.
"""


###############################################################################
# Regularization


def l1_regularization(l1):
"""A regularizer that applies a L1 regularization penalty or L1 weight decay.
Warning:
The implementation may vary across different backends.
Args:
l1 (float): L1 regularization factor.
"""


def l2_regularization(l2):
"""A regularizer that applies a L2 regularization penalty or L2 weight decay.
Warning:
The implementation may vary across different backends.
Args:
l2 (float): L2 regularization factor.
"""


def l1_l2_regularization(l1, l2):
"""A regularizer that applies both L1 and L2 regularization penalties or
L1 and L2 weight decay.
Warning:
The implementation may vary across different backends.
Args:
l1 (float): L1 regularization factor.
l2 (float): L2 regularization factor.
"""
8 changes: 8 additions & 0 deletions deepxde/backend/paddle/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,11 @@ def matmul(x, y):

def sparse_dense_matmul(x, y):
return paddle.sparse.matmul(x, y)


def l1_regularization(l1):
return paddle.regularizer.L1Decay(coeff=l1)


def l2_regularization(l2):
return paddle.regularizer.L2Decay(coeff=l2)
12 changes: 12 additions & 0 deletions deepxde/backend/tensorflow/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,15 @@ def zeros_like(input_tensor):

def matmul(x, y):
return tf.linalg.matmul(x, y)


def l1_regularization(l1):
return tf.keras.regularizers.L1(l1=l1)


def l2_regularization(l2):
return tf.keras.regularizers.L2(l2=l2)


def l1_l2_regularization(l1, l2):
return tf.keras.regularizers.L1L2(l1=l1, l2=l2)
12 changes: 12 additions & 0 deletions deepxde/backend/tensorflow_compat_v1/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,15 @@ def matmul(x, y):

def sparse_dense_matmul(x, y):
return tf.sparse.sparse_dense_matmul(x, y)


def l1_regularization(l1):
return tf.keras.regularizers.L1(l1=l1)


def l2_regularization(l2):
return tf.keras.regularizers.L2(l2=l2)


def l1_l2_regularization(l1, l2):
return tf.keras.regularizers.L1L2(l1=l1, l2=l2)
12 changes: 6 additions & 6 deletions deepxde/nn/regularizers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ..backend import tf
from .. import backend as bkd


def get(identifier):
"""Retrieves a TensorFlow regularizer instance based on the given identifier.
"""Retrieves a regularizer instance based on the given identifier.
Args:
identifier (list/tuple): Specifies the type and factor(s) of the regularizer.
Expand All @@ -11,7 +11,6 @@ def get(identifier):
For "l1l2", provide both "l1" and "l2" factors.
"""

# TODO: other backends
if identifier is None or not identifier:
return None
if not isinstance(identifier, (list, tuple)):
Expand All @@ -23,11 +22,12 @@ def get(identifier):
raise ValueError("Regularization factor must be provided.")

if name == "l1":
return tf.keras.regularizers.L1(l1=factor[0])
return bkd.l1_regularization(factor[0])
if name == "l2":
return tf.keras.regularizers.L2(l2=factor[0])
return bkd.l2_regularization(factor[0])
if name in ("l1l2", "l1+l2"):
# TODO: only supported by 'tensorflow.compat.v1' and 'tensorflow' now.
if len(factor) < 2:
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1])
return bkd.l1_l2_regularization(factor[0], factor[1])
raise ValueError(f"Unknown regularizer name: {name}")

0 comments on commit ec4bdd3

Please sign in to comment.