diff --git a/deepxde/nn/pytorch/fnn.py b/deepxde/nn/pytorch/fnn.py index e07f9581d..b54bef4c6 100644 --- a/deepxde/nn/pytorch/fnn.py +++ b/deepxde/nn/pytorch/fnn.py @@ -9,7 +9,9 @@ class FNN(NN): """Fully-connected neural network.""" - def __init__(self, layer_sizes, activation, kernel_initializer, regularization=None): + def __init__( + self, layer_sizes, activation, kernel_initializer, regularization=None + ): super().__init__() if isinstance(activation, list): if not (len(layer_sizes) - 1) == len(activation): @@ -31,9 +33,10 @@ def __init__(self, layer_sizes, activation, kernel_initializer, regularization=N ) initializer(self.linears[-1].weight) initializer_zero(self.linears[-1].bias) - self.regularizer=regularization + self.regularizer = regularization # currently list with two components: regularization type, weight decay # currently supports l2 regularization + def forward(self, inputs): x = inputs if self._input_transform is not None: