diff --git a/bayesflow/experimental/networks/transformers/isab.py b/bayesflow/experimental/networks/transformers/isab.py
index 5429e7ad2..d44b35b58 100644
--- a/bayesflow/experimental/networks/transformers/isab.py
+++ b/bayesflow/experimental/networks/transformers/isab.py
@@ -1,11 +1,13 @@
 import keras
 import keras.ops as ops
+from keras.saving import register_keras_serializable
 
 from bayesflow.experimental.types import Tensor
 from .mab import MultiHeadAttentionBlock
 
 
+@register_keras_serializable(package="bayesflow.networks")
 class InducedSetAttentionBlock(keras.Layer):
     """Implements the ISAB block from [1] which represents learnable self-attention specifically
     designed to deal with large sets via a learnable set of "inducing points".
diff --git a/bayesflow/experimental/networks/transformers/mab.py b/bayesflow/experimental/networks/transformers/mab.py
index 844deffeb..59c5c4b01 100644
--- a/bayesflow/experimental/networks/transformers/mab.py
+++ b/bayesflow/experimental/networks/transformers/mab.py
@@ -1,11 +1,7 @@
 import keras
 from keras import layers
-from keras.saving import (
-    deserialize_keras_object,
-    register_keras_serializable,
-    serialize_keras_object,
-)
+from keras.saving import register_keras_serializable
 
 from bayesflow.experimental.types import Tensor
diff --git a/bayesflow/experimental/networks/transformers/pma.py b/bayesflow/experimental/networks/transformers/pma.py
index 2649508b4..316af8c11 100644
--- a/bayesflow/experimental/networks/transformers/pma.py
+++ b/bayesflow/experimental/networks/transformers/pma.py
@@ -2,11 +2,13 @@
 import keras
 import keras.ops as ops
 from keras import layers
+from keras.saving import register_keras_serializable
 
 from bayesflow.experimental.types import Tensor
 from .mab import MultiHeadAttentionBlock
 
 
+@register_keras_serializable(package="bayesflow.networks")
 class PoolingByMultiHeadAttention(keras.Layer):
     """Implements the pooling with multi-head attention (PMA) block from [1] which represents
     a permutation-invariant encoder for set-based inputs.
@@ -21,11 +23,11 @@ class PoolingByMultiHeadAttention(keras.Layer):
 
     def __init__(
         self,
-        num_seeds: int = 1,
-        seed_dim: int = None,
         summary_dim: int = 16,
+        num_seeds: int = 1,
         key_dim: int = 32,
         num_heads: int = 4,
+        seed_dim: int = None,
         dropout: float = 0.05,
         num_dense_feedforward: int = 2,
         dense_units: int = 128,
diff --git a/bayesflow/experimental/networks/transformers/sab.py b/bayesflow/experimental/networks/transformers/sab.py
index 9440a3740..7154b3c4f 100644
--- a/bayesflow/experimental/networks/transformers/sab.py
+++ b/bayesflow/experimental/networks/transformers/sab.py
@@ -1,9 +1,5 @@
-from keras.saving import (
-    deserialize_keras_object,
-    register_keras_serializable,
-    serialize_keras_object,
-)
+from keras.saving import register_keras_serializable
 
 from .mab import MultiHeadAttentionBlock
diff --git a/bayesflow/experimental/networks/transformers/set_transformer.py b/bayesflow/experimental/networks/transformers/set_transformer.py
index f9071011a..8f9ad8f0a 100644
--- a/bayesflow/experimental/networks/transformers/set_transformer.py
+++ b/bayesflow/experimental/networks/transformers/set_transformer.py
@@ -1,4 +1,102 @@
-class SetTransformer:
-    #TODO
-    pass
+import keras
+from keras.saving import register_keras_serializable
+
+from bayesflow.experimental.types import Tensor
+
+from .sab import SetAttentionBlock
+from .isab import InducedSetAttentionBlock
+from .pma import PoolingByMultiHeadAttention
+
+
+@register_keras_serializable(package="bayesflow.networks")
+class SetTransformer(keras.Layer):
+    """Implements the set transformer architecture from [1] which ultimately represents
+    a learnable permutation-invariant function. Designed to naturally model interactions in
+    the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.
+
+    [1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
+        Set transformer: A framework for attention-based permutation-invariant neural networks.
+        In International conference on machine learning (pp. 3744-3753). PMLR.
+    """
+
+    def __init__(
+        self,
+        summary_dim: int = 16,
+        num_attention_blocks: int = 2,
+        num_inducing_points: int = None,
+        num_seeds: int = 1,
+        key_dim: int = 32,
+        num_heads: int = 4,
+        dropout: float = 0.05,
+        num_dense_feedforward: int = 2,
+        dense_units: int = 128,
+        dense_activation: str = "gelu",
+        kernel_initializer: str = "he_normal",
+        use_bias: bool = True,
+        layer_norm: bool = True,
+        set_attention_output_dim: int = None,
+        seed_dim: int = None,
+        **kwargs
+    ):
+        """
+        #TODO
+        """
+
+        super().__init__(**kwargs)
+
+        # Construct a series of set-attention blocks
+        self.attention_blocks = keras.Sequential()
+
+        attention_block_settings = dict(
+            num_inducing_points=num_inducing_points,
+            key_dim=key_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            num_dense_feedforward=num_dense_feedforward,
+            output_dim=set_attention_output_dim,
+            dense_units=dense_units,
+            dense_activation=dense_activation,
+            kernel_initializer=kernel_initializer,
+            use_bias=use_bias,
+            layer_norm=layer_norm
+        )
+
+        for _ in range(num_attention_blocks):
+            if num_inducing_points is not None:
+                block = InducedSetAttentionBlock(**attention_block_settings)
+            else:
+                block = SetAttentionBlock(**{k: v for k, v in attention_block_settings.items() if k != "num_inducing_points"})
+            self.attention_blocks.add(block)
+
+        # Pooling will be applied as a final step to the abstract representations obtained from set attention
+        attention_block_settings.pop("num_inducing_points")
+        attention_block_settings.pop("output_dim")
+        pooling_settings = dict(
+            seed_dim=seed_dim,
+            num_seeds=num_seeds,
+            summary_dim=summary_dim
+        )
+        self.pooling_by_attention = PoolingByMultiHeadAttention(**(attention_block_settings | pooling_settings))
+
+        # Output projector keeps the final output dimension equal to summary_dim in case num_seeds > 1
+        self.output_projector = keras.layers.Dense(summary_dim)
+
+    def call(self, set_x: Tensor, **kwargs) -> Tensor:
+        """Performs the forward pass through the set-transformer.
+
+        Parameters
+        ----------
+        set_x : Tensor
+            The input set of shape (batch_size, set_size, input_dim)
+
+        Returns
+        -------
+        set_summary : Tensor
+            Output representation of shape (batch_size, summary_dim)
+        """
+
+        set_summary = self.attention_blocks(set_x, **kwargs)
+        set_summary = self.pooling_by_attention(set_summary, **kwargs)
+        set_summary = self.output_projector(set_summary)
+        return set_summary
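
Usage sketch (not part of the diff): a minimal example of how the new SetTransformer layer could be exercised, assuming the module location shown in this patch and the constructor defaults added above. The batch/set/feature sizes are illustrative only.

# Minimal usage sketch. The import path mirrors the file location in this diff;
# whether the class is also re-exported at a higher-level package is an assumption.
import keras

from bayesflow.experimental.networks.transformers.set_transformer import SetTransformer

# A batch of 8 sets, each containing 50 exchangeable elements of dimension 3
set_x = keras.random.normal((8, 50, 3))

# Default configuration: num_inducing_points is None, so the encoder stacks two
# SetAttentionBlocks, followed by pooling-by-attention and a Dense projection to summary_dim
summary_net = SetTransformer(summary_dim=16, num_attention_blocks=2)

set_summary = summary_net(set_x)
print(set_summary.shape)  # expected: (8, 16)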