Commit

Bump up transformers

stefanradev93 committed Nov 5, 2024
1 parent 39f3c45 commit bfca3b1
Showing 11 changed files with 374 additions and 19 deletions.
2 changes: 1 addition & 1 deletion bayesflow/networks/__init__.py
@@ -7,4 +7,4 @@
from .mlp import MLP
from .lstnet import LSTNet
from .summary_network import SummaryNetwork
from .transformers import SetTransformer
from .transformers import SetTransformer, TimeSeriesTransformer, FusionTransformer
1 change: 1 addition & 0 deletions bayesflow/networks/embeddings/__init__.py
@@ -1 +1,2 @@
from .fourier_embedding import FourierEmbedding
from .time2vec import Time2Vec
72 changes: 72 additions & 0 deletions bayesflow/networks/embeddings/time2vec.py
@@ -0,0 +1,72 @@
import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import expand_tile


@serializable(package="bayesflow.networks")
class Time2Vec(keras.Layer):
"""
Implements the Time2Vec learnable embedding from [1].
[1] Kazemi, S. M., Goel, R., Eghbali, S., Ramanan, J., Sahota, J., Thakur, S., ... & Brubaker, M.
(2019). Time2vec: Learning a vector representation of time. arXiv preprint arXiv:1907.05321.
"""

def __init__(self, num_periodic_features: int = 8):
"""Creates a learnable Time2Vec embedding layer.
Parameters
----------
num_periodic_features : int, optional (default - 8)
The number of learnable periodic (sine) features. One learnable linear trend feature is always added.
"""
super().__init__()

self.num_periodic_features = num_periodic_features
self.linear_weight = self.add_weight(
shape=(1,),
initializer="glorot_uniform",
trainable=True,
name="trend_weight",
)

self.linear_bias = self.add_weight(
shape=(1,),
initializer="glorot_uniform",
trainable=True,
name="trend_bias",
)

self.periodic_weights = self.add_weight(
shape=(self.num_periodic_features,),
initializer="glorot_normal",
trainable=True,
name="periodic_weights",
)

self.periodic_biases = self.add_weight(
shape=(self.num_periodic_features,),
initializer="glorot_normal",
trainable=True,
name="periodic_biases",
)

def call(self, x: Tensor, t: Tensor = None) -> Tensor:
"""Creates time representations and concatenates them to x.
Parameters:
-----------
x : Tensor of shape (batch_size, sequence_length, dim)
The input sequence.
t : Tensor of shape (batch_size, sequence_length)
Vector of times
Returns:
--------
emb : Tensor
Embedding of shape (batch_size, fourier_emb_dim) if `include_identity`
is False, else (batch_size, fourier_emb_dim+1)
"""

if t is None:
t = keras.ops.linspace(0, keras.ops.shape(x)[1], keras.ops.shape(x)[1], dtype=x.dtype)
t = expand_tile(t, keras.ops.shape(x)[0], axis=0)

linear = t * self.linear_weight + self.linear_bias
periodic = keras.ops.sin(t[..., None] * self.periodic_weights[None, :] + self.periodic_biases[None, :])
emb = keras.ops.concatenate([linear[..., None], periodic], axis=-1)
return keras.ops.concatenate([x, emb], axis=-1)
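
As a quick sanity check on the shapes, here is a minimal usage sketch of the new layer. The random input, the batch dimensions, and the printed shape are illustrative assumptions based on the concatenation logic above, not part of the commit.

import keras
import numpy as np

from bayesflow.networks.embeddings import Time2Vec

# Toy batch: 16 sequences of length 50 with 2 observed channels
x = keras.ops.convert_to_tensor(np.random.randn(16, 50, 2).astype("float32"))

t2v = Time2Vec(num_periodic_features=8)
emb = t2v(x)  # t=None -> linearly spaced times over [0, sequence_length]

# 2 input channels + 1 linear trend feature + 8 periodic features = 11
print(keras.ops.shape(emb))  # expected: (16, 50, 11)

The trend feature comes from the ``linear_*`` weights and the sine features from the ``periodic_*`` weights created in ``__init__``.
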
2 changes: 2 additions & 0 deletions bayesflow/networks/transformers/__init__.py
@@ -1 +1,3 @@
from .set_transformer import SetTransformer
from .time_series_transformer import TimeSeriesTransformer
from .fusion_transformer import FusionTransformer
159 changes: 156 additions & 3 deletions bayesflow/networks/transformers/fusion_transformer.py
@@ -1,3 +1,156 @@
class FusionTransformer:
# TODO
pass
import keras
from keras import layers
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import check_lengths_same

from ..embeddings import Time2Vec
from ..summary_network import SummaryNetwork

from .mab import MultiHeadAttentionBlock


@serializable(package="bayesflow.networks")
class FusionTransformer(SummaryNetwork):
"""Implements a more flexible version of the TimeSeriesTransformer that applies a series of self-attention layers
followed by cross-attention between the representation and a learnable template summarized via a recurrent net."""

def __init__(
self,
summary_dim: int = 16,
embed_dims: tuple = (64, 64),
num_heads: tuple = (4, 4),
mlp_depths: tuple = (2, 2),
mlp_widths: tuple = (128, 128),
dropout: float = 0.05,
mlp_activation: str = "gelu",
kernel_initializer: str = "he_normal",
use_bias: bool = True,
layer_norm: bool = True,
t2v_embed_dim: int = 8,
template_type: str = "lstm",
bidirectional: bool = True,
template_dim: int = 128,
**kwargs,
):
"""Creates a fusion transformer used to flexibly compress time series. If the time intervals vary across
batches, it is highly recommended that your simulator also returns a "time" vector denoting absolute or
relative time.
Parameters
----------
summary_dim : int, optional (default - 16)
Dimensionality of the final summary output.
embed_dims : tuple of int, optional (default - (64, 64))
Dimensions of the keys, values, and queries for each attention block.
num_heads : tuple of int, optional (default - (4, 4))
Number of attention heads for each attention block.
mlp_depths : tuple of int, optional (default - (2, 2))
Depth of the multi-layer perceptron (MLP) blocks for each component.
mlp_widths : tuple of int, optional (default - (128, 128))
Width of each MLP layer in each block for each component.
dropout : float, optional (default - 0.05)
Dropout rate applied to the attention and MLP layers. If set to None, no dropout is applied.
mlp_activation : str, optional (default - 'gelu')
Activation function used in the dense layers. Common choices include "relu", "elu", and "gelu".
kernel_initializer : str, optional (default - 'he_normal')
Initializer for the kernel weights matrix. Common choices include "glorot_uniform", "he_normal", etc.
use_bias : bool, optional (default - True)
Whether to include a bias term in the dense layers.
layer_norm : bool, optional (default - True)
Whether to apply layer normalization after the attention and MLP layers.
t2v_embed_dim : int, optional (default - 8)
The dimensionality of the Time2Vec embedding.
template_type : str, optional (default - 'lstm')
The many-to-one (learnable) transformation of the time series.
If ``lstm``, an LSTM network will be used.
If ``gru``, a GRU unit will be used.
bidirectional : bool, optional (default - True)
Indicates whether the recurrent template network is bidirectional (i.e., forward
and backward in time) or unidirectional (forward in time). Defaults to True, which
may increase performance in some applications.
template_dim : int, optional (default - 128)
Only used if ``template_type`` in ['lstm', 'gru']. The number of hidden
units (equiv. output dimensions) of the recurrent network.
**kwargs : dict
Additional keyword arguments passed to the base layer.
"""

super().__init__(**kwargs)

# Ensure all tuple-settings have the same length
check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)

# Initialize Time2Vec embedding layer
self.time2vec = Time2Vec(t2v_embed_dim)

# Construct a series of set-attention blocks
self.attention_blocks = []
for i in range(len(embed_dims)):
layer_attention_settings = dict(
dropout=dropout,
mlp_activation=mlp_activation,
kernel_initializer=kernel_initializer,
use_bias=use_bias,
layer_norm=layer_norm,
num_heads=num_heads[i],
embed_dim=embed_dims[i],
mlp_depth=mlp_depths[i],
mlp_width=mlp_widths[i],
)

block = MultiHeadAttentionBlock(**layer_attention_settings)
self.attention_blocks.append(block)

# A recurrent network will learn a dynamic many-to-one template
if template_type.upper() == "LSTM":
self.template_net = (
layers.Bidirectional(layers.LSTM(template_dim // 2, dropout=dropout))
if bidirectional
else (layers.LSTM(template_dim, dropout=dropout))
)
elif template_type.upper() == "GRU":
self.template_net = (
layers.Bidirectional(layers.GRU(template_dim // 2, dropout=dropout))
if bidirectional
else (layers.GRU(template_dim, dropout=dropout))
)
else:
raise ValueError("Argument `template_type` should be in ['lstm', 'gru']")

self.output_projector = keras.layers.Dense(summary_dim)

def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = False, **kwargs) -> Tensor:
"""Compresses the input sequence into a summary vector of size `summary_dim`.
Parameters
----------
input_sequence : Tensor
Input of shape (batch_size, sequence_length, input_dim)
time : Tensor, optional (default - None)
Time vector of shape (batch_size, sequence_length).
Note: if no time vector is specified, time values for the Time2Vec embedding are
inferred as a linearly spaced grid over [0, sequence_length].
training : boolean, optional (default - False)
Passed to the optional internal dropout and spectral normalization
layers to distinguish between train and test time behavior.
**kwargs : dict, optional (default - {})
Additional keyword arguments passed to the internal attention layer,
such as ``attention_mask`` or ``return_attention_scores``
Returns
-------
out : Tensor
Summary vector of shape (batch_size, summary_dim)
"""

inp = self.time2vec(input_sequence, t=time)
template = self.template_net(inp, training=training)

for layer in self.attention_blocks[:-1]:
inp = layer(inp, inp, training=training, **kwargs)

summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), inp, training=training, **kwargs)
summary = self.output_projector(keras.ops.squeeze(summary, axis=1))
return summary
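
For orientation, a hedged end-to-end sketch of how the new summary network might be used. The simulated data, shapes, and constructor arguments below are illustrative assumptions, not part of the commit.

import keras
import numpy as np

from bayesflow.networks import FusionTransformer

# Made-up batch: 8 simulated time series, 100 steps, 3 channels, plus explicit times
series = keras.ops.convert_to_tensor(np.random.randn(8, 100, 3).astype("float32"))
times = keras.ops.convert_to_tensor(np.tile(np.linspace(0.0, 1.0, 100, dtype="float32"), (8, 1)))

summary_net = FusionTransformer(summary_dim=16, template_type="gru", bidirectional=False)
summaries = summary_net(series, time=times)  # recurrent template + cross-attention

print(keras.ops.shape(summaries))  # expected: (8, 16)

Note that with ``bidirectional=True`` the recurrent units are halved to ``template_dim // 2`` per direction, so the template dimensionality stays at ``template_dim``. In a typical BayesFlow workflow, the resulting summaries would then condition an inference network.
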
14 changes: 7 additions & 7 deletions bayesflow/networks/transformers/mab.py
@@ -56,16 +56,16 @@ def __init__(
self.output_projector = layers.Dense(embed_dim)
self.ln_post = layers.LayerNormalization() if layer_norm else None

def call(self, set_x: Tensor, set_y: Tensor, training: bool = False, **kwargs) -> Tensor:
def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -> Tensor:
"""Performs the forward pass through the attention layer.
Parameters
----------
set_x : Tensor (e.g., np.ndarray, tf.Tensor, ...)
Input of shape (batch_size, set_size_x, input_dim), which will
seq_x : Tensor (e.g., np.ndarray, tf.Tensor, ...)
Input of shape (batch_size, seq_size_x, input_dim), which will
play the role of a query (Q).
set_y : Tensor
Input of shape (batch_size, set_size_y, input_dim), which will
seq_y : Tensor
Input of shape (batch_size, seq_size_y, input_dim), which will
play the role of key (K) and value (V).
training : boolean, optional (default - False)
Passed to the optional internal dropout and spectral normalization
@@ -80,8 +80,8 @@ def call(self, set_x: Tensor, set_y: Tensor, training: bool = False, **kwargs) -
Output of shape (batch_size, seq_size_x, output_dim)
"""

h = self.input_projector(set_x) + self.attention(
query=set_x, key=set_y, value=set_y, training=training, **kwargs
h = self.input_projector(seq_x) + self.attention(
query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
)
if self.ln_pre is not None:
h = self.ln_pre(h, training=training)
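
To make the renamed roles concrete, here is a small illustrative sketch of the block in cross-attention mode. The module import path and the constructor arguments mirror the diff and the settings dict used in FusionTransformer above; the shapes are assumptions.

import keras
import numpy as np

from bayesflow.networks.transformers.mab import MultiHeadAttentionBlock

block = MultiHeadAttentionBlock(
    embed_dim=64,
    num_heads=4,
    mlp_depth=2,
    mlp_width=128,
    dropout=0.05,
    mlp_activation="gelu",
    kernel_initializer="he_normal",
    use_bias=True,
    layer_norm=True,
)

# seq_x plays the role of the query (e.g., a 1-element template sequence),
# seq_y the role of key and value (e.g., an embedded time series)
seq_x = keras.ops.convert_to_tensor(np.random.randn(8, 1, 64).astype("float32"))
seq_y = keras.ops.convert_to_tensor(np.random.randn(8, 100, 64).astype("float32"))

out = block(seq_x, seq_y)
print(keras.ops.shape(out))  # expected: (8, 1, 64)
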
2 changes: 1 addition & 1 deletion bayesflow/networks/transformers/pma.py
@@ -35,7 +35,7 @@ def __init__(
layer_norm: bool = True,
**kwargs,
):
"""Creates a multi-head attention block (MAB) which will perform cross-attention between an input set
"""Creates a multi-head attention block (MAB) which will perform cross-attention between an input sequence
and a set of seed vectors (typically one for a single summary) with summary_dim output dimensions.
Could also be used as part of a ``DeepSet`` for representing learnable instead of fixed pooling.
11 changes: 4 additions & 7 deletions bayesflow/networks/transformers/set_transformer.py
@@ -2,6 +2,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import check_lengths_same

from ..summary_network import SummaryNetwork

@@ -62,7 +63,7 @@ def __init__(
dropout : float, optional (default - 0.05)
Dropout rate applied to the attention and MLP layers. If set to None, no dropout is applied.
mlp_activation : str, optional (default - 'gelu')
Activation function used in the dense layers. Common choices include "relu", "tanh", and "gelu".
Activation function used in the dense layers. Common choices include "relu", "elu", and "gelu".
kernel_initializer : str, optional (default - 'he_normal')
Initializer for the kernel weights matrix. Common choices include "glorot_uniform", "he_normal", etc.
use_bias : bool, optional (default - True)
@@ -79,7 +80,8 @@ def __init__(

super().__init__(**kwargs)

SetTransformer._check_lengths(embed_dims, num_heads, mlp_depths, mlp_widths)
check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)

num_attention_layers = len(embed_dims)

# Construct a series of set-attention blocks
@@ -144,8 +146,3 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
summary = self.pooling_by_attention(summary, training=training, **kwargs)
summary = self.output_projector(summary)
return summary

@staticmethod
def _check_lengths(*args):
if len(set(map(len, args))) > 1:
raise ValueError("All tuple arguments must have the same length.")