diff --git a/bayesflow/summary_networks.py b/bayesflow/summary_networks.py
index b29a8a4ea..57a0e6adf 100644
--- a/bayesflow/summary_networks.py
+++ b/bayesflow/summary_networks.py
@@ -140,7 +140,7 @@ def __init__(
         # Construct final attention layer, which will perform cross-attention
         # between the outputs of the self-attention layers and the dynamic template
         if bidirectional:
-            final_input_dim = template_dim*2
+            final_input_dim = template_dim * 2
         else:
             final_input_dim = template_dim
         self.output_attention = MultiHeadAttentionBlock(
@@ -184,7 +184,8 @@ def call(self, x, **kwargs):
 
 class SetTransformer(tf.keras.Model):
     """Implements the set transformer architecture from [1] which ultimately represents
-    a learnable permutation-invariant function.
+    a learnable permutation-invariant function. Designed to naturally model interactions in
+    the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.
 
     [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
     Set transformer: A framework for attention-based permutation-invariant neural networks.
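
For context, a minimal usage sketch (not part of the diff) of the `SetTransformer` summary network whose docstring is expanded above. The `input_dim` and `summary_dim` keyword arguments are assumed from the BayesFlow v1.x constructor; exact names may differ across versions.

```python
# Minimal sketch, not part of this change set. Assumes the BayesFlow v1.x
# SetTransformer constructor with `input_dim` and `summary_dim` arguments.
import numpy as np
from bayesflow.summary_networks import SetTransformer

# Toy data: a batch of 16 sets, each containing 50 exchangeable
# observations of dimension 4.
x = np.random.default_rng(42).normal(size=(16, 50, 4)).astype(np.float32)

# The network maps each variable-size set to a fixed-size summary vector,
# invariant to the ordering of observations within the set.
net = SetTransformer(input_dim=4, summary_dim=10)
summary = net(x)
print(summary.shape)  # (16, 10): one summary vector per set
```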