Improve SetTransformer documentation
elseml committed Dec 14, 2023
1 parent f1f19fd commit 219aeec
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions bayesflow/summary_networks.py
@@ -140,7 +140,7 @@ def __init__(
        # Construct final attention layer, which will perform cross-attention
        # between the outputs of the self-attention layers and the dynamic template
        if bidirectional:
-            final_input_dim = template_dim*2
+            final_input_dim = template_dim * 2
        else:
            final_input_dim = template_dim
        self.output_attention = MultiHeadAttentionBlock(
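A minimal sketch of why the bidirectional branch needs `final_input_dim = template_dim * 2`: when summaries from two directions are concatenated along the feature axis, the width of the result doubles. The array shapes below are hypothetical and only illustrate the dimensionality argument, not the actual bayesflow layer.

```python
import numpy as np

# Hypothetical shapes: 10 set elements, each summarized into template_dim features.
template_dim = 64
forward = np.random.rand(10, template_dim)   # summary from one processing direction
backward = np.random.rand(10, template_dim)  # summary from the other direction

# Concatenating both directions along the feature axis doubles the width,
# which is why the bidirectional branch sets final_input_dim = template_dim * 2.
combined = np.concatenate([forward, backward], axis=-1)
assert combined.shape == (10, template_dim * 2)
```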
@@ -184,7 +184,8 @@ def call(self, x, **kwargs):

class SetTransformer(tf.keras.Model):
"""Implements the set transformer architecture from [1] which ultimately represents
a learnable permutation-invariant function.
a learnable permutation-invariant function. Designed to naturally model interactions in
the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.
[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
Set transformer: A framework for attention-based permutation-invariant neural networks.
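The docstring's key claim is permutation invariance: the summary of a set must not depend on the order of its elements. The sketch below is not the bayesflow implementation; it is a toy mean-pooled "deep set" with hypothetical weights `W`, showing the property that both `DeepSet` and `SetTransformer` provide (the transformer additionally models pairwise interactions via attention).

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 8))  # hypothetical per-element embedding weights

def summarize(x):
    # Embed each set element, then pool with a symmetric (mean) operation;
    # the mean is order-independent, so the summary is permutation-invariant.
    return np.tanh(x @ W).mean(axis=0)

x = rng.standard_normal((5, 3))  # a "set" of 5 elements with 3 features each
perm = rng.permutation(5)
assert np.allclose(summarize(x), summarize(x[perm]))
```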
