diff --git a/bayesflow/attention.py b/bayesflow/attention.py
index 6e9d49e3d..b279cf252 100644
--- a/bayesflow/attention.py
+++ b/bayesflow/attention.py
@@ -191,8 +191,10 @@ def call(self, x, **kwargs):
             Output of shape (batch_size, set_size, input_dim)
         """
 
-        batch_size = x.shape[0]
-        h = self.mab0(tf.stack([self.I] * batch_size), x, **kwargs)
+        batch_size = tf.shape(x)[0]
+        I_expanded = self.I[None, ...]
+        I_tiled = tf.tile(I_expanded, [batch_size, 1, 1])
+        h = self.mab0(I_tiled, x, **kwargs)
         return self.mab1(x, h, **kwargs)
 
 
@@ -240,7 +242,7 @@ def __init__(
             summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs
         )
         init = tf.keras.initializers.GlorotUniform()
-        self.seed_vec = init(shape=(num_seeds, summary_dim))
+        self.seed_vec = tf.Variable(init(shape=(num_seeds, summary_dim)), name="seed_vec", trainable=True)
         self.fc = Sequential([Dense(**dense_settings) for _ in range(num_dense_fc)])
         self.fc.add(Dense(summary_dim))
 
@@ -258,7 +260,9 @@ def call(self, x, **kwargs):
             Output of shape (batch_size, num_seeds * summary_dim)
         """
 
-        batch_size = x.shape[0]
         out = self.fc(x)
-        out = self.mab(tf.stack([self.seed_vec] * batch_size), out, **kwargs)
-        return tf.reshape(out, (out.shape[0], -1))
+        batch_size = tf.shape(x)[0]
+        seed_expanded = self.seed_vec[None, ...]
+        seed_tiled = tf.tile(seed_expanded, [batch_size, 1, 1])
+        out = self.mab(seed_tiled, out, **kwargs)
+        return tf.reshape(out, (tf.shape(out)[0], -1))
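
A note on the first and third hunks: `x.shape[0]` reads the *static* shape, which is `None` whenever the batch dimension is unknown (e.g., during a `tf.function` trace or when the model is built with a flexible input signature), so `tf.stack([self.I] * batch_size)` fails at trace time. `tf.shape(x)[0]` is resolved at run time, and `tf.tile` accepts it as a tensor multiple. A minimal, self-contained sketch of the pattern (the template shape and function name here are made up for illustration, not part of the patch):

```python
import tensorflow as tf

# Fixed-shape template, analogous to self.I or self.seed_vec.
template = tf.ones((4, 3))

@tf.function(input_signature=[tf.TensorSpec(shape=(None, 5, 3))])
def tile_per_batch(x):
    # Static shape: x.shape[0] is None during tracing, so
    # tf.stack([template] * x.shape[0]) would raise here.
    batch_size = tf.shape(x)[0]  # dynamic shape, resolved at run time
    tiled = tf.tile(template[None, ...], [batch_size, 1, 1])
    return tiled  # shape (batch_size, 4, 3)

print(tile_per_batch(tf.zeros((7, 5, 3))).shape)  # (7, 4, 3)
```

As a side benefit, `tf.tile` emits a single op instead of the `batch_size` stacked copies that the Python list multiplication would build, so the graph stays constant-size regardless of batch size.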
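On the second hunk: calling an initializer returns a plain constant tensor, so the seed vector was frozen at its random initialization and invisible to Keras variable tracking; wrapping it in `tf.Variable(..., trainable=True)` turns it into a learned parameter. A quick way to see the difference, using a toy layer with illustrative names only:

```python
import tensorflow as tf

init = tf.keras.initializers.GlorotUniform()

class WithVariable(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Wrapped in tf.Variable: tracked by Keras, updated by the optimizer.
        self.seed_vec = tf.Variable(init(shape=(2, 8)), name="seed_vec", trainable=True)

class WithConstant(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Plain tensor: frozen at initialization, receives no gradient updates.
        self.seed_vec = init(shape=(2, 8))

print(len(WithVariable().trainable_variables))  # 1
print(len(WithConstant().trainable_variables))  # 0
```

`trainable=True` is the default for `tf.Variable`, but spelling it out in the patch documents the intent.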