Skip to content

Commit

Permalink
Backend Tensorflow 1.x: branch subnet refactoring for DeepONet (#1849)
Browse files Browse the repository at this point in the history
  • Loading branch information
vl-dud authored Oct 20, 2024
1 parent f7aa563 commit ba8e824
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions deepxde/nn/tensorflow_compat_v1/deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,58 +318,67 @@ def build(self):
self.built = True

def build_branch_net(self):
y_func = self.X_func
if callable(self.layer_size_func[1]):
# User-defined network
y_func = self.layer_size_func[1](y_func)
elif self.stacked:
return self.layer_size_func[1](self.X_func)

if self.stacked:
# Stacked fully connected network
stack_size = self.layer_size_func[-1]
for i in range(1, len(self.layer_size_func) - 1):
y_func = self._stacked_dense(
y_func,
self.layer_size_func[i],
stack_size,
activation=self.activation_branch,
trainable=self.trainable_branch,
)
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
return self._build_stacked_branch_net()

# Unstacked fully connected network
return self._build_unstacked_branch_net()

def _build_stacked_branch_net(self):
y_func = self.X_func
stack_size = self.layer_size_func[-1]

for i in range(1, len(self.layer_size_func) - 1):
y_func = self._stacked_dense(
y_func,
1,
stack_size,
use_bias=self.use_bias,
self.layer_size_func[i],
stack_size=stack_size,
activation=self.activation_branch,
trainable=self.trainable_branch,
)
else:
# Unstacked fully connected network
for i in range(1, len(self.layer_size_func) - 1):
y_func = self._dense(
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func,
self.layer_size_func[i],
activation=self.activation_branch,
regularizer=self.regularizer,
trainable=self.trainable_branch,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
return self._stacked_dense(
y_func,
1,
stack_size=stack_size,
use_bias=self.use_bias,
trainable=self.trainable_branch,
)

def _build_unstacked_branch_net(self):
y_func = self.X_func

for i in range(1, len(self.layer_size_func) - 1):
y_func = self._dense(
y_func,
self.layer_size_func[-1],
use_bias=self.use_bias,
self.layer_size_func[i],
activation=self.activation_branch,
regularizer=self.regularizer,
trainable=self.trainable_branch,
)
return y_func
if self.dropout_rate_branch[i - 1] > 0:
y_func = tf.layers.dropout(
y_func,
rate=self.dropout_rate_branch[i - 1],
training=self.training,
)
return self._dense(
y_func,
self.layer_size_func[-1],
use_bias=self.use_bias,
regularizer=self.regularizer,
trainable=self.trainable_branch,
)

def build_trunk_net(self):
y_loc = self.X_loc
Expand Down

0 comments on commit ba8e824

Please sign in to comment.