Skip to content

Commit

Permalink
Tensorflow 1.x backend: branch subnet refactoring for DeepONet
Browse files Browse the repository at this point in the history
  • Loading branch information
vl-dud committed Oct 2, 2024
1 parent f7aa563 commit 3b97ba9
Showing 1 changed file with 35 additions and 41 deletions.
76 changes: 35 additions & 41 deletions deepxde/nn/tensorflow_compat_v1/deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,54 +321,48 @@ def build_branch_net(self):
y_func = self.X_func
if callable(self.layer_size_func[1]):
# User-defined network
y_func = self.layer_size_func[1](y_func)
elif self.stacked:
# Stacked fully connected network
stack_size = self.layer_size_func[-1]
for i in range(1, len(self.layer_size_func) - 1):
y_func = self._stacked_dense(
y_func,
self.layer_size_func[i],
stack_size,
activation=self.activation_branch,
return self.layer_size_func[1](y_func)

def _add_branch_layer(
inputs, units, stack_size=None, activation=None, use_bias=True
):
if stack_size is None:
return self._dense(
inputs,
units,
activation=activation,
regularizer=self.regularizer,
trainable=self.trainable_branch,
use_bias=use_bias,
)
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
y_func = self._stacked_dense(
y_func,
1,
return self._stacked_dense(
inputs,
units,
stack_size,
use_bias=self.use_bias,
activation=activation,
trainable=self.trainable_branch,
use_bias=use_bias,
)
else:
# Unstacked fully connected network
for i in range(1, len(self.layer_size_func) - 1):
y_func = self._dense(
y_func,
self.layer_size_func[i],
activation=self.activation_branch,
regularizer=self.regularizer,
trainable=self.trainable_branch,
)
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
y_func = self._dense(

for i in range(1, len(self.layer_size_func) - 1):
y_func = _add_branch_layer(
y_func,
self.layer_size_func[-1],
use_bias=self.use_bias,
regularizer=self.regularizer,
trainable=self.trainable_branch,
self.layer_size_func[i],
self.layer_size_func[-1] if self.stacked else None,
activation=self.activation_branch,
)
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
y_func = _add_branch_layer(
y_func,
1 if self.stacked else self.layer_size_func[-1],
self.layer_size_func[-1] if self.stacked else None,
use_bias=self.use_bias,
)
return y_func

def build_trunk_net(self):
Expand Down

0 comments on commit 3b97ba9

Please sign in to comment.