Skip to content

Commit

Permalink
Bugfixed save/load set transformers with inducing points
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Jul 16, 2023
1 parent 8c3d8da commit 1b64a2c
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions bayesflow/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,10 @@ def call(self, x, **kwargs):
Output of shape (batch_size, set_size, input_dim)
"""

batch_size = x.shape[0]
h = self.mab0(tf.stack([self.I] * batch_size), x, **kwargs)
batch_size = tf.shape(x)[0]
I_expanded = self.I[None, ...]
I_tiled = tf.tile(I_expanded, [batch_size, 1, 1])
h = self.mab0(I_tiled, x, **kwargs)
return self.mab1(x, h, **kwargs)


Expand Down Expand Up @@ -240,7 +242,7 @@ def __init__(
summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs
)
init = tf.keras.initializers.GlorotUniform()
self.seed_vec = init(shape=(num_seeds, summary_dim))
self.seed_vec = tf.Variable(init(shape=(num_seeds, summary_dim)), name="seed_vec", trainable=True)
self.fc = Sequential([Dense(**dense_settings) for _ in range(num_dense_fc)])
self.fc.add(Dense(summary_dim))

Expand All @@ -258,7 +260,9 @@ def call(self, x, **kwargs):
Output of shape (batch_size, num_seeds * summary_dim)
"""

batch_size = x.shape[0]
out = self.fc(x)
out = self.mab(tf.stack([self.seed_vec] * batch_size), out, **kwargs)
return tf.reshape(out, (out.shape[0], -1))
batch_size = tf.shape(x)[0]
seed_expanded = self.seed_vec[None, ...]
seed_tiled = tf.tile(seed_expanded, [batch_size, 1, 1])
out = self.mab(seed_tiled, out, **kwargs)
return tf.reshape(out, (tf.shape(out)[0], -1))

0 comments on commit 1b64a2c

Please sign in to comment.