Backend Tensorflow 1.x: replace tf.layers.dense with tf.keras.layers.Dense
vl-dud committed Oct 21, 2024
1 parent ba8e824 commit fba2305
Showing 1 changed file with 9 additions and 13 deletions.
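The pattern behind every hunk below is the same: the deprecated tf.layers.dense(inputs, units, ...) creates and applies a layer in a single call, whereas tf.keras.layers.Dense(units, ...) only configures the layer object, which is then called on the inputs. A minimal sketch of that migration under the TF 1.x compat backend (not taken from the diff; the import style, placeholder shape, and layer width are illustrative):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, as used by the tensorflow.compat.v1 backend

# Illustrative input tensor; shape and width are arbitrary for this sketch.
inputs = tf.placeholder(tf.float32, shape=[None, 128])

# Old, deprecated style: layer creation and application in one call.
# y = tf.layers.dense(inputs, 64, activation=tf.nn.relu)

# New style: configure the Keras layer, then apply it to the inputs.
y = tf.keras.layers.Dense(64, activation=tf.nn.relu)(inputs)
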
22 changes: 9 additions & 13 deletions deepxde/nn/tensorflow_compat_v1/deeponet.py
@@ -321,7 +321,7 @@ def build_branch_net(self):
        if callable(self.layer_size_func[1]):
            # User-defined network
            return self.layer_size_func[1](self.X_func)

        if self.stacked:
            # Stacked fully connected network
            return self._build_stacked_branch_net()
@@ -422,15 +422,14 @@ def _dense(
        regularizer=None,
        trainable=True,
    ):
-        return tf.layers.dense(
-            inputs,
+        return tf.keras.layers.Dense(
            units,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=regularizer,
            trainable=trainable,
-        )
+        )(inputs)

    def _stacked_dense(
        self, inputs, units, stack_size, activation=None, use_bias=True, trainable=True
@@ -637,25 +636,23 @@ def build_branch_net(self):
        else:
            # Fully connected network
            for i in range(1, len(self.layer_size_func) - 1):
-                y_func = tf.layers.dense(
-                    y_func,
+                y_func = tf.keras.layers.Dense(
                    self.layer_size_func[i],
                    activation=self.activation_branch,
                    kernel_initializer=self.kernel_initializer,
                    kernel_regularizer=self.regularizer,
-                )
+                )(y_func)
                if self.dropout_rate_branch[i - 1] > 0:
                    y_func = tf.layers.dropout(
                        y_func,
                        rate=self.dropout_rate_branch[i - 1],
                        training=self.training,
                    )
-            y_func = tf.layers.dense(
-                y_func,
+            y_func = tf.keras.layers.Dense(
                self.layer_size_func[-1],
                kernel_initializer=self.kernel_initializer,
                kernel_regularizer=self.regularizer,
-            )
+            )(y_func)
        return y_func

    def build_trunk_net(self):
@@ -664,13 +661,12 @@ def build_trunk_net(self):
        if self._input_transform is not None:
            y_loc = self._input_transform(y_loc)
        for i in range(1, len(self.layer_size_loc)):
-            y_loc = tf.layers.dense(
-                y_loc,
+            y_loc = tf.keras.layers.Dense(
                self.layer_size_loc[i],
                activation=self.activation_trunk,
                kernel_initializer=self.kernel_initializer,
                kernel_regularizer=self.regularizer,
-            )
+            )(y_loc)
            if self.dropout_rate_trunk[i - 1] > 0:
                y_loc = tf.layers.dropout(
                    y_loc, rate=self.dropout_rate_trunk[i - 1], training=self.training
