diff --git a/deepxde/nn/tensorflow_compat_v1/deeponet.py b/deepxde/nn/tensorflow_compat_v1/deeponet.py index 2eeb4db4c..aba79dc11 100644 --- a/deepxde/nn/tensorflow_compat_v1/deeponet.py +++ b/deepxde/nn/tensorflow_compat_v1/deeponet.py @@ -321,54 +321,48 @@ def build_branch_net(self): y_func = self.X_func if callable(self.layer_size_func[1]): # User-defined network - y_func = self.layer_size_func[1](y_func) - elif self.stacked: - # Stacked fully connected network - stack_size = self.layer_size_func[-1] - for i in range(1, len(self.layer_size_func) - 1): - y_func = self._stacked_dense( - y_func, - self.layer_size_func[i], - stack_size, - activation=self.activation_branch, + return self.layer_size_func[1](y_func) + + def _add_branch_layer( + inputs, units, stack_size=None, activation=None, use_bias=True + ): + if stack_size is None: + return self._dense( + inputs, + units, + activation=activation, + regularizer=self.regularizer, trainable=self.trainable_branch, + use_bias=use_bias, ) - if self.dropout_rate_branch[i - 1] > 0: - y_func = tf.layers.dropout( - y_func, - rate=self.dropout_rate_branch[i - 1], - training=self.training, - ) - y_func = self._stacked_dense( - y_func, - 1, + return self._stacked_dense( + inputs, + units, stack_size, - use_bias=self.use_bias, + activation=activation, trainable=self.trainable_branch, + use_bias=use_bias, ) - else: - # Unstacked fully connected network - for i in range(1, len(self.layer_size_func) - 1): - y_func = self._dense( - y_func, - self.layer_size_func[i], - activation=self.activation_branch, - regularizer=self.regularizer, - trainable=self.trainable_branch, - ) - if self.dropout_rate_branch[i - 1] > 0: - y_func = tf.layers.dropout( - y_func, - rate=self.dropout_rate_branch[i - 1], - training=self.training, - ) - y_func = self._dense( + + for i in range(1, len(self.layer_size_func) - 1): + y_func = _add_branch_layer( y_func, - self.layer_size_func[-1], - use_bias=self.use_bias, - regularizer=self.regularizer, - trainable=self.trainable_branch, + self.layer_size_func[i], + self.layer_size_func[-1] if self.stacked else None, + activation=self.activation_branch, ) + if self.dropout_rate_branch[i - 1] > 0: + y_func = tf.layers.dropout( + y_func, + rate=self.dropout_rate_branch[i - 1], + training=self.training, + ) + y_func = _add_branch_layer( + y_func, + 1 if self.stacked else self.layer_size_func[-1], + self.layer_size_func[-1] if self.stacked else None, + use_bias=self.use_bias, + ) return y_func def build_trunk_net(self):