From bfca3b187ab4da0aa326529961d9f45354e1bec1 Mon Sep 17 00:00:00 2001
From: stefanradev93
Date: Tue, 5 Nov 2024 17:57:13 -0500
Subject: [PATCH] Bump up transformers

---
 bayesflow/networks/__init__.py | 2 +-
 bayesflow/networks/embeddings/__init__.py | 1 +
 bayesflow/networks/embeddings/time2vec.py | 72 ++++++++
 bayesflow/networks/transformers/__init__.py | 2 +
 .../transformers/fusion_transformer.py | 159 +++++++++++++++++-
 bayesflow/networks/transformers/mab.py | 14 +-
 bayesflow/networks/transformers/pma.py | 2 +-
 .../networks/transformers/set_transformer.py | 11 +-
 .../transformers/time_series_transformer.py | 126 ++++++++++++++
 bayesflow/utils/__init__.py | 1 +
 bayesflow/utils/validators.py | 3 +
 11 files changed, 374 insertions(+), 19 deletions(-)
 create mode 100644 bayesflow/networks/embeddings/time2vec.py
 create mode 100644 bayesflow/networks/transformers/time_series_transformer.py
 create mode 100644 bayesflow/utils/validators.py

diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py
index 7ae8ff93f..fce9b27fa 100644
--- a/bayesflow/networks/__init__.py
+++ b/bayesflow/networks/__init__.py
@@ -7,4 +7,4 @@ from .mlp import MLP
 from .lstnet import LSTNet
 from .summary_network import SummaryNetwork
-from .transformers import SetTransformer
+from .transformers import SetTransformer, TimeSeriesTransformer, FusionTransformer
diff --git a/bayesflow/networks/embeddings/__init__.py b/bayesflow/networks/embeddings/__init__.py
index c76d31251..289ef6ab4 100644
--- a/bayesflow/networks/embeddings/__init__.py
+++ b/bayesflow/networks/embeddings/__init__.py
@@ -1 +1,2 @@
 from .fourier_embedding import FourierEmbedding
+from .time2vec import Time2Vec
diff --git a/bayesflow/networks/embeddings/time2vec.py b/bayesflow/networks/embeddings/time2vec.py
new file mode 100644
index 000000000..991e89170
--- /dev/null
+++ b/bayesflow/networks/embeddings/time2vec.py
@@ -0,0 +1,72 @@
+import keras
+from keras.saving import register_keras_serializable as serializable
+
+from bayesflow.types import Tensor
+from bayesflow.utils import expand_tile
+
+
+@serializable(package="bayesflow.networks")
+class Time2Vec(keras.Layer):
+    """
+    Implements the learnable Time2Vec embedding from [1].
+    [1] Kazemi, S. M., Goel, R., Eghbali, S., Ramanan, J., Sahota, J., Thakur, S., ... & Brubaker, M.
+    (2019). Time2vec: Learning a vector representation of time. arXiv preprint arXiv:1907.05321.
+    """
+
+    def __init__(self, num_periodic_features: int = 8):
+        super().__init__()
+
+        self.num_periodic_features = num_periodic_features
+        self.linear_weight = self.add_weight(
+            shape=(1,),
+            initializer="glorot_uniform",
+            trainable=True,
+            name="trend_weight",
+        )
+
+        self.linear_bias = self.add_weight(
+            shape=(1,),
+            initializer="glorot_uniform",
+            trainable=True,
+            name="trend_bias",
+        )
+
+        self.periodic_weights = self.add_weight(
+            shape=(self.num_periodic_features,),
+            initializer="glorot_normal",
+            trainable=True,
+            name="periodic_weights",
+        )
+
+        self.periodic_biases = self.add_weight(
+            shape=(self.num_periodic_features,),
+            initializer="glorot_normal",
+            trainable=True,
+            name="periodic_biases",
+        )
+
+    def call(self, x: Tensor, t: Tensor = None) -> Tensor:
+        """Creates time representations and concatenates them to x.
+
+        Parameters
+        ----------
+        x : Tensor of shape (batch_size, sequence_length, dim)
+            The input sequence.
+        t : Tensor of shape (batch_size, sequence_length), optional (default - None)
+            Vector of time points; inferred as a linearly spaced grid on [0, sequence_length] if None.
+
+        Returns
+        -------
+        out : Tensor
+            Concatenation of `x` and its time embedding, of shape
+            (batch_size, sequence_length, dim + num_periodic_features + 1).
+        """
+
+        if t is None:
+            t = keras.ops.linspace(0, keras.ops.shape(x)[1], keras.ops.shape(x)[1], dtype=x.dtype)
+            t = expand_tile(t, keras.ops.shape(x)[0], axis=0)
+
+        linear = t * self.linear_weight + self.linear_bias
+        periodic = keras.ops.sin(t[..., None] * self.periodic_weights[None, :] + self.periodic_biases[None, :])
+        emb = keras.ops.concatenate([linear[..., None], periodic], axis=-1)
+        return keras.ops.concatenate([x, emb], axis=-1)
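For intuition: Time2Vec maps each time point tau to one learned linear trend feature, w0 * tau + b0, plus `num_periodic_features` learned sinusoids sin(w_i * tau + b_i), which the layer concatenates to the observables. A minimal NumPy sketch of that computation, with illustrative fixed weights standing in for the learned ones:

import numpy as np

# Sketch of the Time2Vec computation above; w0, b0, w, b are toy values,
# whereas the layer learns them via add_weight.
def time2vec(t, w0=0.5, b0=0.0, w=np.array([1.0, 2.0]), b=np.zeros(2)):
    linear = (w0 * t + b0)[..., None]        # learned linear trend feature
    periodic = np.sin(t[..., None] * w + b)  # learned periodic features
    return np.concatenate([linear, periodic], axis=-1)

t = np.linspace(0.0, 10.0, 5)[None, :]       # toy time grid, shape (1, 5)
print(time2vec(t).shape)                     # (1, 5, 3) = 1 linear + 2 periodic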
diff --git a/bayesflow/networks/transformers/__init__.py b/bayesflow/networks/transformers/__init__.py
index ab9e03de8..1aef426c8 100644
--- a/bayesflow/networks/transformers/__init__.py
+++ b/bayesflow/networks/transformers/__init__.py
@@ -1 +1,3 @@
 from .set_transformer import SetTransformer
+from .time_series_transformer import TimeSeriesTransformer
+from .fusion_transformer import FusionTransformer
diff --git a/bayesflow/networks/transformers/fusion_transformer.py b/bayesflow/networks/transformers/fusion_transformer.py
index 3dde8eede..8747b9c27 100644
--- a/bayesflow/networks/transformers/fusion_transformer.py
+++ b/bayesflow/networks/transformers/fusion_transformer.py
@@ -1,3 +1,156 @@
-class FusionTransformer:
-    # TODO
-    pass
+import keras
+from keras import layers
+from keras.saving import register_keras_serializable as serializable
+
+from bayesflow.types import Tensor
+from bayesflow.utils import check_lengths_same
+
+from ..embeddings import Time2Vec
+from ..summary_network import SummaryNetwork
+
+from .mab import MultiHeadAttentionBlock
+
+
+@serializable(package="bayesflow.networks")
+class FusionTransformer(SummaryNetwork):
+    """Implements a more flexible version of the TimeSeriesTransformer that applies a series of self-attention layers
+    followed by cross-attention between the representation and a template summary computed by a recurrent net."""
+
+    def __init__(
+        self,
+        summary_dim: int = 16,
+        embed_dims: tuple = (64, 64),
+        num_heads: tuple = (4, 4),
+        mlp_depths: tuple = (2, 2),
+        mlp_widths: tuple = (128, 128),
+        dropout: float = 0.05,
+        mlp_activation: str = "gelu",
+        kernel_initializer: str = "he_normal",
+        use_bias: bool = True,
+        layer_norm: bool = True,
+        t2v_embed_dim: int = 8,
+        template_type: str = "lstm",
+        bidirectional: bool = True,
+        template_dim: int = 128,
+        **kwargs,
+    ):
+        """Creates a fusion transformer used to flexibly compress time series. If the time intervals vary across
+        batches, it is highly recommended that your simulator also returns a "time" vector denoting absolute or
+        relative time.
+
+        Parameters
+        ----------
+        summary_dim : int, optional (default - 16)
+            Dimensionality of the final summary output.
+        embed_dims : tuple of int, optional (default - (64, 64))
+            Dimensions of the keys, values, and queries for each attention block.
+        num_heads : tuple of int, optional (default - (4, 4))
+            Number of attention heads for each embedding dimension.
+        mlp_depths : tuple of int, optional (default - (2, 2))
+            Depth of the multi-layer perceptron (MLP) blocks for each component.
+        mlp_widths : tuple of int, optional (default - (128, 128))
+            Width of each MLP layer in each block for each component.
+        dropout : float, optional (default - 0.05)
+            Dropout rate applied to the attention and MLP layers. If set to None, no dropout is applied.
+        mlp_activation : str, optional (default - 'gelu')
+            Activation function used in the dense layers. Common choices include "relu", "elu", and "gelu".
+        kernel_initializer : str, optional (default - 'he_normal')
+            Initializer for the kernel weights matrix. Common choices include "glorot_uniform", "he_normal", etc.
+        use_bias : bool, optional (default - True)
+            Whether to include a bias term in the dense layers.
+        layer_norm : bool, optional (default - True)
+            Whether to apply layer normalization after the attention and MLP layers.
+        t2v_embed_dim : int, optional (default - 8)
+            The dimensionality of the Time2Vec embedding.
+        template_type : str, optional (default - 'lstm')
+            The many-to-one (learnable) transformation of the time series.
+            If ``lstm``, an LSTM network will be used.
+            If ``gru``, a GRU unit will be used.
+        bidirectional : bool, optional (default - True)
+            Indicates whether the involved recurrent template network is bidirectional (i.e., forward
+            and backward in time) or unidirectional (forward in time). Defaults to True, since
+            bidirectionality may increase performance in some applications.
+        template_dim : int, optional (default - 128)
+            Only used if ``template_type`` in ['lstm', 'gru']. The number of hidden
+            units (equiv. output dimensions) of the recurrent network.
+        **kwargs : dict
+            Additional keyword arguments passed to the base layer.
+        """
+
+        super().__init__(**kwargs)
+
+        # Ensure all tuple-settings have the same length
+        check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)
+
+        # Initialize Time2Vec embedding layer
+        self.time2vec = Time2Vec(t2v_embed_dim)
+
+        # Construct a series of multi-head attention blocks
+        self.attention_blocks = []
+        for i in range(len(embed_dims)):
+            layer_attention_settings = dict(
+                dropout=dropout,
+                mlp_activation=mlp_activation,
+                kernel_initializer=kernel_initializer,
+                use_bias=use_bias,
+                layer_norm=layer_norm,
+                num_heads=num_heads[i],
+                embed_dim=embed_dims[i],
+                mlp_depth=mlp_depths[i],
+                mlp_width=mlp_widths[i],
+            )
+
+            block = MultiHeadAttentionBlock(**layer_attention_settings)
+            self.attention_blocks.append(block)
+
+        # A recurrent network will learn a dynamic many-to-one template
+        if template_type.upper() == "LSTM":
+            self.template_net = (
+                layers.Bidirectional(layers.LSTM(template_dim // 2, dropout=dropout))
+                if bidirectional
+                else (layers.LSTM(template_dim, dropout=dropout))
+            )
+        elif template_type.upper() == "GRU":
+            self.template_net = (
+                layers.Bidirectional(layers.GRU(template_dim // 2, dropout=dropout))
+                if bidirectional
+                else (layers.GRU(template_dim, dropout=dropout))
+            )
+        else:
+            raise ValueError("Argument `template_type` should be in ['lstm', 'gru']")
+
+        self.output_projector = keras.layers.Dense(summary_dim)
+
+    def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = False, **kwargs) -> Tensor:
+        """Compresses the input sequence into a summary vector of size `summary_dim`.
+
+        Parameters
+        ----------
+        input_sequence : Tensor
+            Input of shape (batch_size, sequence_length, input_dim)
+        time : Tensor, optional (default - None)
+            Time vector of shape (batch_size, sequence_length).
+            Note: time values for Time2Vec embeddings will be inferred on a linearly spaced
+            interval between [0, sequence length], if no time vector is specified.
+        training : boolean, optional (default - False)
+            Passed to the internal dropout layers to distinguish between train
+            and test time behavior.
+        **kwargs : dict, optional (default - {})
+            Additional keyword arguments passed to the internal attention layers,
+            such as ``attention_mask`` or ``return_attention_scores``.
+
+        Returns
+        -------
+        out : Tensor
+            Summary vector of shape (batch_size, summary_dim)
+        """
+
+        inp = self.time2vec(input_sequence, t=time)
+        template = self.template_net(inp, training=training)
+
+        for layer in self.attention_blocks[:-1]:
+            inp = layer(inp, inp, training=training, **kwargs)
+
+        summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), inp, training=training, **kwargs)
+        summary = self.output_projector(keras.ops.squeeze(summary, axis=1))
+        return summary
diff --git a/bayesflow/networks/transformers/mab.py b/bayesflow/networks/transformers/mab.py
index 7fb4654d9..757af5cc5 100644
--- a/bayesflow/networks/transformers/mab.py
+++ b/bayesflow/networks/transformers/mab.py
@@ -56,16 +56,16 @@ def __init__(
         self.output_projector = layers.Dense(embed_dim)
         self.ln_post = layers.LayerNormalization() if layer_norm else None

-    def call(self, set_x: Tensor, set_y: Tensor, training: bool = False, **kwargs) -> Tensor:
+    def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Performs the forward pass through the attention layer.

         Parameters
         ----------
-        set_x : Tensor (e.g., np.ndarray, tf.Tensor, ...)
-            Input of shape (batch_size, set_size_x, input_dim), which will
+        seq_x : Tensor (e.g., np.ndarray, tf.Tensor, ...)
+            Input of shape (batch_size, seq_size_x, input_dim), which will
             play the role of a query (Q).
-        set_y : Tensor
-            Input of shape (batch_size, set_size_y, input_dim), which will
+        seq_y : Tensor
+            Input of shape (batch_size, seq_size_y, input_dim), which will
             play the role of key (K) and value (V).
         training : boolean, optional (default - True)
             Passed to the optional internal dropout and spectral normalization
@@ -80,8 +80,8 @@ def call(self, set_x: Tensor, set_y: Tensor, training: bool = False, **kwargs) -> Tensor:
-        Output of shape (batch_size, set_size_x, output_dim)
+        Output of shape (batch_size, seq_size_x, output_dim)
         """

-        h = self.input_projector(set_x) + self.attention(
-            query=set_x, key=set_y, value=set_y, training=training, **kwargs
+        h = self.input_projector(seq_x) + self.attention(
+            query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
         )
         if self.ln_pre is not None:
             h = self.ln_pre(h, training=training)
diff --git a/bayesflow/networks/transformers/pma.py b/bayesflow/networks/transformers/pma.py
index c41d245cf..b52a4168f 100644
--- a/bayesflow/networks/transformers/pma.py
+++ b/bayesflow/networks/transformers/pma.py
@@ -35,7 +35,7 @@ def __init__(
         layer_norm: bool = True,
         **kwargs,
     ):
-        """Creates a multi-head attention block (MAB) which will perform cross-attention between an input set
+        """Creates a pooling-by-multi-head-attention (PMA) block which will perform cross-attention between an input sequence
         and a set of seed vectors (typically one for a single summary) with summary_dim output dimensions.

         Could also be used as part of a ``DeepSet`` for representing learnable instead of fixed pooling.
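The FusionTransformer above thus embeds time, summarizes the sequence into a template via the recurrent net, self-attends over the sequence, and finally cross-attends the template to it. A hypothetical smoke test of that wiring; the array shapes and the keyword-call convention are assumptions, not guaranteed by this patch:

import numpy as np
from bayesflow.networks import FusionTransformer

# Hypothetical check: compress 2 series of length 100 with 3 observables
# into 16 summary statistics (assumes the usual Keras __call__ convention).
net = FusionTransformer(summary_dim=16, template_type="gru", bidirectional=False)
x = np.random.normal(size=(2, 100, 3)).astype("float32")
t = np.repeat(np.linspace(0, 1, 100, dtype="float32")[None], 2, axis=0)
summary = net(x, time=t)
print(summary.shape)  # expected: (2, 16)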
diff --git a/bayesflow/networks/transformers/set_transformer.py b/bayesflow/networks/transformers/set_transformer.py
index 38a730a49..4972bc76b 100644
--- a/bayesflow/networks/transformers/set_transformer.py
+++ b/bayesflow/networks/transformers/set_transformer.py
@@ -2,6 +2,7 @@
 from keras.saving import register_keras_serializable as serializable

 from bayesflow.types import Tensor
+from bayesflow.utils import check_lengths_same

 from ..summary_network import SummaryNetwork

@@ -62,7 +63,7 @@ def __init__(
         dropout : float, optional (default - 0.05)
             Dropout rate applied to the attention and MLP layers. If set to None, no dropout is applied.
         mlp_activation : str, optional (default - 'gelu')
-            Activation function used in the dense layers. Common choices include "relu", "tanh", and "gelu".
+            Activation function used in the dense layers. Common choices include "relu", "elu", and "gelu".
         kernel_initializer : str, optional (default - 'he_normal')
             Initializer for the kernel weights matrix. Common choices include "glorot_uniform", "he_normal", etc.
         use_bias : bool, optional (default - True)
@@ -79,7 +80,8 @@ def __init__(

         super().__init__(**kwargs)

-        SetTransformer._check_lengths(embed_dims, num_heads, mlp_depths, mlp_widths)
+        check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)
+
         num_attention_layers = len(embed_dims)

         # Construct a series of set-attention blocks
@@ -144,8 +146,3 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         summary = self.pooling_by_attention(summary, training=training, **kwargs)
         summary = self.output_projector(summary)
         return summary
-
-    @staticmethod
-    def _check_lengths(*args):
-        if len(set(map(len, args))) > 1:
-            raise ValueError("All tuple arguments must have the same length.")
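The private `_check_lengths` helper removed here is superseded by the shared `check_lengths_same` validator introduced at the end of this patch; its intended behavior, sketched:

from bayesflow.utils import check_lengths_same

check_lengths_same((64, 64), (4, 4), (2, 2), (128, 128))  # same lengths: passes silently
try:
    check_lengths_same((64, 64), (4, 4, 4))
except ValueError as e:
    print(e)  # All tuple arguments must have the same length, but lengths are (2, 3).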
diff --git a/bayesflow/networks/transformers/time_series_transformer.py b/bayesflow/networks/transformers/time_series_transformer.py
new file mode 100644
index 000000000..f076b3327
--- /dev/null
+++ b/bayesflow/networks/transformers/time_series_transformer.py
@@ -0,0 +1,126 @@
+import keras
+from keras.saving import register_keras_serializable as serializable
+
+from bayesflow.types import Tensor
+from bayesflow.utils import check_lengths_same
+
+from ..embeddings import Time2Vec
+from ..summary_network import SummaryNetwork
+
+from .mab import MultiHeadAttentionBlock
+
+
+@serializable(package="bayesflow.networks")
+class TimeSeriesTransformer(SummaryNetwork):
+    def __init__(
+        self,
+        summary_dim: int = 16,
+        embed_dims: tuple = (64, 64),
+        num_heads: tuple = (4, 4),
+        mlp_depths: tuple = (2, 2),
+        mlp_widths: tuple = (128, 128),
+        dropout: float = 0.05,
+        mlp_activation: str = "gelu",
+        kernel_initializer: str = "he_normal",
+        use_bias: bool = True,
+        layer_norm: bool = True,
+        t2v_embed_dim: int = 8,
+        **kwargs,
+    ):
+        """Creates a regular transformer, coupled with Time2Vec embeddings of time, for flexibly compressing time
+        series. If the time intervals vary across batches, it is highly recommended that your simulator also
+        returns a "time" vector denoting absolute or relative time.
+
+        Parameters
+        ----------
+        summary_dim : int, optional (default - 16)
+            Dimensionality of the final summary output.
+        embed_dims : tuple of int, optional (default - (64, 64))
+            Dimensions of the keys, values, and queries for each attention block.
+        num_heads : tuple of int, optional (default - (4, 4))
+            Number of attention heads for each embedding dimension.
+        mlp_depths : tuple of int, optional (default - (2, 2))
+            Depth of the multi-layer perceptron (MLP) blocks for each component.
+        mlp_widths : tuple of int, optional (default - (128, 128))
+            Width of each MLP layer in each block for each component.
+        dropout : float, optional (default - 0.05)
+            Dropout rate applied to the attention and MLP layers. If set to None, no dropout is applied.
+        mlp_activation : str, optional (default - 'gelu')
+            Activation function used in the dense layers. Common choices include "relu", "elu", and "gelu".
+        kernel_initializer : str, optional (default - 'he_normal')
+            Initializer for the kernel weights matrix. Common choices include "glorot_uniform", "he_normal", etc.
+        use_bias : bool, optional (default - True)
+            Whether to include a bias term in the dense layers.
+        layer_norm : bool, optional (default - True)
+            Whether to apply layer normalization after the attention and MLP layers.
+        t2v_embed_dim : int, optional (default - 8)
+            The dimensionality of the Time2Vec embedding.
+        **kwargs : dict
+            Additional keyword arguments passed to the base layer.
+        """
+
+        super().__init__(**kwargs)
+
+        # Ensure all tuple-settings have the same length
+        check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)
+
+        # Initialize Time2Vec embedding layer
+        self.time2vec = Time2Vec(t2v_embed_dim)
+
+        # Construct a series of multi-head attention blocks
+        self.attention_blocks = []
+        for i in range(len(embed_dims)):
+            layer_attention_settings = dict(
+                dropout=dropout,
+                mlp_activation=mlp_activation,
+                kernel_initializer=kernel_initializer,
+                use_bias=use_bias,
+                layer_norm=layer_norm,
+                num_heads=num_heads[i],
+                embed_dim=embed_dims[i],
+                mlp_depth=mlp_depths[i],
+                mlp_width=mlp_widths[i],
+            )
+
+            block = MultiHeadAttentionBlock(**layer_attention_settings)
+            self.attention_blocks.append(block)
+
+        # Pooling will be applied as a final step to the abstract representations obtained from self-attention
+        self.pooling = keras.layers.GlobalAvgPool1D()
+        self.output_projector = keras.layers.Dense(summary_dim)
+
+    def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = False, **kwargs) -> Tensor:
+        """Compresses the input sequence into a summary vector of size `summary_dim`.
+
+        Parameters
+        ----------
+        input_sequence : Tensor
+            Input of shape (batch_size, sequence_length, input_dim)
+        time : Tensor, optional (default - None)
+            Time vector of shape (batch_size, sequence_length).
+            Note: time values for Time2Vec embeddings will be inferred on a linearly spaced
+            interval between [0, sequence length], if no time vector is specified.
+        training : boolean, optional (default - False)
+            Passed to the internal dropout layers to distinguish between train
+            and test time behavior.
+        **kwargs : dict, optional (default - {})
+            Additional keyword arguments passed to the internal attention layers,
+            such as ``attention_mask`` or ``return_attention_scores``.
+
+        Returns
+        -------
+        out : Tensor
+            Summary vector of shape (batch_size, summary_dim)
+        """
+
+        # Concatenate learnable time embedding to input sequence
+        inp = self.time2vec(input_sequence, t=time)
+
+        # Apply self-attention blocks
+        for layer in self.attention_blocks:
+            inp = layer(inp, inp, training=training, **kwargs)
+
+        # Global average pooling and output projection
+        summary = self.pooling(inp)
+        summary = self.output_projector(summary)
+        return summary
diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py
index 4674e5eab..c7e461af8 100644
--- a/bayesflow/utils/__init__.py
+++ b/bayesflow/utils/__init__.py
@@ -41,3 +41,4 @@
     concatenate,
     tree_stack,
 )
+from .validators import check_lengths_same
diff --git a/bayesflow/utils/validators.py b/bayesflow/utils/validators.py
new file mode 100644
index 000000000..ac08583c3
--- /dev/null
+++ b/bayesflow/utils/validators.py
@@ -0,0 +1,3 @@
+def check_lengths_same(*args):
+    if len(set(map(len, args))) > 1:
+        raise ValueError(f"All tuple arguments must have the same length, but lengths are {tuple(map(len, args))}.")
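For completeness, a hypothetical end-to-end check of the plain TimeSeriesTransformer; shapes are assumed, and when `time` is omitted, Time2Vec falls back to a linearly spaced grid on [0, sequence_length]:

import numpy as np
from bayesflow.networks import TimeSeriesTransformer

# Hypothetical smoke test: 8 series of length 50 with 2 observables,
# compressed into 16 summary statistics via self-attention + average pooling.
net = TimeSeriesTransformer(summary_dim=16, embed_dims=(64, 64), num_heads=(4, 4))
x = np.random.normal(size=(8, 50, 2)).astype("float32")
summary = net(x)
print(summary.shape)  # expected: (8, 16)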