diff --git a/deepxde/nn/tensorflow_compat_v1/deeponet.py b/deepxde/nn/tensorflow_compat_v1/deeponet.py index 2eeb4db4c..f7fbe992e 100644 --- a/deepxde/nn/tensorflow_compat_v1/deeponet.py +++ b/deepxde/nn/tensorflow_compat_v1/deeponet.py @@ -318,58 +318,67 @@ def build(self): self.built = True def build_branch_net(self): - y_func = self.X_func if callable(self.layer_size_func[1]): # User-defined network - y_func = self.layer_size_func[1](y_func) - elif self.stacked: + return self.layer_size_func[1](self.X_func) + + if self.stacked: # Stacked fully connected network - stack_size = self.layer_size_func[-1] - for i in range(1, len(self.layer_size_func) - 1): - y_func = self._stacked_dense( - y_func, - self.layer_size_func[i], - stack_size, - activation=self.activation_branch, - trainable=self.trainable_branch, - ) - if self.dropout_rate_branch[i - 1] > 0: - y_func = tf.layers.dropout( - y_func, - rate=self.dropout_rate_branch[i - 1], - training=self.training, - ) + return self._build_stacked_branch_net() + + # Unstacked fully connected network + return self._build_unstacked_branch_net() + + def _build_stacked_branch_net(self): + y_func = self.X_func + stack_size = self.layer_size_func[-1] + + for i in range(1, len(self.layer_size_func) - 1): y_func = self._stacked_dense( y_func, - 1, - stack_size, - use_bias=self.use_bias, + self.layer_size_func[i], + stack_size=stack_size, + activation=self.activation_branch, trainable=self.trainable_branch, ) - else: - # Unstacked fully connected network - for i in range(1, len(self.layer_size_func) - 1): - y_func = self._dense( + if self.dropout_rate_branch[i - 1] > 0: + y_func = tf.layers.dropout( y_func, - self.layer_size_func[i], - activation=self.activation_branch, - regularizer=self.regularizer, - trainable=self.trainable_branch, + rate=self.dropout_rate_branch[i - 1], + training=self.training, ) - if self.dropout_rate_branch[i - 1] > 0: - y_func = tf.layers.dropout( - y_func, - rate=self.dropout_rate_branch[i - 1], - training=self.training, - ) + return self._stacked_dense( + y_func, + 1, + stack_size=stack_size, + use_bias=self.use_bias, + trainable=self.trainable_branch, + ) + + def _build_unstacked_branch_net(self): + y_func = self.X_func + + for i in range(1, len(self.layer_size_func) - 1): y_func = self._dense( y_func, - self.layer_size_func[-1], - use_bias=self.use_bias, + self.layer_size_func[i], + activation=self.activation_branch, regularizer=self.regularizer, trainable=self.trainable_branch, ) - return y_func + if self.dropout_rate_branch[i - 1] > 0: + y_func = tf.layers.dropout( + y_func, + rate=self.dropout_rate_branch[i - 1], + training=self.training, + ) + return self._dense( + y_func, + self.layer_size_func[-1], + use_bias=self.use_bias, + regularizer=self.regularizer, + trainable=self.trainable_branch, + ) def build_trunk_net(self): y_loc = self.X_loc