From f7aa5637cf81d943a0504cd43f7fa90fc70b505c Mon Sep 17 00:00:00 2001 From: Christopher Laurens Woolford <70346478+cwoolfo1@users.noreply.github.com> Date: Fri, 6 Sep 2024 20:47:26 -0500 Subject: [PATCH] Backend pytorch: FNN supports regularization (#1833) --- deepxde/nn/pytorch/fnn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepxde/nn/pytorch/fnn.py b/deepxde/nn/pytorch/fnn.py index 873013acc..94f09e5da 100644 --- a/deepxde/nn/pytorch/fnn.py +++ b/deepxde/nn/pytorch/fnn.py @@ -9,7 +9,9 @@ class FNN(NN): """Fully-connected neural network.""" - def __init__(self, layer_sizes, activation, kernel_initializer): + def __init__( + self, layer_sizes, activation, kernel_initializer, regularization=None + ): super().__init__() if isinstance(activation, list): if not (len(layer_sizes) - 1) == len(activation): @@ -21,6 +23,7 @@ def __init__(self, layer_sizes, activation, kernel_initializer): self.activation = activations.get(activation) initializer = initializers.get(kernel_initializer) initializer_zero = initializers.get("zeros") + self.regularizer = regularization self.linears = torch.nn.ModuleList() for i in range(1, len(layer_sizes)):