Finalize set transformer
stefanradev93 committed Jun 12, 2024
1 parent b951c14 commit 78a7ad7
Showing 5 changed files with 109 additions and 15 deletions.
2 changes: 2 additions & 0 deletions bayesflow/experimental/networks/transformers/isab.py
@@ -1,11 +1,13 @@

import keras
import keras.ops as ops
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
from .mab import MultiHeadAttentionBlock


@register_keras_serializable(package="bayesflow.networks")
class InducedSetAttentionBlock(keras.Layer):
"""Implements the ISAB block from [1] which represents learnable self-attention specifically
designed to deal with large sets via a learnable set of "inducing points".
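
For orientation, the ISAB of [1] chains two attention steps: a small set of m learnable inducing points first attends to the full input set, and the set then attends back to that compressed summary, replacing the quadratic cost of full self-attention with roughly O(n * m). The sketch below imitates this with plain keras.layers.MultiHeadAttention and randomly initialized inducing points; it is illustrative only and does not reproduce the exact InducedSetAttentionBlock defined in isab.py (which also adds feedforward layers and optional normalization).

import keras
import numpy as np

num_inducing, dim, num_heads = 4, 32, 4

x = np.random.normal(size=(2, 100, dim)).astype("float32")                  # batch of 2 large sets
inducing = np.random.normal(size=(1, num_inducing, dim)).astype("float32")  # learnable weights in the real block

att0 = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=dim)
att1 = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=dim)

i = keras.ops.broadcast_to(inducing, (x.shape[0], num_inducing, dim))
h = att0(query=i, value=x, key=x)    # inducing points summarize the set (cost ~ n * m)
out = att1(query=x, value=h, key=h)  # each set element attends to the m-point summary
print(out.shape)                     # (2, 100, 32)
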
6 changes: 1 addition & 5 deletions bayesflow/experimental/networks/transformers/mab.py
@@ -1,11 +1,7 @@

import keras
from keras import layers
-from keras.saving import (
-    deserialize_keras_object,
-    register_keras_serializable,
-    serialize_keras_object,
-)
+from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor

6 changes: 4 additions & 2 deletions bayesflow/experimental/networks/transformers/pma.py
@@ -2,11 +2,13 @@
import keras
import keras.ops as ops
from keras import layers
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
from .mab import MultiHeadAttentionBlock


@register_keras_serializable(package="bayesflow.networks")
class PoolingByMultiHeadAttention(keras.Layer):
"""Implements the pooling with multi-head attention (PMA) block from [1] which represents
a permutation-invariant encoder for set-based inputs.
@@ -21,11 +23,11 @@ class PoolingByMultiHeadAttention(keras.Layer):

def __init__(
self,
-num_seeds: int = 1,
-seed_dim: int = None,
summary_dim: int = 16,
+num_seeds: int = 1,
key_dim: int = 32,
num_heads: int = 4,
+seed_dim: int = None,
dropout: float = 0.05,
num_dense_feedforward: int = 2,
dense_units: int = 128,
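
Analogously, the PMA block of [1] pools a set into num_seeds output vectors by letting learnable seed vectors attend to the (feedforward-transformed) set, PMA(X) = MAB(S, rFF(X)). A rough sketch in the same spirit as above, again omitting the feedforward and normalization parts of the actual PoolingByMultiHeadAttention:

import keras
import numpy as np

num_seeds, dim = 1, 32

x = np.random.normal(size=(2, 100, dim)).astype("float32")            # batch of 2 sets
seeds = np.random.normal(size=(1, num_seeds, dim)).astype("float32")  # learnable weights in the real block

att = keras.layers.MultiHeadAttention(num_heads=4, key_dim=dim)
s = keras.ops.broadcast_to(seeds, (x.shape[0], num_seeds, dim))
pooled = att(query=s, value=x, key=x)  # each seed attends to the whole set
print(pooled.shape)                    # (2, 1, 32): one pooled vector per seed, invariant to set order

With num_seeds > 1 the pooled output has one row per seed, which is why the new SetTransformer below appends a Dense(summary_dim) output projector.
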
6 changes: 1 addition & 5 deletions bayesflow/experimental/networks/transformers/sab.py
@@ -1,9 +1,5 @@

-from keras.saving import (
-    deserialize_keras_object,
-    register_keras_serializable,
-    serialize_keras_object,
-)
+from keras.saving import register_keras_serializable

from .mab import MultiHeadAttentionBlock

104 changes: 101 additions & 3 deletions bayesflow/experimental/networks/transformers/set_transformer.py
@@ -1,4 +1,102 @@

-class SetTransformer:
-    #TODO
-    pass
import keras
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor

from .sab import SetAttentionBlock
from .isab import InducedSetAttentionBlock
from .pma import PoolingByMultiHeadAttention


@register_keras_serializable(package="bayesflow.networks")
class SetTransformer(keras.Layer):
"""Implements the set transformer architecture from [1] which ultimately represents
a learnable permutation-invariant function. Designed to naturally model interactions in
the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.
[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
Set transformer: A framework for attention-based permutation-invariant neural networks.
In International conference on machine learning (pp. 3744-3753). PMLR.
"""

def __init__(
self,
summary_dim: int = 16,
num_attention_blocks: int = 2,
num_inducing_points: int = None,
num_seeds: int = 1,
key_dim: int = 32,
num_heads: int = 4,
dropout: float = 0.05,
num_dense_feedforward: int = 2,
dense_units: int = 128,
dense_activation: str = "gelu",
kernel_initializer: str = "he_normal",
use_bias: bool = True,
layer_norm: bool = True,
set_attention_output_dim: int = None,
seed_dim: int = None,
**kwargs
):
"""
#TODO
"""

super().__init__(**kwargs)

# Construct a series of set-attention blocks
self.attention_blocks = keras.Sequential()

attention_block_settings = dict(
num_inducing_points=num_inducing_points,
key_dim=key_dim,
num_heads=num_heads,
dropout=dropout,
num_dense_feedforward=num_dense_feedforward,
output_dim=set_attention_output_dim,
dense_units=dense_units,
dense_activation=dense_activation,
kernel_initializer=kernel_initializer,
use_bias=use_bias,
layer_norm=layer_norm
)

for _ in range(num_attention_blocks):
if num_inducing_points is not None:
block = InducedSetAttentionBlock(**attention_block_settings)
else:
block = SetAttentionBlock(**{k: v for k, v in attention_block_settings.items() if k != "num_inducing_points"})
self.attention_blocks.add(block)

# Pooling will be applied as a final step to the abstract representations obtained from set attention
attention_block_settings.pop("num_inducing_points")
attention_block_settings.pop("output_dim")
pooling_settings = dict(
seed_dim=seed_dim,
num_seeds=num_seeds,
summary_dim=summary_dim
)
self.pooling_by_attention = PoolingByMultiHeadAttention(**(attention_block_settings | pooling_settings))

# Output projector ensures that the final output dimension equals summary_dim even when num_seeds > 1
self.output_projector = keras.layers.Dense(summary_dim)

def call(self, set_x: Tensor, **kwargs) -> Tensor:
"""Performs the forward pass through the set-transformer.
Parameters
----------
set_x : Tensor
The input set of shape (batch_size, set_size, input_dim)
Returns
-------
set_summary : Tensor
Output representation of shape (batch_size, summary_dim)
"""

set_summary = self.attention_blocks(set_x, **kwargs)
set_summary = self.pooling_by_attention(set_summary, **kwargs)
set_summary = self.output_projector(set_summary)
return set_summary
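
Putting the pieces together, a minimal usage sketch for the new class (assuming the module path shown in this diff, a Keras 3 backend, and arbitrary toy shapes):

import numpy as np

from bayesflow.experimental.networks.transformers.set_transformer import SetTransformer

x = np.random.normal(size=(2, 10, 3)).astype("float32")  # batch of 2 sets, 10 elements of dimension 3 each

# Defaults use SetAttentionBlock blocks; passing num_inducing_points switches to InducedSetAttentionBlock.
summary_net = SetTransformer(summary_dim=16, num_attention_blocks=2)
summary = summary_net(x)
print(summary.shape)  # (2, 16) according to the call docstring: one permutation-invariant summary per set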
